add rerank module + bug fixes from e2e test
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
- New reranker.py: SiliconFlow bge-reranker-v2-m3 integration
- config.py: add RERANK_MODEL
- api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping
- candidate_generator.py: fix .format() conflict with JSON braces
- import_pipeline.py: fix file existence check + UUID point IDs
- Add .gitignore for __pycache__

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
1947a0c0d5
commit
c9882c8d04
1
rag-worker/.gitignore
vendored
Normal file
1
rag-worker/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
__pycache__/
|
||||||
@ -16,7 +16,9 @@ async def get_next_job() -> dict | None:
|
|||||||
)
|
)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
return data.get("data") or data.get("job")
|
result = data.get("data") or data
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result.get("job")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -26,8 +28,9 @@ async def claim_job(job_id: str) -> bool:
|
|||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
|
f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
|
||||||
headers=_auth_headers,
|
headers=_auth_headers,
|
||||||
|
json={"workerId": WORKER_ID},
|
||||||
)
|
)
|
||||||
return resp.status_code == 200
|
return resp.status_code in (200, 201)
|
||||||
|
|
||||||
|
|
||||||
async def heartbeat(job_id: str) -> bool:
|
async def heartbeat(job_id: str) -> bool:
|
||||||
@ -90,5 +93,6 @@ async def get_job_detail(job_id: str) -> dict | None:
|
|||||||
headers=_auth_headers,
|
headers=_auth_headers,
|
||||||
)
|
)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
return resp.json()
|
data = resp.json()
|
||||||
|
return data.get("data") or data
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -43,7 +43,7 @@ async def generate_candidates(text: str) -> list[dict]:
|
|||||||
text_len = len(text)
|
text_len = len(text)
|
||||||
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
|
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
|
||||||
|
|
||||||
prompt = _PROMPT.format(text=text[:16000]) # 限制上下文长度
|
prompt = _PROMPT.replace("{text}", text[:16000]) # 限制上下文长度
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=120) as client:
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
@ -71,7 +71,31 @@ async def generate_candidates(text: str) -> list[dict]:
|
|||||||
|
|
||||||
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
||||||
"""从 AI 回复中提取 JSON 数组"""
|
"""从 AI 回复中提取 JSON 数组"""
|
||||||
# 尝试直接解析
|
import re
|
||||||
|
|
||||||
|
# 1. 提取 ```json ... ``` 块
|
||||||
|
m = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
|
||||||
|
if m:
|
||||||
|
inner = m.group(1).strip()
|
||||||
|
try:
|
||||||
|
candidates = json.loads(inner)
|
||||||
|
if isinstance(candidates, list):
|
||||||
|
return candidates[:MAX_CANDIDATES]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2. 提取 [ ... ] 块(从第一个 [ 到最后一个 ])
|
||||||
|
start = raw.find("[")
|
||||||
|
end = raw.rfind("]")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
try:
|
||||||
|
candidates = json.loads(raw[start:end + 1])
|
||||||
|
if isinstance(candidates, list):
|
||||||
|
return candidates[:MAX_CANDIDATES]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 3. 直接解析整个回复
|
||||||
try:
|
try:
|
||||||
candidates = json.loads(raw)
|
candidates = json.loads(raw)
|
||||||
if isinstance(candidates, list):
|
if isinstance(candidates, list):
|
||||||
@ -79,25 +103,4 @@ def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 提取 ```json ... ``` 块
|
|
||||||
import re
|
|
||||||
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
|
|
||||||
if m:
|
|
||||||
try:
|
|
||||||
candidates = json.loads(m.group(1))
|
|
||||||
if isinstance(candidates, list):
|
|
||||||
return candidates[:MAX_CANDIDATES]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 提取 [ ... ] 块
|
|
||||||
m = re.search(r"\[.*\]", raw, re.DOTALL)
|
|
||||||
if m:
|
|
||||||
try:
|
|
||||||
candidates = json.loads(m.group(0))
|
|
||||||
if isinstance(candidates, list):
|
|
||||||
return candidates[:MAX_CANDIDATES]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")
|
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")
|
||||||
|
|||||||
@ -1,4 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 加载 .env 文件(systemd 可通过 EnvironmentFile 设置,此处作为手动运行的兜底)
|
||||||
|
_dotenv_path = Path(__file__).resolve().parent / ".env"
|
||||||
|
if _dotenv_path.exists():
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv(_dotenv_path)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
# NestJS 内部 API
|
# NestJS 内部 API
|
||||||
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000")
|
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000")
|
||||||
@ -9,6 +19,7 @@ SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
|
|||||||
SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
|
SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3")
|
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3")
|
||||||
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))
|
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))
|
||||||
|
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||||
|
|
||||||
# DeepSeek
|
# DeepSeek
|
||||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
|
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
|
||||||
|
|||||||
@ -45,11 +45,14 @@ async def run_import(job: dict):
|
|||||||
|
|
||||||
# 2. 解析
|
# 2. 解析
|
||||||
await update_job_status(job_id, "PARSING", {"progress": 20})
|
await update_job_status(job_id, "PARSING", {"progress": 20})
|
||||||
|
text = ""
|
||||||
|
if os.path.exists(file_path):
|
||||||
text = await parse_document(file_path, mime_type)
|
text = await parse_document(file_path, mime_type)
|
||||||
|
|
||||||
# 如果文件不在本地(纯文本导入),直接从 source/import 中取文本
|
# 如果文件不在本地(纯文本导入),直接从 source/import 中取文本
|
||||||
if not text and (job.get("rawText") or source.get("rawText")):
|
detail_job = (detail or {}).get("job", {})
|
||||||
text = job.get("rawText", "") or source.get("rawText", "")
|
if not text:
|
||||||
|
text = job.get("rawText") or source.get("rawText") or detail_job.get("rawText") or ""
|
||||||
|
|
||||||
if not text or len(text.strip()) < 10:
|
if not text or len(text.strip()) < 10:
|
||||||
raise ValueError("文档解析后内容过少,可能为空白或损坏文件")
|
raise ValueError("文档解析后内容过少,可能为空白或损坏文件")
|
||||||
@ -72,7 +75,7 @@ async def run_import(job: dict):
|
|||||||
points = []
|
points = []
|
||||||
chunk_records = []
|
chunk_records = []
|
||||||
for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
|
for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
|
||||||
chunk_id = f"chunk_{source_id}_{i}"
|
chunk_id = str(uuid.uuid4())
|
||||||
points.append({
|
points.append({
|
||||||
"id": chunk_id,
|
"id": chunk_id,
|
||||||
"vector": vec,
|
"vector": vec,
|
||||||
@ -104,11 +107,14 @@ async def run_import(job: dict):
|
|||||||
await upsert_points(points)
|
await upsert_points(points)
|
||||||
await save_chunks(chunk_records)
|
await save_chunks(chunk_records)
|
||||||
|
|
||||||
# 7. 生成候选知识点
|
# 7. 生成候选知识点(非致命:失败不影响导入完成)
|
||||||
await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
|
await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
|
||||||
|
try:
|
||||||
candidates = await generate_candidates(text)
|
candidates = await generate_candidates(text)
|
||||||
if candidates:
|
if candidates:
|
||||||
await save_candidates(user_id, kb_id, source_id, job_id, candidates)
|
await save_candidates(user_id, kb_id, source_id, job_id, candidates)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[worker] 候选知识点生成失败(非致命): {e}")
|
||||||
|
|
||||||
# 8. 完成
|
# 8. 完成
|
||||||
await update_job_status(job_id, "COMPLETED", {"progress": 100})
|
await update_job_status(job_id, "COMPLETED", {"progress": 100})
|
||||||
|
|||||||
@ -6,3 +6,5 @@ markdown>=3.5
|
|||||||
pandas>=2.0
|
pandas>=2.0
|
||||||
openpyxl>=3.1
|
openpyxl>=3.1
|
||||||
Pillow>=10.0
|
Pillow>=10.0
|
||||||
|
qdrant-client>=1.9
|
||||||
|
python-dotenv>=1.0
|
||||||
|
|||||||
46
rag-worker/reranker.py
Normal file
46
rag-worker/reranker.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Rerank 服务:调用硅基流动 bge-reranker-v2-m3 对检索结果精排"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, RERANK_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
async def rerank(
    query: str,
    documents: list[str],
    top_n: int = 5,
) -> list[dict]:
    """Re-score candidate documents against a query and return the best ones.

    Sends ``query`` and ``documents`` to the ``/rerank`` endpoint under
    ``SILICONFLOW_BASE_URL`` using ``RERANK_MODEL`` and requests at most
    ``min(top_n, len(documents))`` results.

    Args:
        query: The search query the documents are scored against.
        documents: Candidate document texts to be re-ranked.
        top_n: Maximum number of results to request. Values ``<= 0`` yield
            an empty list without calling the API.

    Returns:
        One dict per entry of the API's ``results`` list, in the order the
        API returned them, each with:

        - ``index``: position of the document in the input ``documents``.
        - ``score``: the ``relevance_score`` reported by the API
          (the original author documents the range as 0.0-1.0).
        - ``text``: the document text echoed by the API, or the caller's
          original text when the API did not echo one.

    Raises:
        RuntimeError: If the API responds with a status other than 200.
    """
    # Nothing to rank, or nothing requested: skip the network round trip.
    if not documents or top_n <= 0:
        return []

    payload = {
        "model": RERANK_MODEL,
        "query": query,
        "documents": documents,
        "top_n": min(top_n, len(documents)),
        "return_documents": True,
        "max_chunks_per_doc": 1024,
    }

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/rerank",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json=payload,
        )

    if resp.status_code != 200:
        raise RuntimeError(f"Rerank API error: {resp.status_code} {resp.text}")

    ranked = []
    for item in resp.json()["results"]:
        idx = item["index"]
        # The echoed document may be missing or null; only trust it when it
        # is a dict, and look up the caller's text lazily otherwise.
        echoed = item.get("document")
        text = echoed.get("text") if isinstance(echoed, dict) else None
        if text is None:
            text = documents[idx]
        ranked.append(
            {
                "index": idx,
                "score": item["relevance_score"],
                "text": text,
            }
        )
    return ranked
|
||||||
Loading…
x
Reference in New Issue
Block a user