1212
1313import sklearn
1414import sklearn .datasets
15+ from sklearn .base import clone
1516from sklearn .ensemble import VotingClassifier , VotingRegressor
1617
18+ from smac .runhistory .runhistory import RunHistory
19+
1720import torch
1821
1922from autoPyTorch .api .tabular_classification import TabularClassificationTask
2326 HoldoutValTypes ,
2427)
2528from autoPyTorch .optimizer .smbo import AutoMLSMBO
29+ from autoPyTorch .pipeline .components .training .metrics .metrics import accuracy
2630
2731
2832# Fixtures
@@ -104,17 +108,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
104108
105109 # Search for an existing run key in disc. A individual model might have
106110 # a timeout and hence was not written to disc
111+ successful_num_run = None
112+ SUCCESS = False
107113 for i , (run_key , value ) in enumerate (estimator .run_history .data .items ()):
108- if 'SUCCESS' not in str (value .status ):
109- continue
110-
111- run_key_model_run_dir = estimator ._backend .get_numrun_directory (
112- estimator .seed , run_key .config_id + 1 , run_key .budget )
113- if os .path .exists (run_key_model_run_dir ):
114- # Runkey config id is different from the num_run
115- # more specifically num_run = config_id + 1(dummy)
114+ if 'SUCCESS' in str (value .status ):
115+ run_key_model_run_dir = estimator ._backend .get_numrun_directory (
116+ estimator .seed , run_key .config_id + 1 , run_key .budget )
116117 successful_num_run = run_key .config_id + 1
117- break
118+ if os .path .exists (run_key_model_run_dir ):
119+ # Runkey config id is different from the num_run
120+ # more specifically num_run = config_id + 1(dummy)
121+ SUCCESS = True
122+ break
123+
124+ assert SUCCESS , f"Successful run was not properly saved for num_run: { successful_num_run } "
118125
119126 if resampling_strategy == HoldoutValTypes .holdout_validation :
120127 model_file = os .path .join (run_key_model_run_dir ,
@@ -272,17 +279,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
272279
273280 # Search for an existing run key in disc. A individual model might have
274281 # a timeout and hence was not written to disc
282+ successful_num_run = None
283+ SUCCESS = False
275284 for i , (run_key , value ) in enumerate (estimator .run_history .data .items ()):
276- if 'SUCCESS' not in str (value .status ):
277- continue
278-
279- run_key_model_run_dir = estimator ._backend .get_numrun_directory (
280- estimator .seed , run_key .config_id + 1 , run_key .budget )
281- if os .path .exists (run_key_model_run_dir ):
282- # Runkey config id is different from the num_run
283- # more specifically num_run = config_id + 1(dummy)
285+ if 'SUCCESS' in str (value .status ):
286+ run_key_model_run_dir = estimator ._backend .get_numrun_directory (
287+ estimator .seed , run_key .config_id + 1 , run_key .budget )
284288 successful_num_run = run_key .config_id + 1
285- break
289+ if os .path .exists (run_key_model_run_dir ):
290+ # Runkey config id is different from the num_run
291+ # more specifically num_run = config_id + 1(dummy)
292+ SUCCESS = True
293+ break
294+
295+ assert SUCCESS , f"Successful run was not properly saved for num_run: { successful_num_run } "
286296
287297 if resampling_strategy == HoldoutValTypes .holdout_validation :
288298 model_file = os .path .join (run_key_model_run_dir ,
@@ -384,7 +394,7 @@ def test_tabular_input_support(openml_id, backend):
384394 estimator ._do_dummy_prediction = unittest .mock .MagicMock ()
385395
386396 with unittest .mock .patch .object (AutoMLSMBO , 'run_smbo' ) as AutoMLSMBOMock :
387- AutoMLSMBOMock .return_value = ({} , {}, 'epochs' )
397+ AutoMLSMBOMock .return_value = (RunHistory () , {}, 'epochs' )
388398 estimator .search (
389399 X_train = X_train , y_train = y_train ,
390400 X_test = X_test , y_test = y_test ,
@@ -394,3 +404,48 @@ def test_tabular_input_support(openml_id, backend):
394404 enable_traditional_pipeline = False ,
395405 load_models = False ,
396406 )
407+
408+
@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
    """Check that the dummy predictor writes its artifacts to the backend's
    temporary directory (never the current working directory) and that the
    pickled dummy model honours the scikit-learn get/set-params contract.
    """
    backend = fit_dictionary_tabular['backend']
    estimator = TabularClassificationTask(
        backend=backend,
        resampling_strategy=HoldoutValTypes.holdout_validation,
        ensemble_size=0,
    )

    # Wire up by hand the attributes that search() would normally set,
    # so that _do_dummy_prediction() can run in isolation.
    estimator._create_dask_client()
    estimator._metric = accuracy
    estimator._logger = estimator._get_logger('test')
    estimator._memory_limit = 5000
    estimator._time_for_task = 60
    estimator._disable_file_output = []
    estimator._all_supported_metrics = False

    estimator._do_dummy_prediction()

    # The dummy run (seed 1, num_run 1, budget 1.0) must leave nothing in
    # the CWD; its ensemble predictions live under the temporary directory.
    dummy_run_dir = os.path.join(
        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0')
    assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
    assert os.path.exists(
        os.path.join(dummy_run_dir, 'predictions_ensemble_1_1_1.0.npy'))

    # The persisted dummy model must be clone-able, i.e. it must comply
    # with scikit-learn's get_params/set_params protocol.
    model_path = os.path.join(dummy_run_dir, '1.1.1.0.model')
    assert os.path.exists(model_path)
    with open(model_path, 'rb') as model_handler:
        clone(pickle.load(model_handler))

    estimator._close_dask_client()
    estimator._clean_logger()

    del estimator
0 commit comments