85 lines
2.3 KiB
Python
85 lines
2.3 KiB
Python
|
|
"""知习 RAG Worker — 文档导入主进程"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import signal
|
|||
|
|
import sys
|
|||
|
|
from config import WORKER_ID, POLL_INTERVAL, HEARTBEAT_INTERVAL
|
|||
|
|
from api_client import get_next_job, claim_job, heartbeat, update_job_status
|
|||
|
|
from pipelines.import_pipeline import run_import
|
|||
|
|
|
|||
|
|
# Process-wide run flag. `shutdown` (registered for SIGINT/SIGTERM in this module)
# sets it to False; the `while running:` loops in this module check it between
# iterations to decide whether to keep polling.
running: bool = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def shutdown(sig, frame):
    """Handle SIGINT/SIGTERM by requesting a graceful stop.

    Prints a notice and clears the module-level ``running`` flag. The polling
    loops check that flag between iterations, so work already in progress is
    not interrupted by this handler.

    Args:
        sig: The signal number that was delivered.
        frame: The interrupted stack frame (unused; required by ``signal.signal``).
    """
    global running
    notice = f"[{WORKER_ID}] 收到信号 {sig},正在退出..."
    print(notice)
    running = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Route both Ctrl-C (SIGINT) and termination requests (SIGTERM) to the graceful
# shutdown handler. This runs at import time as a module-level side effect.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, shutdown)
del _sig
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def heartbeat_loop():
    """Worker-level heartbeat loop (placeholder).

    Simplified implementation: intended as a worker-wide heartbeat that could
    later be extended to per-job heartbeats. For now it only sleeps
    ``HEARTBEAT_INTERVAL`` at a time until the module-level ``running`` flag is
    cleared; it does not report anything itself.
    """
    # NOTE(review): nothing in the code visible here schedules this coroutine;
    # per-job heartbeats are sent by `_per_job_heartbeat`.
    while True:
        if not running:
            return
        await asyncio.sleep(HEARTBEAT_INTERVAL)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def work_loop():
    """Main worker loop: poll for a job, claim it, run the import pipeline.

    Runs until the module-level ``running`` flag is cleared (see ``shutdown``).
    The flag is only checked at the top of each iteration, so a job that has
    already started is allowed to finish.

    Per iteration:
      * no job available -> wait ``POLL_INTERVAL`` and poll again;
      * job without an id, or a claim that fails -> also wait ``POLL_INTERVAL``,
        so a job that keeps being returned cannot make the worker spin on the API;
      * claimed job -> run ``run_import(job)`` with a per-job heartbeat task
        alongside; if it raises, the job is reported as ``FAILED_RETRYABLE``.

    Any other exception raised during an iteration (including from the status
    update) is printed and followed by a ``POLL_INTERVAL`` pause; it does not
    stop the loop.
    """
    print(f"[{WORKER_ID}] RAG Worker 已启动")

    while running:
        try:
            job = await get_next_job()
            if not job:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            # The id may come back under either key.
            job_id = job.get("id") or job.get("jobId")
            if not job_id:
                # Back off: without a pause the same malformed job could be
                # fetched again immediately, busy-looping against the API.
                print(f"[{WORKER_ID}] 任务缺少 id,已跳过")
                await asyncio.sleep(POLL_INTERVAL)
                continue

            # Claim the job; a falsy result means this worker did not get it.
            claimed = await claim_job(job_id)
            if not claimed:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            print(f"[{WORKER_ID}] 开始处理任务 {job_id}")

            # Heartbeat runs as a background task for the duration of the job.
            hb_task = asyncio.create_task(_per_job_heartbeat(job_id))
            error = None
            try:
                await run_import(job)
            except Exception as e:
                error = e
            finally:
                # Stop the heartbeat before reporting the outcome, so no heartbeat
                # for this job is still in flight when its status is sent. The wait
                # is bounded in case the heartbeat coroutine ignores cancellation.
                hb_task.cancel()
                await asyncio.wait({hb_task}, timeout=HEARTBEAT_INTERVAL)

            if error is None:
                print(f"[{WORKER_ID}] 任务 {job_id} 完成")
            else:
                # Some exceptions (e.g. TimeoutError()) stringify to ""; fall back
                # to repr so the stored message is never empty.
                reason = str(error) or repr(error)
                print(f"[{WORKER_ID}] 任务 {job_id} 失败: {reason}")
                await update_job_status(job_id, "FAILED_RETRYABLE", {
                    "errorMessage": reason[:500],
                })

        except Exception as e:
            print(f"[{WORKER_ID}] 轮询异常: {e}")
            await asyncio.sleep(POLL_INTERVAL)

    print(f"[{WORKER_ID}] Worker 已停止")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _per_job_heartbeat(job_id: str) -> None:
    """Report heartbeats for one job every ``HEARTBEAT_INTERVAL`` until cancelled.

    Meant to run as a background task next to the job it reports on. It
    deliberately does NOT watch the module-level ``running`` flag: after a
    shutdown signal the job already in progress is still allowed to finish, so
    its heartbeats must keep going until the owning code cancels this task.

    Heartbeats are best-effort: a failed call is logged and retried on the next
    interval instead of ending the loop.

    Args:
        job_id: Identifier of the job whose heartbeat is reported.
    """
    while True:
        try:
            await heartbeat(job_id)
        except asyncio.CancelledError:
            # On Python 3.7 CancelledError subclasses Exception; re-raise so the
            # broad handler below can never swallow cancellation.
            raise
        except Exception as e:
            print(f"[{WORKER_ID}] 任务 {job_id} 心跳失败: {e}")
        await asyncio.sleep(HEARTBEAT_INTERVAL)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: drive the polling loop on a fresh event loop until it
    # returns (normally after a shutdown signal has cleared the run flag).
    _entry = work_loop()
    asyncio.run(_entry)
|