api-server/src/modules/rag-chat/rag-chat.service.ts

167 lines
5.6 KiB
TypeScript
Raw Normal View History

import { Injectable, NotFoundException, Logger, Optional } from '@nestjs/common';
import { PrismaService } from '../../infrastructure/database/prisma.service';
import { ContentSafetyService } from '../content-safety/content-safety.service';
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
/**
 * Character budget for the knowledge-base snippets that `RagChatService.loadContext`
 * packs into a single chat prompt.
 */
const MAX_CONTEXT_CHARS = 4_000;
/**
 * Retrieval-augmented chat over a knowledge base.
 *
 * Manages chat sessions and their messages, and answers a user's question with
 * the AI gateway, using items from the session's knowledge base as prompt
 * context. Content safety and the AI gateway are optional dependencies: without
 * the gateway every question receives a canned fallback reply.
 */
@Injectable()
export class RagChatService {
  /** Max characters taken from one knowledge item (prompt snippet and stored excerpt). */
  private static readonly SNIPPET_CHARS = 500;
  /** Max citations persisted and returned for a single AI message. */
  private static readonly MAX_CITATIONS = 5;
  /** Number of most recently updated knowledge items considered as context. */
  private static readonly CONTEXT_ITEM_LIMIT = 30;

  private readonly logger = new Logger(RagChatService.name);

  constructor(
    private readonly prisma: PrismaService,
    @Optional() private readonly safety?: ContentSafetyService,
    @Optional() private readonly aiGateway?: AiGatewayService,
  ) {}

  /**
   * Creates a chat session bound to a knowledge base.
   * @param title Optional title; a missing or empty title falls back to '新对话'.
   */
  async createSession(userId: string, knowledgeBaseId: string, title?: string) {
    // NOTE(review): knowledgeBaseId is not checked against userId here, yet
    // loadContext later reads that knowledge base's items. Confirm the caller
    // (or knowledge-base sharing rules) enforces access.
    return this.prisma.chatSession.create({
      data: { userId, knowledgeBaseId, title: title || '新对话' },
    });
  }

  /** Lists a user's sessions, newest activity first, optionally filtered by knowledge base. */
  async listSessions(userId: string, kbId?: string) {
    return this.prisma.chatSession.findMany({
      where: { userId, ...(kbId ? { knowledgeBaseId: kbId } : {}) },
      orderBy: { updatedAt: 'desc' },
    });
  }

  /**
   * Returns a session's messages in chronological order, with their citations.
   * @param userId When given, the session must belong to this user.
   * @throws NotFoundException if `userId` is given and the session is missing or not owned.
   */
  async getMessages(sessionId: string, userId?: string) {
    if (userId !== undefined) {
      await this.findOwnedSession(sessionId, userId);
    }
    return this.prisma.chatMessage.findMany({
      where: { sessionId },
      orderBy: { createdAt: 'asc' },
      include: { citations: true },
    });
  }

  /**
   * Stores the user's message, generates an AI reply from knowledge-base
   * context, and stores the reply together with its citations.
   *
   * @returns `{ blocked: true, message }` if the input fails the content-safety
   *   check (nothing is persisted). Otherwise returns `{ message, citations }`,
   *   where `citations` are exactly the ones persisted with the AI message.
   * @throws NotFoundException if the session is missing or owned by another user.
   */
  async sendMessage(userId: string, sessionId: string, content: string) {
    const session = await this.findOwnedSession(sessionId, userId);

    // Content safety runs before anything is written.
    const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
    if (inputCheck && !inputCheck.safe) {
      return { blocked: true, message: '输入包含违规内容,请修改后重试' };
    }

    await this.prisma.chatMessage.create({
      data: { sessionId, role: 'user', content },
    });

    const context = await this.loadContext(session.knowledgeBaseId);
    const { reply, grounded } = await this.generateReply(userId, content, context);

    // Cite sources only when the reply was actually generated from them, and
    // cap the list once so the response matches what is stored.
    const citations = grounded ? context.citations.slice(0, RagChatService.MAX_CITATIONS) : [];

    // A nested create writes the message and its citations in one operation.
    const aiMsg = await this.prisma.chatMessage.create({
      data: {
        sessionId,
        role: 'ai',
        content: reply,
        // NOTE(review): this is a character count, not a model token count.
        tokens: reply.length,
        citations: {
          create: citations.map((c) => ({
            chunkId: c.id,
            sourceTitle: c.title ?? null,
            excerptText: c.text.slice(0, RagChatService.SNIPPET_CHARS),
          })),
        },
      },
    });

    // Bump the session so listSessions orders it first.
    await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } });
    return { message: aiMsg, citations };
  }

  /**
   * Deletes a session with all of its messages and citations in one transaction.
   * @param userId When given, the session must belong to this user.
   * @throws NotFoundException if `userId` is given and the session is missing or not owned.
   */
  async deleteSession(sessionId: string, userId?: string) {
    if (userId !== undefined) {
      await this.findOwnedSession(sessionId, userId);
    }
    // Children first (citations -> messages -> session), all or nothing.
    await this.prisma.$transaction([
      this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } }),
      this.prisma.chatMessage.deleteMany({ where: { sessionId } }),
      this.prisma.chatSession.delete({ where: { id: sessionId } }),
    ]);
    return { success: true };
  }

  // ── Private ──

  /** Loads a session, throwing NotFoundException when it is missing or owned by another user. */
  private async findOwnedSession(sessionId: string, userId: string) {
    const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
    if (!session || session.userId !== userId) throw new NotFoundException('对话不存在');
    return session;
  }

  /**
   * Produces the assistant reply for one question.
   * @returns the reply text, and `grounded: true` only when the AI gateway
   *   returned a usable answer built from the knowledge-base context.
   */
  private async generateReply(
    userId: string,
    question: string,
    context: { text: string; isEmpty: boolean },
  ): Promise<{ reply: string; grounded: boolean }> {
    if (!this.aiGateway || !context.text) {
      return { reply: this.fallbackReply(context.isEmpty), grounded: false };
    }
    try {
      const resp = await this.aiGateway.generate({
        feature: 'rag-chat',
        userId,
        tier: 'primary',
        promptKey: 'rag-chat',
        promptVersion: 'v1',
        messages: [
          { role: 'system' as const, content: this.buildSystemPrompt(context.text) },
          { role: 'user' as const, content: question },
        ],
        maxTokens: 2048,
      });
      const answer = this.extractAnswer(resp.parsed);
      if (answer === undefined) {
        return { reply: '抱歉AI 暂时无法生成回答。', grounded: false };
      }
      return { reply: answer, grounded: true };
    } catch (err: unknown) {
      this.logger.error('AI Gateway failed, falling back', err instanceof Error ? err.message : String(err));
      return { reply: this.fallbackReply(context.isEmpty), grounded: false };
    }
  }

  /**
   * Extracts a non-blank string answer from the gateway's parsed output,
   * preferring `answer` over `content`. Non-string values are ignored rather
   * than stringified (which would yield e.g. "[object Object]").
   */
  private extractAnswer(parsed: unknown): string | undefined {
    if (typeof parsed !== 'object' || parsed === null) return undefined;
    const { answer, content } = parsed as { answer?: unknown; content?: unknown };
    for (const candidate of [answer, content]) {
      if (typeof candidate === 'string' && candidate.trim() !== '') return candidate;
    }
    return undefined;
  }

  /**
   * Builds prompt context from the knowledge base's most recently updated,
   * non-deleted items. Items are not ranked by relevance to the question.
   *
   * Snippet characters (titles excluded) are capped at MAX_CONTEXT_CHARS in
   * total. Items without content or summary are skipped.
   *
   * @returns `isEmpty: true` when no item has usable text. If the query fails,
   *   the error is logged and an empty context with `isEmpty: false` is
   *   returned, so the caller reports a service problem instead of an empty
   *   knowledge base.
   */
  private async loadContext(kbId: string) {
    const items = await this.prisma.knowledgeItem
      .findMany({
        where: { knowledgeBaseId: kbId, deletedAt: null },
        select: { id: true, title: true, content: true, summary: true },
        orderBy: { updatedAt: 'desc' },
        take: RagChatService.CONTEXT_ITEM_LIMIT,
      })
      .catch((err: unknown) => {
        this.logger.warn(
          `Failed to load context for knowledge base ${kbId}: ${err instanceof Error ? err.message : String(err)}`,
        );
        return null;
      });

    type Item = NonNullable<typeof items>[number];
    const citations: { id: Item['id']; text: string; score: number; title: Item['title'] }[] = [];
    if (items === null) return { text: '', citations, isEmpty: false };

    const parts: string[] = [];
    let remaining = MAX_CONTEXT_CHARS;
    for (const item of items) {
      if (remaining <= 0) break;
      const text = item.content || item.summary || '';
      // Skip (not stop at) items without text: later items may still have content.
      if (!text) continue;
      const snippet = text.slice(0, Math.min(RagChatService.SNIPPET_CHARS, remaining));
      parts.push(`【${item.title ?? ''}】\n${snippet}`);
      // score is a constant placeholder; no relevance scoring is performed.
      citations.push({ id: item.id, text: snippet, score: 1.0, title: item.title });
      remaining -= snippet.length;
    }
    return { text: parts.join('\n\n'), citations, isEmpty: parts.length === 0 };
  }

  /** System prompt that grounds the model in the supplied knowledge-base context. */
  private buildSystemPrompt(context: string) {
    return `你是知习 AI 学习助手。基于以下知识库内容回答用户问题,回答应准确、简洁、有依据。
## 知识库内容
${context}
## 回答要求
- 仅依据上述知识库内容作答,不要编造知识库中没有的信息
- 如果知识库中没有与问题相关的内容,请明确告知用户
- 引用内容时注明来源标题,例如:【来源:xxx】`;
  }

  /** Canned reply used when no AI answer can be produced. */
  private fallbackReply(isEmpty: boolean) {
    if (isEmpty) {
      return '当前知识库还没有知识点内容。请先上传资料或添加知识点,我就可以基于它们回答你的问题了。';
    }
    return '抱歉AI 服务暂时不可用,请稍后再试。';
  }
}