22# Copyright (c) 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
5+ import json
6+ import os
7+ import re
58import shutil
6- from typing import List , Union
9+ from typing import Dict , List , Optional , Tuple , Union
710
11+ from huggingface_hub import hf_hub_download
12+ from huggingface_hub .utils import HfHubHTTPError
813from pydantic import ValidationError
914from rich .table import Table
1015
1722)
1823from ads .aqua .common .utils import (
1924 build_pydantic_error_message ,
25+ format_hf_custom_error_message ,
2026 get_resource_type ,
27+ is_valid_ocid ,
2128 load_config ,
2229 load_gpu_shapes_index ,
2330)
3744 ShapeRecommendationReport ,
3845 ShapeReport ,
3946)
47+ from ads .config import COMPARTMENT_OCID
4048from ads .model .datascience_model import DataScienceModel
4149from ads .model .service .oci_datascience_model_deployment import (
4250 OCIDataScienceModelDeployment ,
@@ -91,20 +99,23 @@ def which_shapes(
9199 try :
92100 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
93101
94- ds_model = self ._get_data_science_model (request .model_id )
95-
96- model_name = ds_model .display_name if ds_model .display_name else ""
97-
98102 if request .deployment_config :
103+ if is_valid_ocid (request .model_id ):
104+ ds_model = self ._get_data_science_model (request .model_id )
105+ model_name = ds_model .display_name
106+ else :
107+ model_name = request .model_id
108+
99109 shape_recommendation_report = (
100110 ShapeRecommendationReport .from_deployment_config (
101111 request .deployment_config , model_name , shapes
102112 )
103113 )
104114
105115 else :
106- data = self ._get_model_config (ds_model )
107-
116+ data , model_name = self ._get_model_config_and_name (
117+ model_id = request .model_id ,
118+ )
108119 llm_config = LLMConfig .from_raw_config (data )
109120
110121 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
@@ -135,7 +146,57 @@ def which_shapes(
135146
136147 return shape_recommendation_report
137148
138- def valid_compute_shapes (self , compartment_id : str ) -> List ["ComputeShapeSummary" ]:
def _get_model_config_and_name(
    self,
    model_id: str,
) -> Tuple[Dict, str]:
    """
    Loads the model configuration and a display name for the given model ID.

    The ID is classified with ``is_valid_ocid``: a valid OCID is resolved
    through the Data Science model it identifies; anything else is treated
    as a Hugging Face Hub model ID. There is no fallback from one path to
    the other.

    Parameters
    ----------
    model_id : str
        The model OCID or Hugging Face model ID.

    Returns
    -------
    Tuple[Dict, str]
        A tuple containing:
        - The model configuration dictionary.
        - The display name for the model. For an OCID this is the model's
          display name (empty string when it is not set); for a Hugging
          Face model it is ``model_id`` itself.
    """
    if is_valid_ocid(model_id):
        logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
        ds_model = self._get_data_science_model(model_id)
        config = self._get_model_config(ds_model)
        # The display name may be unset; fall back to "" so the declared
        # ``str`` return type holds for callers.
        model_name = ds_model.display_name or ""
    else:
        logger.info(
            f"Assuming Hugging Face model ID: Fetching config for '{model_id}'."
        )
        config = self._fetch_hf_config(model_id)
        model_name = model_id

    return config, model_name
184+
def _fetch_hf_config(self, model_id: str) -> Dict:
    """
    Retrieves and parses ``config.json`` for a Hugging Face Hub repository.

    Parameters
    ----------
    model_id : str
        The Hugging Face repository ID, passed to ``hf_hub_download`` as
        ``repo_id``.

    Returns
    -------
    Dict
        The parsed contents of the downloaded ``config.json``.

    Notes
    -----
    An ``HfHubHTTPError`` raised by the download is handed to
    ``format_hf_custom_error_message``.
    NOTE(review): presumably that helper always raises; if it ever returns
    normally, this method returns ``None`` -- confirm against the helper.
    """
    try:
        local_config = hf_hub_download(repo_id=model_id, filename="config.json")
    except HfHubHTTPError as http_err:
        format_hf_custom_error_message(http_err)
    else:
        # The download succeeded; read the config file from the returned path.
        with open(local_config, "r", encoding="utf-8") as config_file:
            return json.load(config_file)
196+
197+ def valid_compute_shapes (
198+ self , compartment_id : Optional [str ] = None
199+ ) -> List ["ComputeShapeSummary" ]:
139200 """
140201 Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
141202
@@ -151,9 +212,23 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary
151212
152213 Raises
153214 ------
154- ValueError
155- If the file cannot be opened, parsed, or the 'shapes' key is missing.
215+ AquaValueError
216+ If a compartment_id is not provided and cannot be found in the
217+ environment variables.
156218 """
219+ if not compartment_id :
220+ compartment_id = COMPARTMENT_OCID
221+ if compartment_id :
222+ logger .info (f"Using compartment_id from environment: { compartment_id } " )
223+
224+ if not compartment_id :
225+ raise AquaValueError (
226+ "A compartment OCID is required to list available shapes. "
227+ "Please specify it using the --compartment_id parameter.\n \n "
228+ "Example:\n "
229+ 'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
230+ )
231+
157232 oci_shapes = OCIDataScienceModelDeployment .shapes (compartment_id = compartment_id )
158233 set_user_shapes = {shape .name : shape for shape in oci_shapes }
159234
@@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel):
324399 """
325400
326401 model_task = model .freeform_tags .get ("task" , "" ).lower ()
402+ model_task = re .sub (r"-" , "_" , model_task )
327403 model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
328404
329405 logger .info (f"Current model task type: { model_task } " )
0 commit comments