api-server/rag-worker/pipelines/import_pipeline.py
WangDL fbdae9078f
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s
feat: Python RAG Worker + NestJS 内部 API(文档解析/切片/embedding/Qdrant/候选生成)
2026-05-19 22:35:12 +08:00

128 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""导入主流程:下载 → 解析 → 清洗 → 切片 → embedding → Qdrant → AI 候选"""
import os
import uuid
from parser import download_file, parse_document
from chunker import chunk_document
from embedder import embed_batch
from indexer import upsert_points
from candidate_generator import generate_candidates
from api_client import (
heartbeat as send_heartbeat,
update_job_status,
save_chunks,
save_candidates,
get_job_detail,
)
async def run_import(job: dict):
    """Run the full document import pipeline for a single job.

    Stages, each reported through ``update_job_status``: download, parse,
    clean, chunk, embed, index the vectors, generate candidate knowledge
    points, complete.

    Args:
        job: Job record. ``id`` is required. ``sourceId``, ``userId`` and
            ``knowledgeBaseId`` are read in camelCase first and fall back
            to their snake_case spelling. ``rawText`` is used when no
            text could be extracted from a downloaded file.

    Raises:
        ValueError: If the job has no source id, or if the extracted text
            has fewer than 10 non-blank characters.
        RuntimeError: If the embedder returns a different number of
            vectors than there are chunks.
        Exception: Any error raised inside the pipeline is re-raised after
            the job has been reported as ``FAILED_RETRYABLE``.
    """
    import logging
    import shutil

    job_id = job["id"]
    source_id = job.get("sourceId") or job.get("source_id")
    # .get() on the camelCase key too: subscripting raised KeyError before
    # the snake_case fallback could ever be consulted.
    # NOTE(review): a job with neither spelling now continues with None
    # instead of raising KeyError — confirm whether these ids are mandatory.
    user_id = job.get("userId") or job.get("user_id")
    kb_id = job.get("knowledgeBaseId") or job.get("knowledge_base_id")
    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")

    # Fetch the source details for this job (from the NestJS API).
    # ``or {}`` also covers a detail/source that is present but null.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = source.get("mimeType") or source.get("mime_type") or "text/plain"

    # The filename comes from the upload; keep only its last path component
    # so it cannot point outside the per-job temp directory.
    raw_name = (
        source.get("originalFilename")
        or source.get("original_filename")
        or "unknown"
    )
    original_filename = os.path.basename(str(raw_name).replace("\\", "/"))
    if original_filename in ("", ".", ".."):
        original_filename = "unknown"

    tmp_dir = f"/data/tmp/imports/{job_id}"
    file_path = os.path.join(tmp_dir, original_filename)

    try:
        # 1. Download the file.
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl", "")
        if file_url:
            os.makedirs(tmp_dir, exist_ok=True)
            await download_file(file_url, file_path)

        # 2. Parse. Only when a local file exists; otherwise go straight to
        # the raw-text fallback instead of parsing a missing path.
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = ""
        if os.path.exists(file_path):
            text = await parse_document(file_path, mime_type)
        # Plain-text imports carry their content on the job or the source.
        if not text:
            text = job.get("rawText") or source.get("rawText") or ""
        if not text or len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean. No text transformation happens here; only the stage and
        # the text length are reported.
        await update_job_status(
            job_id, "CLEANING", {"progress": 40, "textLength": len(text)}
        )

        # 4. Chunk.
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)

        # 5. Embed.
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        vectors = await embed_batch([c["content"] for c in chunks])
        # zip() below would silently drop chunks on a short result.
        if len(vectors) != len(chunks):
            raise RuntimeError(
                f"embedding 数量({len(vectors)})与切片数量({len(chunks)})不一致"
            )

        # 6. Index into the vector store and persist the chunk rows.
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points = []
        chunk_records = []
        for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
            # NOTE(review): Qdrant point ids must be unsigned integers or
            # UUIDs; confirm upsert_points maps this string id accordingly.
            chunk_id = f"chunk_{source_id}_{i}"
            page_number = chunk.get("pageNumber")
            section_title = chunk.get("sectionTitle", "")
            points.append({
                "id": chunk_id,
                "vector": vec,
                "payload": {
                    "userId": user_id,
                    "knowledgeBaseId": kb_id,
                    "sourceId": source_id,
                    "chunkId": chunk_id,
                    "pageNumber": page_number,
                    "sectionTitle": section_title,
                    "deleted": False,
                },
            })
            chunk_records.append({
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "content": chunk["content"],
                "chunkIndex": chunk["chunkIndex"],
                "pageNumber": page_number,
                "sectionTitle": section_title,
                # NOTE(review): this is a character count, not a token count.
                "tokenCount": len(chunk["content"]),
                "externalVectorId": chunk_id,
                "embeddingModel": "bge-m3",
                "embeddingStatus": "COMPLETED",
                "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
            })
        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points.
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        candidates = await generate_candidates(text)
        if candidates:
            await save_candidates(user_id, kb_id, source_id, job_id, candidates)

        # 8. Done.
        await update_job_status(job_id, "COMPLETED", {"progress": 100})
    except Exception as exc:
        # Reporting the failure is best effort: if it fails as well, log it
        # and still re-raise the pipeline error, which is the root cause.
        try:
            await update_job_status(job_id, "FAILED_RETRYABLE", {
                "errorCode": "WORKER_ERROR",
                "errorMessage": str(exc)[:500],
            })
        except Exception:
            logging.getLogger(__name__).exception(
                "failed to report FAILED_RETRYABLE for job %s", job_id
            )
        raise
    finally:
        # Remove the per-job temp directory; ignore_errors also covers the
        # case where it was never created.
        shutil.rmtree(tmp_dir, ignore_errors=True)