From c1492772e6beb5ebb1d77cd4f224478dd8258b94 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 17 Oct 2025 22:25:54 +0530 Subject: [PATCH 1/4] rrules for solve!, init --- ext/LinearSolveMooncakeExt.jl | 16 ++++ src/adjoint.jl | 75 +++++++++++++++ test/nopre/mooncake.jl | 176 ++++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+) diff --git a/ext/LinearSolveMooncakeExt.jl b/ext/LinearSolveMooncakeExt.jl index b3179cbcf..cf3c58167 100644 --- a/ext/LinearSolveMooncakeExt.jl +++ b/ext/LinearSolveMooncakeExt.jl @@ -29,4 +29,20 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T} end end +function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache) + println("inside increment and get rdata 2") + f.fields.A .+= t.A + f.fields.b .+= t.b + + return NoRData() +end + +# rrules for LinearCache +@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode +@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode + +# rrule for solve! +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode + end \ No newline at end of file diff --git a/src/adjoint.jl b/src/adjoint.jl index 281a1ee69..af04793bf 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -99,3 +99,78 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...) ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p) return prob, ∇prob end + +function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...) + assump = OperatorAssumptions(issquare(prob.A)) + alg = defaultalg(prob.A, prob.b, assump) + CRC.rrule(T, prob, alg, args...; kwargs...) +end + +function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...) + init_res = LinearSolve.init(prob, alg) + function init_adjoint(∂init) + ∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent()) + return NoTangent(), ∂prob, NoTangent(), ntuple((_ -> NoTangent(), length(args))...) + end + + return init_res, init_adjoint +end + +function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...) + assump = OperatorAssumptions() + alg = defaultalg(cache.A, cache.b, assump) + CRC.rrule(T, cache, alg, args...; kwargs) +end + +function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A( + alg, cache.A, cache.b), kwargs...) + _cache = deepcopy(cache) + (; A, sensealg) = _cache + @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis." + + # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve + if sensealg.linsolve === missing + if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod || + alg isa DefaultLinearSolver) + A_ = alias_A ? 
deepcopy(A) : A + end + else + A_ = deepcopy(A) + end + + sol = solve!(_cache) + + function solve!_adjoint(∂sol) + ∂∅ = NoTangent() + ∂u = ∂sol.u + + if sensealg.linsolve === missing + λ = if _cache.cacheval isa Factorization + _cache.cacheval' \ ∂u + elseif _cache.cacheval isa Tuple && _cache.cacheval[1] isa Factorization + first(_cache.cacheval)' \ ∂u + elseif alg isa AbstractKrylovSubspaceMethod + invprob = LinearProblem(adjoint(_cache.A), ∂u) + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + elseif alg isa DefaultLinearSolver + LinearSolve.defaultalg_adjoint_eval(_cache, ∂u) + else + invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + end + else + invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` + λ = solve( + invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u + end + + tu = adjoint(sol.u) + ∂A = BroadcastArray(@~ .-(λ .* tu)) + ∂b = λ + ∂prob = LinearProblem(∂A, ∂b, ∂∅) + ∂cache = LinearSolve.init(∂prob) + return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) + end + + return sol, solve!_adjoint +end \ No newline at end of file diff --git a/test/nopre/mooncake.jl b/test/nopre/mooncake.jl index bbd9c5196..5555ab776 100644 --- a/test/nopre/mooncake.jl +++ b/test/nopre/mooncake.jl @@ -153,3 +153,179 @@ for alg in ( @test results[1] ≈ fA(A) @test mooncake_gradient ≈ fd_jac rtol = 1e-5 end + +# Tests for solve! and init rrules. + +n = 4 +A = rand(n, n); +b1 = rand(n); +b2 = rand(n); + +function f(A, b1, b2; alg=LUFactorization()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = copy(solve!(cache).u) + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +f_primal = f(copy(A), copy(b1), copy(b2)) +value, gradient = Mooncake.value_and_gradient!!( + prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)), + f, copy(A), copy(b1), copy(b2) +) + +dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2)) + +@test value == f_primal +@test gradient[2] ≈ dA2 +@test gradient[3] ≈ db12 +@test gradient[4] ≈ db22 + +function f2(A, b1, b2; alg=RFLUFactorization()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = copy(solve!(cache).u) + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +f_primal = f2(copy(A), copy(b1), copy(b2)) +value, gradient = Mooncake.value_and_gradient!!( + prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)), + f2, copy(A), copy(b1), copy(b2) +) + +@test value == f_primal +@test gradient[2] ≈ dA2 +@test gradient[3] ≈ db12 +@test gradient[4] ≈ db22 + +function f3(A, b1, b2; alg=LUFactorization()) + # alg = KrylovJL_GMRES()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + s1 = copy(solve!(cache).u) + cache.b = b2 + s2 = solve!(cache).u + norm(s1 + s2) +end + +f_primal = f3(copy(A), copy(b1), copy(b2)) +value, gradient = Mooncake.value_and_gradient!!( + prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)), + f3, copy(A), copy(b1), copy(b2) +) + +@test value == f_primal +@test gradient[2] ≈ dA2 atol = 5e-5 +@test gradient[3] ≈ db12 +@test gradient[4] ≈ db22 + +A = rand(n, n); +b1 = rand(n); + +function fnice(A, b, alg) + prob = LinearProblem(A, b) + sol1 = solve(prob, alg) + return sum(sol1.u) +end + +@testset for alg in ( + LUFactorization(), + RFLUFactorization(), + KrylovJL_GMRES() +) + # for B + fb_closure = b -> 
fnice(A, b, alg) + fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec + + val, en_jac = Mooncake.value_and_gradient!!( + prepare_gradient_cache(fnice, copy(A), copy(b1), alg), + fnice, copy(A), copy(b1), alg + ) + @test en_jac[3] ≈ fd_jac_b rtol = 1e-5 + + # For A + fA_closure = A -> fnice(A, b1, alg) + fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec + A_grad = en_jac[2] |> vec + @test A_grad ≈ fd_jac_A rtol = 1e-5 +end + +# The below test function cases fails ! +# AVOID Adjoint case in code as : `solve!(cache); s1 = copy(cache.u)`. +# Instead stick to code like : `sol = solve!(cache); s1 = copy(sol.u)`. + +function f4(A, b1, b2; alg=LUFactorization()) + prob = LinearProblem(A, b1) + cache = init(prob, alg) + solve!(cache) + s1 = copy(cache.u) + cache.b = b2 + solve!(cache) + s2 = copy(cache.u) + norm(s1 + s2) +end + +# value, grad = Mooncake.value_and_gradient!!( +# prepare_gradient_cache(f4, copy(A), copy(b1), copy(b2)), +# f4, copy(A), copy(b1), copy(b2) +# ) +# (0.0, (Mooncake.NoTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0])) + +# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2)) + +# @test value == f_primal +# @test grad[2] ≈ dA2 +# @test grad[3] ≈ db12 +# @test grad[4] ≈ db22 + +function testls(A, b, u) + oa = OperatorAssumptions( + true, condition=LinearSolve.OperatorCondition.WellConditioned) + prob = LinearProblem(A, b) + linsolve = init(prob, LUFactorization(), assumptions=oa) + cache = solve!(linsolve) + sum(cache.u) +end + +# A = [1.0 2.0; 3.0 4.0] +# b = [1.0, 2.0] +# u = zero(b) +# value, gradient = Mooncake.value_and_gradient!!( +# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)), +# testls, copy(A), copy(b), copy(u) +# ) + +# dA = gradient[2] +# db = gradient[3] +# du = gradient[4] + +function testls(A, b, u) + oa = OperatorAssumptions( + true, condition=LinearSolve.OperatorCondition.WellConditioned) + prob = LinearProblem(A, b) + linsolve = init(prob, LUFactorization(), assumptions=oa) + solve!(linsolve) + sum(linsolve.u) +end + +# value, gradient = Mooncake.value_and_gradient!!( +# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)), +# testls, copy(A), copy(b), copy(u) +# ) + +# dA2 = gradient[2] +# db2 = gradient[3] +# du2 = gradient[4] + +# @test dA == dA2 +# @test db == db2 +# @test du == du2 From 80b0cb643381810374d6a4b3b3099897b70bf97b Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Mon, 20 Oct 2025 21:20:16 +0530 Subject: [PATCH 2/4] fix erroring out when using solve! as is. 
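
The `solve!` pullback cannot currently produce a correct adjoint when the
result is read back from the mutated cache instead of from the returned
solution, so that pattern now raises a descriptive error rather than silently
returning zero gradients. A minimal sketch of the two patterns (same code as
in the tests below):

    # unsupported for now: reads the result out of the mutated cache
    solve!(cache)
    s1 = copy(cache.u)

    # supported: read the result from the returned solution object
    sol = solve!(cache)
    s1 = copy(sol.u)

Also removes the leftover debug `println` from `increment_and_get_rdata!`,
accumulates the `u` field of the `LinearCache` tangent, and reuses the
caller's cache in the `solve!` rrule instead of deep-copying it.
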
--- ext/LinearSolveMooncakeExt.jl | 2 +- src/adjoint.jl | 25 ++++++++------ test/nopre/mooncake.jl | 64 ++++++----------------------------- 3 files changed, 26 insertions(+), 65 deletions(-) diff --git a/ext/LinearSolveMooncakeExt.jl b/ext/LinearSolveMooncakeExt.jl index cf3c58167..c52bc67b4 100644 --- a/ext/LinearSolveMooncakeExt.jl +++ b/ext/LinearSolveMooncakeExt.jl @@ -30,9 +30,9 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T} end function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache) - println("inside increment and get rdata 2") f.fields.A .+= t.A f.fields.b .+= t.b + f.fields.u .+= t.u return NoRData() end diff --git a/src/adjoint.jl b/src/adjoint.jl index af04793bf..cb2fbe305 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -124,8 +124,7 @@ end function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A( alg, cache.A, cache.b), kwargs...) - _cache = deepcopy(cache) - (; A, sensealg) = _cache + (; A, sensealg) = cache @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis." # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve @@ -138,22 +137,21 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, a A_ = deepcopy(A) end - sol = solve!(_cache) - + sol = solve!(cache) function solve!_adjoint(∂sol) ∂∅ = NoTangent() ∂u = ∂sol.u if sensealg.linsolve === missing - λ = if _cache.cacheval isa Factorization - _cache.cacheval' \ ∂u - elseif _cache.cacheval isa Tuple && _cache.cacheval[1] isa Factorization - first(_cache.cacheval)' \ ∂u + λ = if cache.cacheval isa Factorization + cache.cacheval' \ ∂u + elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization + first(cache.cacheval)' \ ∂u elseif alg isa AbstractKrylovSubspaceMethod - invprob = LinearProblem(adjoint(_cache.A), ∂u) + invprob = LinearProblem(adjoint(cache.A), ∂u) solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u elseif alg isa DefaultLinearSolver - LinearSolve.defaultalg_adjoint_eval(_cache, ∂u) + LinearSolve.defaultalg_adjoint_eval(cache, ∂u) else invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u @@ -167,8 +165,13 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, a tu = adjoint(sol.u) ∂A = BroadcastArray(@~ .-(λ .* tu)) ∂b = λ + + if (iszero(∂b) || iszero(∂A)) && !iszero(tu) + error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.") + end + ∂prob = LinearProblem(∂A, ∂b, ∂∅) - ∂cache = LinearSolve.init(∂prob) + ∂cache = LinearSolve.init(∂prob, u=∂u) return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) end diff --git a/test/nopre/mooncake.jl b/test/nopre/mooncake.jl index 5555ab776..90eecd1fb 100644 --- a/test/nopre/mooncake.jl +++ b/test/nopre/mooncake.jl @@ -257,10 +257,6 @@ end @test A_grad ≈ fd_jac_A rtol = 1e-5 end -# The below test function cases fails ! -# AVOID Adjoint case in code as : `solve!(cache); s1 = copy(cache.u)`. -# Instead stick to code like : `sol = solve!(cache); s1 = copy(sol.u)`. 
- function f4(A, b1, b2; alg=LUFactorization()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -272,11 +268,16 @@ function f4(A, b1, b2; alg=LUFactorization()) norm(s1 + s2) end -# value, grad = Mooncake.value_and_gradient!!( -# prepare_gradient_cache(f4, copy(A), copy(b1), copy(b2)), -# f4, copy(A), copy(b1), copy(b2) -# ) -# (0.0, (Mooncake.NoTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0])) +A = rand(n, n); +b1 = rand(n); +b2 = rand(n); +# f_primal = f4(copy(A), copy(b1), copy(b2)) + +rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2)) +@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!( + rule, 1.0, + f4, copy(A), copy(b1), copy(b2) +) # dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) # db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) @@ -285,47 +286,4 @@ end # @test value == f_primal # @test grad[2] ≈ dA2 # @test grad[3] ≈ db12 -# @test grad[4] ≈ db22 - -function testls(A, b, u) - oa = OperatorAssumptions( - true, condition=LinearSolve.OperatorCondition.WellConditioned) - prob = LinearProblem(A, b) - linsolve = init(prob, LUFactorization(), assumptions=oa) - cache = solve!(linsolve) - sum(cache.u) -end - -# A = [1.0 2.0; 3.0 4.0] -# b = [1.0, 2.0] -# u = zero(b) -# value, gradient = Mooncake.value_and_gradient!!( -# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)), -# testls, copy(A), copy(b), copy(u) -# ) - -# dA = gradient[2] -# db = gradient[3] -# du = gradient[4] - -function testls(A, b, u) - oa = OperatorAssumptions( - true, condition=LinearSolve.OperatorCondition.WellConditioned) - prob = LinearProblem(A, b) - linsolve = init(prob, LUFactorization(), assumptions=oa) - solve!(linsolve) - sum(linsolve.u) -end - -# value, gradient = Mooncake.value_and_gradient!!( -# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)), -# testls, copy(A), copy(b), copy(u) -# ) - -# dA2 = gradient[2] -# db2 = gradient[3] -# du2 = gradient[4] - -# @test dA == dA2 -# @test db == db2 -# @test du == du2 +# @test grad[4] ≈ db22 \ No newline at end of file From 79fbb058ebf587e51c22fb0f29d972d92b5a3e3d Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Mon, 20 Oct 2025 21:26:20 +0530 Subject: [PATCH 3/4] some code formatting --- test/nopre/mooncake.jl | 68 ++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/test/nopre/mooncake.jl b/test/nopre/mooncake.jl index 90eecd1fb..da33c7725 100644 --- a/test/nopre/mooncake.jl +++ b/test/nopre/mooncake.jl @@ -155,7 +155,6 @@ for alg in ( end # Tests for solve! and init rrules. 
- n = 4 A = rand(n, n); b1 = rand(n); @@ -205,8 +204,7 @@ value, gradient = Mooncake.value_and_gradient!!( @test gradient[3] ≈ db12 @test gradient[4] ≈ db22 -function f3(A, b1, b2; alg=LUFactorization()) - # alg = KrylovJL_GMRES()) +function f3(A, b1, b2; alg=KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) s1 = copy(solve!(cache).u) @@ -226,37 +224,6 @@ value, gradient = Mooncake.value_and_gradient!!( @test gradient[3] ≈ db12 @test gradient[4] ≈ db22 -A = rand(n, n); -b1 = rand(n); - -function fnice(A, b, alg) - prob = LinearProblem(A, b) - sol1 = solve(prob, alg) - return sum(sol1.u) -end - -@testset for alg in ( - LUFactorization(), - RFLUFactorization(), - KrylovJL_GMRES() -) - # for B - fb_closure = b -> fnice(A, b, alg) - fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec - - val, en_jac = Mooncake.value_and_gradient!!( - prepare_gradient_cache(fnice, copy(A), copy(b1), alg), - fnice, copy(A), copy(b1), alg - ) - @test en_jac[3] ≈ fd_jac_b rtol = 1e-5 - - # For A - fA_closure = A -> fnice(A, b1, alg) - fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec - A_grad = en_jac[2] |> vec - @test A_grad ≈ fd_jac_A rtol = 1e-5 -end - function f4(A, b1, b2; alg=LUFactorization()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -286,4 +253,35 @@ rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2)) # @test value == f_primal # @test grad[2] ≈ dA2 # @test grad[3] ≈ db12 -# @test grad[4] ≈ db22 \ No newline at end of file +# @test grad[4] ≈ db22 + +A = rand(n, n); +b1 = rand(n); + +function fnice(A, b, alg) + prob = LinearProblem(A, b) + sol1 = solve(prob, alg) + return sum(sol1.u) +end + +@testset for alg in ( + LUFactorization(), + RFLUFactorization(), + KrylovJL_GMRES() +) + # for B + fb_closure = b -> fnice(A, b, alg) + fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec + + val, en_jac = Mooncake.value_and_gradient!!( + prepare_gradient_cache(fnice, copy(A), copy(b1), alg), + fnice, copy(A), copy(b1), alg + ) + @test en_jac[3] ≈ fd_jac_b rtol = 1e-5 + + # For A + fA_closure = A -> fnice(A, b1, alg) + fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec + A_grad = en_jac[2] |> vec + @test A_grad ≈ fd_jac_A rtol = 1e-5 +end \ No newline at end of file From 36de150a8fe7b2e2516e58cae7f985ba049f6636 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sat, 8 Nov 2025 03:18:33 +0530 Subject: [PATCH 4/4] Mutation handling, nopre lts tests pass --- ext/LinearSolveMooncakeExt.jl | 103 +++++++++++++++++++++++++++++++--- src/adjoint.jl | 62 -------------------- test/nopre/mooncake.jl | 65 ++++++++++++--------- 3 files changed, 131 insertions(+), 99 deletions(-) diff --git a/ext/LinearSolveMooncakeExt.jl b/ext/LinearSolveMooncakeExt.jl index c52bc67b4..161c422aa 100644 --- a/ext/LinearSolveMooncakeExt.jl +++ b/ext/LinearSolveMooncakeExt.jl @@ -1,18 +1,19 @@ module LinearSolveMooncakeExt using Mooncake -using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!! 
+using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!, @is_primitive, primal, zero_fcodual, CoDual, rdata, fdata using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem, - LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, - defaultalg_adjoint_eval, solve + LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, LinearSolveAdjoint, + defaultalg_adjoint_eval, solve, LUFactorization using LinearSolve.LinearAlgebra +using LazyArrays: @~, BroadcastArray using SciMLBase -@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve),LinearProblem,Nothing} true ReverseMode @from_chainrules MinimalCtx Tuple{ - typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode + typeof(SciMLBase.solve),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode @from_chainrules MinimalCtx Tuple{ - Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode + Type{<:LinearProblem},AbstractMatrix,AbstractVector,SciMLBase.NullParameters} true ReverseMode function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem) f.data.A .+= t.A @@ -41,8 +42,92 @@ end @from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode @from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode -# rrule for solve! -@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode -@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode +# rrules for solve! +# NOTE - Avoid Mooncake.prepare_gradient_cache, only use Mooncake.prepare_pullback_cache (and therefore Mooncake.value_and_pullback!!) +# calling Mooncake.prepare_gradient_cache for functions with solve! will activate unsupported Adjoint case exception for below rrules +# This because in Mooncake.prepare_gradient_cache we reset stacks + state by passing in zero gradient in the reverse pass once. +# However, if one has a valid cache then they can directly use Mooncake.value_and_gradient!!. + +@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm,Vararg} +@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing,Vararg} + +function Mooncake.rrule!!(sig::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{Nothing}, args::Vararg{Any,N}; kwargs...) where {N} + cache = primal(_cache) + assump = OperatorAssumptions() + _alg.x = defaultalg(cache.A, cache.b, assump) + Mooncake.rrule!!(sig, _cache, _alg, args...; kwargs...) +end + +function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{<:SciMLLinearSolveAlgorithm}, args::Vararg{Any,N}; alias_A=zero_fcodual(LinearSolve.default_alias_A( + _alg.x, _cache.x.A, _cache.x.b)), kwargs...) where {N} + + cache = primal(_cache) + alg = primal(_alg) + _args = map(primal, args) + + (; A, b, sensealg) = cache + A_orig = copy(A) + b_orig = copy(b) + + @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis." 
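+    # NOTE: `A_orig`/`b_orig` snapshot the primal operator and right-hand side;
+    # `solve!` may overwrite or alias them in place, and they are restored onto
+    # `cache` after the forward solve so the pullback below can re-`init` a
+    # cache from the original problem when rebuilding the primal solution.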
+ + # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve + if sensealg.linsolve === missing + if !(alg isa LinearSolve.AbstractFactorization || alg isa LinearSolve.AbstractKrylovSubspaceMethod || + alg isa LinearSolve.DefaultLinearSolver) + A_ = alias_A ? deepcopy(A) : A + end + else + A_ = deepcopy(A) + end + + sol = zero_fcodual(solve!(cache)) + cache.A = A_orig + cache.b = b_orig + + function solve!_adjoint(::NoRData) + ∂∅ = NoRData() + cachenew = init(LinearProblem(cache.A, cache.b), LUFactorization(), _args...; kwargs...) + new_sol = solve!(cachenew) + ∂u = sol.dx.data.u + + if sensealg.linsolve === missing + λ = if cache.cacheval isa Factorization + cache.cacheval' \ ∂u + elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization + first(cache.cacheval)' \ ∂u + elseif alg isa AbstractKrylovSubspaceMethod + invprob = LinearProblem(adjoint(cache.A), ∂u) + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + elseif alg isa DefaultLinearSolver + LinearSolve.defaultalg_adjoint_eval(cache, ∂u) + else + invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` + solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + end + else + invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` + λ = solve( + invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u + end + + tu = adjoint(new_sol.u) + ∂A = BroadcastArray(@~ .-(λ .* tu)) + ∂b = λ + + if (iszero(∂b) || iszero(∂A)) && !iszero(tu) + error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.") + end + + fdata(_cache.dx).fields.A .+= ∂A + fdata(_cache.dx).fields.b .+= ∂b + fdata(_cache.dx).fields.u .+= ∂u + + # rdata for cache is a struct with NoRdata field values + return (∂∅, rdata(_cache.dx), ∂∅, ntuple(_ -> ∂∅, length(args))...) + end + + return sol, solve!_adjoint +end end \ No newline at end of file diff --git a/src/adjoint.jl b/src/adjoint.jl index cb2fbe305..4ecc5b260 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -115,65 +115,3 @@ function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, return init_res, init_adjoint end - -function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...) - assump = OperatorAssumptions() - alg = defaultalg(cache.A, cache.b, assump) - CRC.rrule(T, cache, alg, args...; kwargs) -end - -function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A( - alg, cache.A, cache.b), kwargs...) - (; A, sensealg) = cache - @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis." - - # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve - if sensealg.linsolve === missing - if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod || - alg isa DefaultLinearSolver) - A_ = alias_A ? 
deepcopy(A) : A - end - else - A_ = deepcopy(A) - end - - sol = solve!(cache) - function solve!_adjoint(∂sol) - ∂∅ = NoTangent() - ∂u = ∂sol.u - - if sensealg.linsolve === missing - λ = if cache.cacheval isa Factorization - cache.cacheval' \ ∂u - elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization - first(cache.cacheval)' \ ∂u - elseif alg isa AbstractKrylovSubspaceMethod - invprob = LinearProblem(adjoint(cache.A), ∂u) - solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u - elseif alg isa DefaultLinearSolver - LinearSolve.defaultalg_adjoint_eval(cache, ∂u) - else - invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` - solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u - end - else - invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A` - λ = solve( - invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u - end - - tu = adjoint(sol.u) - ∂A = BroadcastArray(@~ .-(λ .* tu)) - ∂b = λ - - if (iszero(∂b) || iszero(∂A)) && !iszero(tu) - error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.") - end - - ∂prob = LinearProblem(∂A, ∂b, ∂∅) - ∂cache = LinearSolve.init(∂prob, u=∂u) - return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...) - end - - return sol, solve!_adjoint -end \ No newline at end of file diff --git a/test/nopre/mooncake.jl b/test/nopre/mooncake.jl index da33c7725..c50a89b45 100644 --- a/test/nopre/mooncake.jl +++ b/test/nopre/mooncake.jl @@ -11,9 +11,7 @@ b1 = rand(n); function f(A, b1; alg = LUFactorization()) prob = LinearProblem(A, b1) - sol1 = solve(prob, alg) - s1 = sol1.u norm(s1) end @@ -160,7 +158,7 @@ A = rand(n, n); b1 = rand(n); b2 = rand(n); -function f(A, b1, b2; alg=LUFactorization()) +function f_(A, b1, b2; alg=LUFactorization()) prob = LinearProblem(A, b1) cache = init(prob, alg) s1 = copy(solve!(cache).u) @@ -169,22 +167,23 @@ function f(A, b1, b2; alg=LUFactorization()) norm(s1 + s2) end -f_primal = f(copy(A), copy(b1), copy(b2)) -value, gradient = Mooncake.value_and_gradient!!( - prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)), - f, copy(A), copy(b1), copy(b2) +f_primal = f_(copy(A), copy(b1), copy(b2)) +rule = Mooncake.build_rrule(f_, copy(A), copy(b1), copy(b2)) +value, gradient = Mooncake.value_and_pullback!!( + rule, 1.0, + f_, copy(A), copy(b1), copy(b2) ) -dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) -db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) -db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2)) +dA2 = ForwardDiff.gradient(x -> f_(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x -> f_(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f_(eltype(x).(A), eltype(x).(b1), x), copy(b2)) @test value == f_primal @test gradient[2] ≈ dA2 @test gradient[3] ≈ db12 @test gradient[4] ≈ db22 -function f2(A, b1, b2; alg=RFLUFactorization()) +function f_2(A, b1, b2; alg=RFLUFactorization()) prob = LinearProblem(A, b1) cache = init(prob, alg) s1 = copy(solve!(cache).u) @@ -193,18 +192,23 @@ function f2(A, b1, b2; alg=RFLUFactorization()) norm(s1 + s2) end -f_primal = f2(copy(A), copy(b1), copy(b2)) -value, gradient = Mooncake.value_and_gradient!!( - prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)), - f2, copy(A), copy(b1), copy(b2) +f_primal = f_2(copy(A), copy(b1), copy(b2)) +rule = Mooncake.build_rrule(f_2, copy(A), copy(b1), copy(b2)) +value, 
gradient = Mooncake.value_and_pullback!!( + rule, 1.0, + f_2, copy(A), copy(b1), copy(b2) ) +dA2 = ForwardDiff.gradient(x -> f_2(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), eltype(x).(b1), x), copy(b2)) + @test value == f_primal @test gradient[2] ≈ dA2 @test gradient[3] ≈ db12 @test gradient[4] ≈ db22 -function f3(A, b1, b2; alg=KrylovJL_GMRES()) +function f_3(A, b1, b2; alg=KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) s1 = copy(solve!(cache).u) @@ -213,18 +217,23 @@ function f3(A, b1, b2; alg=KrylovJL_GMRES()) norm(s1 + s2) end -f_primal = f3(copy(A), copy(b1), copy(b2)) -value, gradient = Mooncake.value_and_gradient!!( - prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)), - f3, copy(A), copy(b1), copy(b2) +f_primal = f_3(copy(A), copy(b1), copy(b2)) +rule = Mooncake.build_rrule(f_3, copy(A), copy(b1), copy(b2)) +value, gradient = Mooncake.value_and_pullback!!( + rule, 1.0, + f_3, copy(A), copy(b1), copy(b2) ) +dA2 = ForwardDiff.gradient(x -> f_3(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +db12 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), eltype(x).(b1), x), copy(b2)) + @test value == f_primal -@test gradient[2] ≈ dA2 atol = 5e-5 +@test gradient[2] ≈ dA2 @test gradient[3] ≈ db12 @test gradient[4] ≈ db22 -function f4(A, b1, b2; alg=LUFactorization()) +function f_4(A, b1, b2; alg=LUFactorization()) prob = LinearProblem(A, b1) cache = init(prob, alg) solve!(cache) @@ -238,17 +247,17 @@ end A = rand(n, n); b1 = rand(n); b2 = rand(n); -# f_primal = f4(copy(A), copy(b1), copy(b2)) +f_primal = f_4(copy(A), copy(b1), copy(b2)) -rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2)) +rule = Mooncake.build_rrule(f_4, copy(A), copy(b1), copy(b2)) @test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!( rule, 1.0, - f4, copy(A), copy(b1), copy(b2) + f_4, copy(A), copy(b1), copy(b2) ) -# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) -# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) -# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2)) +# dA2 = ForwardDiff.gradient(x -> f_4(x, eltype(x).(b1), eltype(x).(b2)), copy(A)) +# db12 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), x, eltype(x).(b2)), copy(b1)) +# db22 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), eltype(x).(b1), x), copy(b2)) # @test value == f_primal # @test grad[2] ≈ dA2
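
Usage sketch (assuming the Mooncake and LinearSolve versions targeted by this
series): the new tests drive the `solve!` rrules through `Mooncake.build_rrule`
and `Mooncake.value_and_pullback!!` rather than `prepare_gradient_cache` (see
the NOTE in ext/LinearSolveMooncakeExt.jl above). A minimal end-to-end example
in that style; the `loss` function here is illustrative:

    using LinearSolve, Mooncake, LinearAlgebra

    function loss(A, b; alg = LUFactorization())
        cache = init(LinearProblem(A, b), alg)
        sol = solve!(cache)   # read the result from `sol`, not from `cache.u`
        return norm(sol.u)
    end

    A, b = rand(4, 4), rand(4)
    rule = Mooncake.build_rrule(loss, A, b)
    val, grads = Mooncake.value_and_pullback!!(rule, 1.0, loss, A, b)
    # grads[2] holds the gradient w.r.t. A, grads[3] the gradient w.r.t. b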