@@ -44,9 +44,20 @@ def get_configs(path: Path) -> List[str]:
4444 default = 'configs/config_example.json' ,
4545 help = 'The path to a configuration file or '
4646 'a directory that contains configuration files' )
47+ parser .add_argument ('--device' , '--devices' , default = 'host cpu gpu none' , type = str , nargs = '+' ,
48+ choices = ('host' , 'cpu' , 'gpu' , 'none' ),
49+ help = 'Availible execution context devices. '
50+ 'This parameter only marks devices as available, '
51+ 'make sure to add the device to the config file '
52+ 'to run it on a specific device' )
4753 parser .add_argument ('--dummy-run' , default = False , action = 'store_true' ,
4854 help = 'Run configuration parser and datasets generation '
4955 'without benchmarks running' )
56+ parser .add_argument ('--dtype' , '--dtypes' , type = str , default = "float32 float64" , nargs = '+' ,
57+ choices = ("float32" , "float64" ),
58+ help = 'Available floating point data types'
59+ 'This parameter only marks dtype as available, '
60+ 'make sure to add the dtype parameter to the config file ' )
5061 parser .add_argument ('--no-intel-optimized' , default = False , action = 'store_true' ,
5162 help = 'Use Scikit-learn without Intel optimizations' )
5263 parser .add_argument ('--output-file' , default = 'results.json' ,
@@ -93,6 +104,28 @@ def get_configs(path: Path) -> List[str]:
93104 for params_set in config ['cases' ]:
94105 params = common_params .copy ()
95106 params .update (params_set .copy ())
107+
108+ device = []
109+ if 'device' not in params :
110+ if 'sklearn' in params ['lib' ]:
111+ logging .info ('The device parameter value is not defined in config, '
112+ 'none is used' )
113+ device = ['none' ]
114+ elif not isinstance (params ['device' ], list ):
115+ device = [params ['device' ]]
116+ else :
117+ device = params ['device' ]
118+ params ["device" ] = [dv for dv in device if dv in args .device ]
119+
120+ dtype = []
121+ if 'dtype' not in params :
122+ dtype = ['float64' ]
123+ elif not isinstance (params ['dtype' ], list ):
124+ dtype = [params ['dtype' ]]
125+ else :
126+ dtype = params ['dtype' ]
127+ params ['dtype' ] = [dt for dt in dtype if dt in args .dtype ]
128+
96129 algorithm = params ['algorithm' ]
97130 libs = params ['lib' ]
98131 if not isinstance (libs , list ):
0 commit comments