"""Embedding 服务:调用硅基流动 bge-m3""" import asyncio import httpx from config import ( SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, EMBEDDING_MODEL, EMBEDDING_DIM, ) BATCH_SIZE = 50 MAX_RETRIES = 2 async def embed_single(text: str) -> list[float]: """单条文本 embedding""" async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{SILICONFLOW_BASE_URL}/embeddings", headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}, json={ "model": EMBEDDING_MODEL, "input": [text], }, ) if resp.status_code != 200: raise RuntimeError(f"Embedding API error: {resp.status_code} {resp.text}") data = resp.json() return data["data"][0]["embedding"] async def embed_batch(texts: list[str]) -> list[list[float]]: """批量 embedding,自动分批 + 重试""" all_embeddings = [] for i in range(0, len(texts), BATCH_SIZE): batch = texts[i:i + BATCH_SIZE] for attempt in range(MAX_RETRIES + 1): try: async with httpx.AsyncClient(timeout=60) as client: resp = await client.post( f"{SILICONFLOW_BASE_URL}/embeddings", headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}, json={ "model": EMBEDDING_MODEL, "input": batch, }, ) if resp.status_code == 200: data = resp.json() all_embeddings.extend([d["embedding"] for d in data["data"]]) break else: err = f"Status {resp.status_code}" if attempt == MAX_RETRIES: raise RuntimeError(f"Embedding batch failed after {MAX_RETRIES} retries: {err}") except Exception as e: if attempt == MAX_RETRIES: raise RuntimeError(f"Embedding batch failed: {e}") await asyncio.sleep(2 ** attempt) return all_embeddings