47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
|
|
"""Rerank 服务:调用硅基流动 bge-reranker-v2-m3 对检索结果精排"""
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, RERANK_MODEL
|
||
|
|
|
||
|
|
|
||
|
|
async def rerank(
    query: str,
    documents: list[str],
    top_n: int = 5,
) -> list[dict]:
    """Re-score candidate documents against a query and return the best ones.

    Posts the query and candidates to the ``/rerank`` endpoint under
    ``SILICONFLOW_BASE_URL`` using ``RERANK_MODEL`` and reshapes the reply
    into plain dicts.

    Args:
        query: Search query the documents are scored against.
        documents: Candidate document texts.
        top_n: Maximum number of results to request; clamped to
            ``len(documents)`` before being sent.

    Returns:
        One dict per entry of the response's ``results`` list, in the order
        the API returned them, each with:

        - ``index``: position of the document in ``documents``
        - ``score``: the ``relevance_score`` reported by the API (the
          original author's note gives a 0.0-1.0 range; not validated here)
        - ``text``: document text echoed back by the API, or
          ``documents[index]`` when the response carries no usable text

        An empty list is returned without calling the API when ``documents``
        is empty or ``top_n`` is not positive.

    Raises:
        RuntimeError: If the API answers with a status code other than 200.
    """
    # Nothing to rank, or nothing requested: skip the network round trip
    # instead of sending a zero/negative ``top_n`` to the service.
    if not documents or top_n <= 0:
        return []

    payload = {
        "model": RERANK_MODEL,
        "query": query,
        "documents": documents,
        "top_n": min(top_n, len(documents)),
        "return_documents": True,
        "max_chunks_per_doc": 1024,
    }

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/rerank",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json=payload,
        )

    if resp.status_code != 200:
        raise RuntimeError(f"Rerank API error: {resp.status_code} {resp.text}")

    data = resp.json()

    ranked: list[dict] = []
    for item in data["results"]:
        idx = item["index"]
        # ``document`` may be missing, null, or not an object; only trust it
        # when it is a dict, otherwise fall back to the caller's own text.
        doc = item.get("document")
        text = doc.get("text") if isinstance(doc, dict) else None
        if text is None:
            text = documents[idx]
        ranked.append(
            {
                "index": idx,
                "score": item["relevance_score"],
                "text": text,
            }
        )
    return ranked
|