Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ ur_command_list_manager::getSignalEvent(ur_event_handle_t hUserEvent,
ur_result_t ur_command_list_manager::appendKernelLaunchLocked(
ur_kernel_handle_t hKernel, ze_kernel_handle_t hZeKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent,
bool cooperative, std::vector<void *> *pKMemObj, void *pNext) {
const size_t *pLocalWorkSize, wait_list_view &waitListView,
ur_event_handle_t phEvent, bool cooperative, std::vector<void *> *pKMemObj,
void *pNext) {

ze_group_count_t zeThreadGroupDimensions{1, 1, 1};
uint32_t WG[3]{};
Expand All @@ -165,11 +165,10 @@ ur_result_t ur_command_list_manager::appendKernelLaunchLocked(
pGlobalWorkSize, pLocalWorkSize));

auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_KERNEL_LAUNCH);
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);

UR_CALL(hKernel->prepareForSubmission(
hContext.get(), hDevice.get(), pGlobalWorkOffset, workDim, WG[0], WG[1],
WG[2], getZeCommandList(), waitListView, pKMemObj));
WG[2], getZeCommandList(), waitListView));

if (pKMemObj) {
// zeCommandListAppendLaunchKernelWithArguments
Expand Down Expand Up @@ -231,11 +230,13 @@ ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked(

std::scoped_lock<ur_shared_mutex> Lock(hKernel->Mutex);

wait_list_view waitListView =
getWaitListView(phEventWaitList, numEventsInWaitList);

// last arguments: pKMemObj == nullptr and pNext == nullptr
return appendKernelLaunchLocked(hKernel, hZeKernel, workDim,
pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, numEventsInWaitList,
phEventWaitList, phEvent, cooperative);
return appendKernelLaunchLocked(
hKernel, hZeKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, waitListView, phEvent, cooperative);
}

ur_result_t ur_command_list_manager::appendKernelLaunch(
Expand Down Expand Up @@ -1164,6 +1165,9 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew(
hKernel->kernelMemObj.resize(numArgs, 0);
hKernel->kernelArgs.resize(numArgs, 0);

wait_list_view waitListView =
getWaitListView(phEventWaitList, numEventsInWaitList);

for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) {
switch (pArgs[argIndex].type) {
case UR_EXP_KERNEL_ARG_TYPE_LOCAL:
Expand All @@ -1176,12 +1180,13 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew(
hKernel->kernelArgs[argIndex] = (void *)&pArgs[argIndex].value.pointer;
break;
case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ:
// prepareForSubmission() will save zePtr in kernelMemObj[argIndex]
// compute zePtr for the given memory handle and store it in
// hKernel->kernelMemObj[argIndex]
UR_CALL(hKernel->computeZePtr(
pArgs[argIndex].value.memObjTuple.hMem, hDevice.get(),
ur_mem_buffer_t::device_access_mode_t::read_write, getZeCommandList(),
waitListView, &hKernel->kernelMemObj[argIndex]));
hKernel->kernelArgs[argIndex] = &hKernel->kernelMemObj[argIndex];
UR_CALL(hKernel->addPendingMemoryAllocation(
{pArgs[argIndex].value.memObjTuple.hMem,
ur_mem_buffer_t::device_access_mode_t::read_write,
pArgs[argIndex].index}));
break;
case UR_EXP_KERNEL_ARG_TYPE_SAMPLER:
hKernel->kernelArgs[argIndex] = &pArgs[argIndex].value.sampler->ZeSampler;
Expand All @@ -1193,8 +1198,8 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpNew(

return appendKernelLaunchLocked(
hKernel, hZeKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent,
cooperativeKernelLaunchRequested, &hKernel->kernelMemObj, pNext);
pLocalWorkSize, waitListView, phEvent, cooperativeKernelLaunchRequested,
&hKernel->kernelMemObj, pNext);
}

ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,7 @@ struct ur_command_list_manager {
ur_kernel_handle_t hKernel, ze_kernel_handle_t hZeKernel,
uint32_t workDim, const size_t *pGlobalWorkOffset,
const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t phEvent, bool cooperative,
wait_list_view &waitListView, ur_event_handle_t phEvent, bool cooperative,
std::vector<void *> *pKMemObj = nullptr, void *pNext = nullptr);

ur_result_t appendKernelLaunchUnlocked(
Expand Down
59 changes: 30 additions & 29 deletions unified-runtime/source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,36 @@ ur_result_t ur_kernel_handle_t_::setExecInfo(ur_kernel_exec_info_t propName,
return UR_RESULT_SUCCESS;
}

// Resolve the Level Zero pointer that backs hMem and write it to *pZePtr.
//
//  - hMem == nullptr : nullptr is stored.
//  - hMem is a buffer: the result of getDevicePtr() over the whole buffer
//                      (offset 0, length getSize()) for hDevice/accessMode,
//                      using zeCommandList and waitListView, is stored.
//  - hMem is an image: the image's ZeImage handle, reinterpreted as void *,
//                      is stored.
//
// pZePtr is checked with UR_ASSERT against
// UR_RESULT_ERROR_INVALID_NULL_POINTER before it is written
// (NOTE(review): presumably UR_ASSERT returns that code on failure - confirm
// against the macro definition). Otherwise returns UR_RESULT_SUCCESS.
ur_result_t ur_kernel_handle_t_::computeZePtr(
    ur_mem_handle_t hMem, ur_device_handle_t hDevice,
    ur_mem_buffer_t::device_access_mode_t accessMode,
    ze_command_list_handle_t zeCommandList, wait_list_view &waitListView,
    void **pZePtr) {
  UR_ASSERT(pZePtr, UR_RESULT_ERROR_INVALID_NULL_POINTER);

  // No memory object: the argument resolves to a null pointer.
  if (!hMem) {
    *pZePtr = nullptr;
    return UR_RESULT_SUCCESS;
  }

  // Images are passed by their Level Zero image handle.
  if (hMem->isImage()) {
    auto image = static_cast<ur_mem_image_t *>(hMem->getImage());
    *pZePtr = reinterpret_cast<void *>(image->getZeImage());
    return UR_RESULT_SUCCESS;
  }

  // Buffers: request a device pointer covering the entire buffer.
  auto buffer = hMem->getBuffer();
  *pZePtr = buffer->getDevicePtr(hDevice, accessMode, 0, buffer->getSize(),
                                 zeCommandList, waitListView);
  return UR_RESULT_SUCCESS;
}

// Perform any required allocations and set the kernel arguments.
ur_result_t ur_kernel_handle_t_::prepareForSubmission(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
uint32_t groupSizeY, uint32_t groupSizeZ,
ze_command_list_handle_t commandList, wait_list_view &waitListView,
std::vector<void *> *kMemObj) {
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
auto &deviceKernelOpt = deviceKernels[deviceIndex(hDevice)];
if (!deviceKernelOpt.has_value())
return UR_RESULT_ERROR_INVALID_KERNEL;
Expand All @@ -288,34 +311,12 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(

for (auto &pending : pending_allocations) {
void *zePtr = nullptr;
if (pending.hMem) {
if (!pending.hMem->isImage()) {
auto hBuffer = pending.hMem->getBuffer();
zePtr =
hBuffer->getDevicePtr(hDevice, pending.mode, 0, hBuffer->getSize(),
commandList, waitListView);
} else {
auto hImage = static_cast<ur_mem_image_t *>(pending.hMem->getImage());
zePtr = reinterpret_cast<void *>(hImage->getZeImage());
}
}
// Compute a zePtr pointer for the given memory handle and store it in zePtr
UR_CALL(computeZePtr(pending.hMem, hDevice, pending.mode, commandList,
waitListView, &zePtr));

// kMemObj must be non-null in the path of
// zeCommandListAppendLaunchKernelWithArguments()
if (kMemObj) {
// zeCommandListAppendLaunchKernelWithArguments()
// (==CommandListCoreFamily<gfxCoreFamily>::appendLaunchKernelWithArguments())
// calls setArgumentValue(i, argSize, argValue) for all arguments on its
// own so do not call it here, but save the zePtr pointer in kMemObj
// for this future call.
if (pending.argIndex > kMemObj->size() - 1) {
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
}
(*kMemObj)[pending.argIndex] = zePtr;
} else {
// Set the argument only on this device's kernel.
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, zePtr));
}
// Set the argument only on this device's kernel.
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, zePtr));
}
pending_allocations.clear();

Expand Down
9 changes: 7 additions & 2 deletions unified-runtime/source/adapters/level_zero/v2/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ struct ur_kernel_handle_t_ : ur_object {
ur_result_t addPendingPointerArgument(uint32_t argIndex,
const void *pArgValue);

// Compute a zePtr pointer for the given memory handle and store it in *pZePtr.
// A null hMem yields a null zePtr. For a buffer, the pointer is obtained via
// getDevicePtr() over the whole buffer (offset 0, size getSize()) for hDevice
// with the requested accessMode, using zeCommandList and waitListView. For an
// image, the pointer is the image's ZeImage handle cast to void *.
// pZePtr must not be null (checked against
// UR_RESULT_ERROR_INVALID_NULL_POINTER); returns UR_RESULT_SUCCESS otherwise.
ur_result_t computeZePtr(ur_mem_handle_t hMem, ur_device_handle_t hDevice,
ur_mem_buffer_t::device_access_mode_t accessMode,
ze_command_list_handle_t zeCommandList,
wait_list_view &waitListView, void **pZePtr);

// Set all required values for the kernel before submission (including pending
// memory allocations).
// The kMemObj argument must be a non-empty vector
Expand All @@ -104,8 +110,7 @@ struct ur_kernel_handle_t_ : ur_object {
uint32_t workDim, uint32_t groupSizeX,
uint32_t groupSizeY, uint32_t groupSizeZ,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView,
std::vector<void *> *kMemObj = nullptr);
wait_list_view &waitListView);

// Get context of the kernel.
ur_context_handle_t getContext() const { return hProgram->Context; }
Expand Down
Loading