@@ -29,6 +29,7 @@ supports_line_search(::HalleyDescent) = true
2929 δu
3030 δus
3131 b
32+ fu
3233 lincache
3334 timer
3435end
@@ -41,13 +42,14 @@ function __internal_init(
4142 reltol = nothing , timer = get_timer_output (), kwargs... ) where {INV, N}
4243 @bb δu = similar (u)
4344 @bb b = similar (u)
45+ @bb fu = similar (fu)
4446 δus = N ≤ 1 ? nothing : map (2 : N) do i
4547 @bb δu_ = similar (u)
4648 end
4749 INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
4850 lincache = LinearSolverCache (
4951 alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol, linsolve_kwargs... )
50- return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, lincache, timer)
52+ return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
5153end
5254
5355function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
@@ -67,7 +69,7 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
6769 end
6870 b = cache. b
6971 # compute the hessian-vector-vector product
70- hvvp = derivative (Base . Fix2 ( cache. f, cache. p) , u, δu, 2 )
72+ hvvp = evaluate_hvvp ( cache, cache . f, cache. p, u, δu)
7173 # second linear solve, reuse factorization if possible
7274 if INV
7375 @bb b = J × vec (hvvp)
@@ -83,3 +85,14 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
8385 cache. b = b
8486 return δu, true , (;)
8587end
88+
89+ function evaluate_hvvp (
90+ cache:: HalleyDescentCache , f:: NonlinearFunction{iip} , p, u, δu) where {iip}
91+ if iip
92+ binary_f = (y, x) -> f (y, x, p)
93+ derivative (binary_f, cache. fu, u, δu, Val {3} ())
94+ else
95+ unary_f = Base. Fix2 (f, p)
96+ derivative (unary_f, u, δu, Val {3} ())
97+ end
98+ end
0 commit comments