feat: #71 RAG Chat SSE 流式输出 + DeepSeek V4 Pro 思考过程
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 20s

- AiProvider 接口新增 StreamChunk 类型 + generateStream() 方法
- DeepSeekProvider 实现 generateStream():stream=true,读 reader 逐 chunk yield
- AiGatewayService 新增 generateStream():透传 provider stream + 记录用量
- RagChatService 新增 sendMessageStream():流式调用 + 保存最终 AI 回复到 DB
- POST /api/rag-chat/sessions/:id/stream 新 SSE endpoint
- thinking chunk:DeepSeek V4 Pro reasoning_content → type: "thinking"
- 流式模式下禁用 response_format json_object,不阻塞思考过程

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wangdl 2026-06-06 14:49:20 +08:00
parent f4de598d96
commit 6f77162cf8
5 changed files with 230 additions and 3 deletions

View File

@ -10,6 +10,7 @@ import { BaseDomainEvent } from '../../../common/events/base-domain.event';
import { PrismaService } from '../../../infrastructure/database/prisma.service';
import type { AiProvider } from '../providers/ai-provider.interface';
import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types';
import type { StreamChunk } from '../providers/ai-provider.interface';
class AIUsageRecorded extends BaseDomainEvent {
eventType = 'ai.usage.recorded';
@ -207,6 +208,63 @@ export class AiGatewayService {
}
}
/**
 * Streams a completion for `request` through the preferred provider of the
 * resolved model tier.
 *
 * The prompt template's system prompt is prepended to `request.messages`.
 * Every chunk produced by the provider is forwarded unchanged, except that a
 * provider `error` chunk terminates the stream after being re-emitted.
 * Token usage is logged (best effort) once the stream completes, but only
 * when the provider reported a non-zero input token count.
 *
 * @param request   Gateway request (tier, prompt key/version, messages, limits).
 * @param timeoutMs Time budget for the whole stream; the upstream request is
 *                  aborted when it elapses.
 * @returns An async generator of stream chunks. Failures raised while
 *          streaming are reported as a final `error` chunk instead of being
 *          thrown.
 */
async *generateStream(request: GatewayRequest, timeoutMs = this.DEFAULT_TIMEOUT_MS): AsyncGenerator<StreamChunk, void, undefined> {
  const tierConfig = this.modelRouter.resolve(request.tier);
  const prompt = this.promptTemplate.get(request.promptKey, request.promptVersion);
  const messages = [
    { role: 'system' as const, content: prompt.systemPrompt },
    ...request.messages,
  ];
  const target = tierConfig.preferred;
  const provider = this.resolveProviderForTarget(target.provider);

  // Streaming is an optional provider capability; report its absence
  // explicitly instead of relying on a non-null assertion.
  if (!provider.generateStream) {
    yield { type: 'error', content: `Provider ${target.provider} does not support streaming` };
    return;
  }

  const controller = new AbortController();
  const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
  const startedAt = Date.now();
  let inputTokens = 0;
  let outputTokens = 0;

  try {
    for await (const chunk of provider.generateStream({
      model: target.model,
      messages,
      temperature: 0.3,
      maxTokens: request.maxTokens ?? 4096,
      signal: controller.signal,
    })) {
      if (chunk.type === 'error') {
        yield { type: 'error', content: chunk.content };
        return;
      }
      if (chunk.type === 'done' && chunk.usage) {
        inputTokens = chunk.usage.inputTokens;
        outputTokens = chunk.usage.outputTokens;
      }
      yield chunk;
    }

    // Record usage after the stream completes.
    if (inputTokens > 0) {
      const estimatedCost = this.costCalculator.calculate(target.provider, target.model, inputTokens, outputTokens);
      // Fire-and-forget: a failed usage log must not break the stream.
      void this.usageLog.log({
        userId: request.userId, feature: request.feature,
        provider: target.provider, model: target.model, tier: request.tier,
        promptKey: request.promptKey, promptVersion: prompt.version,
        inputTokens, outputTokens, estimatedCost,
        // Wall-clock duration of the stream, including time the consumer
        // spent between chunks.
        latencyMs: Date.now() - startedAt,
        success: true,
      }).catch(() => {});
    }
  } catch (error: unknown) {
    const message = error instanceof Error ? error.message : String(error);
    yield { type: 'error', content: message };
  } finally {
    clearTimeout(timeoutId);
    // Cancels the upstream request when the consumer stops iterating early;
    // harmless once the stream has already finished.
    controller.abort();
  }
}
/**
 * Appends the JSON-only output instruction and the schema description to a
 * system prompt.
 *
 * @param systemPrompt Base system prompt.
 * @param schemaDesc   Textual description of the expected JSON schema.
 * @returns The system prompt, a blank line, the instruction, then the schema.
 */
private buildSystemPrompt(systemPrompt: string, schemaDesc: string): string {
  const jsonOnlyInstruction = '请严格按照以下 JSON Schema 输出,只输出 JSON不要包含其他内容';
  return `${systemPrompt}\n\n${jsonOnlyInstruction}\n${schemaDesc}`;
}

View File

@ -17,7 +17,14 @@ export interface AiGenerateOutput {
latencyMs: number;
}
/**
 * One unit of a streamed AI response.
 */
export interface StreamChunk {
// Kind of chunk: 'thinking' carries reasoning text, 'content' carries answer
// text, 'error' carries an error message, 'done' marks the end of the stream.
type: 'thinking' | 'content' | 'error' | 'done';
// Text payload; optional, so consumers must handle its absence.
content?: string;
// Token counts for the request; optional.
usage?: { inputTokens: number; outputTokens: number };
}
/**
 * Contract implemented by every AI model provider.
 */
export interface AiProvider {
// Provider identifier.
readonly name: string;
// Runs a single completion and resolves with the full output.
generate(input: AiGenerateInput): Promise<AiGenerateOutput>;
// Optional streaming variant: yields chunks as they arrive. Callers must
// check that the provider implements it before calling.
generateStream?(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined>;
}

View File

@ -1,6 +1,6 @@
import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provider.interface';
import type { AiProvider, AiGenerateInput, AiGenerateOutput, StreamChunk } from './ai-provider.interface';
@Injectable()
export class DeepSeekProvider implements AiProvider {
@ -64,4 +64,95 @@ export class DeepSeekProvider implements AiProvider {
latencyMs,
};
}
/**
 * Streams a chat completion from the DeepSeek API as typed chunks.
 *
 * Emits `thinking` chunks for `reasoning_content` deltas, `content` chunks
 * for answer deltas, and exactly one terminal chunk: `done` (with the token
 * usage reported by the API, or zeros when none was reported) or `error`.
 *
 * @param input Model, messages, sampling options and an optional abort signal.
 * @returns An async generator of stream chunks. Configuration and HTTP-status
 *          problems are reported as `error` chunks; network failures and
 *          aborts raised by `fetch`/the body reader propagate to the caller.
 */
async *generateStream(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined> {
  if (!this.apiKey) {
    yield { type: 'error', content: 'DeepSeek API key not configured' };
    return;
  }

  const body: Record<string, unknown> = {
    model: input.model,
    messages: input.messages,
    temperature: input.temperature ?? 0.3,
    max_tokens: input.maxTokens ?? 4096,
    stream: true,
    // Ask the API to send token usage in a final chunk before [DONE];
    // without it the `done` chunk would report zero tokens.
    stream_options: { include_usage: true },
  };

  // When streaming, do NOT use response_format json_object — it disables reasoning_content
  const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${this.apiKey}`,
    },
    body: JSON.stringify(body),
    signal: input.signal,
  });

  if (!response.ok) {
    const errorText = await response.text().catch(() => 'unknown');
    this.logger.error(`DeepSeek stream error ${response.status}: ${errorText}`);
    yield { type: 'error', content: `DeepSeek API returned ${response.status}` };
    return;
  }

  const reader = response.body?.getReader();
  if (!reader) {
    yield { type: 'error', content: 'No response body' };
    return;
  }

  const decoder = new TextDecoder();
  let buffer = '';
  let inputTokens = 0;
  let outputTokens = 0;

  // Parses a single SSE line into zero or more chunks; `finished` is true
  // when the line is the `[DONE]` sentinel. Updates the usage counters.
  const parseLine = (line: string): { chunks: StreamChunk[]; finished: boolean } => {
    const chunks: StreamChunk[] = [];
    // Per the SSE format the space after "data:" is optional.
    if (!line.startsWith('data:')) return { chunks, finished: false };
    const data = line.slice(5).trim();
    if (data === '[DONE]') return { chunks, finished: true };
    try {
      const json = JSON.parse(data);
      // Read usage BEFORE looking at the delta: the usage-only chunk has an
      // empty `choices` array and therefore no delta.
      if (json.usage) {
        inputTokens = json.usage.prompt_tokens ?? inputTokens;
        outputTokens = json.usage.completion_tokens ?? outputTokens;
      }
      const delta = json.choices?.[0]?.delta;
      // reasoning_content = thinking process (DeepSeek V4 Pro)
      if (delta?.reasoning_content) {
        chunks.push({ type: 'thinking', content: delta.reasoning_content });
      }
      // content = actual response
      if (delta?.content) {
        chunks.push({ type: 'content', content: delta.content });
      }
    } catch {
      // Skip unparseable chunks (e.g. keep-alive or partial payloads).
    }
    return { chunks, finished: false };
  };

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split('\n');
      // The last element is an incomplete line; keep it for the next read.
      buffer = lines.pop() ?? '';
      for (const line of lines) {
        const { chunks, finished } = parseLine(line);
        for (const chunk of chunks) yield chunk;
        if (finished) {
          yield { type: 'done', usage: { inputTokens, outputTokens } };
          return;
        }
      }
    }

    // Flush the decoder and handle a final line that had no trailing newline.
    buffer += decoder.decode();
    if (buffer) {
      const { chunks } = parseLine(buffer);
      for (const chunk of chunks) yield chunk;
    }
    yield { type: 'done', usage: { inputTokens, outputTokens } };
  } finally {
    // Cancel the body so the HTTP connection is released when the consumer
    // stops early; cancelling an already finished or errored stream is safe
    // to ignore.
    try {
      await reader.cancel();
    } catch {
      // Nothing to clean up: the stream is already closed or errored.
    }
    reader.releaseLock();
  }
}
}

View File

@ -1,5 +1,6 @@
import { Controller, Get, Post, Delete, Body, Param, UseGuards } from '@nestjs/common';
import { Controller, Get, Post, Delete, Body, Param, Res } from '@nestjs/common';
import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger';
import { Response } from 'express';
import { RagChatService } from './rag-chat.service';
import { CurrentUser } from '../../common/decorators/current-user.decorator';
import type { UserPayload } from '../../common/types';
@ -29,11 +30,27 @@ export class RagChatController {
}
/**
 * Sends a user message to a chat session and returns the service's reply
 * once it is complete (non-streaming).
 */
@Post('sessions/:id/messages')
@ApiOperation({ summary: '发送消息' })
@ApiOperation({ summary: '发送消息(同步)' })
async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) {
  const ownerId = String(user.id);
  const { content } = dto;
  return this.svc.sendMessage(ownerId, id, content);
}
/**
 * Sends a user message and streams the reply as Server-Sent Events.
 *
 * Each chunk produced by the service is written as one `data: <json>` event.
 * The response is always ended, including when the service throws after the
 * SSE headers have been sent; in that case a final `error` event is emitted
 * so the client can tell a failure from a normal end of stream.
 */
@Post('sessions/:id/stream')
@ApiOperation({ summary: '发送消息(SSE 流式,支持思考过程)' })
async sendMessageStream(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }, @Res() res: Response) {
  res.setHeader('Content-Type', 'text/event-stream');
  res.setHeader('Cache-Control', 'no-cache');
  res.setHeader('Connection', 'keep-alive');
  // Tell nginx-style proxies not to buffer the event stream.
  res.setHeader('X-Accel-Buffering', 'no');

  const userId = String(user.id);

  try {
    for await (const chunk of this.svc.sendMessageStream(userId, id, dto.content)) {
      // Stop pulling from the generator once the client has disconnected.
      if (res.destroyed) break;
      res.write(`data: ${JSON.stringify(chunk)}\n\n`);
    }
  } catch (error: unknown) {
    // Headers are already sent, so the failure can only be reported in-band.
    if (!res.destroyed) {
      const message = error instanceof Error ? error.message : String(error);
      res.write(`data: ${JSON.stringify({ type: 'error', content: message })}\n\n`);
    }
  } finally {
    if (!res.writableEnded) {
      res.end();
    }
  }
}
@Delete('sessions/:id')
@ApiOperation({ summary: '删除对话' })
async deleteSession(@Param('id') id: string) {

View File

@ -3,6 +3,7 @@ import { PrismaService } from '../../infrastructure/database/prisma.service';
import { ContentSafetyService } from '../content-safety/content-safety.service';
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema';
import type { StreamChunk } from '../ai/providers/ai-provider.interface';
const MAX_CONTEXT_CHARS = 4000;
@ -113,6 +114,59 @@ export class RagChatService {
return { message: aiMsg, citations };
}
/**
 * Streams the AI reply to a user message in a chat session.
 *
 * Steps: verify the session belongs to `userId`, run the input safety check
 * (when a safety service is available), persist the user message, load the
 * knowledge-base context, then either emit the fallback reply (no context) or
 * forward the gateway stream. When the gateway stream produced any content,
 * the accumulated text is saved as the AI message together with up to five
 * citations. The session's `updatedAt` is refreshed at the end.
 *
 * @param userId    Id of the requesting user.
 * @param sessionId Id of the chat session.
 * @param content   User message text.
 * @returns An async generator of stream chunks; validation failures are
 *          reported as a single `error` chunk.
 */
async *sendMessageStream(userId: string, sessionId: string, content: string): AsyncGenerator<StreamChunk, void, undefined> {
  const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
  // A session owned by someone else is reported the same way as a missing one.
  if (!session || session.userId !== userId) {
    yield { type: 'error', content: '对话不存在' };
    return;
  }

  const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
  if (inputCheck && !inputCheck.safe) {
    yield { type: 'error', content: '输入包含违规内容' };
    return;
  }

  // Save user message
  await this.prisma.chatMessage.create({ data: { sessionId, role: 'user', content } });

  // Load context
  const context = await this.loadContext(session.knowledgeBaseId);

  if (!context.text) {
    yield { type: 'content', content: this.fallbackReply(true) };
    // Terminate with `done` like the gateway path does, so clients see the
    // same end-of-stream marker on both paths.
    yield { type: 'done' };
  } else {
    const gateway = this.aiGateway;
    // Report a missing gateway as a stream error instead of throwing a
    // TypeError in the middle of the stream.
    if (!gateway) {
      yield { type: 'error', content: 'AI gateway not available' };
      return;
    }

    // NOTE(review): the gateway prepends its own template system prompt to
    // these messages — confirm that two system messages are intended.
    const messages = [
      { role: 'system' as const, content: this.buildSystemPrompt(context.text) },
      { role: 'user' as const, content },
    ];

    let fullContent = '';
    for await (const chunk of gateway.generateStream({
      feature: 'rag-chat', userId, tier: 'primary',
      promptKey: 'rag-chat', promptVersion: 'v1', messages, maxTokens: 2048,
    })) {
      // Only answer text is persisted; thinking chunks are forwarded but
      // not stored.
      if (chunk.type === 'content' && chunk.content) {
        fullContent += chunk.content;
      }
      yield chunk;
    }

    // Save AI reply
    if (fullContent) {
      // NOTE(review): `tokens` stores the character count, not a token count.
      const aiMsg = await this.prisma.chatMessage.create({
        data: { sessionId, role: 'ai', content: fullContent, tokens: fullContent.length },
      });
      for (const c of context.citations.slice(0, 5)) {
        await this.prisma.chatCitation.create({
          data: { messageId: aiMsg.id, chunkId: c.id, sourceTitle: c.title ?? null, excerptText: c.text.slice(0, 500) },
        });
      }
    }
  }

  await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } });
}
async deleteSession(sessionId: string) {
await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } });
await this.prisma.chatMessage.deleteMany({ where: { sessionId } });