diff --git a/prisma/migrations/20260525050000_add_rag_chat/migration.sql b/prisma/migrations/20260525050000_add_rag_chat/migration.sql new file mode 100644 index 0000000..861d47b --- /dev/null +++ b/prisma/migrations/20260525050000_add_rag_chat/migration.sql @@ -0,0 +1,35 @@ +CREATE TABLE IF NOT EXISTS `ChatSession` ( + `id` VARCHAR(191) NOT NULL, + `userId` VARCHAR(191) NOT NULL, + `knowledgeBaseId` VARCHAR(191) NOT NULL, + `title` VARCHAR(200) NOT NULL DEFAULT '新对话', + `createdAt` DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + `updatedAt` DATETIME(3) NOT NULL, + INDEX `ChatSession_userId_idx`(`userId`), + INDEX `ChatSession_knowledgeBaseId_idx`(`knowledgeBaseId`), + PRIMARY KEY (`id`) +) DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `ChatMessage` ( + `id` VARCHAR(191) NOT NULL, + `sessionId` VARCHAR(191) NOT NULL, + `role` VARCHAR(16) NOT NULL, + `content` LONGTEXT NOT NULL, + `tokens` INTEGER NOT NULL DEFAULT 0, + `createdAt` DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + INDEX `ChatMessage_sessionId_idx`(`sessionId`), + PRIMARY KEY (`id`) +) DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +CREATE TABLE IF NOT EXISTS `ChatCitation` ( + `id` VARCHAR(191) NOT NULL, + `messageId` VARCHAR(191) NOT NULL, + `chunkId` VARCHAR(191) NULL, + `sourceId` VARCHAR(191) NULL, + `sourceTitle` VARCHAR(255) NULL, + `excerptText` VARCHAR(2000) NULL, + `pageNumber` INT NULL, + `createdAt` DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + INDEX `ChatCitation_messageId_idx`(`messageId`), + PRIMARY KEY (`id`) +) DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index b297bd8..6b6c576 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -945,6 +945,49 @@ model AdminMessage { @@index([createdAt]) } +model ChatSession { + id String @id @default(cuid()) + userId String + knowledgeBaseId String + title String @default("新对话") @db.VarChar(200) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + messages ChatMessage[] + + @@index([userId]) + @@index([knowledgeBaseId]) +} + +model ChatMessage { + id String @id @default(cuid()) + sessionId String + role String @db.VarChar(16) + content String @db.LongText + tokens Int @default(0) + createdAt DateTime @default(now()) + + session ChatSession @relation(fields: [sessionId], references: [id]) + citations ChatCitation[] + + @@index([sessionId]) +} + +model ChatCitation { + id String @id @default(cuid()) + messageId String + chunkId String? + sourceId String? + sourceTitle String? @db.VarChar(255) + excerptText String? @db.VarChar(2000) + pageNumber Int? + createdAt DateTime @default(now()) + + message ChatMessage @relation(fields: [messageId], references: [id]) + + @@index([messageId]) +} + model AdminCostItem { id String @id @default(cuid()) name String @db.VarChar(100) diff --git a/src/app.module.ts b/src/app.module.ts index 1192323..c639de3 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -47,6 +47,7 @@ import { WaitlistModule } from './modules/waitlist/waitlist.module'; import { KnowledgeSourceModule } from './modules/knowledge-source/knowledge-source.module'; import { ImportCandidateModule } from './modules/import-candidate/import-candidate.module'; import { RagModule } from './modules/rag/rag.module'; +import { RagChatModule } from './modules/rag-chat/rag-chat.module'; import { VectorModule } from './modules/vector/vector.module'; import { JwtAuthGuard } from './common/guards/jwt-auth.guard'; @@ -130,6 +131,7 @@ import appleConfig from './config/apple.config'; ImportCandidateModule, DocumentImportModule, RagModule, + RagChatModule, VectorModule, LearningSessionModule, ActiveRecallModule, diff --git a/src/modules/rag-chat/admin-rag-chat.controller.ts b/src/modules/rag-chat/admin-rag-chat.controller.ts new file mode 100644 index 0000000..943487b --- /dev/null +++ b/src/modules/rag-chat/admin-rag-chat.controller.ts @@ -0,0 +1,38 @@ +import { Controller, Get, Param, Query, UseGuards } from '@nestjs/common'; +import { ApiTags, ApiBearerAuth, ApiOperation } from '@nestjs/swagger'; +import { PrismaService } from '../../infrastructure/database/prisma.service'; +import { AdminAuthGuard } from '../../common/guards/admin-auth.guard'; +import { AdminRolesGuard } from '../../common/guards/admin-roles.guard'; +import { AdminRoles } from '../../common/decorators/admin-roles.decorator'; +import type { AdminRole } from '../../common/types/admin-role.enum'; + +@ApiTags('admin-rag-chat') +@Controller('admin-api/rag-chat') +@UseGuards(AdminAuthGuard, AdminRolesGuard) +@ApiBearerAuth() +export class AdminRagChatController { + constructor(private readonly prisma: PrismaService) {} + + @Get('sessions') + @AdminRoles('ADMIN' as AdminRole) + @ApiOperation({ summary: '用户对话列表' }) + async sessions(@Query('userId') userId?: string) { + return this.prisma.chatSession.findMany({ + where: userId ? { userId } : undefined, + orderBy: { updatedAt: 'desc' }, + take: 100, + include: { _count: { select: { messages: true } } }, + }); + } + + @Get('sessions/:id/messages') + @AdminRoles('ADMIN' as AdminRole) + @ApiOperation({ summary: '对话消息详情' }) + async messages(@Param('id') id: string) { + return this.prisma.chatMessage.findMany({ + where: { sessionId: id }, + orderBy: { createdAt: 'asc' }, + include: { citations: true }, + }); + } +} diff --git a/src/modules/rag-chat/rag-chat.controller.ts b/src/modules/rag-chat/rag-chat.controller.ts new file mode 100644 index 0000000..588deed --- /dev/null +++ b/src/modules/rag-chat/rag-chat.controller.ts @@ -0,0 +1,42 @@ +import { Controller, Get, Post, Delete, Body, Param, UseGuards } from '@nestjs/common'; +import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger'; +import { RagChatService } from './rag-chat.service'; +import { CurrentUser } from '../../common/decorators/current-user.decorator'; +import type { UserPayload } from '../../common/types'; + +@ApiTags('rag-chat') +@Controller('rag-chat') +@ApiBearerAuth() +export class RagChatController { + constructor(private readonly svc: RagChatService) {} + + @Post('sessions') + @ApiOperation({ summary: '创建对话' }) + async createSession(@CurrentUser() user: UserPayload, @Body() dto: { knowledgeBaseId: string; title?: string }) { + return this.svc.createSession(String(user.id), dto.knowledgeBaseId, dto.title); + } + + @Get('sessions') + @ApiOperation({ summary: '对话列表' }) + async listSessions(@CurrentUser() user: UserPayload, @Body('knowledgeBaseId') kbId?: string) { + return this.svc.listSessions(String(user.id), kbId); + } + + @Get('sessions/:id/messages') + @ApiOperation({ summary: '对话历史' }) + async messages(@Param('id') id: string) { + return this.svc.getMessages(id); + } + + @Post('sessions/:id/messages') + @ApiOperation({ summary: '发送消息' }) + async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) { + return this.svc.sendMessage(String(user.id), id, dto.content); + } + + @Delete('sessions/:id') + @ApiOperation({ summary: '删除对话' }) + async deleteSession(@Param('id') id: string) { + return this.svc.deleteSession(id); + } +} diff --git a/src/modules/rag-chat/rag-chat.module.ts b/src/modules/rag-chat/rag-chat.module.ts new file mode 100644 index 0000000..f16a57d --- /dev/null +++ b/src/modules/rag-chat/rag-chat.module.ts @@ -0,0 +1,12 @@ +import { Module } from '@nestjs/common'; +import { RagChatController } from './rag-chat.controller'; +import { AdminRagChatController } from './admin-rag-chat.controller'; +import { RagChatService } from './rag-chat.service'; +import { PrismaService } from '../../infrastructure/database/prisma.service'; + +@Module({ + controllers: [RagChatController, AdminRagChatController], + providers: [RagChatService, PrismaService], + exports: [RagChatService], +}) +export class RagChatModule {} diff --git a/src/modules/rag-chat/rag-chat.service.ts b/src/modules/rag-chat/rag-chat.service.ts new file mode 100644 index 0000000..43bf2f8 --- /dev/null +++ b/src/modules/rag-chat/rag-chat.service.ts @@ -0,0 +1,68 @@ +import { Injectable, NotFoundException, Logger } from '@nestjs/common'; +import { PrismaService } from '../../infrastructure/database/prisma.service'; +import { ContentSafetyService } from '../content-safety/content-safety.service'; + +@Injectable() +export class RagChatService { + private readonly logger = new Logger(RagChatService.name); + + constructor( + private readonly prisma: PrismaService, + private readonly safety?: ContentSafetyService, + ) {} + + async createSession(userId: string, knowledgeBaseId: string, title?: string) { + return this.prisma.chatSession.create({ + data: { userId, knowledgeBaseId, title: title || '新对话' }, + }); + } + + async listSessions(userId: string, kbId?: string) { + return this.prisma.chatSession.findMany({ + where: { userId, ...(kbId ? { knowledgeBaseId: kbId } : {}) }, + orderBy: { updatedAt: 'desc' }, + }); + } + + async getMessages(sessionId: string) { + return this.prisma.chatMessage.findMany({ + where: { sessionId }, + orderBy: { createdAt: 'asc' }, + include: { citations: true }, + }); + } + + async sendMessage(userId: string, sessionId: string, content: string) { + const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } }); + if (!session || session.userId !== userId) throw new NotFoundException('对话不存在'); + + // Content safety check on user input + const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' }); + if (inputCheck && !inputCheck.safe) { + return { blocked: true, message: '输入包含违规内容,请修改后重试' }; + } + + // Save user message + await this.prisma.chatMessage.create({ + data: { sessionId, role: 'user', content }, + }); + + // Generate AI response (simplified — real RAG pipeline in M3) + const reply = `感谢提问。基于知识库内容,我暂时无法生成完整回答(RAG 检索管道将在后续版本完善)。`; + const aiMsg = await this.prisma.chatMessage.create({ + data: { sessionId, role: 'ai', content: reply, tokens: reply.length }, + }); + + // Update session timestamp + await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } }); + + return { message: aiMsg, citations: [] }; + } + + async deleteSession(sessionId: string) { + await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } }); + await this.prisma.chatMessage.deleteMany({ where: { sessionId } }); + await this.prisma.chatSession.delete({ where: { id: sessionId } }); + return { success: true }; + } +} diff --git a/test/m2.e2e-spec.ts b/test/m2.e2e-spec.ts index 6f0c53b..089b13a 100644 --- a/test/m2.e2e-spec.ts +++ b/test/m2.e2e-spec.ts @@ -259,4 +259,50 @@ describe('M2 E2E Tests', () => { expect(res.body.data).toHaveProperty('count'); }); }); + + // ══════════════════════════════════════════════ + // M2-07: RAG Chat + // ══════════════════════════════════════════════ + describe('M2-07 RAG Chat', () => { + let token: string; + beforeAll(async () => { token = await loginAdmin(); }); + + it('POST /api/rag-chat/sessions → 201 create session', async () => { + const res = await request(app.getHttpServer()) + .post('/api/rag-chat/sessions') + .send({ knowledgeBaseId: 'kb1', title: 'Test Chat' }) + .expect([200, 201]); + expect(res.body.data).toHaveProperty('id'); + }); + + it('GET /api/rag-chat/sessions → 200 list sessions', async () => { + const res = await request(app.getHttpServer()) + .get('/api/rag-chat/sessions') + .expect(200); + expect(Array.isArray(res.body.data)).toBe(true); + }); + + it('POST /api/rag-chat/sessions/:id/messages → send message', async () => { + const session = await request(app.getHttpServer()) + .post('/api/rag-chat/sessions') + .send({ knowledgeBaseId: 'kb1' }); + const sId = session.body?.data?.id; + if (!sId) return; + + const res = await request(app.getHttpServer()) + .post(`/api/rag-chat/sessions/${sId}/messages`) + .send({ content: '这个知识库的主要内容是什么?' }) + .expect([200, 201]); + expect(res.body.data).toHaveProperty('message'); + }); + + it('GET /admin-api/rag-chat/sessions → 200 admin sessions', async () => { + if (!token) return; + const res = await request(app.getHttpServer()) + .get('/admin-api/rag-chat/sessions') + .set('Authorization', `Bearer ${token}`) + .expect(200); + expect(Array.isArray(res.body.data)).toBe(true); + }); + }); });