44# Copyright (c) 2023 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
7- from ..const import SupportedModels
7+ from ..const import SupportedModels , AUTO_SELECT
88from ..operator_config import ForecastOperatorConfig
99from .arima import ArimaOperatorModel
1010from .automlx import AutoMLXOperatorModel
1414from .prophet import ProphetOperatorModel
1515from .forecast_datasets import ForecastDatasets
1616from .ml_forecast import MLForecastOperatorModel
17+ from ..model_evaluator import ModelEvaluator
1718
1819class UnSupportedModelError (Exception ):
1920 def __init__ (self , model_type : str ):
@@ -62,8 +63,9 @@ def get_model(
6263 In case of not supported model.
6364 """
6465 model_type = operator_config .spec .model
65- if model_type == "auto-select" :
66+ if model_type == AUTO_SELECT :
6667 model_type = cls .auto_select_model (datasets , operator_config )
68+ operator_config .spec .model_kwargs = dict ()
6769 if model_type not in cls ._MAP :
6870 raise UnSupportedModelError (model_type )
6971 return cls ._MAP [model_type ](config = operator_config , datasets = datasets )
@@ -88,7 +90,8 @@ def auto_select_model(
8890 str
8991 The type of the model.
9092 """
91- from ..model_evaluator import ModelEvaluator
92- all_models = cls ._MAP .keys ()
93- model_evaluator = ModelEvaluator (all_models )
93+ all_models = operator_config .spec .model_kwargs .get ("model_list" , cls ._MAP .keys ())
94+ num_backtests = operator_config .spec .model_kwargs .get ("num_backtests" , 5 )
95+ sample_ratio = operator_config .spec .model_kwargs .get ("sample_ratio" , 0.20 )
96+ model_evaluator = ModelEvaluator (all_models , num_backtests , sample_ratio )
9497 return model_evaluator .find_best_model (datasets , operator_config )
0 commit comments