api-server/rag-worker/candidate_generator.py

107 lines
3.4 KiB
Python
Raw Normal View History

"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
import json
import httpx
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
MAX_CANDIDATES = 30
MIN_CANDIDATES = 3
CHARS_PER_CANDIDATE = 2000
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
对于每个知识点请提供
- title: 知识点标题简洁不超过 30
- summary: 一句话概述不超过 80
- content: 详细解释基于原文保持准确
- tags: 2-4 个标签
- recallQuestions: 1-2 个主动回忆问题
- difficulty: 难度评估easy/medium/hard
- confidence: 你对这个知识点重要性的置信度0.0-1.0
请以 JSON 数组格式返回每个元素是一个知识点
```json
[{
"title": "知识点标题",
"summary": "一句话概述",
"content": "详细解释...",
"tags": ["标签1", "标签2"],
"recallQuestions": ["问题1", "问题2"],
"difficulty": "medium",
"confidence": 0.85
}]
```
文档内容
{text}
"""
async def generate_candidates(text: str) -> list[dict]:
"""用 DeepSeek 生成候选知识点"""
# 估算生成数量
text_len = len(text)
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
prompt = _PROMPT.replace("{text}", text[:16000]) # 限制上下文长度
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"{DEEPSEEK_BASE_URL}/chat/completions",
headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
json={
"model": DEEPSEEK_MODEL,
"messages": [
{"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
{"role": "user", "content": prompt},
],
"temperature": 0.3,
"max_tokens": 4096,
},
)
if resp.status_code != 200:
raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")
data = resp.json()
raw = data["choices"][0]["message"]["content"]
# 提取 JSON
return _parse_json_response(raw, expected_count)
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
"""从 AI 回复中提取 JSON 数组"""
import re
# 1. 提取 ```json ... ``` 块
m = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
if m:
inner = m.group(1).strip()
try:
candidates = json.loads(inner)
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 2. 提取 [ ... ] 块(从第一个 [ 到最后一个 ]
start = raw.find("[")
end = raw.rfind("]")
if start != -1 and end != -1 and end > start:
try:
candidates = json.loads(raw[start:end + 1])
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 3. 直接解析整个回复
try:
candidates = json.loads(raw)
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")