feat: #71 RAG Chat SSE 流式输出 + DeepSeek V4 Pro 思考过程
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 20s
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 20s
- AiProvider 接口新增 StreamChunk 类型 + generateStream() 方法 - DeepSeekProvider 实现 generateStream():stream=true,读 reader 逐 chunk yield - AiGatewayService 新增 generateStream():透传 provider stream + 记录用量 - RagChatService 新增 sendMessageStream():流式调用 + 保存最终 AI 回复到 DB - POST /api/rag-chat/sessions/:id/stream 新 SSE endpoint - thinking chunk:DeepSeek V4 Pro reasoning_content → type: "thinking" - 流式模式下禁用 response_format json_object,不阻塞思考过程 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
f4de598d96
commit
6f77162cf8
@ -10,6 +10,7 @@ import { BaseDomainEvent } from '../../../common/events/base-domain.event';
|
|||||||
import { PrismaService } from '../../../infrastructure/database/prisma.service';
|
import { PrismaService } from '../../../infrastructure/database/prisma.service';
|
||||||
import type { AiProvider } from '../providers/ai-provider.interface';
|
import type { AiProvider } from '../providers/ai-provider.interface';
|
||||||
import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types';
|
import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types';
|
||||||
|
import type { StreamChunk } from '../providers/ai-provider.interface';
|
||||||
|
|
||||||
class AIUsageRecorded extends BaseDomainEvent {
|
class AIUsageRecorded extends BaseDomainEvent {
|
||||||
eventType = 'ai.usage.recorded';
|
eventType = 'ai.usage.recorded';
|
||||||
@ -207,6 +208,63 @@ export class AiGatewayService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async *generateStream(request: GatewayRequest, timeoutMs = this.DEFAULT_TIMEOUT_MS): AsyncGenerator<StreamChunk, void, undefined> {
|
||||||
|
const tierConfig = this.modelRouter.resolve(request.tier);
|
||||||
|
const prompt = this.promptTemplate.get(request.promptKey, request.promptVersion);
|
||||||
|
const messages = [
|
||||||
|
{ role: 'system' as const, content: prompt.systemPrompt },
|
||||||
|
...request.messages,
|
||||||
|
];
|
||||||
|
|
||||||
|
const target = tierConfig.preferred;
|
||||||
|
const provider = this.resolveProviderForTarget(target.provider);
|
||||||
|
|
||||||
|
const controller = new AbortController();
|
||||||
|
const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
|
||||||
|
|
||||||
|
let fullContent = '';
|
||||||
|
let inputTokens = 0;
|
||||||
|
let outputTokens = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const chunk of provider.generateStream!({
|
||||||
|
model: target.model,
|
||||||
|
messages,
|
||||||
|
temperature: 0.3,
|
||||||
|
maxTokens: request.maxTokens ?? 4096,
|
||||||
|
signal: controller.signal,
|
||||||
|
})) {
|
||||||
|
if (chunk.type === 'error') {
|
||||||
|
yield { type: 'error', content: chunk.content };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (chunk.type === 'done' && chunk.usage) {
|
||||||
|
inputTokens = chunk.usage.inputTokens;
|
||||||
|
outputTokens = chunk.usage.outputTokens;
|
||||||
|
}
|
||||||
|
if (chunk.type === 'content') {
|
||||||
|
fullContent += (chunk.content || '');
|
||||||
|
}
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record usage after stream completes
|
||||||
|
if (inputTokens > 0) {
|
||||||
|
const estimatedCost = this.costCalculator.calculate(target.provider, target.model, inputTokens, outputTokens);
|
||||||
|
this.usageLog.log({
|
||||||
|
userId: request.userId, feature: request.feature,
|
||||||
|
provider: target.provider, model: target.model, tier: request.tier,
|
||||||
|
promptKey: request.promptKey, promptVersion: prompt.version,
|
||||||
|
inputTokens, outputTokens, estimatedCost, latencyMs: 0, success: true,
|
||||||
|
}).catch(() => {});
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
yield { type: 'error', content: error.message };
|
||||||
|
} finally {
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private buildSystemPrompt(systemPrompt: string, schemaDesc: string): string {
|
private buildSystemPrompt(systemPrompt: string, schemaDesc: string): string {
|
||||||
return `${systemPrompt}\n\n请严格按照以下 JSON Schema 输出,只输出 JSON,不要包含其他内容:\n${schemaDesc}`;
|
return `${systemPrompt}\n\n请严格按照以下 JSON Schema 输出,只输出 JSON,不要包含其他内容:\n${schemaDesc}`;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,7 +17,14 @@ export interface AiGenerateOutput {
|
|||||||
latencyMs: number;
|
latencyMs: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface StreamChunk {
|
||||||
|
type: 'thinking' | 'content' | 'error' | 'done';
|
||||||
|
content?: string;
|
||||||
|
usage?: { inputTokens: number; outputTokens: number };
|
||||||
|
}
|
||||||
|
|
||||||
export interface AiProvider {
|
export interface AiProvider {
|
||||||
readonly name: string;
|
readonly name: string;
|
||||||
generate(input: AiGenerateInput): Promise<AiGenerateOutput>;
|
generate(input: AiGenerateInput): Promise<AiGenerateOutput>;
|
||||||
|
generateStream?(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined>;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { Injectable, Logger } from '@nestjs/common';
|
import { Injectable, Logger } from '@nestjs/common';
|
||||||
import { ConfigService } from '@nestjs/config';
|
import { ConfigService } from '@nestjs/config';
|
||||||
import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provider.interface';
|
import type { AiProvider, AiGenerateInput, AiGenerateOutput, StreamChunk } from './ai-provider.interface';
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class DeepSeekProvider implements AiProvider {
|
export class DeepSeekProvider implements AiProvider {
|
||||||
@ -64,4 +64,95 @@ export class DeepSeekProvider implements AiProvider {
|
|||||||
latencyMs,
|
latencyMs,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async *generateStream(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined> {
|
||||||
|
if (!this.apiKey) {
|
||||||
|
yield { type: 'error', content: 'DeepSeek API key not configured' };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const body: Record<string, any> = {
|
||||||
|
model: input.model,
|
||||||
|
messages: input.messages,
|
||||||
|
temperature: input.temperature ?? 0.3,
|
||||||
|
max_tokens: input.maxTokens ?? 4096,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
// When streaming, do NOT use response_format json_object — it disables reasoning_content
|
||||||
|
const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${this.apiKey}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
signal: input.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorText = await response.text().catch(() => 'unknown');
|
||||||
|
this.logger.error(`DeepSeek stream error ${response.status}: ${errorText}`);
|
||||||
|
yield { type: 'error', content: `DeepSeek API returned ${response.status}` };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body?.getReader();
|
||||||
|
if (!reader) {
|
||||||
|
yield { type: 'error', content: 'No response body' };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffer = '';
|
||||||
|
let inputTokens = 0;
|
||||||
|
let outputTokens = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
const lines = buffer.split('\n');
|
||||||
|
buffer = lines.pop() || '';
|
||||||
|
|
||||||
|
for (const line of lines) {
|
||||||
|
if (!line.startsWith('data: ')) continue;
|
||||||
|
const data = line.slice(6).trim();
|
||||||
|
if (data === '[DONE]') {
|
||||||
|
yield { type: 'done', usage: { inputTokens, outputTokens } };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const json = JSON.parse(data);
|
||||||
|
const delta = json.choices?.[0]?.delta;
|
||||||
|
if (!delta) continue;
|
||||||
|
|
||||||
|
// Track usage
|
||||||
|
if (json.usage) {
|
||||||
|
inputTokens = json.usage.prompt_tokens ?? inputTokens;
|
||||||
|
outputTokens = json.usage.completion_tokens ?? outputTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// reasoning_content = thinking process (DeepSeek V4 Pro)
|
||||||
|
if (delta.reasoning_content) {
|
||||||
|
yield { type: 'thinking', content: delta.reasoning_content };
|
||||||
|
}
|
||||||
|
|
||||||
|
// content = actual response
|
||||||
|
if (delta.content) {
|
||||||
|
yield { type: 'content', content: delta.content };
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Skip unparseable chunks
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield { type: 'done', usage: { inputTokens, outputTokens } };
|
||||||
|
} finally {
|
||||||
|
reader.releaseLock();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import { Controller, Get, Post, Delete, Body, Param, UseGuards } from '@nestjs/common';
|
import { Controller, Get, Post, Delete, Body, Param, Res } from '@nestjs/common';
|
||||||
import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger';
|
import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger';
|
||||||
|
import { Response } from 'express';
|
||||||
import { RagChatService } from './rag-chat.service';
|
import { RagChatService } from './rag-chat.service';
|
||||||
import { CurrentUser } from '../../common/decorators/current-user.decorator';
|
import { CurrentUser } from '../../common/decorators/current-user.decorator';
|
||||||
import type { UserPayload } from '../../common/types';
|
import type { UserPayload } from '../../common/types';
|
||||||
@ -29,11 +30,27 @@ export class RagChatController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Post('sessions/:id/messages')
|
@Post('sessions/:id/messages')
|
||||||
@ApiOperation({ summary: '发送消息' })
|
@ApiOperation({ summary: '发送消息(同步)' })
|
||||||
async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) {
|
async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) {
|
||||||
return this.svc.sendMessage(String(user.id), id, dto.content);
|
return this.svc.sendMessage(String(user.id), id, dto.content);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Post('sessions/:id/stream')
|
||||||
|
@ApiOperation({ summary: '发送消息(SSE 流式,支持思考过程)' })
|
||||||
|
async sendMessageStream(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }, @Res() res: Response) {
|
||||||
|
res.setHeader('Content-Type', 'text/event-stream');
|
||||||
|
res.setHeader('Cache-Control', 'no-cache');
|
||||||
|
res.setHeader('Connection', 'keep-alive');
|
||||||
|
res.setHeader('X-Accel-Buffering', 'no');
|
||||||
|
|
||||||
|
const userId = String(user.id);
|
||||||
|
for await (const chunk of this.svc.sendMessageStream(userId, id, dto.content)) {
|
||||||
|
if (res.destroyed) break;
|
||||||
|
res.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||||
|
}
|
||||||
|
res.end();
|
||||||
|
}
|
||||||
|
|
||||||
@Delete('sessions/:id')
|
@Delete('sessions/:id')
|
||||||
@ApiOperation({ summary: '删除对话' })
|
@ApiOperation({ summary: '删除对话' })
|
||||||
async deleteSession(@Param('id') id: string) {
|
async deleteSession(@Param('id') id: string) {
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import { PrismaService } from '../../infrastructure/database/prisma.service';
|
|||||||
import { ContentSafetyService } from '../content-safety/content-safety.service';
|
import { ContentSafetyService } from '../content-safety/content-safety.service';
|
||||||
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
|
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
|
||||||
import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema';
|
import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema';
|
||||||
|
import type { StreamChunk } from '../ai/providers/ai-provider.interface';
|
||||||
|
|
||||||
const MAX_CONTEXT_CHARS = 4000;
|
const MAX_CONTEXT_CHARS = 4000;
|
||||||
|
|
||||||
@ -113,6 +114,59 @@ export class RagChatService {
|
|||||||
return { message: aiMsg, citations };
|
return { message: aiMsg, citations };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async *sendMessageStream(userId: string, sessionId: string, content: string): AsyncGenerator<StreamChunk, void, undefined> {
|
||||||
|
const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
|
||||||
|
if (!session || session.userId !== userId) {
|
||||||
|
yield { type: 'error', content: '对话不存在' };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
|
||||||
|
if (inputCheck && !inputCheck.safe) {
|
||||||
|
yield { type: 'error', content: '输入包含违规内容' };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save user message
|
||||||
|
await this.prisma.chatMessage.create({ data: { sessionId, role: 'user', content } });
|
||||||
|
|
||||||
|
// Load context
|
||||||
|
const context = await this.loadContext(session.knowledgeBaseId);
|
||||||
|
if (!context.text) {
|
||||||
|
yield { type: 'content', content: this.fallbackReply(true) };
|
||||||
|
} else {
|
||||||
|
const messages = [
|
||||||
|
{ role: 'system' as const, content: this.buildSystemPrompt(context.text) },
|
||||||
|
{ role: 'user' as const, content },
|
||||||
|
];
|
||||||
|
|
||||||
|
let fullContent = '';
|
||||||
|
for await (const chunk of this.aiGateway!.generateStream({
|
||||||
|
feature: 'rag-chat', userId, tier: 'primary',
|
||||||
|
promptKey: 'rag-chat', promptVersion: 'v1', messages, maxTokens: 2048,
|
||||||
|
})) {
|
||||||
|
if (chunk.type === 'content' && chunk.content) {
|
||||||
|
fullContent += chunk.content;
|
||||||
|
}
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save AI reply
|
||||||
|
if (fullContent) {
|
||||||
|
const aiMsg = await this.prisma.chatMessage.create({
|
||||||
|
data: { sessionId, role: 'ai', content: fullContent, tokens: fullContent.length },
|
||||||
|
});
|
||||||
|
for (const c of context.citations.slice(0, 5)) {
|
||||||
|
await this.prisma.chatCitation.create({
|
||||||
|
data: { messageId: aiMsg.id, chunkId: c.id, sourceTitle: c.title ?? null, excerptText: c.text.slice(0, 500) },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } });
|
||||||
|
}
|
||||||
|
|
||||||
async deleteSession(sessionId: string) {
|
async deleteSession(sessionId: string) {
|
||||||
await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } });
|
await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } });
|
||||||
await this.prisma.chatMessage.deleteMany({ where: { sessionId } });
|
await this.prisma.chatMessage.deleteMany({ where: { sessionId } });
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user