# api-server/rag-worker/pipelines/import_pipeline.py

"""导入主流程:下载 → 解析 → 清洗 → 切片 → embedding → Qdrant → AI 候选"""
import logging
import os
import shutil
import uuid
from parser import download_file, parse_document
from chunker import chunk_document
from embedder import embed_batch
from indexer import upsert_points
from candidate_generator import generate_candidates
from api_client import (
heartbeat as send_heartbeat,
update_job_status,
save_chunks,
save_candidates,
get_job_detail,
)
async def run_import(job: dict) -> None:
    """Run the full document import pipeline for one job.

    Stages, each reported through ``update_job_status``: download, parse,
    clean, chunk, embed, vector indexing, candidate generation, completion.

    Args:
        job: Job payload. ``id`` is required. ``sourceId``, ``userId`` and
            ``knowledgeBaseId`` are read in camelCase with a snake_case
            fallback. ``rawText`` is used when no file yields any text.

    Raises:
        ValueError: If the job has no source id, if the extracted text has
            fewer than 10 characters after stripping, or if the embedder
            returns a different number of vectors than there are chunks.
        Exception: Any failure inside the staged section is re-raised after
            an attempt to mark the job ``FAILED_RETRYABLE``.
    """
    job_id = job["id"]
    source_id = job.get("sourceId") or job.get("source_id")
    # Use .get() for the camelCase keys too: indexing with job["userId"]
    # raised KeyError before the snake_case fallback could be consulted.
    user_id = job.get("userId") or job.get("user_id")
    kb_id = job.get("knowledgeBaseId") or job.get("knowledge_base_id")
    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")

    # Fetch the job detail, which carries the source metadata.
    # `or {}` also covers a detail whose "source" is explicitly None.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = source.get("mimeType") or source.get("mime_type") or "text/plain"
    original_filename = (
        source.get("originalFilename")
        or source.get("original_filename")
        or "unknown"
    )

    tmp_dir = f"/data/tmp/imports/{job_id}"
    # The filename comes from remote metadata: keep only its last component
    # so values such as "../../x" or "/etc/x" cannot escape tmp_dir.
    safe_name = os.path.basename(original_filename)
    if safe_name in ("", ".", ".."):
        safe_name = "unknown"
    file_path = os.path.join(tmp_dir, safe_name)

    try:
        # 1. Download the file when a URL is available.
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl", "")
        if file_url:
            await download_file(file_url, file_path)

        # 2. Parse. Skip the parser when there is no local file (plain-text
        # imports) so the raw-text fallback below is always reachable.
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = ""
        if os.path.exists(file_path):
            text = await parse_document(file_path, mime_type)
        if not text:
            text = job.get("rawText") or source.get("rawText") or ""
        if len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean. Only the status is reported here; this function applies
        # no cleaning transform to `text` itself.
        await update_job_status(
            job_id, "CLEANING", {"progress": 40, "textLength": len(text)}
        )

        # 4. Chunk.
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)

        # 5. Embed.
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        vectors = await embed_batch([c["content"] for c in chunks])
        # zip() would silently drop trailing chunks on a short result.
        if len(vectors) != len(chunks):
            raise ValueError(
                f"embedding 数量({len(vectors)})与切片数量({len(chunks)})不一致"
            )

        # 6. Index vectors and persist chunk records.
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points = []
        chunk_records = []
        for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
            # NOTE(review): Qdrant point ids must be unsigned ints or UUIDs;
            # confirm that upsert_points maps this string id accordingly.
            chunk_id = f"chunk_{source_id}_{i}"
            page_number = chunk.get("pageNumber")
            section_title = chunk.get("sectionTitle", "")
            points.append({
                "id": chunk_id,
                "vector": vec,
                "payload": {
                    "userId": user_id,
                    "knowledgeBaseId": kb_id,
                    "sourceId": source_id,
                    "chunkId": chunk_id,
                    "pageNumber": page_number,
                    "sectionTitle": section_title,
                    "deleted": False,
                },
            })
            chunk_records.append({
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "content": chunk["content"],
                "chunkIndex": chunk["chunkIndex"],
                "pageNumber": page_number,
                "sectionTitle": section_title,
                # Character count, used as the token-count value.
                "tokenCount": len(chunk["content"]),
                "externalVectorId": chunk_id,
                "embeddingModel": "bge-m3",
                "embeddingStatus": "COMPLETED",
                "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
            })
        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points.
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        candidates = await generate_candidates(text)
        if candidates:
            await save_candidates(user_id, kb_id, source_id, job_id, candidates)

        # 8. Done.
        await update_job_status(job_id, "COMPLETED", {"progress": 100})
    except Exception as exc:
        try:
            await update_job_status(job_id, "FAILED_RETRYABLE", {
                "errorCode": "WORKER_ERROR",
                "errorMessage": str(exc)[:500],
            })
        except Exception:
            # Status reporting is best-effort: a failure here must not
            # replace the original pipeline error raised below.
            logging.getLogger(__name__).exception(
                "could not report FAILED_RETRYABLE for job %s", job_id
            )
        raise
    finally:
        # Remove the per-job temp directory; a missing directory is ignored.
        shutil.rmtree(tmp_dir, ignore_errors=True)