22# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
5- from typing import List , Union
5+ from typing import List , Optional , Union
66from urllib .parse import urlparse
77
88from tornado .web import HTTPError
99
10+ from ads .aqua .app import logger
11+ from ads .aqua .client .client import Client , ExtendedRequestError
1012from ads .aqua .common .decorator import handle_exceptions
13+ from ads .aqua .common .enums import PredictEndpoints
1114from ads .aqua .extension .base_handler import AquaAPIhandler
1215from ads .aqua .extension .errors import Errors
13- from ads .aqua .modeldeployment import AquaDeploymentApp , MDInferenceResponse
14- from ads .aqua .modeldeployment .entities import ModelParams
16+ from ads .aqua .modeldeployment import AquaDeploymentApp
1517from ads .config import COMPARTMENT_OCID
1618
1719
@@ -175,23 +177,107 @@ def list_shapes(self):
175177 )
176178
177179
178- class AquaDeploymentInferenceHandler (AquaAPIhandler ):
179- @staticmethod
180- def validate_predict_url (endpoint ):
181- try :
182- url = urlparse (endpoint )
183- if url .scheme != "https" :
184- return False
185- if not url .netloc :
186- return False
187- return url .path .endswith ("/predict" )
188- except Exception :
189- return False
180+ class AquaDeploymentStreamingInferenceHandler (AquaAPIhandler ):
181+ def _get_model_deployment_response (
182+ self ,
183+ model_deployment_id : str ,
184+ payload : dict ,
185+ route_override_header : Optional [str ],
186+ ):
187+ """
188+ Returns the model deployment inference response in a streaming fashion.
189+
190+ This method connects to the specified model deployment endpoint and
191+ streams the inference output back to the caller, handling both text
192+ and chat completion endpoints depending on the route override.
193+
194+ Parameters
195+ ----------
196+ model_deployment_id : str
197+ The OCID of the model deployment to invoke.
198+ Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'
199+
200+ payload : dict
201+ Dictionary containing the model inference parameters.
202+ Same example for text completions:
203+ {
204+ "max_tokens": 1024,
205+ "temperature": 0.5,
206+ "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
207+ "top_p": 0.4,
208+ "top_k": 100,
209+ "model": "odsc-llm",
210+ "frequency_penalty": 1,
211+ "presence_penalty": 1,
212+ "stream": true
213+ }
214+
215+ route_override_header : Optional[str]
216+ Optional override for the inference route, used for routing between
217+ different endpoint types (e.g., chat vs. text completions).
218+ Example: '/v1/chat/completions'
219+
220+ Returns
221+ -------
222+ Generator[str]
223+ A generator that yields strings of the model's output as they are received.
224+
225+ Raises
226+ ------
227+ HTTPError
228+ If the request to the model deployment fails or if streaming cannot be established.
229+ """
230+
231+ model_deployment = AquaDeploymentApp ().get (model_deployment_id )
232+ endpoint = model_deployment .endpoint + "/predictWithResponseStream"
233+ endpoint_type = model_deployment .environment_variables .get (
234+ "MODEL_DEPLOY_PREDICT_ENDPOINT" , PredictEndpoints .TEXT_COMPLETIONS_ENDPOINT
235+ )
236+ aqua_client = Client (endpoint = endpoint )
237+
238+ if PredictEndpoints .CHAT_COMPLETIONS_ENDPOINT in (
239+ endpoint_type ,
240+ route_override_header ,
241+ ):
242+ try :
243+ for chunk in aqua_client .chat (
244+ messages = payload .pop ("messages" ),
245+ payload = payload ,
246+ stream = True ,
247+ ):
248+ try :
249+ yield chunk ["choices" ][0 ]["delta" ]["content" ]
250+ except Exception as e :
251+ logger .debug (
252+ f"Exception occurred while parsing streaming response: { e } "
253+ )
254+ except ExtendedRequestError as ex :
255+ raise HTTPError (400 , str (ex ))
256+ except Exception as ex :
257+ raise HTTPError (500 , str (ex ))
258+
259+ elif endpoint_type == PredictEndpoints .TEXT_COMPLETIONS_ENDPOINT :
260+ try :
261+ for chunk in aqua_client .generate (
262+ prompt = payload .pop ("prompt" ),
263+ payload = payload ,
264+ stream = True ,
265+ ):
266+ try :
267+ yield chunk ["choices" ][0 ]["text" ]
268+ except Exception as e :
269+ logger .debug (
270+ f"Exception occurred while parsing streaming response: { e } "
271+ )
272+ except ExtendedRequestError as ex :
273+ raise HTTPError (400 , str (ex ))
274+ except Exception as ex :
275+ raise HTTPError (500 , str (ex ))
190276
191277 @handle_exceptions
192- def post (self , * args , ** kwargs ): # noqa: ARG002
278+ def post (self , model_deployment_id ):
193279 """
194- Handles inference request for the Active Model Deployments
280+ Handles streaming inference request for the Active Model Deployments
195281 Raises
196282 ------
197283 HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs): # noqa: ARG002
205291 if not input_data :
206292 raise HTTPError (400 , Errors .NO_INPUT_DATA )
207293
208- endpoint = input_data .get ("endpoint" )
209- if not endpoint :
210- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("endpoint" ))
211-
212- if not self .validate_predict_url (endpoint ):
213- raise HTTPError (400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("endpoint" ))
214-
215294 prompt = input_data .get ("prompt" )
216- if not prompt :
217- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("prompt" ))
295+ messages = input_data .get ("messages" )
218296
219- model_params = (
220- input_data .get ("model_params" ) if input_data .get ("model_params" ) else {}
221- )
222- try :
223- model_params_obj = ModelParams (** model_params )
224- except Exception as ex :
297+ if not prompt and not messages :
225298 raise HTTPError (
226- 400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("model_params" )
227- ) from ex
228-
229- return self .finish (
230- MDInferenceResponse (prompt , model_params_obj ).get_model_deployment_response (
231- endpoint
299+ 400 , Errors .MISSING_REQUIRED_PARAMETER .format ("prompt/messages" )
232300 )
301+ if not input_data .get ("model" ):
302+ raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("model" ))
303+ route_override_header = self .request .headers .get ("route" , None )
304+ self .set_header ("Content-Type" , "text/event-stream" )
305+ response_gen = self ._get_model_deployment_response (
306+ model_deployment_id , input_data , route_override_header
233307 )
308+ try :
309+ for chunk in response_gen :
310+ self .write (chunk )
311+ self .flush ()
312+ self .finish ()
313+ except Exception as ex :
189 314+ self .set_status (getattr (ex , "status_code" , 500 ))
315+ self .write ({"message" : "Error occurred" , "reason" : str (ex )})
316+ self .finish ()
234317
235318
236319class AquaDeploymentParamsHandler (AquaAPIhandler ):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
294377 ("deployments/?([^/]*)" , AquaDeploymentHandler ),
295378 ("deployments/?([^/]*)/activate" , AquaDeploymentHandler ),
296379 ("deployments/?([^/]*)/deactivate" , AquaDeploymentHandler ),
297- ("inference" , AquaDeploymentInferenceHandler ),
380+ ("inference/stream/?([^/]*) " , AquaDeploymentStreamingInferenceHandler ),
298381]
0 commit comments