@@ -169,6 +169,7 @@ void MetricsLibrary::release() {
169169 api = {};
170170 callbacks = {};
171171 context = {};
172+ isWorkloadPartitionEnabled = false ;
172173 initializationState = ZE_RESULT_ERROR_UNINITIALIZED;
173174}
174175
@@ -193,13 +194,19 @@ bool MetricsLibrary::load() {
193194 return true ;
194195}
195196
197+ void MetricsLibrary::enableWorkloadPartition () {
198+ isWorkloadPartitionEnabled = true ;
199+ }
200+
196201void MetricsLibrary::getSubDeviceClientOptions (
197- NEO::Device &neoDevice,
198202 ClientOptionsData_1_0 &subDevice,
199203 ClientOptionsData_1_0 &subDeviceIndex,
200- ClientOptionsData_1_0 &subDeviceCount) {
204+ ClientOptionsData_1_0 &subDeviceCount,
205+ ClientOptionsData_1_0 &workloadPartition) {
206+
207+ const auto &deviceImp = *static_cast <DeviceImp *>(&metricContext.getDevice ());
201208
202- if (!neoDevice. isSubDevice () ) {
209+ if (!deviceImp. isSubdevice ) {
203210
204211 // Root device.
205212 subDevice.Type = ClientOptionsType::SubDevice;
@@ -209,7 +216,10 @@ void MetricsLibrary::getSubDeviceClientOptions(
209216 subDeviceIndex.SubDeviceIndex .Index = 0 ;
210217
211218 subDeviceCount.Type = ClientOptionsType::SubDeviceCount;
212- subDeviceCount.SubDeviceCount .Count = std::max (neoDevice.getNumSubDevices (), 1u );
219+ subDeviceCount.SubDeviceCount .Count = std::max (deviceImp.neoDevice ->getRootDevice ()->getNumSubDevices (), 1u );
220+
221+ workloadPartition.Type = ClientOptionsType::WorkloadPartition;
222+ workloadPartition.WorkloadPartition .Enabled = false ;
213223
214224 } else {
215225
@@ -218,10 +228,13 @@ void MetricsLibrary::getSubDeviceClientOptions(
218228 subDevice.SubDevice .Enabled = true ;
219229
220230 subDeviceIndex.Type = ClientOptionsType::SubDeviceIndex;
221- subDeviceIndex.SubDeviceIndex .Index = static_cast <NEO::SubDevice *>(& neoDevice)->getSubDeviceIndex ();
231+ subDeviceIndex.SubDeviceIndex .Index = static_cast <NEO::SubDevice *>(deviceImp. neoDevice )->getSubDeviceIndex ();
222232
223233 subDeviceCount.Type = ClientOptionsType::SubDeviceCount;
224- subDeviceCount.SubDeviceCount .Count = std::max (neoDevice.getRootDevice ()->getNumSubDevices (), 1u );
234+ subDeviceCount.SubDeviceCount .Count = std::max (deviceImp.neoDevice ->getRootDevice ()->getNumSubDevices (), 1u );
235+
236+ workloadPartition.Type = ClientOptionsType::WorkloadPartition;
237+ workloadPartition.WorkloadPartition .Enabled = isWorkloadPartitionEnabled;
225238 }
226239}
227240
@@ -230,7 +243,7 @@ bool MetricsLibrary::createContext() {
230243 const auto &hwHelper = device.getHwHelper ();
231244 const auto &asyncComputeEngines = hwHelper.getGpgpuEngineInstances (device.getHwInfo ());
232245 ContextCreateData_1_0 createData = {};
233- ClientOptionsData_1_0 clientOptions[5 ] = {};
246+ ClientOptionsData_1_0 clientOptions[6 ] = {};
234247 ClientData_1_0 clientData = {};
235248 ClientType_1_0 clientType = {};
236249 ClientDataLinuxAdapter_1_0 adapter = {};
@@ -259,7 +272,7 @@ bool MetricsLibrary::createContext() {
259272 clientOptions[1 ].Tbs .Enabled = metricContext.getMetricStreamer () != nullptr ;
260273
261274 // Sub device client options #2
262- getSubDeviceClientOptions (*device. getNEODevice (), clientOptions[2 ], clientOptions[3 ], clientOptions[4 ]);
275+ getSubDeviceClientOptions (clientOptions[2 ], clientOptions[3 ], clientOptions[4 ], clientOptions[ 5 ]);
263276
264277 clientData.Linux .Adapter = &adapter;
265278 clientData.ClientOptions = clientOptions;
@@ -422,7 +435,7 @@ ze_result_t metricQueryPoolCreate(zet_context_handle_t hContext, zet_device_hand
422435 const auto &deviceImp = *static_cast <DeviceImp *>(device);
423436 auto metricPoolImp = new MetricQueryPoolImp (device->getMetricContext (), hMetricGroup, *pDesc);
424437
425- if (!deviceImp. isSubdevice && deviceImp .isMultiDeviceCapable ()) {
438+ if (metricContext .isMultiDeviceCapable ()) {
426439
427440 auto emptyMetricGroups = std::vector<zet_metric_group_handle_t >();
428441 auto &metricGroups = hMetricGroup
@@ -436,12 +449,15 @@ ze_result_t metricQueryPoolCreate(zet_context_handle_t hContext, zet_device_hand
436449 for (size_t i = 0 ; i < deviceImp.numSubDevices ; ++i) {
437450
438451 auto &subDevice = deviceImp.subDevices [i];
452+ auto &subDeviceMetricContext = subDevice->getMetricContext ();
453+
454+ subDeviceMetricContext.getMetricsLibrary ().enableWorkloadPartition ();
439455
440456 zet_metric_group_handle_t metricGroupHandle = useMetricGroupSubDevice
441- ? metricGroups[subDevice-> getMetricContext () .getSubDeviceIndex ()]
457+ ? metricGroups[subDeviceMetricContext .getSubDeviceIndex ()]
442458 : hMetricGroup;
443459
444- auto metricPoolSubdeviceImp = new MetricQueryPoolImp (subDevice-> getMetricContext () , metricGroupHandle, *pDesc);
460+ auto metricPoolSubdeviceImp = new MetricQueryPoolImp (subDeviceMetricContext , metricGroupHandle, *pDesc);
445461
446462 // Create metric query pool.
447463 if (!metricPoolSubdeviceImp->create ()) {
@@ -534,7 +550,7 @@ bool MetricQueryPoolImp::allocateGpuMemory() {
534550 if (description.type == ZET_METRIC_QUERY_POOL_TYPE_PERFORMANCE) {
535551 // Get allocation size.
536552 const auto &deviceImp = *static_cast <DeviceImp *>(&metricContext.getDevice ());
537- allocationSize = (!deviceImp. isSubdevice && deviceImp .isMultiDeviceCapable ())
553+ allocationSize = (metricContext .isMultiDeviceCapable ())
538554 ? deviceImp.subDevices [0 ]->getMetricContext ().getMetricsLibrary ().getQueryReportGpuSize () * description.count * deviceImp.numSubDevices
539555 : metricsLibrary.getQueryReportGpuSize () * description.count ;
540556
@@ -867,7 +883,7 @@ ze_result_t MetricQuery::appendMemoryBarrier(CommandList &commandList) {
867883
868884 DeviceImp *pDeviceImp = static_cast <DeviceImp *>(commandList.device );
869885
870- if (! pDeviceImp->isSubdevice && pDeviceImp ->isMultiDeviceCapable ()) {
886+ if (pDeviceImp->metricContext ->isMultiDeviceCapable ()) {
871887 // Use one of the sub-device contexts to append to command list.
872888 pDeviceImp = static_cast <DeviceImp *>(pDeviceImp->subDevices [0 ]);
873889 }
@@ -893,9 +909,10 @@ ze_result_t MetricQuery::appendStreamerMarker(CommandList &commandList,
893909
894910 DeviceImp *pDeviceImp = static_cast <DeviceImp *>(commandList.device );
895911
896- if (! pDeviceImp->isSubdevice && pDeviceImp ->isMultiDeviceCapable ()) {
912+ if (pDeviceImp->metricContext ->isMultiDeviceCapable ()) {
897913 // Use one of the sub-device contexts to append to command list.
898914 pDeviceImp = static_cast <DeviceImp *>(pDeviceImp->subDevices [0 ]);
915+ pDeviceImp->metricContext ->getMetricsLibrary ().enableWorkloadPartition ();
899916 }
900917 auto &metricContext = pDeviceImp->getMetricContext ();
901918 auto &metricsLibrary = metricContext.getMetricsLibrary ();
0 commit comments