From 549b38119dbf22e9c7081f1783ceb51219e54b2a Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Wed, 29 Oct 2025 21:07:24 +0530 Subject: [PATCH 1/2] fix broken broken tests for Mooncake. --- .../NonlinearSolveBaseChainRulesCoreExt.jl | 18 +++++++-- .../ext/NonlinearSolveBaseMooncakeExt.jl | 39 ++++++++----------- test/adjoint_tests.jl | 2 +- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl index d60be6211..f4659305f 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -19,13 +19,23 @@ function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, end function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem, - sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), - kwargs...) - NonlinearSolveBase._solve_adjoint( + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + primal, inner_thunking_pb = NonlinearSolveBase._solve_adjoint( prob, sensealg, u0, p, originator, args...; kwargs...) + + # when using mooncake ∂sol would be a NamedTuple Tangent with cotangents of all the solution struct's fields. + # However the pullback for this rule - "steadystatebackpass" as defined in SciMLSensitivity/src/concrete_solve.jl/ + # handles AD only when ∂sol is a ChainRulesCore.AbstractThunk object or a sol.u vector and similar data structures (not namedtuples). + # When using Mooncake, we pass in sol.u to inner_thunking_pb directly as this is the only field relevant to the solution's cotangent (given solve_up, AbstractNonlinearProblem setting). + + function solve_up_adjoint(∂sol) + return inner_thunking_pb(∂sol isa Tangent{Any,<:NamedTuple} ? ∂sol.u : ∂sol) + end + return primal, solve_up_adjoint end end \ No newline at end of file diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl index bed496361..4e90f2907 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl @@ -2,30 +2,25 @@ module NonlinearSolveBaseMooncakeExt using NonlinearSolveBase, Mooncake using SciMLBase: SciMLBase -import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, - @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, - NoPullback +using Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, + @from_chainrules, @zero_adjoint, @mooncake_overlay, MinimalCtx, + NoPullback -@from_rrule(MinimalCtx, - Tuple{ - typeof(NonlinearSolveBase.solve_up), - SciMLBase.AbstractNonlinearProblem, - Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, - Any, - Any, - Any - }, - true,) +@from_chainrules MinimalCtx Tuple{typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractNonlinearProblem, + Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any, + Any +} true # Dispatch for auto-alg -@from_rrule(MinimalCtx, - Tuple{ - typeof(NonlinearSolveBase.solve_up), - SciMLBase.AbstractNonlinearProblem, - Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, - Any, - Any - }, - true,) +@from_chainrules MinimalCtx Tuple{ + typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractNonlinearProblem, + Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any +} true end diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl index 8882c1916..a765cc159 100644 --- a/test/adjoint_tests.jl +++ b/test/adjoint_tests.jl @@ -23,5 +23,5 @@ @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme - @test_broken ∂p_forwarddiff ≈ ∂p_mooncake + @test ∂p_forwarddiff ≈ ∂p_mooncake end From ad70695db6aaeef744db03b0bd393746ae435300 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Wed, 29 Oct 2025 21:12:34 +0530 Subject: [PATCH 2/2] minor fix. --- .../ext/NonlinearSolveBaseChainRulesCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl index f4659305f..869d24c57 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -6,7 +6,7 @@ using SciMLBase using SciMLBase: AbstractSensitivityAlgorithm import ChainRulesCore -import ChainRulesCore: NoTangent +import ChainRulesCore: NoTangent, Tangent function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, sensealg::Union{Nothing, AbstractSensitivityAlgorithm},