add rerank module + bug fixes from e2e test
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
- New reranker.py: SiliconFlow bge-reranker-v2-m3 integration
- config.py: add RERANK_MODEL
- api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping
- candidate_generator.py: fix .format() conflict with JSON braces
- import_pipeline.py: fix file existence check + UUID point IDs
- Add .gitignore for __pycache__

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
1947a0c0d5
commit
c9882c8d04
1
rag-worker/.gitignore
vendored
Normal file
1
rag-worker/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__/
|
||||
@ -16,7 +16,9 @@ async def get_next_job() -> dict | None:
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
return data.get("data") or data.get("job")
|
||||
result = data.get("data") or data
|
||||
if isinstance(result, dict):
|
||||
return result.get("job")
|
||||
return None
|
||||
|
||||
|
||||
@ -26,8 +28,9 @@ async def claim_job(job_id: str) -> bool:
|
||||
resp = await client.post(
|
||||
f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
|
||||
headers=_auth_headers,
|
||||
json={"workerId": WORKER_ID},
|
||||
)
|
||||
return resp.status_code == 200
|
||||
return resp.status_code in (200, 201)
|
||||
|
||||
|
||||
async def heartbeat(job_id: str) -> bool:
|
||||
@ -90,5 +93,6 @@ async def get_job_detail(job_id: str) -> dict | None:
|
||||
headers=_auth_headers,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
return resp.json()
|
||||
data = resp.json()
|
||||
return data.get("data") or data
|
||||
return None
|
||||
|
||||
@ -43,7 +43,7 @@ async def generate_candidates(text: str) -> list[dict]:
|
||||
text_len = len(text)
|
||||
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
|
||||
|
||||
prompt = _PROMPT.format(text=text[:16000]) # 限制上下文长度
|
||||
prompt = _PROMPT.replace("{text}", text[:16000]) # 限制上下文长度
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = await client.post(
|
||||
@ -71,7 +71,31 @@ async def generate_candidates(text: str) -> list[dict]:
|
||||
|
||||
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
||||
"""从 AI 回复中提取 JSON 数组"""
|
||||
# 尝试直接解析
|
||||
import re
|
||||
|
||||
# 1. 提取 ```json ... ``` 块
|
||||
m = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
|
||||
if m:
|
||||
inner = m.group(1).strip()
|
||||
try:
|
||||
candidates = json.loads(inner)
|
||||
if isinstance(candidates, list):
|
||||
return candidates[:MAX_CANDIDATES]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 2. 提取 [ ... ] 块(从第一个 [ 到最后一个 ])
|
||||
start = raw.find("[")
|
||||
end = raw.rfind("]")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
try:
|
||||
candidates = json.loads(raw[start:end + 1])
|
||||
if isinstance(candidates, list):
|
||||
return candidates[:MAX_CANDIDATES]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 3. 直接解析整个回复
|
||||
try:
|
||||
candidates = json.loads(raw)
|
||||
if isinstance(candidates, list):
|
||||
@ -79,25 +103,4 @@ def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 提取 ```json ... ``` 块
|
||||
import re
|
||||
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
|
||||
if m:
|
||||
try:
|
||||
candidates = json.loads(m.group(1))
|
||||
if isinstance(candidates, list):
|
||||
return candidates[:MAX_CANDIDATES]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 提取 [ ... ] 块
|
||||
m = re.search(r"\[.*\]", raw, re.DOTALL)
|
||||
if m:
|
||||
try:
|
||||
candidates = json.loads(m.group(0))
|
||||
if isinstance(candidates, list):
|
||||
return candidates[:MAX_CANDIDATES]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")
|
||||
|
||||
@ -1,4 +1,14 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 加载 .env 文件(systemd 可通过 EnvironmentFile 设置,此处作为手动运行的兜底)
|
||||
_dotenv_path = Path(__file__).resolve().parent / ".env"
|
||||
if _dotenv_path.exists():
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(_dotenv_path)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# NestJS 内部 API
|
||||
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000")
|
||||
@ -9,6 +19,7 @@ SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
|
||||
SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3")
|
||||
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))
|
||||
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
|
||||
|
||||
@ -45,11 +45,14 @@ async def run_import(job: dict):
|
||||
|
||||
# 2. 解析
|
||||
await update_job_status(job_id, "PARSING", {"progress": 20})
|
||||
text = await parse_document(file_path, mime_type)
|
||||
text = ""
|
||||
if os.path.exists(file_path):
|
||||
text = await parse_document(file_path, mime_type)
|
||||
|
||||
# 如果文件不在本地(纯文本导入),直接从 source/import 中取文本
|
||||
if not text and (job.get("rawText") or source.get("rawText")):
|
||||
text = job.get("rawText", "") or source.get("rawText", "")
|
||||
detail_job = (detail or {}).get("job", {})
|
||||
if not text:
|
||||
text = job.get("rawText") or source.get("rawText") or detail_job.get("rawText") or ""
|
||||
|
||||
if not text or len(text.strip()) < 10:
|
||||
raise ValueError("文档解析后内容过少,可能为空白或损坏文件")
|
||||
@ -72,7 +75,7 @@ async def run_import(job: dict):
|
||||
points = []
|
||||
chunk_records = []
|
||||
for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
|
||||
chunk_id = f"chunk_{source_id}_{i}"
|
||||
chunk_id = str(uuid.uuid4())
|
||||
points.append({
|
||||
"id": chunk_id,
|
||||
"vector": vec,
|
||||
@ -104,11 +107,14 @@ async def run_import(job: dict):
|
||||
await upsert_points(points)
|
||||
await save_chunks(chunk_records)
|
||||
|
||||
# 7. 生成候选知识点
|
||||
# 7. 生成候选知识点(非致命:失败不影响导入完成)
|
||||
await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
|
||||
candidates = await generate_candidates(text)
|
||||
if candidates:
|
||||
await save_candidates(user_id, kb_id, source_id, job_id, candidates)
|
||||
try:
|
||||
candidates = await generate_candidates(text)
|
||||
if candidates:
|
||||
await save_candidates(user_id, kb_id, source_id, job_id, candidates)
|
||||
except Exception as e:
|
||||
print(f"[worker] 候选知识点生成失败(非致命): {e}")
|
||||
|
||||
# 8. 完成
|
||||
await update_job_status(job_id, "COMPLETED", {"progress": 100})
|
||||
|
||||
@ -6,3 +6,5 @@ markdown>=3.5
|
||||
pandas>=2.0
|
||||
openpyxl>=3.1
|
||||
Pillow>=10.0
|
||||
qdrant-client>=1.9
|
||||
python-dotenv>=1.0
|
||||
|
||||
46
rag-worker/reranker.py
Normal file
46
rag-worker/reranker.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Rerank 服务:调用硅基流动 bge-reranker-v2-m3 对检索结果精排"""
|
||||
|
||||
import httpx
|
||||
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, RERANK_MODEL
|
||||
|
||||
|
||||
async def rerank(
    query: str,
    documents: list[str],
    top_n: int = 5,
) -> list[dict]:
    """Re-score candidate documents against a query via the rerank API.

    Posts the query and documents to the ``/rerank`` endpoint under
    ``SILICONFLOW_BASE_URL`` using the model named by ``RERANK_MODEL``.

    Args:
        query: The query the documents are ranked against.
        documents: Candidate document texts.
        top_n: Maximum number of results to request; capped at
            ``len(documents)`` before being sent.

    Returns:
        One dict per entry of the response's ``results`` list, in the order
        the API returned them, each with:
        - ``index``: position of the document in ``documents``
        - ``score``: the ``relevance_score`` reported by the API
        - ``text``: the document text echoed by the API, or the matching
          entry of ``documents`` when the response carries no text

        An empty list is returned without any request when ``documents``
        is empty.

    Raises:
        RuntimeError: If the API responds with a status code other than 200.
    """
    if not documents:
        return []

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/rerank",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": RERANK_MODEL,
                "query": query,
                "documents": documents,
                "top_n": min(top_n, len(documents)),
                "return_documents": True,
                "max_chunks_per_doc": 1024,
            },
        )
        if resp.status_code != 200:
            raise RuntimeError(f"Rerank API error: {resp.status_code} {resp.text}")

        data = resp.json()

    ranked = []
    for r in data["results"]:
        # The API may omit "document" or send it as null; `or {}` covers both,
        # whereas a .get() default would only cover the missing-key case.
        echoed = r.get("document") or {}
        text = echoed.get("text")
        if text is None:
            # Fall back to the caller's copy only when the API did not echo
            # the text, so the lookup is skipped when it is not needed.
            text = documents[r["index"]]
        ranked.append(
            {
                "index": r["index"],
                "score": r["relevance_score"],
                "text": text,
            }
        )
    return ranked
|
||||
Loading…
x
Reference in New Issue
Block a user