@@ -19,42 +19,37 @@ for higher order derivatives partial can be any iterable, i.e.
1919```
2020"""
2121struct DiffPt{Dim}
22- pos # the actual position
23- partial
22+ pos # the actual position
23+ partial
2424end
2525
26- DiffPt (x;partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
26+ DiffPt (x; partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
2727
2828"""
2929Take the partial derivative of a function `fun` with input dimesion `dim`.
3030If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned.
3131"""
3232function partial (fun, dim, partials= ())
33- if ! isnothing (local next = iterate (partials))
34- idx, state = next
35- return partial (
36- x -> FD. derivative (0 ) do dx
37- fun (x .+ dx * OneHotVector (idx, dim))
38- end ,
39- dim,
40- Base. rest (partials, state),
41- )
42- end
43- return fun
33+ if ! isnothing (local next = iterate (partials))
34+ idx, state = next
35+ return partial (
36+ x -> FD. derivative (0 ) do dx
37+ fun (x .+ dx * OneHotVector (idx, dim))
38+ end , dim, Base. rest (partials, state)
39+ )
40+ end
41+ return fun
4442end
4543
4644"""
4745Take the partial derivative of a function with two dim-dimensional inputs,
4846i.e. 2*dim dimensional input
4947"""
5048function partial (k, dim; partials_x= (), partials_y= ())
51- local f (x,y) = partial (t -> k (t,y), dim, partials_x)(x)
52- return (x,y) -> partial (t -> f (x,t), dim, partials_y)(y)
49+ local f (x, y) = partial (t -> k (t, y), dim, partials_x)(x)
50+ return (x, y) -> partial (t -> f (x, t), dim, partials_y)(y)
5351end
5452
55-
56-
57-
5853"""
5954 _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
6055
@@ -65,15 +60,10 @@ redirection over `_evaluate` is necessary
6560unboxes the partial instructions from DiffPt and applies them to k,
6661evaluates them at the positions of DiffPt
6762"""
68- function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim, T<: Kernel }
69- return partial (
70- k, Dim,
71- partials_x= x. partial, partials_y= y. partial
72- )(x. pos, y. pos)
63+ function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim,T<: Kernel }
64+ return partial (k, Dim; partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
7365end
7466
75-
76-
7767#=
7868This is a hack to work around the fact that the `where {T<:Kernel}` clause is
7969not allowed for the `(::T)(x,y)` syntax. If we were to only implement
@@ -85,8 +75,7 @@ then julia would not know whether to use
8575```
8676=#
8777for T in [SimpleKernel, Kernel] # subtypes(Kernel)
88- (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
89- (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
90- (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
78+ (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
79+ (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
80+ (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
9181end
92-
0 commit comments