77
88import os
99from copy import deepcopy
10- from unittest import TestCase , mock , SkipTest
10+ from unittest import SkipTest , TestCase , mock , skipIf
1111
12- from langchain . llms import Cohere
12+ import langchain_core
1313from langchain .chains import LLMChain
14+ from langchain .llms import Cohere
1415from langchain .prompts import PromptTemplate
1516from langchain .schema .runnable import RunnableParallel , RunnablePassthrough
1617
17- from ads .llm .serialize import load , dump
1818from ads .llm import (
1919 GenerativeAI ,
2020 GenerativeAIEmbeddings ,
2121 ModelDeploymentTGI ,
2222 ModelDeploymentVLLM ,
2323)
24+ from ads .llm .serialize import dump , load
25+
26+
27+ def version_tuple (version ):
28+ return tuple (map (int , version .split ("." )))
2429
2530
2631class ChainSerializationTest (TestCase ):
@@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self):
142147 self .assertEqual (llm_chain .llm .model , "my_model" )
143148 self .assertEqual (llm_chain .input_keys , ["subject" ])
144149
150+ @skipIf (
151+ version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ),
152+ "Serialization not supported in this langchain_core version" ,
153+ )
145154 def test_oci_gen_ai_serialization (self ):
146155 """Tests serialization of OCI Gen AI LLM."""
147156 try :
@@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self):
157166 self .assertEqual (llm .compartment_id , self .COMPARTMENT_ID )
158167 self .assertEqual (llm .client_kwargs , self .GEN_AI_KWARGS )
159168
169+ @skipIf (
170+ version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ),
171+ "Serialization not supported in this langchain_core version" ,
172+ )
160173 def test_gen_ai_embeddings_serialization (self ):
161174 """Tests serialization of OCI Gen AI embeddings."""
162175 try :
@@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self):
201214 element_3 = kwargs .get ("last" )
202215 self .assertNotIn ("_type" , element_3 )
203216 self .assertEqual (element_3 .get ("id" ), ["ads" , "llm" , "ModelDeploymentTGI" ])
204- self .assertEqual (
205- element_3 .get ("kwargs" ),
206- {"endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" },
207- )
217+
218+ if version_tuple (langchain_core .__version__ ) > (0 , 1 , 50 ):
219+ self .assertEqual (
220+ element_3 .get ("kwargs" ),
221+ {
222+ "max_tokens" : 256 ,
223+ "temperature" : 0.2 ,
224+ "p" : 0.75 ,
225+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
226+ "best_of" : 1 ,
227+ "do_sample" : True ,
228+ "watermark" : True ,
229+ },
230+ )
231+ else :
232+ self .assertEqual (
233+ element_3 .get ("kwargs" ),
234+ {
235+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
236+ },
237+ )
208238
209239 chain = load (serialized )
210240 self .assertEqual (len (chain .steps ), 3 )
0 commit comments