diff --git a/backend/llm_model/embeddings.py b/backend/llm_model/embeddings.py
index e0bbd1a..4fb17cb 100644
--- a/backend/llm_model/embeddings.py
+++ b/backend/llm_model/embeddings.py
@@ -8,6 +8,19 @@
 from config.db_config import db, db_session_manager
 from config.logging_config import logger
 
+# OpenAI / Azure OpenAI allows up to 300,000 tokens per embedding request.
+# Leave some headroom to reduce the chance of repeated retries at the limit.
+_MAX_TOKENS_PER_REQ = 290_000
+
+
+def _estimate_tokens(text: str) -> int:
+    """
+    Roughly estimate how many tokens `text` consumes.
+    English text averages about 4 characters per token; Chinese about 1.3-2 characters per token.
+    A compromise of 3.5 characters per token is used so the estimate leans toward the high side.
+    """
+    return max(1, int(len(text) / 3.5))
+
 
 class EmbeddingManager:
     """Embedding Manager"""
@@ -144,9 +157,31 @@ async def _get_embeddings_with_context(text: Union[str, List[str]], model_name:
         if isinstance(text, str):
             embedding = await embedding_model.aembed_query(text[:8192])
         else:
-            text = [t[:8192] for t in text]
-            embedding = await embedding_model.aembed_documents(text)
+            # First, trim each text to 8192 characters
+            texts = [t[:8192] for t in text]
+
+            # ---- Batching logic ----
+            batches, cur_batch, cur_tokens = [], [], 0
+            for t in texts:
+                tok = _estimate_tokens(t)
+                # If adding `t` would exceed the per-request token limit, finalize the current batch
+                if cur_batch and cur_tokens + tok > _MAX_TOKENS_PER_REQ:
+                    batches.append(cur_batch)
+                    cur_batch, cur_tokens = [], 0
+                cur_batch.append(t)
+                cur_tokens += tok
+            if cur_batch:  # Flush the last batch
+                batches.append(cur_batch)
+
+            # Send requests sequentially to preserve output order
+            embedding = []
+            for bt in batches:
+                bt_emb = await embedding_model.aembed_documents(bt)
+                embedding.extend(bt_emb)
+
         return np.array(embedding)
+
     except Exception as e:
         logger.error(f"Failed to generate Embedding: {str(e)}")
         raise
+
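
Review note: below is a minimal standalone sketch of the batching behaviour this patch introduces, assuming the same 3.5-characters-per-token heuristic and 290,000-token budget. `build_batches` and the `__main__` check are hypothetical, illustration-only helpers (the patch itself inlines the loop and calls `aembed_documents` per batch); they exercise only the grouping, not the embedding call.

```python
from typing import List

# Assumed values, mirroring the patch.
_MAX_TOKENS_PER_REQ = 290_000


def _estimate_tokens(text: str) -> int:
    """Rough estimate: about 3.5 characters per token, never less than 1."""
    return max(1, int(len(text) / 3.5))


def build_batches(texts: List[str], limit: int = _MAX_TOKENS_PER_REQ) -> List[List[str]]:
    """Greedily pack texts, in order, into batches whose estimated token total stays within `limit`."""
    batches, cur_batch, cur_tokens = [], [], 0
    for t in texts:
        tok = _estimate_tokens(t)
        # Close the current batch before it would exceed the per-request budget.
        if cur_batch and cur_tokens + tok > limit:
            batches.append(cur_batch)
            cur_batch, cur_tokens = [], 0
        cur_batch.append(t)
        cur_tokens += tok
    if cur_batch:
        batches.append(cur_batch)
    return batches


if __name__ == "__main__":
    # 200 texts of 8192 characters each estimate to 2340 tokens apiece, so the
    # first batch fills at 123 texts and the remaining 77 spill into a second one.
    docs = [f"{i:04d}" + "x" * 8188 for i in range(200)]
    batches = build_batches(docs)
    print([len(b) for b in batches])  # -> [123, 77]
    assert [t for b in batches for t in b] == docs  # input order preserved
    assert all(sum(_estimate_tokens(t) for t in b) <= _MAX_TOKENS_PER_REQ for b in batches)
```

Packing greedily in input order keeps the flattened embedding list aligned with the input texts, which is why the patch sends the batches sequentially rather than concurrently.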