64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
|
|
"""Embedding 服务:调用硅基流动 bge-m3"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import httpx
|
|||
|
|
from config import (
|
|||
|
|
SILICONFLOW_API_KEY,
|
|||
|
|
SILICONFLOW_BASE_URL,
|
|||
|
|
EMBEDDING_MODEL,
|
|||
|
|
EMBEDDING_DIM,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Maximum number of texts embed_batch sends to the API in a single request.
BATCH_SIZE = 50
|
|||
|
|
# Retries allowed after the first attempt; embed_batch tries each batch
# up to MAX_RETRIES + 1 times.
MAX_RETRIES = 2
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def embed_single(text: str) -> list[float]:
    """Return the embedding vector for a single piece of text.

    Posts ``text`` as a one-element ``input`` list to the ``/embeddings``
    endpoint and returns the ``embedding`` field of the first item in the
    response's ``data`` list.

    Args:
        text: The text to embed.

    Returns:
        The embedding reported by the API for ``text``.

    Raises:
        RuntimeError: If the API answers with a status code other than 200.
    """
    endpoint = f"{SILICONFLOW_BASE_URL}/embeddings"
    auth_headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    payload = {"model": EMBEDDING_MODEL, "input": [text]}

    async with httpx.AsyncClient(timeout=30) as session:
        response = await session.post(endpoint, headers=auth_headers, json=payload)

        # Anything other than 200 is surfaced with the status and raw body.
        if response.status_code != 200:
            raise RuntimeError(
                f"Embedding API error: {response.status_code} {response.text}"
            )

        first_item = response.json()["data"][0]
        return first_item["embedding"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed many texts, splitting them into batches and retrying failures.

    The texts are sent to the ``/embeddings`` endpoint in slices of at most
    ``BATCH_SIZE`` items. Each slice is attempted up to ``MAX_RETRIES + 1``
    times, sleeping ``2 ** attempt`` seconds before the next attempt.

    Args:
        texts: Texts to embed. An empty list returns ``[]`` without issuing
            any request.

    Returns:
        One embedding per input text: the embeddings of each batch, in the
        order the API returned them, concatenated in batch order.

    Raises:
        RuntimeError: If a batch still fails on its last attempt, whether
            due to a non-200 status, an exception while sending the request
            or parsing the response, or a response whose number of
            embeddings does not match the number of texts sent.
    """
    if not texts:
        return []

    url = f"{SILICONFLOW_BASE_URL}/embeddings"
    headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    all_embeddings: list[list[float]] = []

    # One client serves every batch and retry so connections can be reused
    # instead of being re-established for each request.
    async with httpx.AsyncClient(timeout=60) as client:
        for start in range(0, len(texts), BATCH_SIZE):
            batch = texts[start:start + BATCH_SIZE]
            for attempt in range(MAX_RETRIES + 1):
                status_error = None
                try:
                    resp = await client.post(
                        url,
                        headers=headers,
                        json={
                            "model": EMBEDDING_MODEL,
                            "input": batch,
                        },
                    )
                    if resp.status_code == 200:
                        vectors = [d["embedding"] for d in resp.json()["data"]]
                        # A short or long response would silently misalign
                        # embeddings with their texts; treat it as a failure.
                        if len(vectors) != len(batch):
                            raise ValueError(
                                f"expected {len(batch)} embeddings, "
                                f"got {len(vectors)}"
                            )
                        all_embeddings.extend(vectors)
                        break
                    status_error = f"Status {resp.status_code}"
                except Exception as e:
                    if attempt == MAX_RETRIES:
                        raise RuntimeError(f"Embedding batch failed: {e}") from e
                else:
                    # Reached only for a non-200 response (``break`` skips
                    # this clause). Raising here, outside the ``try``, keeps
                    # the error from being caught and wrapped a second time.
                    if attempt == MAX_RETRIES:
                        raise RuntimeError(
                            f"Embedding batch failed after {MAX_RETRIES} retries: {status_error}"
                        )
                # Exponential backoff before the next attempt.
                await asyncio.sleep(2 ** attempt)

    return all_embeddings
|