95 lines
2.8 KiB
Python
95 lines
2.8 KiB
Python
import httpx
|
|
from config import API_BASE_URL, RAG_WORKER_SECRET, WORKER_ID
|
|
|
|
# Headers sent with every request in this module: a bearer token built from
# RAG_WORKER_SECRET plus the worker's id under "X-Worker-Id".
_auth_headers = {
    "Authorization": f"Bearer {RAG_WORKER_SECRET}",
    "X-Worker-Id": WORKER_ID,
}
|
|
|
|
|
|
async def get_next_job() -> dict | None:
    """Fetch the next QUEUED import job.

    Sends a GET to ``/internal/rag/jobs/next``. When the response status is
    200, the JSON body is parsed and its ``"data"`` value is returned, falling
    back to the ``"job"`` value when ``"data"`` is missing or falsy. Any other
    status yields ``None``.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/next"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.get(url, headers=_auth_headers)
        if response.status_code != 200:
            return None
        payload = response.json()
        return payload.get("data") or payload.get("job")
|
|
|
|
|
|
async def claim_job(job_id: str) -> bool:
    """Claim the job identified by *job_id*.

    POSTs to ``/internal/rag/jobs/{job_id}/claim`` and returns ``True`` only
    when the response status is exactly 200.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.post(url, headers=_auth_headers)
        claimed = response.status_code == 200
    return claimed
|
|
|
|
|
|
async def heartbeat(job_id: str) -> bool:
    """Send a heartbeat for the job identified by *job_id*.

    POSTs to ``/internal/rag/jobs/{job_id}/heartbeat`` with a 10-second
    timeout and returns ``True`` only when the response status is exactly 200.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/{job_id}/heartbeat"
    async with httpx.AsyncClient(timeout=10) as http:
        response = await http.post(url, headers=_auth_headers)
        acknowledged = response.status_code == 200
    return acknowledged
|
|
|
|
|
|
async def update_job_status(
    job_id: str, status: str, data: dict | None = None
) -> bool:
    """Update the status of an import job.

    POSTs a JSON body to ``/internal/rag/jobs/{job_id}/status`` made of
    ``{"status": status}`` merged with the entries of *data*.

    Args:
        job_id: Identifier interpolated into the request path.
        status: Value sent under the ``"status"`` key.
        data: Optional extra fields merged into the body. They are unpacked
            after ``"status"``, so a ``"status"`` key inside *data* overrides
            the *status* argument.

    Returns:
        ``True`` when the API answered with a 2xx status, ``False`` otherwise.
        The response used to be discarded, which made a rejected update
        indistinguishable from a successful one; callers that ignore the
        return value behave exactly as before.

    Exceptions raised by ``httpx`` while sending the request are not caught
    here.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/{job_id}/status"
    body = {"status": status, **(data or {})}
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(url, headers=_auth_headers, json=body)
    # Surface the outcome instead of dropping it silently.
    return 200 <= resp.status_code < 300
|
|
|
|
|
|
async def save_chunks(chunks: list[dict]) -> bool:
    """Save a batch of KnowledgeChunk records.

    POSTs ``{"chunks": chunks}`` as JSON to ``/internal/rag/chunks`` with a
    60-second timeout.

    Args:
        chunks: Chunk dictionaries sent under the ``"chunks"`` key.

    Returns:
        ``True`` when the API answered with a 2xx status, ``False`` otherwise.
        The response used to be discarded, so a failed save went unnoticed;
        callers that ignore the return value behave exactly as before.

    Exceptions raised by ``httpx`` while sending the request are not caught
    here.
    """
    url = f"{API_BASE_URL}/internal/rag/chunks"
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            url,
            headers=_auth_headers,
            json={"chunks": chunks},
        )
    # Surface the outcome instead of dropping it silently.
    return 200 <= resp.status_code < 300
|
|
|
|
|
|
async def save_candidates(
    user_id: str,
    kb_id: str,
    source_id: str,
    import_id: str,
    candidates: list[dict],
) -> bool:
    """Save candidate knowledge points.

    POSTs a JSON body to ``/internal/rag/candidates`` with a 60-second
    timeout.

    Args:
        user_id: Sent as ``"userId"``.
        kb_id: Sent as ``"knowledgeBaseId"``.
        source_id: Sent as ``"sourceId"``.
        import_id: Sent as ``"importId"``.
        candidates: Candidate dictionaries sent under ``"candidates"``.

    Returns:
        ``True`` when the API answered with a 2xx status, ``False`` otherwise.
        The response used to be discarded, so a failed save went unnoticed;
        callers that ignore the return value behave exactly as before.

    Exceptions raised by ``httpx`` while sending the request are not caught
    here.
    """
    url = f"{API_BASE_URL}/internal/rag/candidates"
    body = {
        "userId": user_id,
        "knowledgeBaseId": kb_id,
        "sourceId": source_id,
        "importId": import_id,
        "candidates": candidates,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(url, headers=_auth_headers, json=body)
    # Surface the outcome instead of dropping it silently.
    return 200 <= resp.status_code < 300
|
|
|
|
|
|
async def get_job_detail(job_id: str) -> dict | None:
    """Fetch the details of a job, including its source information.

    Sends a GET to ``/internal/rag/jobs/{job_id}``. Returns the parsed JSON
    body when the response status is exactly 200, otherwise ``None``.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/{job_id}"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.get(url, headers=_auth_headers)
        if response.status_code != 200:
            return None
        return response.json()
|