11# -*- coding: utf-8 -*-
2- import time
3- import uvicorn
42import asyncio
5- import logging
6- import configparser
3+ from contextlib import asynccontextmanager
4+ import uvicorn
75import json
8- from fastapi import FastAPI , Request , HTTPException
9- from pydantic import BaseModel
10- from concurrent .futures import ThreadPoolExecutor
11- from starlette .responses import PlainTextResponse
12- import functools
13-
14- from modelcache import cache
15- from modelcache .adapter import adapter
16- from modelcache .manager import CacheBase , VectorBase , get_data_manager
17- from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
18- from modelcache .processor .pre import query_multi_splicing
19- from modelcache .processor .pre import insert_multi_splicing
20- from modelcache .utils .model_filter import model_blacklist_filter
21- from modelcache .embedding import Data2VecAudio
22-
23- #创建一个FastAPI实例
24- app = FastAPI ()
25-
26- class RequestData (BaseModel ):
27- type : str
28- scope : dict = None
29- query : str = None
30- chat_info : dict = None
31- remove_type : str = None
32- id_list : list = []
33-
34- data2vec = Data2VecAudio ()
35- mysql_config = configparser .ConfigParser ()
36- mysql_config .read ('modelcache/config/mysql_config.ini' )
37-
38- milvus_config = configparser .ConfigParser ()
39- milvus_config .read ('modelcache/config/milvus_config.ini' )
40-
41- # redis_config = configparser.ConfigParser()
42- # redis_config.read('modelcache/config/redis_config.ini')
43-
44- # 初始化datamanager
45- data_manager = get_data_manager (
46- CacheBase ("mysql" , config = mysql_config ),
47- VectorBase ("milvus" , dimension = data2vec .dimension , milvus_config = milvus_config )
48- )
49-
50- # # 使用redis初始化datamanager
51- # data_manager = get_data_manager(
52- # CacheBase("mysql", config=mysql_config),
53- # VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
54- # )
55-
56- cache .init (
57- embedding_func = data2vec .to_embeddings ,
58- data_manager = data_manager ,
59- similarity_evaluation = SearchDistanceEvaluation (),
60- query_pre_embedding_func = query_multi_splicing ,
61- insert_pre_embedding_func = insert_multi_splicing ,
62- )
63-
64- executor = ThreadPoolExecutor (max_workers = 6 )
65-
66- # 异步保存查询信息
67- async def save_query_info (result , model , query , delta_time_log ):
68- loop = asyncio .get_running_loop ()
69- func = functools .partial (cache .data_manager .save_query_resp , result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log )
70- await loop .run_in_executor (None , func )
71-
72-
73-
74- @app .get ("/welcome" , response_class = PlainTextResponse )
6+ from fastapi .responses import JSONResponse
7+ from fastapi import FastAPI , Request
8+ from modelcache .cache import Cache
9+ from modelcache .embedding import EmbeddingModel
10+
11+ @asynccontextmanager
12+ async def lifespan (app : FastAPI ):
13+ global cache
14+ cache , _ = await Cache .init (
15+ sql_storage = "mysql" ,
16+ vector_storage = "milvus" ,
17+ embedding_model = EmbeddingModel .HUGGINGFACE_ALL_MPNET_BASE_V2 ,
18+ embedding_workers_num = 2
19+ )
20+ yield
21+
22+ app = FastAPI (lifespan = lifespan )
23+ cache : Cache = None
24+
25+ @app .get ("/welcome" )
7526async def first_fastapi ():
7627 return "hello, modelcache!"
7728
7829@app .post ("/modelcache" )
7930async def user_backend (request : Request ):
80- try :
81- raw_body = await request .body ()
82- # 解析字符串为JSON对象
83- if isinstance (raw_body , bytes ):
84- raw_body = raw_body .decode ("utf-8" )
85- if isinstance (raw_body , str ):
86- try :
87- # 尝试将字符串解析为JSON对象
88- request_data = json .loads (raw_body )
89- except json .JSONDecodeError as e :
90- # 如果无法解析,返回格式错误
91- result = {"errorCode" : 101 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' ,
92- "answer" : '' }
93- asyncio .create_task (save_query_info (result , model = '' , query = '' , delta_time_log = 0 ))
94- raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
95- else :
96- request_data = raw_body
97-
98- # 确保request_data是字典对象
99- if isinstance (request_data , str ):
100- try :
101- request_data = json .loads (request_data )
102- except json .JSONDecodeError :
103- raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
104-
105- request_type = request_data .get ('type' )
106- model = None
107- if 'scope' in request_data :
108- model = request_data ['scope' ].get ('model' , '' ).replace ('-' , '_' ).replace ('.' , '_' )
109- query = request_data .get ('query' )
110- chat_info = request_data .get ('chat_info' )
11131
112- if not request_type or request_type not in ['query' , 'insert' , 'remove' , 'register' ]:
113- result = {"errorCode" : 102 ,
114- "errorDesc" : "type exception, should one of ['query', 'insert', 'remove', 'register']" ,
115- "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
116- asyncio .create_task (save_query_info (result , model = model , query = '' , delta_time_log = 0 ))
117- raise HTTPException (status_code = 102 , detail = "Type exception, should be one of ['query', 'insert', 'remove', 'register']" )
32+ try :
33+ request_data = await request .json ()
34+ except Exception :
35+ result = {"errorCode" : 400 , "errorDesc" : "bad request" , "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
36+ return JSONResponse (status_code = 400 , content = result )
11837
38+ try :
39+ return await cache .handle_request (request_data )
11940 except Exception as e :
120- request_data = raw_body if 'raw_body' in locals () else None
121- result = {
122- "errorCode" : 103 ,
123- "errorDesc" : str (e ),
124- "cacheHit" : False ,
125- "delta_time" : 0 ,
126- "hit_query" : '' ,
127- "answer" : '' ,
128- "para_dict" : request_data
129- }
130- return result
131-
132-
133- # model filter
134- filter_resp = model_blacklist_filter (model , request_type )
135- if isinstance (filter_resp , dict ):
136- return filter_resp
137-
138- if request_type == 'query' :
139- try :
140- start_time = time .time ()
141- response = adapter .ChatCompletion .create_query (scope = {"model" : model }, query = query )
142- delta_time = f"{ round (time .time () - start_time , 2 )} s"
143-
144- if response is None :
145- result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
146- elif response in ['adapt_query_exception' ]:
147- result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
148- "hit_query" : '' , "answer" : '' }
149- else :
150- answer = response ['data' ]
151- hit_query = response ['hitQuery' ]
152- result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
153-
154- delta_time_log = round (time .time () - start_time , 2 )
155- asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
156- return result
157- except Exception as e :
158- result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
159- "hit_query" : '' , "answer" : '' }
160- logging .info (f'result: { str (result )} ' )
161- return result
162-
163- if request_type == 'insert' :
164- try :
165- response = adapter .ChatCompletion .create_insert (model = model , chat_info = chat_info )
166- if response == 'success' :
167- return {"errorCode" : 0 , "errorDesc" : "" , "writeStatus" : "success" }
168- else :
169- return {"errorCode" : 301 , "errorDesc" : response , "writeStatus" : "exception" }
170- except Exception as e :
171- return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
172-
173- if request_type == 'remove' :
174- response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .get ("remove_type" ), id_list = request_data .get ("id_list" ))
175- if not isinstance (response , dict ):
176- return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
177-
178- state = response .get ('status' )
179- if state == 'success' :
180- return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
181- else :
182- return {"errorCode" : 402 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
183-
184- if request_type == 'register' :
185- response = adapter .ChatCompletion .create_register (model = model )
186- if response in ['create_success' , 'already_exists' ]:
187- return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
188- else :
189- return {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
41+ result = {"errorCode" : 500 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
42+ cache .save_query_resp (result , model = '' , query = '' , delta_time = 0 )
43+ return JSONResponse (status_code = 500 , content = result )
19044
191- # TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
19245if __name__ == '__main__' :
193- uvicorn .run (app , host = '0.0.0.0' , port = 5000 )
46+ uvicorn .run (app , host = '0.0.0.0' , port = 5000 , loop = "asyncio" , http = "httptools" )
0 commit comments