__all__ = ["Model"]


def _env_flag(name):
    """Return True when environment variable ``name`` is "true" or "1" (case-insensitive).

    An unset variable is treated as "false".
    """
    raw_value = os.getenv(name, "false")
    return raw_value.lower() in ("true", "1")


# Passed as ``trust_remote_code`` / ``trust_remote`` when loading configs and models.
TRUST_REMOTE_CODE = _env_flag("TRUST_REMOTE_CODE")
# Passed as ``disable_tensor_cache`` when a model is wrapped in an HPU graph.
DISABLE_TENSOR_CACHE = _env_flag("DISABLE_TENSOR_CACHE")

# Disable gradients
torch.set_grad_enabled(False)
2327
3236 __all__ .append (FlashBert )
3337
3438
def wrap_model_if_hpu(model_handle, device):
    """Return ``model_handle``, wrapping its ``model`` in an HPU graph on HPU devices.

    Args:
        model_handle: Object exposing a ``model`` attribute; on HPU the
            attribute is replaced in place by the wrapped model.
        device: Object whose ``type`` attribute names the backend device.

    Returns:
        The same ``model_handle`` object that was passed in.
    """
    if device.type != "hpu":
        # Nothing to wrap on other devices; hand the handle back untouched.
        return model_handle

    # Imported here so the import only runs on the HPU path.
    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

    wrapped = wrap_in_hpu_graph(
        model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
    )
    model_handle.model = wrapped
    return model_handle
48+
49+
def create_model(model_class, model_path, device, datatype, pool="cls"):
    """Instantiate ``model_class`` and apply HPU graph wrapping when applicable.

    Args:
        model_class: Callable invoked as
            ``model_class(model_path, device, datatype, pool, trust_remote=...)``.
        model_path: Forwarded to ``model_class`` as the first positional argument.
        device: Forwarded to ``model_class`` and used to decide on HPU wrapping.
        datatype: Forwarded to ``model_class`` as the third positional argument.
        pool: Forwarded to ``model_class`` as the fourth positional argument;
            defaults to ``"cls"``.

    Returns:
        Whatever ``wrap_model_if_hpu`` returns for the new instance.
    """
    positional = (model_path, device, datatype, pool)
    instance = model_class(*positional, trust_remote=TRUST_REMOTE_CODE)
    return wrap_model_if_hpu(instance, device)
60+
61+
3562def get_model (model_path : Path , dtype : Optional [str ], pool : str ):
3663 if dtype == "float32" :
3764 datatype = torch .float32
@@ -46,6 +73,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
4673 logger .info (f"backend device: { device } " )
4774
4875 config = AutoConfig .from_pretrained (model_path , trust_remote_code = TRUST_REMOTE_CODE )
76+
4977 if (
5078 hasattr (config , "auto_map" )
5179 and isinstance (config .auto_map , dict )
@@ -54,8 +82,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
5482 == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
5583 ):
5684 # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
57- return FlashJinaBert (model_path , config , device , datatype , pool )
58- elif config .model_type == "bert" :
85+ return create_model (FlashJinaBert , model_path , device , datatype )
86+
87+ if config .model_type == "bert" :
5988 config : BertConfig
6089 if (
6190 use_ipex ()
@@ -66,98 +95,36 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
6695 ):
6796 if pool != "cls" :
6897 if config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
69- return MaskedLanguageModel (
70- model_path ,
71- device ,
72- datatype ,
73- trust_remote = TRUST_REMOTE_CODE ,
98+ return create_model (
99+ MaskedLanguageModel , model_path , device , datatype , pool
74100 )
75- return DefaultModel (
76- model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
77- )
101+ return create_model (DefaultModel , model_path , device , datatype , pool )
102+
78103 try :
79- return FlashBert ( model_path , device , datatype )
80- except FileNotFoundError as e :
104+ return create_model ( FlashBert , model_path , device , datatype )
105+ except FileNotFoundError :
81106 logger .info (
82107 "Do not have safetensors file for this model, use default transformers model path instead"
83108 )
84- return DefaultModel (
85- model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
86- )
109+ return create_model (DefaultModel , model_path , device , datatype , pool )
110+
87111 if config .architectures [0 ].endswith ("Classification" ):
88- return ClassificationModel (
89- model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
90- )
112+ return create_model (ClassificationModel , model_path , device , datatype )
91113 elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
92- return MaskedLanguageModel (
93- model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
94- )
114+ return create_model (MaskedLanguageModel , model_path , device , datatype )
95115 else :
96- return DefaultModel (
97- model_path ,
98- device ,
99- datatype ,
100- pool ,
101- trust_remote = TRUST_REMOTE_CODE ,
102- )
103- elif config .model_type == "mistral" and device .type == "hpu" :
116+ return create_model (DefaultModel , model_path , device , datatype , pool )
117+
118+ if config .model_type == "mistral" and device .type == "hpu" :
104119 try :
105- return FlashMistral (
106- model_path ,
107- device ,
108- datatype ,
109- pool ,
110- )
111- except FileNotFoundError as e :
112- return DefaultModel (
113- model_path ,
114- device ,
115- datatype ,
116- pool ,
117- trust_remote = TRUST_REMOTE_CODE ,
118- )
120+ return create_model (FlashMistral , model_path , device , datatype , pool )
121+ except FileNotFoundError :
122+ return create_model (DefaultModel , model_path , device , datatype , pool )
123+
124+ # Default case
125+ if config .architectures [0 ].endswith ("Classification" ):
126+ return create_model (ClassificationModel , model_path , device , datatype )
127+ elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
128+ return create_model (MaskedLanguageModel , model_path , device , datatype )
119129 else :
120- if device .type == "hpu" :
121- from habana_frameworks .torch .hpu import wrap_in_hpu_graph
122-
123- if config .architectures [0 ].endswith ("Classification" ):
124- model_handle = ClassificationModel (
125- model_path ,
126- device ,
127- datatype ,
128- trust_remote = TRUST_REMOTE_CODE ,
129- )
130- elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
131- model_handle = MaskedLanguageModel (
132- model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
133- )
134- else :
135- model_handle = DefaultModel (
136- model_path ,
137- device ,
138- datatype ,
139- pool ,
140- trust_remote = TRUST_REMOTE_CODE ,
141- )
142- model_handle .model = wrap_in_hpu_graph (model_handle .model )
143- return model_handle
144- elif use_ipex ():
145- if config .architectures [0 ].endswith ("Classification" ):
146- return ClassificationModel (
147- model_path ,
148- device ,
149- datatype ,
150- trust_remote = TRUST_REMOTE_CODE ,
151- )
152- elif config .architectures [0 ].endswith ("ForMaskedLM" ) and pool == "splade" :
153- return MaskedLanguageModel (
154- model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
155- )
156- else :
157- return DefaultModel (
158- model_path ,
159- device ,
160- datatype ,
161- pool ,
162- trust_remote = TRUST_REMOTE_CODE ,
163- )
130+ return create_model (DefaultModel , model_path , device , datatype , pool )
0 commit comments