1515 InvalidParameterError ,
1616)
1717from ads .opctl .operator .lowcode .common .utils import merge_category_columns
18+ from ads .opctl .operator .lowcode .forecast .operator_config import ForecastOperatorSpec
1819
1920
2021class Transformations (ABC ):
@@ -34,6 +35,7 @@ def __init__(self, dataset_info, name="historical_data"):
3435 self .dataset_info = dataset_info
3536 self .target_category_columns = dataset_info .target_category_columns
3637 self .target_column_name = dataset_info .target_column
38+ self .raw_column_names = None
3739 self .dt_column_name = (
3840 dataset_info .datetime_column .name if dataset_info .datetime_column else None
3941 )
@@ -60,7 +62,8 @@ def run(self, data):
6062
6163 """
6264 clean_df = self ._remove_trailing_whitespace (data )
63- # clean_df = self._normalize_column_names(clean_df)
65+ if isinstance (self .dataset_info , ForecastOperatorSpec ):
66+ clean_df = self ._clean_column_names (clean_df )
6467 if self .name == "historical_data" :
6568 self ._check_historical_dataset (clean_df )
6669 clean_df = self ._set_series_id_column (clean_df )
@@ -98,8 +101,36 @@ def run(self, data):
98101 def _remove_trailing_whitespace (self , df ):
99102 return df .apply (lambda x : x .str .strip () if x .dtype == "object" else x )
100103
101- # def _normalize_column_names(self, df):
102- # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
104+ def _clean_column_names (self , df ):
105+ """
106+ Remove all whitespaces from column names in a DataFrame and store the original names.
107+
108+ Parameters:
109+ df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
110+
111+ Returns:
112+ pd.DataFrame: The DataFrame with cleaned column names.
113+ """
114+
115+ self .raw_column_names = {
116+ col : col .replace (" " , "" ) for col in df .columns if " " in col
117+ }
118+ df .columns = [self .raw_column_names .get (col , col ) for col in df .columns ]
119+
120+ if self .target_column_name :
121+ self .target_column_name = self .raw_column_names .get (
122+ self .target_column_name , self .target_column_name
123+ )
124+ self .dt_column_name = self .raw_column_names .get (
125+ self .dt_column_name , self .dt_column_name
126+ )
127+
128+ if self .target_category_columns :
129+ self .target_category_columns = [
130+ self .raw_column_names .get (col , col )
131+ for col in self .target_category_columns
132+ ]
133+ return df
103134
104135 def _set_series_id_column (self , df ):
105136 self ._target_category_columns_map = {}
@@ -233,6 +264,10 @@ def _check_historical_dataset(self, df):
233264 expected_names = [self .target_column_name , self .dt_column_name ] + (
234265 self .target_category_columns if self .target_category_columns else []
235266 )
267+
268+ if self .raw_column_names :
269+ expected_names .extend (list (self .raw_column_names .values ()))
270+
236271 if set (df .columns ) != set (expected_names ):
237272 raise DataMismatchError (
238273 f"Expected { self .name } to have columns: { expected_names } , but instead found column names: { df .columns } . Is the { self .name } path correct?"
0 commit comments