@@ -74,7 +74,7 @@ def get_estimator(library_name: str, estimator_name: str):
7474def get_estimator_methods (bench_case : BenchCase ) -> Dict [str , List [str ]]:
7575 # default estimator methods
7676 estimator_methods = {
77- "training" : ["fit" ],
77+ "training" : ["partial_fit" , " fit" ],
7878 "inference" : ["predict" , "predict_proba" , "transform" ],
7979 }
8080 for stage in estimator_methods .keys ():
@@ -334,7 +334,9 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
334334 return acceleration_lines > 0 and fallback_lines == 0
335335
336336
337- def create_online_function (method_instance , data_args , batch_size ):
337+ def create_online_function (
338+ estimator_instance , method_instance , data_args , num_batches , batch_size
339+ ):
338340 n_batches = data_args [0 ].shape [0 ] // batch_size
339341
340342 if "y" in list (inspect .signature (method_instance ).parameters ):
@@ -345,23 +347,27 @@ def ndarray_function(x, y):
345347 x [i * batch_size : (i + 1 ) * batch_size ],
346348 y [i * batch_size : (i + 1 ) * batch_size ],
347349 )
350+ estimator_instance ._onedal_finalize_fit ()
348351
349352 def dataframe_function (x , y ):
350353 for i in range (n_batches ):
351354 method_instance (
352355 x .iloc [i * batch_size : (i + 1 ) * batch_size ],
353356 y .iloc [i * batch_size : (i + 1 ) * batch_size ],
354357 )
358+ estimator_instance ._onedal_finalize_fit ()
355359
356360 else :
357361
358362 def ndarray_function (x ):
359363 for i in range (n_batches ):
360364 method_instance (x [i * batch_size : (i + 1 ) * batch_size ])
365+ estimator_instance ._onedal_finalize_fit ()
361366
362367 def dataframe_function (x ):
363368 for i in range (n_batches ):
364369 method_instance (x .iloc [i * batch_size : (i + 1 ) * batch_size ])
370+ estimator_instance ._onedal_finalize_fit ()
365371
366372 if "ndarray" in str (type (data_args [0 ])):
367373 return ndarray_function
@@ -414,12 +420,28 @@ def measure_sklearn_estimator(
414420 data_args = (x_train ,)
415421 else :
416422 data_args = (x_test ,)
417- batch_size = get_bench_case_value (
418- bench_case , f"algorithm:batch_size:{ stage } "
419- )
420- if batch_size is not None :
423+
424+ if method == "partial_fit" :
425+ num_batches = get_bench_case_value (bench_case , "data:num_batches" )
426+ batch_size = get_bench_case_value (bench_case , "data:batch_size" )
427+
428+ if batch_size is None :
429+ if num_batches is None :
430+ num_batches = 5
431+ batch_size = (
432+ data_args [0 ].shape [0 ] + num_batches - 1
433+ ) // num_batches
434+ if num_batches is None :
435+ num_batches = (
436+ data_args [0 ].shape [0 ] + batch_size - 1
437+ ) // batch_size
438+
421439 method_instance = create_online_function (
422- method_instance , data_args , batch_size
440+ estimator_instance ,
441+ method_instance ,
442+ data_args ,
443+ num_batches ,
444+ batch_size ,
423445 )
424446 # daal4py model builders enabling branch
425447 if enable_modelbuilders and stage == "inference" :
0 commit comments