@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
2828from libc.stdint cimport uint32_t
2929
3030from dpctl._backend cimport ( # noqa: E211, E402;
31+ DPCTLBuildOptionList_Append,
32+ DPCTLBuildOptionList_Create,
33+ DPCTLBuildOptionList_Delete,
34+ DPCTLBuildOptionListRef,
3135 DPCTLCString_Delete,
3236 DPCTLKernel_Copy,
3337 DPCTLKernel_Delete,
@@ -42,13 +46,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4246 DPCTLKernelBundle_Copy,
4347 DPCTLKernelBundle_CreateFromOCLSource,
4448 DPCTLKernelBundle_CreateFromSpirv,
49+ DPCTLKernelBundle_CreateFromSYCLSource,
4550 DPCTLKernelBundle_Delete,
4651 DPCTLKernelBundle_GetKernel,
52+ DPCTLKernelBundle_GetSyclKernel,
4753 DPCTLKernelBundle_HasKernel,
54+ DPCTLKernelBundle_HasSyclKernel,
55+ DPCTLKernelNameList_Append,
56+ DPCTLKernelNameList_Create,
57+ DPCTLKernelNameList_Delete,
58+ DPCTLKernelNameListRef,
4859 DPCTLSyclContextRef,
4960 DPCTLSyclDeviceRef,
5061 DPCTLSyclKernelBundleRef,
5162 DPCTLSyclKernelRef,
63+ DPCTLVirtualHeaderList_Append,
64+ DPCTLVirtualHeaderList_Create,
65+ DPCTLVirtualHeaderList_Delete,
66+ DPCTLVirtualHeaderListRef,
5267)
5368
5469__all__ = [
@@ -197,9 +212,10 @@ cdef class SyclProgram:
197212 """
198213
199214 @staticmethod
200- cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
215+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source ):
201216 cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
202217 ret._program_ref = KBRef
218+ ret._is_sycl_source = is_sycl_source
203219 return ret
204220
205221 def __dealloc__ (self ):
@@ -210,13 +226,19 @@ cdef class SyclProgram:
210226
211227 cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
212228 name = kernel_name.encode(' utf8' )
229+ if self ._is_sycl_source:
230+ return SyclKernel._create(
231+ DPCTLKernelBundle_GetSyclKernel(self ._program_ref, name),
232+ kernel_name)
213233 return SyclKernel._create(
214234 DPCTLKernelBundle_GetKernel(self ._program_ref, name),
215235 kernel_name
216236 )
217237
218238 def has_sycl_kernel (self , str kernel_name ):
219239 name = kernel_name.encode(' utf8' )
240+ if self ._is_sycl_source:
241+ return DPCTLKernelBundle_HasSyclKernel(self ._program_ref, name)
220242 return DPCTLKernelBundle_HasKernel(self ._program_ref, name)
221243
222244 def addressof_ref (self ):
@@ -272,7 +294,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
272294 if KBref is NULL :
273295 raise SyclProgramCompilationError()
274296
275- return SyclProgram._create(KBref)
297+ return SyclProgram._create(KBref, False )
276298
277299
278300cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
@@ -318,7 +340,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
318340 if KBref is NULL :
319341 raise SyclProgramCompilationError()
320342
321- return SyclProgram._create(KBref)
343+ return SyclProgram._create(KBref, False )
344+
345+
346+ cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers = [], list registered_names = [], list copts = []):
347+ """
348+ Creates an executable SYCL kernel_bundle from SYCL source code.
349+
350+ This uses the DPC++ ``kernel_compiler`` extension to create a
351+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
352+ SYCL source code.
353+
354+ Parameters:
355+ q (:class:`dpctl.SyclQueue`)
356+ The :class:`dpctl.SyclQueue` for which the
357+ :class:`.SyclProgram` is going to be built.
358+ source (unicode)
359+ SYCL source code string.
360+ headers (list)
361+ Optional list of virtual headers, where each entry in the list
362+ needs to be a tuple of header name and header content. See the
363+ documentation of the ``include_files`` property in the DPC++
364+ ``kernel_compiler`` extension for more information.
365+ Default: []
366+ registered_names (list, optional)
367+ Optional list of kernel names to register. See the
368+ documentation of the ``registered_names`` property in the DPC++
369+ ``kernel_compiler`` extension for more information.
370+ Default: []
371+ copts (list)
372+ Optional list of compilation flags that will be used
373+ when compiling the program. Default: ``""``.
374+
375+ Returns:
376+ program (:class:`.SyclProgram`)
377+ A :class:`.SyclProgram` object wrapping the
378+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
379+ returned by the C API.
380+
381+ Raises:
382+ SyclProgramCompilationError
383+ If a SYCL kernel bundle could not be created.
384+ """
385+ cdef DPCTLSyclKernelBundleRef KBref
386+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
387+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
388+ cdef bytes bSrc = source.encode(' utf8' )
389+ cdef const char * Src = < const char * > bSrc
390+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
391+ cdef bytes bOpt
392+ cdef const char * sOpt
393+ cdef bytes bName
394+ cdef const char * sName
395+ cdef bytes bContent
396+ cdef const char * sContent
397+ for opt in copts:
398+ if not isinstance (opt, unicode ):
399+ DPCTLBuildOptionList_Delete(BuildOpts)
400+ raise SyclProgramCompilationError()
401+ bOpt = opt.encode(' utf8' )
402+ sOpt = < const char * > bOpt
403+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
404+
405+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
406+ for name in registered_names:
407+ if not isinstance (name, unicode ):
408+ DPCTLBuildOptionList_Delete(BuildOpts)
409+ DPCTLKernelNameList_Delete(KernelNames)
410+ raise SyclProgramCompilationError()
411+ bName = name.encode(' utf8' )
412+ sName = < const char * > bName
413+ DPCTLKernelNameList_Append(KernelNames, sName)
414+
415+
416+ cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
417+ for name, content in headers:
418+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
419+ DPCTLBuildOptionList_Delete(BuildOpts)
420+ DPCTLKernelNameList_Delete(KernelNames)
421+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
422+ raise SyclProgramCompilationError()
423+ bName = name.encode(' utf8' )
424+ sName = < const char * > bName
425+ bContent = content.encode(' utf8' )
426+ sContent = < const char * > bContent
427+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
428+
429+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
430+ VirtualHeaders, KernelNames,
431+ BuildOpts)
432+
433+ if KBref is NULL :
434+ DPCTLBuildOptionList_Delete(BuildOpts)
435+ DPCTLKernelNameList_Delete(KernelNames)
436+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
437+ raise SyclProgramCompilationError()
438+
439+ DPCTLBuildOptionList_Delete(BuildOpts)
440+ DPCTLKernelNameList_Delete(KernelNames)
441+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
442+
443+ return SyclProgram._create(KBref, True )
322444
323445
324446cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(SyclProgram pro):
@@ -335,4 +457,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
335457 reference.
336458 """
337459 cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
338- return SyclProgram._create(copied_KBRef)
460+ return SyclProgram._create(copied_KBRef, False )
0 commit comments