3636
3737@internal_caches HalleyDescentCache :lincache
3838
39- function __internal_init (
40- prob :: NonlinearProblem , alg :: HalleyDescent , J, fu, u; shared:: Val{N} = Val (1 ),
41- pre_inverted :: Val{INV} = False, linsolve_kwargs = (;), abstol = nothing ,
42- reltol = nothing , timer = get_timer_output (), kwargs... ) where {INV, N}
39+ function __internal_init (prob :: NonlinearProblem , alg :: HalleyDescent , J, fu, u; stats,
40+ shared:: Val{N} = Val (1 ), pre_inverted :: Val{INV} = False ,
41+ linsolve_kwargs = (;), abstol = nothing , reltol = nothing ,
42+ timer = get_timer_output (), kwargs... ) where {INV, N}
4343 @bb δu = similar (u)
4444 @bb b = similar (u)
4545 @bb fu = similar (fu)
@@ -48,23 +48,27 @@ function __internal_init(
4848 end
4949 INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
5050 lincache = LinearSolverCache (
51- alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol, linsolve_kwargs... )
51+ alg, alg. linsolve, J, _vec (fu), _vec (u); stats, abstol, reltol, linsolve_kwargs... )
5252 return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
5353end
5454
5555function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
5656 skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... ) where {INV}
5757 δu = get_du (cache, idx)
58- skip_solve && return δu, true , (; )
58+ skip_solve && return DescentResult (; δu )
5959 if INV
6060 @assert J!= = nothing " `J` must be provided when `pre_inverted = Val(true)`."
6161 @bb δu = J × vec (fu)
6262 else
6363 @static_timeit cache. timer " linear solve 1" begin
64- δu = cache. lincache (;
64+ linres = cache. lincache (;
6565 A = J, b = _vec (fu), kwargs... , linu = _vec (δu), du = _vec (δu),
6666 reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
67- δu = _restructure (get_du (cache, idx), δu)
67+ δu = _restructure (get_du (cache, idx), linres. u)
68+ if ! linres. success
69+ set_du! (cache, δu, idx)
70+ return DescentResult (; δu, success = false , linsolve_success = false )
71+ end
6872 end
6973 end
7074 b = cache. b
@@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
7579 @bb b = J × vec (hvvp)
7680 else
7781 @static_timeit cache. timer " linear solve 2" begin
78- b = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
82+ linres = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
7983 du = _vec (b), reuse_A_if_factorization = true )
80- b = _restructure (cache. b, b)
84+ b = _restructure (cache. b, linres. u)
85+ if ! linres. success
86+ set_du! (cache, δu, idx)
87+ return DescentResult (; δu, success = false , linsolve_success = false )
88+ end
8189 end
8290 end
8391 @bb @. δu = δu * δu / (b / 2 - δu)
8492 set_du! (cache, δu, idx)
8593 cache. b = b
86- return δu, true , (; )
94+ return DescentResult (; δu )
8795end
8896
8997function evaluate_hvvp (
0 commit comments