@@ -18,12 +18,12 @@ for higher order derivatives partial can be any iterable, i.e.
1818 k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
1919```
2020"""
21- struct DiffPt{Dim}
21+ struct DiffPt
2222 pos # the actual position
2323 partial
2424end
2525
26- DiffPt (x; partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
26+ DiffPt (x; partial= ()) = DiffPt (x, partial) # convenience constructor
2727
2828"""
2929 partial(fun, idx)
@@ -34,10 +34,8 @@ Return ∂ᵢf where
3434"""
3535function partial (fun, idx)
3636 return x -> FD. derivative (0 ) do dx
37- y = similar (x)
38- y = copyto! (y, x)
39- y[idx] += dx
40- fun (y)
37+ dim = length (x)
38+ fun (x .+ dx * OneHotVector (idx, dim))
4139 end
4240end
4341
5856Take the partial derivative of a function with two dim-dimensional inputs,
5957i.e. 2*dim dimensional input
6058"""
61- function partial (k, dim ; partials_x= (), partials_y= ())
62- local f (x, y) = partial (t -> k (t, y), dim, partials_x)(x)
63- return (x, y) -> partial (t -> f (x, t), dim, partials_y)(y)
59+ function partial (k; partials_x= (), partials_y= ())
60+ local f (x, y) = partial (t -> k (t, y), partials_x... )(x)
61+ return (x, y) -> partial (t -> f (x, t), partials_y... )(y)
6462end
6563
6664"""
67- _evaluate(k::T, x::DiffPt{Dim} , y::DiffPt{Dim} ) where {Dim, T<:Kernel}
65+ _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
6866
69- implements `(k::T)(x::DiffPt{Dim} , y::DiffPt{Dim} )` for all kernel types. But since
67+ implements `(k::T)(x::DiffPt, y::DiffPt)` for all kernel types. But since
7068generics are not allowed in the syntax above by the dispatch system, this
7169redirection over `_evaluate` is necessary
7270
7371unboxes the partial instructions from DiffPt and applies them to k,
7472evaluates them at the positions of DiffPt
7573"""
76- function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim, T<: Kernel }
77- return partial (k, Dim; partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
74+ function _evaluate (k:: T , x:: DiffPt , y:: DiffPt ) where {T<: Kernel }
75+ return partial (k, partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
7876end
7977
8078#=
@@ -101,7 +99,7 @@ for T in [
10199 NormalizedKernel,
102100 KernelTensorProduct
103101 ] # subtypes(Kernel)
104- (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
105- (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
106- (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
102+ (k:: T )(x:: DiffPt , y:: DiffPt ) = _evaluate (k, x, y)
103+ (k:: T )(x:: DiffPt , y) = _evaluate (k, x, DiffPt (y))
104+ (k:: T )(x, y:: DiffPt ) = _evaluate (k, DiffPt (x), y)
107105end
0 commit comments