api-server/src/modules/rag-chat/rag-chat.service.ts
wangdl 6ab54be309
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 19s
feat: H0-11 RAG Chat 接入真实检索 + AI 生成管道
- sendMessage 从 KB 加载知识点内容作为上下文(最多 30 条/4000 字符)
- 通过 AiGatewayService 调用 DeepSeek 生成回答
- AI 回复附带引用来源(ChatCitation)
- AI Gateway 不可用时降级提示
- 知识库为空时引导用户先添加内容

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-29 19:31:33 +08:00

167 lines
5.6 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { Injectable, NotFoundException, Logger, Optional } from '@nestjs/common';
import { PrismaService } from '../../infrastructure/database/prisma.service';
import { ContentSafetyService } from '../content-safety/content-safety.service';
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
/** Character budget for the knowledge-base snippets collected into the AI prompt context. */
const MAX_CONTEXT_CHARS = 4000;
/** A knowledge-item excerpt that an AI reply was grounded on. */
export interface RagCitation {
  /** Id of the knowledge item the excerpt was taken from. */
  id: string;
  /** Title of that knowledge item (empty string when it has none). */
  title: string;
  /** The excerpt that was placed into the prompt. */
  text: string;
  /** Relevance score; currently a constant because no ranking is performed. */
  score: number;
}

/** Knowledge-base material prepared for a single prompt. */
interface RagContext {
  /** Prompt-ready text; empty when nothing usable was found or loading failed. */
  text: string;
  /** One entry per excerpt included in `text`, in the same order. */
  citations: RagCitation[];
  /** True only when the knowledge base was read successfully and held no usable text. */
  isEmpty: boolean;
}

/**
 * Chat sessions over a knowledge base: stores sessions and messages through Prisma,
 * builds a prompt context from the knowledge base's items, and asks the AI gateway
 * for an answer. Content-safety and AI-gateway dependencies are optional; without
 * the gateway every reply is a canned fallback message.
 */
@Injectable()
export class RagChatService {
  /** Maximum number of knowledge items fetched per question. */
  private static readonly MAX_CONTEXT_ITEMS = 30;
  /** Maximum characters taken from a single knowledge item. */
  private static readonly MAX_SNIPPET_CHARS = 500;
  /** Maximum number of citations persisted and returned per AI reply. */
  private static readonly MAX_SAVED_CITATIONS = 5;

  private readonly logger = new Logger(RagChatService.name);

  constructor(
    private readonly prisma: PrismaService,
    @Optional() private readonly safety?: ContentSafetyService,
    @Optional() private readonly aiGateway?: AiGatewayService,
  ) {}

  /**
   * Creates a chat session bound to one knowledge base.
   *
   * @param title Optional title; a missing or empty title falls back to the default.
   */
  async createSession(userId: string, knowledgeBaseId: string, title?: string) {
    // `||` is deliberate: an empty-string title should also get the default.
    return this.prisma.chatSession.create({
      data: { userId, knowledgeBaseId, title: title || '新对话' },
    });
  }

  /**
   * Lists a user's sessions, most recently updated first.
   *
   * @param kbId When given, only sessions of that knowledge base are returned.
   */
  async listSessions(userId: string, kbId?: string) {
    return this.prisma.chatSession.findMany({
      where: { userId, ...(kbId ? { knowledgeBaseId: kbId } : {}) },
      orderBy: { updatedAt: 'desc' },
    });
  }

  /**
   * Returns all messages of a session in chronological order, with their citations.
   *
   * NOTE(review): no ownership check happens here — presumably the caller verifies
   * that the session belongs to the requesting user; confirm.
   */
  async getMessages(sessionId: string) {
    return this.prisma.chatMessage.findMany({
      where: { sessionId },
      orderBy: { createdAt: 'asc' },
      include: { citations: true },
    });
  }

  /**
   * Handles one user turn: checks the input, stores it, generates a reply from the
   * knowledge base, and stores the reply together with its citations.
   *
   * @returns `{ blocked: true, message }` when the safety check rejects the input
   *   (nothing is stored in that case); otherwise `{ message, citations }` where
   *   `citations` is exactly the list that was persisted for the AI message.
   * @throws NotFoundException when the session does not exist or belongs to another user.
   */
  async sendMessage(userId: string, sessionId: string, content: string) {
    const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
    if (!session || session.userId !== userId) throw new NotFoundException('对话不存在');

    // Content safety (skipped when no safety service is injected).
    const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
    if (inputCheck && !inputCheck.safe) {
      return { blocked: true, message: '输入包含违规内容,请修改后重试' };
    }

    await this.prisma.chatMessage.create({
      data: { sessionId, role: 'user', content },
    });

    const context = await this.loadContext(session.knowledgeBaseId);
    const { reply, grounded } = await this.generateReply(userId, content, context);

    // Cap once so the response carries the same citations that end up in the database.
    const citations = grounded
      ? context.citations.slice(0, RagChatService.MAX_SAVED_CITATIONS)
      : [];

    // NOTE(review): `tokens` stores the reply's character count, not a tokenizer count.
    const aiMsg = await this.prisma.chatMessage.create({
      data: { sessionId, role: 'ai', content: reply, tokens: reply.length },
    });

    // Sequential on purpose: keeps the rows in the same order as the prompt context.
    for (const citation of citations) {
      await this.prisma.chatCitation.create({
        data: {
          messageId: aiMsg.id,
          chunkId: citation.id,
          content: citation.text.slice(0, RagChatService.MAX_SNIPPET_CHARS),
          score: citation.score,
        },
      });
    }

    // Touch the session so it sorts first in listSessions.
    await this.prisma.chatSession.update({
      where: { id: sessionId },
      data: { updatedAt: new Date() },
    });

    return { message: aiMsg, citations };
  }

  /**
   * Deletes a session with all of its messages and citations in one transaction,
   * so a failure cannot leave a partially deleted conversation behind.
   *
   * NOTE(review): no ownership check happens here — presumably the caller verifies
   * that the session belongs to the requesting user; confirm.
   *
   * @throws NotFoundException when no session with this id exists.
   */
  async deleteSession(sessionId: string) {
    // Children first, then the session itself.
    const [, , removedSessions] = await this.prisma.$transaction([
      this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } }),
      this.prisma.chatMessage.deleteMany({ where: { sessionId } }),
      this.prisma.chatSession.deleteMany({ where: { id: sessionId } }),
    ]);
    if (removedSessions.count === 0) throw new NotFoundException('对话不存在');
    return { success: true };
  }

  // ── Private ──

  /**
   * Produces the reply text for a question.
   *
   * @returns `grounded: true` only when the AI gateway answered using the context;
   *   every fallback path returns `grounded: false` so no citations get attached.
   */
  private async generateReply(
    userId: string,
    question: string,
    context: RagContext,
  ): Promise<{ reply: string; grounded: boolean }> {
    if (!this.aiGateway || !context.text) {
      return { reply: this.fallbackReply(context.isEmpty), grounded: false };
    }

    try {
      const resp = await this.aiGateway.generate({
        feature: 'rag-chat',
        userId,
        tier: 'primary',
        promptKey: 'rag-chat',
        promptVersion: 'v1',
        messages: [
          { role: 'system' as const, content: this.buildSystemPrompt(context.text) },
          { role: 'user' as const, content: question },
        ],
        maxTokens: 2048,
      });
      return { reply: this.extractReply(resp.parsed), grounded: true };
    } catch (err: unknown) {
      this.logger.error('AI Gateway failed, falling back', this.describeError(err));
      return { reply: this.fallbackReply(context.isEmpty), grounded: false };
    }
  }

  /**
   * Picks the answer text out of the gateway's parsed payload: `answer` first, then
   * `content`. Only non-blank strings are accepted, so an empty or non-string field
   * can never be stored as the reply.
   */
  private extractReply(parsed: unknown): string {
    if (parsed !== null && typeof parsed === 'object') {
      const { answer, content } = parsed as { answer?: unknown; content?: unknown };
      for (const candidate of [answer, content]) {
        if (typeof candidate === 'string' && candidate.trim() !== '') return candidate;
      }
    }
    return '抱歉AI 暂时无法生成回答。';
  }

  /**
   * Loads the most recently updated, non-deleted items of a knowledge base and turns
   * them into prompt text plus matching citations.
   *
   * Each item contributes its content (or its summary when the content is blank),
   * truncated per item and clamped so that the excerpts together never exceed
   * `MAX_CONTEXT_CHARS` (titles are not counted against the budget). Items without
   * any text are skipped rather than ending the scan.
   *
   * Never throws: a failed query is logged and reported as "not empty, no text",
   * which makes the caller answer with the service-unavailable message instead of
   * wrongly telling the user that the knowledge base has no content.
   */
  private async loadContext(kbId: string): Promise<RagContext> {
    try {
      const items = await this.prisma.knowledgeItem.findMany({
        where: { knowledgeBaseId: kbId, deletedAt: null },
        select: { id: true, title: true, content: true, summary: true },
        orderBy: { updatedAt: 'desc' },
        take: RagChatService.MAX_CONTEXT_ITEMS,
      });

      const parts: string[] = [];
      const citations: RagCitation[] = [];
      let remaining = MAX_CONTEXT_CHARS;

      for (const item of items) {
        if (remaining <= 0) break;

        const body = item.content?.trim() || item.summary?.trim() || '';
        if (!body) continue;

        const snippet = body.slice(0, Math.min(RagChatService.MAX_SNIPPET_CHARS, remaining));
        // 《title》 matches the citation format the system prompt asks the model to use.
        const heading = item.title ? `《${item.title}》\n` : '';
        parts.push(`${heading}${snippet}`);
        citations.push({ id: item.id, text: snippet, score: 1.0, title: item.title ?? '' });
        remaining -= snippet.length;
      }

      return { text: parts.join('\n\n'), citations, isEmpty: citations.length === 0 };
    } catch (err: unknown) {
      this.logger.error('Failed to load knowledge base context', this.describeError(err));
      return { text: '', citations: [], isEmpty: false };
    }
  }

  /** Builds the system prompt that embeds the knowledge-base excerpts. */
  private buildSystemPrompt(context: string) {
    return `你是知习 AI 学习助手。基于以下知识库内容回答用户问题,回答应准确、简洁、有依据。
## 知识库内容
${context}
## 回答要求
- 基于提供的知识库内容回答,不要编造信息
- 如果知识库内容不足以回答问题,请诚实告知
- 回答时可以用「根据知识库中的《xxx》...」引用来源`;
  }

  /**
   * Canned reply used whenever no AI answer is produced.
   *
   * @param isEmpty True when the knowledge base has no usable content, which selects
   *   the "add content first" guidance instead of the service-unavailable message.
   */
  private fallbackReply(isEmpty: boolean) {
    if (isEmpty) {
      return '当前知识库还没有知识点内容。请先上传资料或添加知识点,我就可以基于它们回答你的问题了。';
    }
    return '抱歉AI 服务暂时不可用,请稍后再试。';
  }

  /** Renders a caught value for logging, preferring the stack trace of real errors. */
  private describeError(err: unknown): string {
    if (err instanceof Error) return err.stack ?? err.message;
    return String(err);
  }
}