1010from huggingface_hub .hf_api import HfApi , ModelInfo
1111from huggingface_hub .utils import GatedRepoError
1212from notebook .base .handlers import IPythonHandler
13+ from parameterized import parameterized
1314
1415from ads .aqua .common .errors import AquaRuntimeError
1516from ads .aqua .common .utils import get_hf_model_info
@@ -129,9 +130,25 @@ def test_list(self, mock_list):
129130 compartment_id = None , project_id = None , model_type = None
130131 )
131132
133+ @parameterized .expand (
134+ [
135+ (None , None , False , None ),
136+ ("odsc-llm-fine-tuning" , None , False , None ),
137+ (None , "test.gguf" , True , None ),
138+ (None , None , True , "iad.ocir.io/<namespace>/<image>:<tag>" ),
139+ ],
140+ )
132141 @patch ("notebook.base.handlers.APIHandler.finish" )
133142 @patch ("ads.aqua.model.AquaModelApp.register" )
134- def test_register (self , mock_register , mock_finish ):
143+ def test_register (
144+ self ,
145+ finetuning_container ,
146+ model_file ,
147+ download_from_hf ,
148+ inference_container_uri ,
149+ mock_register ,
150+ mock_finish ,
151+ ):
135152 mock_register .return_value = AquaModel (
136153 id = "test_id" ,
137154 inference_container = "odsc-tgi-serving" ,
@@ -144,18 +161,23 @@ def test_register(self, mock_register, mock_finish):
144161 model = "test_model_name" ,
145162 os_path = "test_os_path" ,
146163 inference_container = "odsc-tgi-serving" ,
164+ finetuning_container = finetuning_container ,
165+ model_file = model_file ,
166+ download_from_hf = download_from_hf ,
167+ inference_container_uri = inference_container_uri ,
147168 )
148169 )
149170 result = self .model_handler .post ()
150171 mock_register .assert_called_with (
151172 model = "test_model_name" ,
152173 os_path = "test_os_path" ,
153174 inference_container = "odsc-tgi-serving" ,
154- finetuning_container = None ,
175+ finetuning_container = finetuning_container ,
155176 compartment_id = None ,
156177 project_id = None ,
157- model_file = None ,
158- download_from_hf = False ,
178+ model_file = model_file ,
179+ download_from_hf = download_from_hf ,
180+ inference_container_uri = inference_container_uri ,
159181 )
160182 assert result ["id" ] == "test_id"
161183 assert result ["inference_container" ] == "odsc-tgi-serving"
0 commit comments