@@ -222,25 +222,52 @@ To manually specify an inverse, call
222222function pushfwd end
223223export pushfwd
224224
225- @inline pushfwd (f, μ) = _pushfwd_impl (f, μ, AdaptRootMeasure ())
226- @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl (f, μ, style)
227- @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl (f, μ, style)
225+ @inline pushfwd (f, μ) = _pushfwd_impl1 (f, μ, AdaptRootMeasure ())
226+ @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl1 (f, μ, style)
227+ @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl1 (f, μ, style)
228228
229- _pushfwd_impl (f, μ, style) = PushforwardMeasure (f, inverse (f), μ, style)
229+ _pushfwd_impl1 (f, μ, style:: PushFwdStyle ) = _pushfwd_impl2 (f, inverse (f), μ, style)
230+ _pushfwd_impl1 (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
231+ _pushfwd_impl1 (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
230232
231- function _pushfwd_impl (
233+ _pushfwd_impl2 (f, finv, μ, style:: PushFwdStyle ) = PushforwardMeasure (f, finv, μ, style)
234+
235+ function _pushfwd_impl2 (
232236 f,
237+ finv,
233238 μ:: PushforwardMeasure{F,I,M,S} ,
234239 style:: S ,
235240) where {F,I,M,S<: PushFwdStyle }
236241 orig_μ = μ. origin
237242 new_f = fcomp (f, μ. f)
238- new_f_inv = fcomp (μ. finv, inverse (f) )
243+ new_f_inv = fcomp (μ. finv, finv )
239244 PushforwardMeasure (new_f, new_f_inv, orig_μ, style)
240245end
241246
242- _pushfwd_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
243- _pushfwd_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
247+ struct _CurriedPushfwd{F,I,S<: PushFwdStyle } <: Function
248+ f:: F
249+ finv:: I
250+ style:: S
251+
252+ function _CurriedPushfwd {F,I,S} (f:: F , finv:: I , style:: S ) where {F,I,S<: PushFwdStyle }
253+ new {F,I,S} (f, finv, style)
254+ end
255+
256+ function _CurriedPushfwd (f, finv, style:: S ) where {S<: PushFwdStyle }
257+ new {Core.Typeof(f),Core.Typeof(finv),S} (f, finv, style)
258+ end
259+ end
260+
261+ @inline (cf:: _CurriedPushfwd{F,FI} )(μ) where {F,FI} =
262+ _pushfwd_impl2 (cf. f, cf. finv, μ, cf. style)
263+
264+ @inline pushfwd (f) = _curried_pushfwd_impl (f, AdaptRootMeasure ())
265+ @inline pushfwd (f, style:: AdaptRootMeasure ) = _curried_pushfwd_impl (f, style)
266+ @inline pushfwd (f, style:: PushfwdRootMeasure ) = _curried_pushfwd_impl (f, style)
267+
268+ _curried_pushfwd_impl (f, style:: PushFwdStyle ) = _CurriedPushfwd (f, inverse (f), style)
269+ @inline _curried_pushfwd_impl (:: typeof (identity), :: AdaptRootMeasure ) = identity
270+ @inline _curried_pushfwd_impl (:: typeof (identity), :: PushfwdRootMeasure ) = identity
244271
245272# ##############################################################################
246273# pullback
@@ -267,8 +294,16 @@ export pullbck
267294@inline pullbck (f, μ, style:: AdaptRootMeasure ) = _pullback_impl (f, μ, style)
268295@inline pullbck (f, μ, style:: PushfwdRootMeasure ) = _pullback_impl (f, μ, style)
269296
270- function _pullback_impl (f, μ, style = AdaptRootMeasure ())
271- pushfwd (inverse (f), μ, style)
272- end
297+ _pullback_impl (f, μ, style:: PushFwdStyle ) = _pushfwd_impl2 (inverse (f), f, μ, style)
298+ _pullback_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
299+ _pullback_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
300+
301+ @inline pullbck (f) = _curried_pullbck_impl (f, AdaptRootMeasure ())
302+ @inline pullbck (f, style:: AdaptRootMeasure ) = _curried_pullbck_impl (f, style)
303+ @inline pullbck (f, style:: PushfwdRootMeasure ) = _curried_pullbck_impl (f, style)
304+
305+ _curried_pullbck_impl (f, style:: PushFwdStyle ) = _CurriedPushfwd (inverse (f), f, style)
306+ @inline _curried_pullbck_impl (:: typeof (identity), :: AdaptRootMeasure ) = identity
307+ @inline _curried_pullbck_impl (:: typeof (identity), :: PushfwdRootMeasure ) = identity
273308
274309@deprecate pullback (f, μ, style:: PushFwdStyle = AdaptRootMeasure ()) pullbck (f, μ, style)
0 commit comments