33# Copyright (c) 2023 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
6- import os
6+ from mock import patch
77from typing import Tuple
8+ import os
89import pandas as pd
910import pytest
11+
12+ from ads .common import utils
1013from ads .dataset .classification_dataset import BinaryClassificationDataset
1114from ads .dataset .dataset_with_target import ADSDatasetWithTarget
1215from ads .dataset .pipeline import TransformerPipeline
1316from ads .dataset .target import TargetVariable
1417
1518
1619class TestADSDatasetTarget :
def get_data_path(self):
    """Return the path of the ``orcl_attrition.csv`` fixture.

    The path is built relative to the directory containing this test
    module (``<module dir>/data/orcl_attrition.csv``), so it does not
    depend on the working directory the tests are launched from.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "data", "orcl_attrition.csv")
23+
1724 def test_initialize_dataset_target (self ):
1825 employees = ADSDatasetWithTarget (
1926 df = pd .read_csv (self .get_data_path ()),
2027 target = "Attrition" ,
2128 name = "test_dataset" ,
2229 description = "test_description" ,
23- storage_options = {' config' : {},' region' : ' us-ashburn-1' }
30+ storage_options = {" config" : {}, " region" : " us-ashburn-1" },
2431 )
2532
2633 assert isinstance (employees , ADSDatasetWithTarget )
@@ -32,8 +39,8 @@ def test_dataset_target_from_dataframe(self):
3239 employees = ADSDatasetWithTarget .from_dataframe (
3340 df = pd .read_csv (self .get_data_path ()),
3441 target = "Attrition" ,
35- storage_options = {' config' : {},' region' : ' us-ashburn-1' }
36- ).set_positive_class (' Yes' )
42+ storage_options = {" config" : {}, " region" : " us-ashburn-1" },
43+ ).set_positive_class (" Yes" )
3744
3845 assert isinstance (employees , BinaryClassificationDataset )
3946 self .assert_dataset (employees )
@@ -65,6 +72,45 @@ def assert_dataset(self, dataset):
6572 assert "type_discovery" in dataset .init_kwargs
6673 assert isinstance (dataset .transformer_pipeline , TransformerPipeline )
6774
68- def get_data_path (self ):
69- current_dir = os .path .dirname (os .path .abspath (__file__ ))
70- return os .path .join (current_dir , "data" , "orcl_attrition.csv" )
def test_seggested_sampling_for_imbalanced_dataset(self):
    """Check the sampling action recommended for an imbalanced target.

    The recommendations transformer is fitted on the attrition dataset
    three times: with the default thresholds from ``ads.common.utils``,
    with ``MAX_LEN_FOR_UP_SAMPLING`` patched to 5, and additionally with
    ``MIN_RATIO_FOR_DOWN_SAMPLING`` patched to 0.35. The expected
    "Selected Action" is "Up-sample", "Down-sample" and "Do nothing"
    respectively.
    """
    employees = ADSDatasetWithTarget.from_dataframe(
        df=pd.read_csv(self.get_data_path()),
        target="Attrition",
    ).set_positive_class("Yes")

    rt = employees._get_recommendations_transformer(
        fix_imbalance=True, correlation_threshold=1
    )
    rt.fit(employees)

    def imbalance_reco(field):
        # Read the attribute on every call so that results of a re-fit
        # are picked up.
        return rt.reco_dict_["fix_imbalance"]["Attrition"][field]

    # The expectations below rely on the default thresholds.
    assert utils.MAX_LEN_FOR_UP_SAMPLING == 5000
    assert utils.MIN_RATIO_FOR_DOWN_SAMPLING == 1 / 20

    assert imbalance_reco("Message") == "Imbalanced Target(33.33%)"

    # The dataset is shorter than MAX_LEN_FOR_UP_SAMPLING, so up-sampling
    # is the expected suggestion.
    assert len(employees) < utils.MAX_LEN_FOR_UP_SAMPLING
    assert imbalance_reco("Selected Action") == "Up-sample"

    # Tweak MAX_LEN_FOR_UP_SAMPLING and MIN_RATIO_FOR_DOWN_SAMPLING to
    # obtain the other recommendations.
    with patch("ads.common.utils.MAX_LEN_FOR_UP_SAMPLING", 5):
        assert utils.MAX_LEN_FOR_UP_SAMPLING == 5
        rt.fit(employees)
        # Down-sample is expected because the minor/majority ratio is
        # greater than MIN_RATIO_FOR_DOWN_SAMPLING.
        assert imbalance_reco("Selected Action") == "Down-sample"

        with patch("ads.common.utils.MIN_RATIO_FOR_DOWN_SAMPLING", 0.35):
            rt.fit(employees)
            # With both thresholds tweaked, no sampling is suggested.
            assert imbalance_reco("Selected Action") == "Do nothing"
0 commit comments