Skip to content

Commit 250585d

Browse files
author
Vikas Pandey
committed
improve auto-select logic and handle missing data
1 parent 3a13534 commit 250585d

File tree

3 files changed

+56
-6
lines changed

3 files changed

+56
-6
lines changed

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -329,6 +329,8 @@ def build_fforms_meta_features(self, data, target_col=None, group_cols=None):
329329
if target_col not in data.columns:
330330
raise ValueError(f"Target column '{target_col}' not found in DataFrame")
331331

332+
data[target_col] = data[target_col].fillna(0)
333+
332334
# Check if group_cols are provided and valid
333335
if group_cols is not None:
334336
if not isinstance(group_cols, list):

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,13 @@
99
import sys
1010
from typing import Dict, List
1111

12-
import pandas as pd
1312
import yaml
1413

1514
from ads.opctl import logger
1615
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
1716
from ads.opctl.operator.common.utils import _parse_input_args
1817

19-
from .const import AUTO_SELECT_SERIES
18+
from .const import AUTO_SELECT, AUTO_SELECT_SERIES
2019
from .model.forecast_datasets import ForecastDatasets, ForecastResults
2120
from .operator_config import ForecastOperatorConfig
2221
from .whatifserve import ModelDeploymentManager
@@ -29,8 +28,10 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
2928
datasets = ForecastDatasets(operator_config)
3029
model = ForecastOperatorModelFactory.get_model(operator_config, datasets)
3130

32-
if operator_config.spec.model == AUTO_SELECT_SERIES and hasattr(
33-
operator_config.spec, "meta_features"
31+
if (
32+
operator_config.spec.model == AUTO_SELECT_SERIES
33+
and hasattr(operator_config.spec, "meta_features")
34+
and operator_config.spec.target_category_columns
3435
):
3536
# For AUTO_SELECT_SERIES, handle each series with its specific model
3637
meta_features = operator_config.spec.meta_features
@@ -64,8 +65,6 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
6465
)
6566
sub_results_list.append(sub_results)
6667

67-
# results_df = pd.concat([results_df, sub_result_df], ignore_index=True, axis=0)
68-
# elapsed_time += sub_elapsed_time
6968
# Merge all sub_results into a single ForecastResults object
7069
if sub_results_list:
7170
results = sub_results_list[0]
@@ -75,6 +74,15 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
7574
results = None
7675

7776
else:
77+
# When AUTO_SELECT_SERIES is specified but target_category_columns is not,
78+
# we fall back to AUTO_SELECT behavior.
79+
if (
80+
operator_config.spec.model == AUTO_SELECT_SERIES
81+
and not operator_config.spec.target_category_columns
82+
):
83+
84+
operator_config.spec.model = AUTO_SELECT
85+
model = ForecastOperatorModelFactory.get_model(operator_config, datasets)
7886
# For other cases, use the single selected model
7987
results = model.generate_report()
8088
# saving to model catalog

tests/operators/forecast/test_datasets.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import pandas as pd
1313
import pytest
1414
import yaml
15+
import numpy as np
1516

1617
from ads.opctl.operator.cmd import run
1718
from ads.opctl.operator.lowcode.forecast.__main__ import operate as forecast_operate
@@ -413,5 +414,44 @@ def run_operator(
413414
# generate_train_metrics = True
414415

415416

417+
def test_missing_data_autoselect_series():
418+
"""Test case for auto-select-series with missing data."""
419+
data = {
420+
"Date": pd.to_datetime(
421+
[
422+
"2023-01-01",
423+
"2023-01-02",
424+
"2023-01-03",
425+
"2023-01-04",
426+
"2023-01-05",
427+
"2023-01-06",
428+
"2023-01-07",
429+
"2023-01-08",
430+
"2023-01-09",
431+
"2023-01-10",
432+
]
433+
),
434+
"Y": [1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10],
435+
"Category": ["A", "A", "A", "A", "A", "A", "A", "A", "A", "A"],
436+
}
437+
df = pd.DataFrame(data)
438+
439+
with tempfile.TemporaryDirectory() as tmpdirname:
440+
output_data_path = f"{tmpdirname}/results"
441+
yaml_i = deepcopy(TEMPLATE_YAML)
442+
yaml_i["spec"]["model"] = "auto-select-series"
443+
yaml_i["spec"]["historical_data"].pop("url")
444+
yaml_i["spec"]["historical_data"]["data"] = df
445+
yaml_i["spec"]["target_column"] = "Y"
446+
yaml_i["spec"]["datetime_column"]["name"] = "Date"
447+
yaml_i["spec"]["target_category_columns"] = ["Category"]
448+
yaml_i["spec"]["horizon"] = 2
449+
yaml_i["spec"]["output_directory"]["url"] = output_data_path
450+
451+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
452+
forecast_operate(operator_config)
453+
check_output_for_errors(output_data_path)
454+
455+
416456
if __name__ == "__main__":
417457
pass

0 commit comments

Comments
 (0)