diff --git a/ext/LinearSolveMooncakeExt.jl b/ext/LinearSolveMooncakeExt.jl
index b3179cbcf..161c422aa 100644
--- a/ext/LinearSolveMooncakeExt.jl
+++ b/ext/LinearSolveMooncakeExt.jl
@@ -1,18 +1,19 @@
 module LinearSolveMooncakeExt
 
 using Mooncake
-using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!
+using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!, @is_primitive, primal, zero_fcodual, CoDual, rdata, fdata
 using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem,
-                   LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver,
-                   defaultalg_adjoint_eval, solve
+                   LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, LinearSolveAdjoint,
+                   defaultalg_adjoint_eval, solve, LUFactorization
 using LinearSolve.LinearAlgebra
+using LazyArrays: @~, BroadcastArray
 using SciMLBase
 
-@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode
+@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve),LinearProblem,Nothing} true ReverseMode
 @from_chainrules MinimalCtx Tuple{
-    typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode
+    typeof(SciMLBase.solve),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
 @from_chainrules MinimalCtx Tuple{
-    Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode
+    Type{<:LinearProblem},AbstractMatrix,AbstractVector,SciMLBase.NullParameters} true ReverseMode
 
 function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem)
     f.data.A .+= t.A
@@ -29,4 +30,104 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
     end
 end
 
+function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
+    f.fields.A .+= t.A
+    f.fields.b .+= t.b
+    f.fields.u .+= t.u
+
+    return NoRData()
+end
+
+# rrules for LinearCache
+@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
+@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode
+
+# rrules for solve!
+# NOTE - Avoid Mooncake.prepare_gradient_cache; only use Mooncake.prepare_pullback_cache (and therefore Mooncake.value_and_pullback!!).
+# Calling Mooncake.prepare_gradient_cache on functions that use solve! will trigger the unsupported "Adjoint case" exception in the rrules below.
+# This is because Mooncake.prepare_gradient_cache resets stacks and state by running the reverse pass once with a zero gradient.
+# However, if one already has a valid cache, Mooncake.value_and_gradient!! can be used directly.
+
+@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm,Vararg}
+@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing,Vararg}
+
+function Mooncake.rrule!!(sig::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache},
+        _alg::CoDual{Nothing}, args::Vararg{Any,N}; kwargs...) where {N}
+    cache = primal(_cache)
+    assump = LinearSolve.OperatorAssumptions()
+    alg = LinearSolve.defaultalg(cache.A, cache.b, assump)
+    # CoDual is immutable, so wrap the resolved default algorithm in a fresh CoDual
+    # and dispatch to the concrete-algorithm method below.
+    Mooncake.rrule!!(sig, _cache, zero_fcodual(alg), args...; kwargs...)
+end
+
+function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache},
+        _alg::CoDual{<:SciMLLinearSolveAlgorithm}, args::Vararg{Any,N};
+        alias_A=zero_fcodual(LinearSolve.default_alias_A(
+            _alg.x, _cache.x.A, _cache.x.b)), kwargs...) where {N}
+    cache = primal(_cache)
+    alg = primal(_alg)
+    _args = map(primal, args)
+
+    (; A, b, sensealg) = cache
+    A_orig = copy(A)
+    b_orig = copy(b)
+
+    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
+
+    # logic for caching `A` and `b` for the reverse pass, based on the rrule above for SciMLBase.solve
+    if sensealg.linsolve === missing
+        if !(alg isa LinearSolve.AbstractFactorization || alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
+             alg isa LinearSolve.DefaultLinearSolver)
+            A_ = alias_A ? deepcopy(A) : A
+        end
+    else
+        A_ = deepcopy(A)
+    end
+
+    sol = zero_fcodual(solve!(cache))
+    cache.A = A_orig
+    cache.b = b_orig
+
+    function solve!_adjoint(::NoRData)
+        ∂∅ = NoRData()
+        cachenew = init(LinearProblem(cache.A, cache.b), LUFactorization(), _args...; kwargs...)
+        new_sol = solve!(cachenew)
+        ∂u = sol.dx.data.u
+
+        if sensealg.linsolve === missing
+            λ = if cache.cacheval isa Factorization
+                cache.cacheval' \ ∂u
+            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
+                first(cache.cacheval)' \ ∂u
+            elseif alg isa AbstractKrylovSubspaceMethod
+                invprob = LinearProblem(adjoint(cache.A), ∂u)
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            elseif alg isa DefaultLinearSolver
+                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
+            else
+                invprob = LinearProblem(adjoint(A_), ∂u)  # We cached `A`
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            end
+        else
+            invprob = LinearProblem(adjoint(A_), ∂u)  # We cached `A`
+            λ = solve(
+                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
+        end
+
+        tu = adjoint(new_sol.u)
+        ∂A = BroadcastArray(@~ .-(λ .* tu))
+        ∂b = λ
+
+        if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
+            error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
+        end
+
+        fdata(_cache.dx).fields.A .+= ∂A
+        fdata(_cache.dx).fields.b .+= ∂b
+        fdata(_cache.dx).fields.u .+= ∂u
+
+        # rdata for the cache is a struct with NoRData field values
+        return (∂∅, rdata(_cache.dx), ∂∅, ntuple(_ -> ∂∅, length(args))...)
+    end
+
+    return sol, solve!_adjoint
+end
+
 end
\ No newline at end of file
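For reference, the pullback in `solve!_adjoint` follows the standard adjoint identities for a linear solve: with `u = A \ b` and an output cotangent `∂u`, one adjoint solve gives `λ = A' \ ∂u`, and then `∂b = λ` and `∂A = -λ * u'`. A minimal, self-contained sketch checking these identities against ForwardDiff (the names `g`, `A`, `b` are illustrative and not part of the PR):

    using LinearAlgebra, ForwardDiff

    A = rand(4, 4) + 4I              # well-conditioned test matrix
    b = rand(4)
    g(A, b) = sum(A \ b)             # scalar loss through a linear solve

    u = A \ b
    ∂u = ones(length(u))             # cotangent of `sum` with respect to `u`
    λ = A' \ ∂u                      # adjoint solve
    ∂b = λ
    ∂A = -λ * u'                     # rank-one cotangent for A

    ∂A ≈ ForwardDiff.gradient(X -> g(X, b), A)   # true
    ∂b ≈ ForwardDiff.gradient(x -> g(A, x), b)   # true

The `BroadcastArray(@~ .-(λ .* tu))` expression in the rrule is this same `-λ * u'` outer product, kept lazy via LazyArrays until it is accumulated into `fdata(_cache.dx).fields.A`.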
diff --git a/src/adjoint.jl b/src/adjoint.jl
index 281a1ee69..4ecc5b260 100644
--- a/src/adjoint.jl
+++ b/src/adjoint.jl
@@ -99,3 +99,19 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
     ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
     return prob, ∇prob
 end
+
+function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
+    assump = OperatorAssumptions(issquare(prob.A))
+    alg = defaultalg(prob.A, prob.b, assump)
+    CRC.rrule(T, prob, alg, args...; kwargs...)
+end
+
+function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...)
+    init_res = LinearSolve.init(prob, alg)
+    function init_adjoint(∂init)
+        ∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
+        return NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...
+    end
+
+    return init_res, init_adjoint
+end
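As the NOTE in the extension says, gradients through `solve!` should be driven with `Mooncake.prepare_pullback_cache` / `Mooncake.value_and_pullback!!` (or `Mooncake.build_rrule`, as in the tests below) rather than `Mooncake.prepare_gradient_cache`. A rough usage sketch, assuming the cache-based entry points take the same argument layout as the rule-based calls in the tests (the `loss` function and data are illustrative):

    using LinearSolve, LinearAlgebra, Mooncake

    loss(A, b) = norm(solve!(init(LinearProblem(A, b), LUFactorization())).u)

    A = rand(4, 4) + 4I
    b = rand(4)

    cache = Mooncake.prepare_pullback_cache(loss, A, b)
    val, grads = Mooncake.value_and_pullback!!(cache, 1.0, loss, A, b)
    # grads[2] holds ∂loss/∂A and grads[3] holds ∂loss/∂b,
    # matching the gradient indexing used in the tests below.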
diff --git a/test/nopre/mooncake.jl b/test/nopre/mooncake.jl
index bbd9c5196..c50a89b45 100644
--- a/test/nopre/mooncake.jl
+++ b/test/nopre/mooncake.jl
@@ -11,9 +11,7 @@ b1 = rand(n);
 
 function f(A, b1; alg = LUFactorization())
     prob = LinearProblem(A, b1)
-    sol1 = solve(prob, alg)
-    s1 = sol1.u
     norm(s1)
 end
@@ -153,3 +151,146 @@ for alg in (
     @test results[1] ≈ fA(A)
     @test mooncake_gradient ≈ fd_jac rtol = 1e-5
 end
+
+# Tests for solve! and init rrules.
+n = 4
+A = rand(n, n);
+b1 = rand(n);
+b2 = rand(n);
+
+function f_(A, b1, b2; alg=LUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f_(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_, copy(A), copy(b1), copy(b2)
+)
+
+dA2 = ForwardDiff.gradient(x -> f_(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+@test value == f_primal
+@test gradient[2] ≈ dA2
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+function f_2(A, b1, b2; alg=RFLUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f_2(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_2, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_2, copy(A), copy(b1), copy(b2)
+)
+
+dA2 = ForwardDiff.gradient(x -> f_2(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+@test value == f_primal
+@test gradient[2] ≈ dA2
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+function f_3(A, b1, b2; alg=KrylovJL_GMRES())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f_3(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_3, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_3, copy(A), copy(b1), copy(b2)
+)
+
+dA2 = ForwardDiff.gradient(x -> f_3(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+@test value == f_primal
+@test gradient[2] ≈ dA2
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+function f_4(A, b1, b2; alg=LUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    solve!(cache)
+    s1 = copy(cache.u)
+    cache.b = b2
+    solve!(cache)
+    s2 = copy(cache.u)
+    norm(s1 + s2)
+end
+
+A = rand(n, n);
+b1 = rand(n);
+b2 = rand(n);
+f_primal = f_4(copy(A), copy(b1), copy(b2))
+
+rule = Mooncake.build_rrule(f_4, copy(A), copy(b1), copy(b2))
+@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_4, copy(A), copy(b1), copy(b2)
+)
+
+# dA2 = ForwardDiff.gradient(x -> f_4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+# db12 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+# db22 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+# @test value == f_primal
+# @test grad[2] ≈ dA2
+# @test grad[3] ≈ db12
+# @test grad[4] ≈ db22
+
+A = rand(n, n);
+b1 = rand(n);
+
+function fnice(A, b, alg)
+    prob = LinearProblem(A, b)
+    sol1 = solve(prob, alg)
+    return sum(sol1.u)
+end
+
+@testset for alg in (
+    LUFactorization(),
+    RFLUFactorization(),
+    KrylovJL_GMRES()
+)
+    # for B
+    fb_closure = b -> fnice(A, b, alg)
+    fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
+
+    val, en_jac = Mooncake.value_and_gradient!!(
+        prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
+        fnice, copy(A), copy(b1), alg
+    )
+    @test en_jac[3] ≈ fd_jac_b rtol = 1e-5
+
+    # For A
+    fA_closure = A -> fnice(A, b1, alg)
+    fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
+    A_grad = en_jac[2] |> vec
+    @test A_grad ≈ fd_jac_A rtol = 1e-5
+end
\ No newline at end of file
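The commented-out checks for `f_4` are left disabled on purpose: `f_4` exercises the unsupported pattern that the new `solve!` rrule guards against. Because it reads `cache.u` after `solve!(cache)` instead of keeping the returned solution, the solution object never receives a cotangent, so the pullback sees zero `∂A` and `∂b` alongside a nonzero primal and raises "Adjoint case currently not handled". The supported pattern, as in `f_` above, keeps the returned solution:

    sol = solve!(cache)
    s1 = copy(sol.u)   # differentiate through sol.u, not cache.u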