@@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
163163
164164 linsolve = get_linear_solver (alg. descent)
165165 initialization_cache = __internal_init (prob, alg. initialization, alg, f, fu, u, p;
166- linsolve,
167- maxiters, internalnorm)
166+ linsolve, maxiters, internalnorm)
168167
169168 abstol, reltol, termination_cache = init_termination_cache (abstol, reltol, fu, u,
170169 termination_condition)
@@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
222221 new_jacobian = true
223222 @static_timeit cache. timer " jacobian init/reinit" begin
224223 if get_nsteps (cache) == 0 # First Step is special ignore kwargs
225- J_init = __internal_solve! (cache. initialization_cache,
226- cache. fu,
227- cache. u,
224+ J_init = __internal_solve! (cache. initialization_cache, cache. fu, cache. u,
228225 Val (false ))
229226 if INV
230227 if jacobian_initialized_preinverted (cache. initialization_cache. alg)
@@ -283,54 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
283280 @static_timeit cache. timer " descent" begin
284281 if cache. trustregion_cache != = nothing &&
285282 hasfield (typeof (cache. trustregion_cache), :trust_region )
286- δu, descent_success, descent_intermediates = __internal_solve! (
287- cache. descent_cache,
288- J, cache. fu, cache. u; new_jacobian,
289- trust_region = cache. trustregion_cache. trust_region)
283+ descent_result = __internal_solve! (cache. descent_cache, J, cache. fu, cache. u;
284+ new_jacobian, trust_region = cache. trustregion_cache. trust_region)
290285 else
291- δu, descent_success, descent_intermediates = __internal_solve! (
292- cache. descent_cache,
293- J, cache. fu, cache. u; new_jacobian)
286+ descent_result = __internal_solve! (cache. descent_cache, J, cache. fu, cache. u;
287+ new_jacobian)
294288 end
295289 end
296290
297- if descent_success
298- if GB === :LineSearch
299- @static_timeit cache. timer " linesearch" begin
300- needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
301- end
302- if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
303- cache. force_reinit = true
304- else
305- @static_timeit cache. timer " step" begin
306- @bb axpy! (α, δu, cache. u)
307- evaluate_f! (cache, cache. u, cache. p)
308- end
309- end
310- elseif GB === :TrustRegion
311- @static_timeit cache. timer " trustregion" begin
312- tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache, J,
313- cache. fu, cache. u, δu, descent_intermediates)
314- if tr_accepted
315- @bb copyto! (cache. u, u_new)
316- @bb copyto! (cache. fu, fu_new)
317- end
318- if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
319- cache. trustregion_cache. shrink_counter > cache. max_shrink_times
320- cache. retcode = ReturnCode. ShrinkThresholdExceeded
321- cache. force_stop = true
322- end
323- end
324- α = true
325- elseif GB === :None
291+ if descent_result. success
292+ if GB === :None
326293 @static_timeit cache. timer " step" begin
327- @bb axpy! (1 , δu, cache. u)
294+ if descent_result. u != = missing
295+ @bb copyto! (cache. u, descent_result. u)
296+ elseif descent_result. δu != = missing
297+ @bb axpy! (1 , descent_result. δu, cache. u)
298+ else
299+ error (" This shouldn't occur. `$(cache. alg. descent) ` is incorrectly \
300+ specified." )
301+ end
328302 evaluate_f! (cache, cache. u, cache. p)
329303 end
330304 α = true
331305 else
332- error (" Unknown Globalization Strategy: $(GB) . Allowed values are (:LineSearch, \
333- :TrustRegion, :None)" )
306+ δu = descent_result. δu
307+ @assert δu!= = missing " Descent Supporting LineSearch or TrustRegion must return a `δu`."
308+
309+ if GB === :LineSearch
310+ @static_timeit cache. timer " linesearch" begin
311+ needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
312+ end
313+ if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
314+ cache. force_reinit = true
315+ else
316+ @static_timeit cache. timer " step" begin
317+ @bb axpy! (α, δu, cache. u)
318+ evaluate_f! (cache, cache. u, cache. p)
319+ end
320+ end
321+ elseif GB === :TrustRegion
322+ @static_timeit cache. timer " trustregion" begin
323+ tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache,
324+ J, cache. fu, cache. u, δu, descent_result. extras)
325+ if tr_accepted
326+ @bb copyto! (cache. u, u_new)
327+ @bb copyto! (cache. fu, fu_new)
328+ α = true
329+ else
330+ α = false
331+ end
332+ if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
333+ cache. trustregion_cache. shrink_counter > cache. max_shrink_times
334+ cache. retcode = ReturnCode. ShrinkThresholdExceeded
335+ cache. force_stop = true
336+ end
337+ end
338+ else
339+ error (" Unknown Globalization Strategy: $(GB) . Allowed values are \
340+ (:LineSearch, :TrustRegion, :None)" )
341+ end
334342 end
335343 check_and_update! (cache, cache. fu, cache. u, cache. u_cache)
336344 else
0 commit comments