diff --git a/src/modules/ai/gateway/ai-gateway.service.ts b/src/modules/ai/gateway/ai-gateway.service.ts index d3ea50a..3dc08eb 100644 --- a/src/modules/ai/gateway/ai-gateway.service.ts +++ b/src/modules/ai/gateway/ai-gateway.service.ts @@ -10,6 +10,7 @@ import { BaseDomainEvent } from '../../../common/events/base-domain.event'; import { PrismaService } from '../../../infrastructure/database/prisma.service'; import type { AiProvider } from '../providers/ai-provider.interface'; import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types'; +import type { StreamChunk } from '../providers/ai-provider.interface'; class AIUsageRecorded extends BaseDomainEvent { eventType = 'ai.usage.recorded'; @@ -207,6 +208,63 @@ export class AiGatewayService { } } + async *generateStream(request: GatewayRequest, timeoutMs = this.DEFAULT_TIMEOUT_MS): AsyncGenerator { + const tierConfig = this.modelRouter.resolve(request.tier); + const prompt = this.promptTemplate.get(request.promptKey, request.promptVersion); + const messages = [ + { role: 'system' as const, content: prompt.systemPrompt }, + ...request.messages, + ]; + + const target = tierConfig.preferred; + const provider = this.resolveProviderForTarget(target.provider); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), timeoutMs); + + let fullContent = ''; + let inputTokens = 0; + let outputTokens = 0; + + try { + for await (const chunk of provider.generateStream!({ + model: target.model, + messages, + temperature: 0.3, + maxTokens: request.maxTokens ?? 4096, + signal: controller.signal, + })) { + if (chunk.type === 'error') { + yield { type: 'error', content: chunk.content }; + return; + } + if (chunk.type === 'done' && chunk.usage) { + inputTokens = chunk.usage.inputTokens; + outputTokens = chunk.usage.outputTokens; + } + if (chunk.type === 'content') { + fullContent += (chunk.content || ''); + } + yield chunk; + } + + // Record usage after stream completes + if (inputTokens > 0) { + const estimatedCost = this.costCalculator.calculate(target.provider, target.model, inputTokens, outputTokens); + this.usageLog.log({ + userId: request.userId, feature: request.feature, + provider: target.provider, model: target.model, tier: request.tier, + promptKey: request.promptKey, promptVersion: prompt.version, + inputTokens, outputTokens, estimatedCost, latencyMs: 0, success: true, + }).catch(() => {}); + } + } catch (error: any) { + yield { type: 'error', content: error.message }; + } finally { + clearTimeout(timeoutId); + } + } + private buildSystemPrompt(systemPrompt: string, schemaDesc: string): string { return `${systemPrompt}\n\n请严格按照以下 JSON Schema 输出,只输出 JSON,不要包含其他内容:\n${schemaDesc}`; } diff --git a/src/modules/ai/providers/ai-provider.interface.ts b/src/modules/ai/providers/ai-provider.interface.ts index 9390f9b..965a67d 100644 --- a/src/modules/ai/providers/ai-provider.interface.ts +++ b/src/modules/ai/providers/ai-provider.interface.ts @@ -17,7 +17,14 @@ export interface AiGenerateOutput { latencyMs: number; } +export interface StreamChunk { + type: 'thinking' | 'content' | 'error' | 'done'; + content?: string; + usage?: { inputTokens: number; outputTokens: number }; +} + export interface AiProvider { readonly name: string; generate(input: AiGenerateInput): Promise; + generateStream?(input: AiGenerateInput): AsyncGenerator; } diff --git a/src/modules/ai/providers/deepseek.provider.ts b/src/modules/ai/providers/deepseek.provider.ts index 7be0dab..446e0ee 100644 --- a/src/modules/ai/providers/deepseek.provider.ts +++ b/src/modules/ai/providers/deepseek.provider.ts @@ -1,6 +1,6 @@ import { Injectable, Logger } from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; -import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provider.interface'; +import type { AiProvider, AiGenerateInput, AiGenerateOutput, StreamChunk } from './ai-provider.interface'; @Injectable() export class DeepSeekProvider implements AiProvider { @@ -64,4 +64,95 @@ export class DeepSeekProvider implements AiProvider { latencyMs, }; } + + async *generateStream(input: AiGenerateInput): AsyncGenerator { + if (!this.apiKey) { + yield { type: 'error', content: 'DeepSeek API key not configured' }; + return; + } + + const body: Record = { + model: input.model, + messages: input.messages, + temperature: input.temperature ?? 0.3, + max_tokens: input.maxTokens ?? 4096, + stream: true, + }; + + // When streaming, do NOT use response_format json_object — it disables reasoning_content + const response = await fetch(`${this.baseUrl}/v1/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify(body), + signal: input.signal, + }); + + if (!response.ok) { + const errorText = await response.text().catch(() => 'unknown'); + this.logger.error(`DeepSeek stream error ${response.status}: ${errorText}`); + yield { type: 'error', content: `DeepSeek API returned ${response.status}` }; + return; + } + + const reader = response.body?.getReader(); + if (!reader) { + yield { type: 'error', content: 'No response body' }; + return; + } + + const decoder = new TextDecoder(); + let buffer = ''; + let inputTokens = 0; + let outputTokens = 0; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + const data = line.slice(6).trim(); + if (data === '[DONE]') { + yield { type: 'done', usage: { inputTokens, outputTokens } }; + return; + } + + try { + const json = JSON.parse(data); + const delta = json.choices?.[0]?.delta; + if (!delta) continue; + + // Track usage + if (json.usage) { + inputTokens = json.usage.prompt_tokens ?? inputTokens; + outputTokens = json.usage.completion_tokens ?? outputTokens; + } + + // reasoning_content = thinking process (DeepSeek V4 Pro) + if (delta.reasoning_content) { + yield { type: 'thinking', content: delta.reasoning_content }; + } + + // content = actual response + if (delta.content) { + yield { type: 'content', content: delta.content }; + } + } catch { + // Skip unparseable chunks + } + } + } + yield { type: 'done', usage: { inputTokens, outputTokens } }; + } finally { + reader.releaseLock(); + } + } } diff --git a/src/modules/rag-chat/rag-chat.controller.ts b/src/modules/rag-chat/rag-chat.controller.ts index 588deed..4842f57 100644 --- a/src/modules/rag-chat/rag-chat.controller.ts +++ b/src/modules/rag-chat/rag-chat.controller.ts @@ -1,5 +1,6 @@ -import { Controller, Get, Post, Delete, Body, Param, UseGuards } from '@nestjs/common'; +import { Controller, Get, Post, Delete, Body, Param, Res } from '@nestjs/common'; import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger'; +import { Response } from 'express'; import { RagChatService } from './rag-chat.service'; import { CurrentUser } from '../../common/decorators/current-user.decorator'; import type { UserPayload } from '../../common/types'; @@ -29,11 +30,27 @@ export class RagChatController { } @Post('sessions/:id/messages') - @ApiOperation({ summary: '发送消息' }) + @ApiOperation({ summary: '发送消息(同步)' }) async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) { return this.svc.sendMessage(String(user.id), id, dto.content); } + @Post('sessions/:id/stream') + @ApiOperation({ summary: '发送消息(SSE 流式,支持思考过程)' }) + async sendMessageStream(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }, @Res() res: Response) { + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + + const userId = String(user.id); + for await (const chunk of this.svc.sendMessageStream(userId, id, dto.content)) { + if (res.destroyed) break; + res.write(`data: ${JSON.stringify(chunk)}\n\n`); + } + res.end(); + } + @Delete('sessions/:id') @ApiOperation({ summary: '删除对话' }) async deleteSession(@Param('id') id: string) { diff --git a/src/modules/rag-chat/rag-chat.service.ts b/src/modules/rag-chat/rag-chat.service.ts index 9b5e0fc..fcc98c0 100644 --- a/src/modules/rag-chat/rag-chat.service.ts +++ b/src/modules/rag-chat/rag-chat.service.ts @@ -3,6 +3,7 @@ import { PrismaService } from '../../infrastructure/database/prisma.service'; import { ContentSafetyService } from '../content-safety/content-safety.service'; import { AiGatewayService } from '../ai/gateway/ai-gateway.service'; import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema'; +import type { StreamChunk } from '../ai/providers/ai-provider.interface'; const MAX_CONTEXT_CHARS = 4000; @@ -113,6 +114,59 @@ export class RagChatService { return { message: aiMsg, citations }; } + async *sendMessageStream(userId: string, sessionId: string, content: string): AsyncGenerator { + const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } }); + if (!session || session.userId !== userId) { + yield { type: 'error', content: '对话不存在' }; + return; + } + + const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' }); + if (inputCheck && !inputCheck.safe) { + yield { type: 'error', content: '输入包含违规内容' }; + return; + } + + // Save user message + await this.prisma.chatMessage.create({ data: { sessionId, role: 'user', content } }); + + // Load context + const context = await this.loadContext(session.knowledgeBaseId); + if (!context.text) { + yield { type: 'content', content: this.fallbackReply(true) }; + } else { + const messages = [ + { role: 'system' as const, content: this.buildSystemPrompt(context.text) }, + { role: 'user' as const, content }, + ]; + + let fullContent = ''; + for await (const chunk of this.aiGateway!.generateStream({ + feature: 'rag-chat', userId, tier: 'primary', + promptKey: 'rag-chat', promptVersion: 'v1', messages, maxTokens: 2048, + })) { + if (chunk.type === 'content' && chunk.content) { + fullContent += chunk.content; + } + yield chunk; + } + + // Save AI reply + if (fullContent) { + const aiMsg = await this.prisma.chatMessage.create({ + data: { sessionId, role: 'ai', content: fullContent, tokens: fullContent.length }, + }); + for (const c of context.citations.slice(0, 5)) { + await this.prisma.chatCitation.create({ + data: { messageId: aiMsg.id, chunkId: c.id, sourceTitle: c.title ?? null, excerptText: c.text.slice(0, 500) }, + }); + } + } + } + + await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } }); + } + async deleteSession(sessionId: string) { await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } }); await this.prisma.chatMessage.deleteMany({ where: { sessionId } });