77import json
88import os
99import shutil
10- from unittest .mock import Mock
1110import uuid
11+ from typing import Optional
1212from zipfile import ZipFile
1313
1414import pandas as pd
1515import yaml
1616from ads .catalog .summary import SummaryList
17- from ads .common import auth , oci_client , utils , logger
18- from ads .common .model_artifact import (
19- ConflictStrategy ,
20- ModelArtifact ,
21- OUTPUT_SCHEMA_FILE_NAME ,
22- )
17+ from ads .common import auth , logger , oci_client , utils
18+ from ads .common .decorator .deprecate import deprecated
19+ from ads .common .model_artifact import ConflictStrategy , ModelArtifact
2320from ads .common .model_metadata import (
24- ModelCustomMetadata ,
25- ModelTaxonomyMetadata ,
2621 METADATA_SIZE_LIMIT ,
2722 MetadataSizeTooLarge ,
23+ ModelCustomMetadata ,
24+ ModelTaxonomyMetadata ,
2825)
29- from ads .common .oci_resource import OCIResource , SEARCH_TYPE
26+ from ads .common .oci_resource import SEARCH_TYPE , OCIResource
3027from ads .config import (
31- OCI_IDENTITY_SERVICE_ENDPOINT ,
3228 NB_SESSION_COMPARTMENT_OCID ,
3329 OCI_ODSC_SERVICE_ENDPOINT ,
3430 PROJECT_OCID ,
4440from oci .exceptions import ServiceError
4541from oci .identity import IdentityClient
4642
47-
4843_UPDATE_MODEL_DETAILS_ATTRIBUTES = [
4944 "display_name" ,
5045 "description" ,
@@ -566,29 +561,27 @@ class ModelCatalog:
566561
567562 def __init__ (
568563 self ,
569- compartment_id = None ,
570- ds_client_auth = None ,
571- identity_client_auth = None ,
572- timeout : int = None ,
564+ compartment_id : Optional [ str ] = None ,
565+ ds_client_auth : Optional [ dict ] = None ,
566+ identity_client_auth : Optional [ dict ] = None ,
567+ timeout : Optional [ int ] = None ,
573568 ):
574569 """Initializes model catalog instance.
575570
576571 Parameters
577572 ----------
578- compartment_id : str, optional
579- OCID of model's compartment
580- If None, the default compartment ID `config.NB_SESSION_COMPARTMENT_OCID` would be used
581- ds_client_auth : dict
582- Default is None. The default authetication is set using `ads.set_auth` API. If you need to override the
573+ compartment_id : (str, optional). Defaults to None.
574+ Model compartment OCID. If `None`, the `config.NB_SESSION_COMPARTMENT_OCID` would be used.
575+ ds_client_auth : (dict, optional). Defaults to None.
576+ The default authetication is set using `ads.set_auth` API. If you need to override the
583577 default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
584578 authentication signer and kwargs required to instantiate DataScienceClient object.
585- identity_client_auth : dict
586- Default is None. The default authetication is set using `ads.set_auth` API. If you need to override the
579+ identity_client_auth : ( dict, optional). Defaults to None.
580+ The default authetication is set using `ads.set_auth` API. If you need to override the
587581 default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
588582 authentication signer and kwargs required to instantiate IdentityClient object.
589- timeout: int, optional
583+ timeout: ( int, optional). Defaults to 10 seconds.
590584 The connection timeout in seconds for the client.
591- The default value for connection timeout is 10 seconds.
592585
593586 Raises
594587 ------
@@ -864,16 +857,11 @@ def delete_model(self, model, **kwargs):
864857 logger .error ("Failed to delete the Model." )
865858 return False
866859
867- def download_model (
868- self ,
869- model_id : str ,
870- target_dir : str ,
871- force_overwrite : bool = False ,
872- install_libs : bool = False ,
873- conflict_strategy = ConflictStrategy .IGNORE ,
874- ):
860+ def _download_artifacts (
861+ self , model_id : str , target_dir : str , force_overwrite : Optional [bool ] = False
862+ ) -> None :
875863 """
876- Downloads the model from model_dir to target_dir based on model_id.
864+ Downloads the model artifacts from model catalog to target_dir based on model_id.
877865
878866 Parameters
879867 ----------
@@ -883,46 +871,89 @@ def download_model(
883871 The target location of model after download.
884872 force_overwrite: bool
885873 Overwrite target_dir if exists.
886- install_libs: bool, default: False
887- Install the libraries specified in ds-requirements.txt which are missing in the current environment.
888- conflict_strategy: ConflictStrategy, default: IGNORE
889- Determines how to handle version conflicts between the current environment and requirements of
890- model artifact.
891- Valid values: "IGNORE", "UPDATE" or ConflictStrategy.
892- IGNORE: Use the installed version in case of conflict
893- UPDATE: Force update dependency to the version required by model artifact in case of conflict
874+
875+ Raises
876+ ------
877+ ValueError
878+ If targed dir not exists.
879+ KeyError
880+ If model id not found.
894881
895882 Returns
896883 -------
897- ModelArtifact
898- A ModelArtifact instance.
884+ None
885+ Nothing
899886 """
900887 if os .path .exists (target_dir ) and os .listdir (target_dir ):
901888 if not force_overwrite :
902889 raise ValueError (
903- "Target directory already exists. Set 'force_overwrite' to overwrite."
890+ "Target directory already exists. "
891+ "Set `force_overwrite` to overwrite."
904892 )
905893 shutil .rmtree (target_dir )
906894
907895 try :
908896 zip_contents = self .ds_client .get_model_artifact_content (
909897 model_id
910898 ).data .content
911- except ServiceError as se :
912- if se .status == 404 :
913- raise KeyError (se .message ) from se
899+ except ServiceError as ex :
900+ if ex .status == 404 :
901+ raise KeyError (ex .message ) from ex
914902 else :
915903 raise
916904 zip_file_path = os .path .join (
917905 "/tmp" , "saved_model_" + str (uuid .uuid4 ()) + ".zip"
918906 )
907+
919908 # write contents to zip file
920909 with open (zip_file_path , "wb" ) as zip_file :
921910 zip_file .write (zip_contents )
911+
922912 # Extract all the contents of zip file in target directory
923913 with ZipFile (zip_file_path ) as zip_file :
924914 zip_file .extractall (target_dir )
915+
925916 os .remove (zip_file_path )
917+
918+ @deprecated (
919+ "2.5.9" ,
920+ details = "Instead use `ads.common.model_artifact.ModelArtifact.from_model_catalog()`." ,
921+ )
922+ def download_model (
923+ self ,
924+ model_id : str ,
925+ target_dir : str ,
926+ force_overwrite : bool = False ,
927+ install_libs : bool = False ,
928+ conflict_strategy = ConflictStrategy .IGNORE ,
929+ ):
930+ """
931+ Downloads the model from model_dir to target_dir based on model_id.
932+
933+ Parameters
934+ ----------
935+ model_id: str
936+ The OCID of the model to download.
937+ target_dir: str
938+ The target location of model after download.
939+ force_overwrite: bool
940+ Overwrite target_dir if exists.
941+ install_libs: bool, default: False
942+ Install the libraries specified in ds-requirements.txt which are missing in the current environment.
943+ conflict_strategy: ConflictStrategy, default: IGNORE
944+ Determines how to handle version conflicts between the current environment and requirements of
945+ model artifact.
946+ Valid values: "IGNORE", "UPDATE" or ConflictStrategy.
947+ IGNORE: Use the installed version in case of conflict
948+ UPDATE: Force update dependency to the version required by model artifact in case of conflict
949+
950+ Returns
951+ -------
952+ ModelArtifact
953+ A ModelArtifact instance.
954+ """
955+ self ._download_artifacts (model_id , target_dir , force_overwrite )
956+
926957 result = ModelArtifact (
927958 target_dir ,
928959 conflict_strategy = conflict_strategy ,
0 commit comments