@@ -17,13 +17,15 @@ using StaticArrays
1717using Adapt
1818
1919"""
20- @kernel function f(args) end
20+ @kernel [N] function f(args) end
2121
2222Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
2323The enclosed function is allowed to contain kernel language constructs.
2424In order to call it the kernel has first to be specialized on the backend
2525and then invoked on the arguments.
2626
27+ The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
28+
2729# Kernel language
2830
2931- [`@Const`](@ref)
@@ -55,7 +57,7 @@ macro kernel(expr)
5557end
5658
5759"""
58- @kernel config function f(args) end
60+ @kernel [N] config function f(args) end
5961
6062This allows for two different configurations:
6163
@@ -585,17 +587,17 @@ in a workgroup.
585587 ```
586588 As well as the on-device functionality.
587589"""
588- struct Kernel{Backend, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
590+ struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
589591 backend:: Backend
590592 f:: Fun
591593end
592594
593- function Base. similar (kernel:: Kernel{D, WS, ND} , f:: F ) where {D, WS, ND, F}
594- Kernel {D, WS, ND, F} (kernel. backend, f)
595+ function Base. similar (kernel:: Kernel{D, N, WS, ND} , f:: F ) where {D, N , WS, ND, F}
596+ Kernel {D, N, WS, ND, F} (kernel. backend, f)
595597end
596598
597- workgroupsize (:: Kernel{D, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
598- ndrange (:: Kernel{D, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
599+ workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
600+ ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
599601backend (kernel:: Kernel ) = kernel. backend
600602
601603"""
@@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize.
658660 return iterspace, dynamic
659661end
660662
661- function construct (backend:: Backend , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , S <: _Size , NDRange <: _Size , XPUName}
662- return Kernel {Backend, S, NDRange, XPUName} (backend, xpu_name)
663+ function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N , S <: _Size , NDRange <: _Size , XPUName}
664+ return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
663665end
664666
665667# ##
0 commit comments