@@ -2303,14 +2303,22 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303 }
23042304}
23052305
2306- // Sets arguments for a given kernel and device based on the argument type.
2307- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308- // extension.
2309- static void SetArgBasedOnType (
2310- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+ // Gets UR argument struct for a given kernel and device based on the argument
2307+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+ static void GetUrArgsBasedOnType (
23112310 device_image_impl *DeviceImageImpl,
23122311 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2314+ // UrArg.size stays 0 unless a case below fills it in; zero-size entries are not appended to UrArgs
2315+ ur_exp_kernel_arg_properties_t UrArg = {
2316+ UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2317+ nullptr ,
2318+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2319+ static_cast <uint32_t >(NextTrueIndex),
2320+ 0 ,
2321+ {}};
23142322 switch (Arg.MType ) {
23152323 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23162324 break ;
@@ -2330,52 +2338,56 @@ static void SetArgBasedOnType(
23302338 getMemAllocationFunc
23312339 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23322340 : nullptr ;
2333- ur_kernel_arg_mem_obj_properties_t MemObjData {};
2334- MemObjData. stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES ;
2335- MemObjData. memoryAccess = AccessModeToUr (Req-> MAccessMode ) ;
2336- Adapter. call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337- &MemObjData, MemArg) ;
2341+ ur_exp_kernel_arg_value_t Value = {};
2342+ Value. memObjTuple = {MemArg, AccessModeToUr (Req-> MAccessMode )} ;
2343+ UrArg. type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ ;
2344+ UrArg. size = sizeof (MemArg);
2345+ UrArg. value = Value ;
23382346 break ;
23392347 }
23402348 case kernel_param_kind_t ::kind_std_layout: {
2349+ ur_exp_kernel_arg_type_t Type;
23412350 if (Arg.MPtr ) {
2342- Adapter.call <UrApiKind::urKernelSetArgValue>(
2343- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2351+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442352 } else {
2345- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346- Arg.MSize , nullptr );
2353+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472354 }
2348-
2355+ ur_exp_kernel_arg_value_t Value = {};
2356+ Value.value = {Arg.MPtr };
2357+ UrArg.type = Type;
2358+ UrArg.size = static_cast <size_t >(Arg.MSize );
2359+ UrArg.value = Value;
23492360 break ;
23502361 }
23512362 case kernel_param_kind_t ::kind_sampler: {
23522363 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2353- ur_sampler_handle_t Sampler =
2354- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2355- ->getOrCreateSampler (ContextImpl);
2356- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357- nullptr , Sampler);
2364+ ur_exp_kernel_arg_value_t Value = {};
2365+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2366+ ->getOrCreateSampler (ContextImpl);
2367+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_SAMPLER;
2368+ UrArg.size = sizeof (ur_sampler_handle_t );
2369+ UrArg.value = Value;
23582370 break ;
23592371 }
23602372 case kernel_param_kind_t ::kind_pointer: {
2361- // We need to de-rerence this to get the actual USM allocation - that's the
2373+ ur_exp_kernel_arg_value_t Value = {};
2374+ // We need to dereference to get the actual USM allocation - that's the
23622375 // pointer UR is expecting.
2363- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2364- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365- nullptr , Ptr);
2376+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2377+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_POINTER;
2378+ UrArg.size = sizeof (Arg.MPtr );
2379+ UrArg.value = Value;
23662380 break ;
23672381 }
23682382 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23692383 assert (DeviceImageImpl != nullptr );
23702384 ur_mem_handle_t SpecConstsBuffer =
23712385 DeviceImageImpl->get_spec_const_buffer_ref ();
2372-
2373- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374- MemObjProps.pNext = nullptr ;
2375- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2378- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2386+ ur_exp_kernel_arg_value_t Value = {};
2387+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2388+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2389+ UrArg.size = sizeof (SpecConstsBuffer);
2390+ UrArg.value = Value;
23792391 break ;
23802392 }
23812393 case kernel_param_kind_t ::kind_invalid:
@@ -2384,6 +2396,10 @@ static void SetArgBasedOnType(
23842396 codeToString (UR_RESULT_ERROR_INVALID_VALUE));
23852397 break ;
23862398 }
2399+
2400+ if (UrArg.size ) {
2401+ UrArgs.push_back (UrArg);
2402+ }
23872403}
23882404
23892405static ur_result_t SetKernelParamsAndLaunch (
@@ -2404,22 +2420,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24042420 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24052421 }
24062422
2423+ // Performance optimization: reuse a thread-local vector to avoid a heap allocation on every call
2424+ static thread_local std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2425+ UrArgs.clear ();
2426+ UrArgs.reserve (Args.size ());
2427+
24072428 if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures ) {
2408- auto setFunc = [&Adapter, Kernel,
2409- KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2429+ auto setFunc = [KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102430 size_t NextTrueIndex) {
24112431 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24122432 switch (ParamDesc.kind ) {
24132433 case kernel_param_kind_t ::kind_std_layout: {
24142434 int Size = ParamDesc.info ;
2415- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416- Size, nullptr , ArgPtr);
2435+ ur_exp_kernel_arg_value_t Value = {};
2436+ Value.value = ArgPtr;
2437+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2438+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2439+ static_cast <uint32_t >(NextTrueIndex),
2440+ static_cast <size_t >(Size), Value});
24172441 break ;
24182442 }
24192443 case kernel_param_kind_t ::kind_pointer: {
2420- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2421- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422- nullptr , Ptr);
2444+ ur_exp_kernel_arg_value_t Value = {};
2445+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2446+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2447+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2448+ static_cast <uint32_t >(NextTrueIndex),
2449+ sizeof (Value.pointer ), Value});
24232450 break ;
24242451 }
24252452 default :
@@ -2429,23 +2456,28 @@ static ur_result_t SetKernelParamsAndLaunch(
24292456 applyFuncOnFilteredArgs (EliminatedArgMask, DeviceKernelInfo.NumParams ,
24302457 DeviceKernelInfo.ParamDescGetter , setFunc);
24312458 } else {
2432- auto setFunc = [&Adapter, Kernel, & DeviceImageImpl, &getMemAllocationFunc,
2459+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
24332460 &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435- Queue.getContextImpl (), Arg, NextTrueIndex);
2461+ GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2462+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
24362463 };
24372464 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24382465 }
24392466
2440- const std::optional<int > &ImplicitLocalArg =
2441- DeviceKernelInfo.getImplicitLocalArgPos ();
2467+ std::optional<int > ImplicitLocalArg =
2468+ ProgramManager::getInstance ().kernelImplicitLocalArgPos (
2469+ DeviceKernelInfo.Name );
24422470 // Set the implicit local memory buffer to support
24432471 // get_work_group_scratch_memory. This is for backend not supporting
24442472 // CUDA-style local memory setting. Note that we may have -1 as a position,
24452473 // this indicates the buffer is actually unused and was elided.
24462474 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2447- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2448- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2475+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2476+ nullptr ,
2477+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2478+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2479+ WorkGroupMemorySize,
2480+ {nullptr }});
24492481 }
24502482
24512483 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2468,16 +2500,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682500 /* pPropSizeRet = */ nullptr );
24692501
24702502 const bool EnforcedLocalSize =
2471- (RequiredWGSize[0 ] != 0 &&
2472- (NDRDesc.Dims < 2 || RequiredWGSize[1 ] != 0 ) &&
2473- (NDRDesc.Dims < 3 || RequiredWGSize[2 ] != 0 ));
2503+ (RequiredWGSize[0 ] != 0 || RequiredWGSize[1 ] != 0 ||
2504+ RequiredWGSize[2 ] != 0 );
24742505 if (EnforcedLocalSize)
24752506 LocalSize = RequiredWGSize;
24762507 }
2477-
2478- const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 &&
2479- (NDRDesc.Dims < 2 || NDRDesc.GlobalOffset [1 ] != 0 ) &&
2480- (NDRDesc.Dims < 3 || NDRDesc.GlobalOffset [2 ] != 0 );
2508+ const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 ||
2509+ NDRDesc.GlobalOffset [1 ] != 0 ||
2510+ NDRDesc.GlobalOffset [2 ] != 0 ;
24812511
24822512 std::vector<ur_kernel_launch_property_t > property_list;
24832513
@@ -2505,20 +2535,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052535 {{WorkGroupMemorySize}}});
25062536 }
25072537 ur_event_handle_t UREvent = nullptr ;
2508- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2509- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2510- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2511- LocalSize, property_list.size (),
2512- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2513- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2514- OutEventImpl ? &UREvent : nullptr );
2538+ ur_result_t Error =
2539+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2540+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2541+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2542+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2543+ property_list.size (),
2544+ property_list.empty () ? nullptr : property_list.data (),
2545+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2546+ OutEventImpl ? &UREvent : nullptr );
25152547 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162548 OutEventImpl->setHandle (UREvent);
25172549 }
25182550
25192551 return Error;
25202552}
25212553
2554+ // Sets arguments for a given kernel and device based on the argument type.
2555+ // This is a legacy path which the graphs extension still uses.
2556+ static void SetArgBasedOnType (
2557+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2558+ device_image_impl *DeviceImageImpl,
2559+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2560+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2561+ switch (Arg.MType ) {
2562+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2563+ break ;
2564+ case kernel_param_kind_t ::kind_work_group_memory:
2565+ break ;
2566+ case kernel_param_kind_t ::kind_stream:
2567+ break ;
2568+ case kernel_param_kind_t ::kind_dynamic_accessor:
2569+ case kernel_param_kind_t ::kind_accessor: {
2570+ Requirement *Req = (Requirement *)(Arg.MPtr );
2571+
2572+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2573+ // we may pass default constructed accessors to a command, which don't add
2574+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2575+ // valid case, so we need to properly handle it.
2576+ ur_mem_handle_t MemArg =
2577+ getMemAllocationFunc
2578+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2579+ : nullptr ;
2580+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2581+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2582+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2583+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2584+ &MemObjData, MemArg);
2585+ break ;
2586+ }
2587+ case kernel_param_kind_t ::kind_std_layout: {
2588+ if (Arg.MPtr ) {
2589+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2590+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2591+ } else {
2592+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2593+ Arg.MSize , nullptr );
2594+ }
2595+
2596+ break ;
2597+ }
2598+ case kernel_param_kind_t ::kind_sampler: {
2599+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2600+ ur_sampler_handle_t Sampler =
2601+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2602+ ->getOrCreateSampler (ContextImpl);
2603+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2604+ nullptr , Sampler);
2605+ break ;
2606+ }
2607+ case kernel_param_kind_t ::kind_pointer: {
2608+ // We need to dereference this to get the actual USM allocation - that's the
2609+ // pointer UR is expecting.
2610+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2611+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2612+ nullptr , Ptr);
2613+ break ;
2614+ }
2615+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2616+ assert (DeviceImageImpl != nullptr );
2617+ ur_mem_handle_t SpecConstsBuffer =
2618+ DeviceImageImpl->get_spec_const_buffer_ref ();
2619+
2620+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2621+ MemObjProps.pNext = nullptr ;
2622+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2623+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2624+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2625+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2626+ break ;
2627+ }
2628+ case kernel_param_kind_t ::kind_invalid:
2629+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2630+ " Invalid kernel param kind " +
2631+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2632+ break ;
2633+ }
2634+ }
2635+
25222636static std::tuple<ur_kernel_handle_t , device_image_impl *,
25232637 const KernelArgMask *>
25242638getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments