Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ext/LinearSolveMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,20 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
end
end

# Accumulate the reverse-pass tangent of a `LinearCache` into its fdata `f`
# in-place; `LinearCache` carries no rdata, so `NoRData()` is returned.
function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
    for name in (:A, :b, :u)
        getproperty(f.fields, name) .+= getproperty(t, name)
    end
    return NoRData()
end

# rrules for LinearCache
# Derive Mooncake reverse-mode rules for `init` from the ChainRules rrules in
# src/adjoint.jl, for both a concrete algorithm and `nothing` (default choice).
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode

# rrule for solve!
# NOTE(review): `solve!` mutates its cache; these ChainRules-derived rules are
# only exercised inside Mooncake's derived-rule context — confirm this remains
# safe if they are ever reached from plain ChainRules reverse mode.
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode
Comment on lines +40 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are mutating, so there are no chain rules.


end
78 changes: 78 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,81 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end

# No algorithm supplied: resolve the default algorithm for this problem and
# defer to the rrule for a concrete `SciMLLinearSolveAlgorithm`.
function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
    assumptions = OperatorAssumptions(issquare(prob.A))
    chosen = defaultalg(prob.A, prob.b, assumptions)
    return CRC.rrule(T, prob, chosen, args...; kwargs...)
end

# ChainRules rrule for `LinearSolve.init`: the primal builds the cache; the
# pullback repackages the cache cotangent's `A` and `b` as a problem cotangent.
function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...)
    init_res = LinearSolve.init(prob, alg)
    function init_adjoint(∂init)
        # `p` receives no tangent.
        ∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
        # BUG FIX: the splat must apply to the *result* of `ntuple` so each
        # trailing cotangent is a separate entry of the returned tuple. The
        # original `ntuple((_ -> NoTangent(), length(args))...)` splatted the
        # argument pair into `ntuple` and returned one nested tuple instead
        # (compare the correct form in the `solve!` rrule below).
        return NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...
    end

    return init_res, init_adjoint
end

# No algorithm supplied for `solve!`: resolve the default for this cache and
# dispatch to the concrete-algorithm rrule below.
function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...)
    assump = OperatorAssumptions()
    alg = defaultalg(cache.A, cache.b, assump)
    # BUG FIX: splat the keyword arguments (`kwargs...`). The original passed
    # `; kwargs)`, which forwards a single keyword literally named `kwargs`
    # instead of the caller's keyword set.
    return CRC.rrule(T, cache, alg, args...; kwargs...)
end

# ChainRules rrule for the in-place `SciMLBase.solve!` on a `LinearCache`.
#
# Forward pass: decides whether `A` must be copied for the reverse pass, then
# runs `solve!(cache)` (mutating the cache). Reverse pass: solves the adjoint
# system `A' λ = ∂u` — reusing the cached factorization when available — and
# returns `∂A = -λ uᵀ` (lazily, as a `BroadcastArray`) and `∂b = λ`.
#
# NOTE(review): `solve!` mutates its argument and ChainRules reverse mode does
# not generally support mutation; this rule is intended for consumption through
# Mooncake's `@from_chainrules`-derived rules, where it is tested.
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A(
            alg, cache.A, cache.b), kwargs...)
    (; A, sensealg) = cache
    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

    # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve
    if sensealg.linsolve === missing
        # Factorization/Krylov/default algorithms leave enough information in
        # `cache.cacheval` to solve the adjoint system, so `A` only needs to be
        # preserved for the remaining algorithms (and only when aliased).
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = alias_A ? deepcopy(A) : A
        end
    else
        A_ = deepcopy(A)
    end

    sol = solve!(cache)
    function solve!_adjoint(∂sol)
        ∂∅ = NoTangent()
        ∂u = ∂sol.u

        # Solve the adjoint system A' λ = ∂u, reusing whatever the forward pass
        # left behind where possible.
        if sensealg.linsolve === missing
            λ = if cache.cacheval isa Factorization
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(adjoint(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
            λ = solve(
                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
        end

        # ∂A = -λ uᵀ, built lazily to avoid materializing the outer product.
        tu = adjoint(sol.u)
        ∂A = BroadcastArray(@~ .-(λ .* tu))
        ∂b = λ

        # Guard against the unsupported pattern of reading `cache.u` after a
        # later `solve!` has overwritten it (see test `f4`).
        if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
            error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
        end

        ∂prob = LinearProblem(∂A, ∂b, ∂∅)
        ∂cache = LinearSolve.init(∂prob, u=∂u)
        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end

    return sol, solve!_adjoint
end
132 changes: 132 additions & 0 deletions test/nopre/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,135 @@ for alg in (
@test results[1] ≈ fA(A)
@test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

# Solve twice through one cache (reusing the factorization with a new `b`) and
# reduce to a scalar — exercises the `init` and `solve!` rrules together.
# Note: reads the *returned* solution's `u`, the supported access pattern.
function f(A, b1, b2; alg=LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

# Compare Mooncake's gradient of `f` against ForwardDiff reference gradients.
f_primal = f(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
    f, copy(A), copy(b1), copy(b2)
)

# Reference gradients w.r.t. A, b1, b2 (the other arguments are converted to
# the dual eltype so ForwardDiff can trace through the solve).
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
# gradient[1] is the cotangent w.r.t. `f` itself; entries 2-4 are A, b1, b2.
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

# Same two-solve pattern with RFLUFactorization. The ForwardDiff references
# (dA2, db12, db22) from the LUFactorization run above remain valid because
# A, b1, b2 are unchanged and the objective is algorithm-independent.
function f2(A, b1, b2; alg=RFLUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f2(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
    f2, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

# Same two-solve pattern with an iterative (Krylov) solver.
function f3(A, b1, b2; alg=KrylovJL_GMRES())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f3(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
    f3, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
# Looser tolerance on ∂A: GMRES solves to cache.abstol/reltol, not machine eps.
@test gradient[2] ≈ dA2 atol = 5e-5
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

# Unsupported access pattern: reads `cache.u` after `solve!` instead of the
# returned solution, so the first solve's result is overwritten before the
# reverse pass can use it — the `solve!` rrule detects this and errors.
function f4(A, b1, b2; alg=LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    solve!(cache)
    s1 = copy(cache.u)
    cache.b = b2
    solve!(cache)
    s2 = copy(cache.u)
    norm(s1 + s2)
end

A = rand(n, n);
b1 = rand(n);
b2 = rand(n);
# f_primal = f4(copy(A), copy(b1), copy(b2))

rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
    rule, 1.0,
    f4, copy(A), copy(b1), copy(b2)
)

# Disabled gradient checks for `f4`: re-enable (mirroring the checks for `f`
# above) if/when the adjoint learns to handle the `cache.u` access pattern.
# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

# Out-of-place `solve` driver: single solve, scalar objective.
function fnice(A, b, alg)
    prob = LinearProblem(A, b)
    sol1 = solve(prob, alg)
    return sum(sol1.u)
end

# Per-algorithm gradient checks of `fnice` against finite differences.
@testset for alg in (
    LUFactorization(),
    RFLUFactorization(),
    KrylovJL_GMRES()
)
    # for B
    fb_closure = b -> fnice(A, b, alg)
    fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

    # en_jac entries: 1 = cotangent of `fnice` itself, 2 = A, 3 = b.
    val, en_jac = Mooncake.value_and_gradient!!(
        prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
        fnice, copy(A), copy(b1), alg
    )
    @test en_jac[3] ≈ fd_jac_b rtol = 1e-5

    # For A
    fA_closure = A -> fnice(A, b1, alg)
    fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
    A_grad = en_jac[2] |> vec
    @test A_grad ≈ fd_jac_A rtol = 1e-5
end
Loading