@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
2828from libc.stdint cimport uint32_t
2929
3030from dpctl._backend cimport ( # noqa: E211, E402;
31+ DPCTLBuildOptionList_Append,
32+ DPCTLBuildOptionList_Create,
33+ DPCTLBuildOptionList_Delete,
34+ DPCTLBuildOptionListRef,
3135 DPCTLKernel_Copy,
3236 DPCTLKernel_Delete,
3337 DPCTLKernel_GetCompileNumSubGroups,
@@ -41,13 +45,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4145 DPCTLKernelBundle_Copy,
4246 DPCTLKernelBundle_CreateFromOCLSource,
4347 DPCTLKernelBundle_CreateFromSpirv,
48+ DPCTLKernelBundle_CreateFromSYCLSource,
4449 DPCTLKernelBundle_Delete,
4550 DPCTLKernelBundle_GetKernel,
51+ DPCTLKernelBundle_GetSyclKernel,
4652 DPCTLKernelBundle_HasKernel,
53+ DPCTLKernelBundle_HasSyclKernel,
54+ DPCTLKernelNameList_Append,
55+ DPCTLKernelNameList_Create,
56+ DPCTLKernelNameList_Delete,
57+ DPCTLKernelNameListRef,
4758 DPCTLSyclContextRef,
4859 DPCTLSyclDeviceRef,
4960 DPCTLSyclKernelBundleRef,
5061 DPCTLSyclKernelRef,
62+ DPCTLVirtualHeaderList_Append,
63+ DPCTLVirtualHeaderList_Create,
64+ DPCTLVirtualHeaderList_Delete,
65+ DPCTLVirtualHeaderListRef,
5166)
5267
5368__all__ = [
@@ -196,9 +211,10 @@ cdef class SyclProgram:
196211 """
197212
198213 @staticmethod
199- cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
214+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source ):
200215 cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
201216 ret._program_ref = KBRef
217+ ret._is_sycl_source = is_sycl_source
202218 return ret
203219
204220 def __dealloc__ (self ):
@@ -209,13 +225,19 @@ cdef class SyclProgram:
209225
210226 cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
211227 name = kernel_name.encode(" utf8" )
228+ if self ._is_sycl_source:
229+ return SyclKernel._create(
230+ DPCTLKernelBundle_GetSyclKernel(self ._program_ref, name),
231+ kernel_name)
212232 return SyclKernel._create(
213233 DPCTLKernelBundle_GetKernel(self ._program_ref, name),
214234 kernel_name
215235 )
216236
217237 def has_sycl_kernel (self , str kernel_name ):
218238 name = kernel_name.encode(" utf8" )
239+ if self ._is_sycl_source:
240+ return DPCTLKernelBundle_HasSyclKernel(self ._program_ref, name)
219241 return DPCTLKernelBundle_HasKernel(self ._program_ref, name)
220242
221243 def addressof_ref (self ):
@@ -271,7 +293,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
271293 if KBref is NULL :
272294 raise SyclProgramCompilationError()
273295
274- return SyclProgram._create(KBref)
296+ return SyclProgram._create(KBref, False )
275297
276298
277299cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
@@ -317,7 +339,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
317339 if KBref is NULL :
318340 raise SyclProgramCompilationError()
319341
320- return SyclProgram._create(KBref)
342+ return SyclProgram._create(KBref, False )
343+
344+
345+ cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers = [], list registered_names = [], list copts = []):
346+ """
347+ Creates an executable SYCL kernel_bundle from SYCL source code.
348+
349+ This uses the DPC++ ``kernel_compiler`` extension to create a
350+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
351+ SYCL source code.
352+
353+ Parameters:
354+ q (:class:`dpctl.SyclQueue`)
355+ The :class:`dpctl.SyclQueue` for which the
356+ :class:`.SyclProgram` is going to be built.
357+ source (unicode)
358+ SYCL source code string.
359+ headers (list)
360+ Optional list of virtual headers, where each entry in the list
361+ needs to be a tuple of header name and header content. See the
362+ documentation of the ``include_files`` property in the DPC++
363+ ``kernel_compiler`` extension for more information.
364+ Default: []
365+ registered_names (list, optional)
366+ Optional list of kernel names to register. See the
367+ documentation of the ``registered_names`` property in the DPC++
368+ ``kernel_compiler`` extension for more information.
369+ Default: []
370+ copts (list)
371+ Optional list of compilation flags that will be used
372+ when compiling the program. Default: ``""``.
373+
374+ Returns:
375+ program (:class:`.SyclProgram`)
376+ A :class:`.SyclProgram` object wrapping the
377+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
378+ returned by the C API.
379+
380+ Raises:
381+ SyclProgramCompilationError
382+ If a SYCL kernel bundle could not be created.
383+ """
384+ cdef DPCTLSyclKernelBundleRef KBref
385+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
386+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
387+ cdef bytes bSrc = source.encode(' utf8' )
388+ cdef const char * Src = < const char * > bSrc
389+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
390+ cdef bytes bOpt
391+ cdef const char * sOpt
392+ cdef bytes bName
393+ cdef const char * sName
394+ cdef bytes bContent
395+ cdef const char * sContent
396+ for opt in copts:
397+ if not isinstance (opt, unicode ):
398+ DPCTLBuildOptionList_Delete(BuildOpts)
399+ raise SyclProgramCompilationError()
400+ bOpt = opt.encode(' utf8' )
401+ sOpt = < const char * > bOpt
402+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
403+
404+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
405+ for name in registered_names:
406+ if not isinstance (name, unicode ):
407+ DPCTLBuildOptionList_Delete(BuildOpts)
408+ DPCTLKernelNameList_Delete(KernelNames)
409+ raise SyclProgramCompilationError()
410+ bName = name.encode(' utf8' )
411+ sName = < const char * > bName
412+ DPCTLKernelNameList_Append(KernelNames, sName)
413+
414+
415+ cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
416+ for name, content in headers:
417+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
418+ DPCTLBuildOptionList_Delete(BuildOpts)
419+ DPCTLKernelNameList_Delete(KernelNames)
420+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
421+ raise SyclProgramCompilationError()
422+ bName = name.encode(' utf8' )
423+ sName = < const char * > bName
424+ bContent = content.encode(' utf8' )
425+ sContent = < const char * > bContent
426+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
427+
428+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
429+ VirtualHeaders, KernelNames,
430+ BuildOpts)
431+
432+ if KBref is NULL :
433+ DPCTLBuildOptionList_Delete(BuildOpts)
434+ DPCTLKernelNameList_Delete(KernelNames)
435+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
436+ raise SyclProgramCompilationError()
437+
438+ DPCTLBuildOptionList_Delete(BuildOpts)
439+ DPCTLKernelNameList_Delete(KernelNames)
440+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
441+
442+ return SyclProgram._create(KBref, True )
321443
322444
323445cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(
@@ -336,4 +458,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
336458 reference.
337459 """
338460 cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
339- return SyclProgram._create(copied_KBRef)
461+ return SyclProgram._create(copied_KBRef, False )
0 commit comments