11"""
2- HalleyDescent(; linsolve = nothing, precs = DEFAULT_PRECS )
2+ HalleyDescent(; linsolve = nothing)
33
44Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``.
55Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
66Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
77
8+ Note that `import TaylorDiff` is required to use this descent algorithm.
9+
810See also [`NewtonDescent`](@ref).
911"""
10- @kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
12+ @kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
1113 linsolve = nothing
12- precs = DEFAULT_PRECS
13- end
14-
15- using TaylorDiff: derivative
16-
17- function Base. show (io:: IO , d:: HalleyDescent )
18- modifiers = String[]
19- d. linsolve != = nothing && push! (modifiers, " linsolve = $(d. linsolve) " )
20- d. precs != = DEFAULT_PRECS && push! (modifiers, " precs = $(d. precs) " )
21- print (io, " HalleyDescent($(join (modifiers, " , " )) )" )
2214end
2315
2416supports_line_search (:: HalleyDescent ) = true
2517
26- @concrete mutable struct HalleyDescentCache{pre_inverted} <: AbstractDescentCache
18+ @concrete mutable struct HalleyDescentCache <: AbstractDescentCache
2719 f
2820 p
2921 δu
3022 δus
3123 b
3224 fu
25+ hvvp
3326 lincache
3427 timer
28+ preinverted_jacobian <: Union{Val{false}, Val{true}}
3529end
3630
3731@internal_caches HalleyDescentCache :lincache
3832
39- function __internal_init (prob:: NonlinearProblem , alg:: HalleyDescent , J, fu, u; stats,
40- shared:: Val{N} = Val (1 ), pre_inverted:: Val{INV} = False,
33+ function InternalAPI. init (
34+ prob:: NonlinearProblem , alg:: HalleyDescent , J, fu, u; stats,
35+ shared = Val (1 ), pre_inverted:: Val = Val (false ),
4136 linsolve_kwargs = (;), abstol = nothing , reltol = nothing ,
42- timer = get_timer_output (), kwargs... ) where {INV, N}
37+ timer = get_timer_output (), kwargs... )
4338 @bb δu = similar (u)
4439 @bb b = similar (u)
4540 @bb fu = similar (fu)
46- δus = N ≤ 1 ? nothing : map (2 : N) do i
41+ @bb hvvp = similar (fu)
42+ δus = Utils. unwrap_val (shared) ≤ 1 ? nothing : map (2 : Utils. unwrap_val (shared)) do i
4743 @bb δu_ = similar (u)
4844 end
49- INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
50- lincache = LinearSolverCache (
51- alg, alg. linsolve, J, _vec (fu), _vec (u); stats, abstol, reltol, linsolve_kwargs... )
52- return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
45+ lincache = Utils. unwrap_val (pre_inverted) ? nothing :
46+ construct_linear_solver (
47+ alg, alg. linsolve, J, Utils. safe_vec (fu), Utils. safe_vec (u);
48+ stats, abstol, reltol, linsolve_kwargs...
49+ )
50+ return HalleyDescentCache (
51+ prob. f, prob. p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted)
5352end
5453
55- function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
56- skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... ) where {INV}
57- δu = get_du (cache, idx)
54+ function InternalAPI. solve! (
55+ cache:: HalleyDescentCache , J, fu, u, idx:: Val = Val (1 );
56+ skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... )
57+ δu = SciMLBase. get_du (cache, idx)
5858 skip_solve && return DescentResult (; δu)
59- if INV
59+ if preinverted_jacobian (cache)
6060 @assert J!= = nothing " `J` must be provided when `pre_inverted = Val(true)`."
6161 @bb δu = J × vec (fu)
6262 else
6363 @static_timeit cache. timer " linear solve 1" begin
6464 linres = cache. lincache (;
65- A = J, b = _vec (fu), kwargs... , linu = _vec (δu), du = _vec (δu),
65+ A = J, b = Utils. safe_vec (fu),
66+ kwargs... , linu = Utils. safe_vec (δu),
6667 reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
67- δu = _restructure ( get_du (cache, idx), linres. u)
68+ δu = Utils . restructure (SciMLBase . get_du (cache, idx), linres. u)
6869 if ! linres. success
6970 set_du! (cache, δu, idx)
7071 return DescentResult (; δu, success = false , linsolve_success = false )
@@ -73,15 +74,17 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
7374 end
7475 b = cache. b
7576 # compute the hessian-vector-vector product
76- hvvp = evaluate_hvvp (cache, cache. f, cache. p, u, δu)
77+ hvvp = evaluate_hvvp (cache. hvvp, cache , cache. f, cache. p, u, δu)
7778 # second linear solve, reuse factorization if possible
78- if INV
79+ if preinverted_jacobian (cache)
7980 @bb b = J × vec (hvvp)
8081 else
8182 @static_timeit cache. timer " linear solve 2" begin
82- linres = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
83- du = _vec (b), reuse_A_if_factorization = true )
84- b = _restructure (cache. b, linres. u)
83+ linres = cache. lincache (;
84+ A = J, b = Utils. safe_vec (hvvp),
85+ kwargs... , linu = Utils. safe_vec (b),
86+ reuse_A_if_factorization = true )
87+ b = Utils. restructure (cache. b, linres. u)
8588 if ! linres. success
8689 set_du! (cache, δu, idx)
8790 return DescentResult (; δu, success = false , linsolve_success = false )
@@ -94,13 +97,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
9497 return DescentResult (; δu)
9598end
9699
97- function evaluate_hvvp (
98- cache:: HalleyDescentCache , f:: NonlinearFunction{iip} , p, u, δu) where {iip}
99- if iip
100- binary_f = @closure (y, x) -> f (y, x, p)
101- derivative (binary_f, cache. fu, u, δu, Val {3} ())
102- else
103- unary_f = Base. Fix2 (f, p)
104- derivative (unary_f, u, δu, Val {3} ())
105- end
106- end
100+ evaluate_hvvp (hvvp, cache, f, p, u, δu) = error (" not implemented. please import TaylorDiff" )
0 commit comments