diff --git a/src/modules/ai/ai.module.ts b/src/modules/ai/ai.module.ts index 8de815d..6377369 100644 --- a/src/modules/ai/ai.module.ts +++ b/src/modules/ai/ai.module.ts @@ -13,6 +13,7 @@ import { LearningTrendWorkflow } from './workflows/learning-trend.workflow'; import { AdminAiGatewayController } from './ai.controller'; import { MockAiProvider } from './providers/mock-ai.provider'; import { DeepSeekProvider } from './providers/deepseek.provider'; +import { SiliconFlowProvider } from './providers/siliconflow.provider'; import { MiniMaxProvider } from './providers/minimax.provider'; import type { AiProvider } from './providers/ai-provider.interface'; diff --git a/src/modules/ai/gateway/ai-gateway.service.ts b/src/modules/ai/gateway/ai-gateway.service.ts index afe018a..3f4e5e1 100644 --- a/src/modules/ai/gateway/ai-gateway.service.ts +++ b/src/modules/ai/gateway/ai-gateway.service.ts @@ -4,6 +4,7 @@ import { ModelRouter } from '../model-router'; import { PromptTemplateService } from '../prompts/prompt-template.service'; import { AiCostCalculatorService } from '../usage/ai-cost-calculator.service'; import { AiUsageLogService } from '../usage/ai-usage-log.service'; +import { ContentSafetyService } from '../../content-safety/content-safety.service'; import type { AiProvider } from '../providers/ai-provider.interface'; import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types'; @@ -17,7 +18,9 @@ export class AiGatewayService { private readonly promptTemplate: PromptTemplateService, private readonly costCalculator: AiCostCalculatorService, private readonly usageLog: AiUsageLogService, + private readonly contentSafety?: ContentSafetyService, private readonly providers: Map, + private readonly contentSafety?: ContentSafetyService, ) {} async generate(request: GatewayRequest, timeoutMs = this.DEFAULT_TIMEOUT_MS): Promise { @@ -47,6 +50,9 @@ export class AiGatewayService { signal: controller.signal, }); + const safetyCheck = await this.contentSafety?.check(output.rawText, { contentType: 'ai_output' }).catch(() => ({ safe: true })); + if (!safetyCheck.safe) throw new Error('AI output blocked by content safety'); + const parsed = this.parseJson(output.rawText, request.outputSchema); const estimatedCost = this.costCalculator.calculate( target.provider, diff --git a/src/modules/ai/providers/deepseek.provider.ts b/src/modules/ai/providers/deepseek.provider.ts index 7be0dab..647b138 100644 --- a/src/modules/ai/providers/deepseek.provider.ts +++ b/src/modules/ai/providers/deepseek.provider.ts @@ -4,6 +4,8 @@ import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provide @Injectable() export class DeepSeekProvider implements AiProvider { + private baseUrl = process.env.DEEPSEEK_BASE_URL || 'https://api.deepseek.com/v1'; + private apiKey = process.env.DEEPSEEK_API_KEY || ''; readonly name = 'deepseek'; private readonly logger = new Logger(DeepSeekProvider.name); private readonly apiKey: string; diff --git a/src/modules/ai/providers/siliconflow.provider.ts b/src/modules/ai/providers/siliconflow.provider.ts new file mode 100644 index 0000000..d518080 --- /dev/null +++ b/src/modules/ai/providers/siliconflow.provider.ts @@ -0,0 +1,40 @@ +import { Injectable } from '@nestjs/common'; +import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provider.interface'; + +@Injectable() +export class SiliconFlowProvider implements AiProvider { + readonly name = 'siliconflow'; + + private getConfig() { + return { + baseUrl: process.env.SILICONFLOW_BASE_URL || 'https://api.siliconflow.cn/v1', + apiKey: process.env.SILICONFLOW_API_KEY || '', + }; + } + + async generate(input: AiGenerateInput): Promise { + const { baseUrl, apiKey } = this.getConfig(); + const start = Date.now(); + + const resp = await fetch(`${baseUrl}/chat/completions`, { + method: 'POST', + headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${apiKey}` }, + body: JSON.stringify({ + model: input.model, + messages: input.messages, + temperature: input.temperature ?? 0.3, + max_tokens: input.maxTokens ?? 4096, + }), + signal: input.signal ?? AbortSignal.timeout(60_000), + }); + + if (!resp.ok) throw new Error(`SiliconFlow ${resp.status}: ${await resp.text().slice(0, 200)}`); + const data = await resp.json(); + const choice = data.choices?.[0]; + return { + rawText: choice?.message?.content || '', + usage: { inputTokens: data.usage?.prompt_tokens || 0, outputTokens: data.usage?.completion_tokens || 0 }, + latencyMs: Date.now() - start, + }; + } +}