Browse Source

AI 应用接口开发, 支持流式输出

master
hygl 12 months ago
parent
commit
f937ea2756
  1. 47
      src/main/java/com/wok/supportbot/app/AssistantApp.java
  2. 5
      src/main/java/com/wok/supportbot/app/ProductInfoApp.java
  3. 25
      src/main/java/com/wok/supportbot/config/CorsConfig.java
  4. 89
      src/main/java/com/wok/supportbot/controller/AiController.java
  5. 15
      src/main/java/com/wok/supportbot/controller/HealthController.java
  6. 2
      src/main/java/com/wok/supportbot/rag/config/QueryExpanderConfig.java
  7. 7
      src/main/java/com/wok/supportbot/rag/config/QueryTransformerConfig.java
  8. 7
      src/main/java/com/wok/supportbot/rag/load/InMemoryVectorStoreConfig.java
  9. 7
      src/main/java/com/wok/supportbot/rag/load/PgVectorStoreConfig.java
  10. 4
      src/main/java/com/wok/supportbot/rag/preretrieval/CompressionQueryRewriter.java
  11. 2
      src/main/java/com/wok/supportbot/rag/preretrieval/MultiQueryExpanderRewriter.java
  12. 2
      src/main/java/com/wok/supportbot/rag/preretrieval/RewriteQueryRewriter.java
  13. 2
      src/main/java/com/wok/supportbot/rag/preretrieval/TranslationQueryRewriter.java
  14. 1
      src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java
  15. 10
      src/test/java/com/wok/supportbot/QueryTransformerTests.java
  16. 6
      src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

47
src/main/java/com/wok/supportbot/app/AssistantApp.java

@ -3,6 +3,10 @@ package com.wok.supportbot.app;
import com.wok.supportbot.advisor.MyLoggerAdvisor; import com.wok.supportbot.advisor.MyLoggerAdvisor;
import com.wok.supportbot.advisor.ReReadingAdvisor; import com.wok.supportbot.advisor.ReReadingAdvisor;
import com.wok.supportbot.chatmemory.DatabaseChatMemory; import com.wok.supportbot.chatmemory.DatabaseChatMemory;
import com.wok.supportbot.rag.preretrieval.CompressionQueryRewriter;
import com.wok.supportbot.rag.preretrieval.MultiQueryExpanderRewriter;
import com.wok.supportbot.rag.preretrieval.RewriteQueryRewriter;
import com.wok.supportbot.rag.preretrieval.TranslationQueryRewriter;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
@ -22,6 +26,7 @@ import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -41,7 +46,7 @@ import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvis
public class AssistantApp { public class AssistantApp {
@Resource @Resource
private VectorStore vectorStore;
private VectorStore pgVectorVectorStore;
private final ChatClient chatClient; private final ChatClient chatClient;
@ -90,6 +95,33 @@ public class AssistantApp {
return chatResponse.getResult().getOutput().getText(); return chatResponse.getResult().getOutput().getText();
} }
/**
* AI 基础对话支持多轮对话记忆SSE 流式传输
*
* @param message
* @param chatId
* @return
*/
public Flux<String> doChatByStream(String message, String chatId) {
return chatClient
.prompt()
.user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.stream()
.content();
}
// AI 恋爱知识库问答功能
@Resource
RewriteQueryRewriter rewriteQueryRewriter;
@Resource
CompressionQueryRewriter compressionQueryRewriter;
@Resource
MultiQueryExpanderRewriter multiQueryExpanderRewriter;
@Resource
TranslationQueryRewriter translationQueryRewriter;
/** /**
* RAG 知识库进行对话 * RAG 知识库进行对话
@ -99,13 +131,17 @@ public class AssistantApp {
* @return * @return
*/ */
public String doChatWithRag(String message, String chatId) { public String doChatWithRag(String message, String chatId) {
// 在预检索阶段系统接收用户的原始查询通过查询转换和查询扩展等方法对其进行优化输出增强的用户查询
// String rewrittenMessage = translationQueryRewriter.doQueryRewrite(message);
String rewrittenMessage = rewriteQueryRewriter.doQueryRewrite(message);
ChatResponse chatResponse = chatClient ChatResponse chatResponse = chatClient
.prompt() .prompt()
.user(message)
.user(rewrittenMessage)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
// 应用 RAG 知识库问答 // 应用 RAG 知识库问答
.advisors(QuestionAnswerAdvisor.builder(vectorStore)
.advisors(QuestionAnswerAdvisor.builder(pgVectorVectorStore)
// 相似度阈值为 0.0并返回最相关的前 4 个结果 // 相似度阈值为 0.0并返回最相关的前 4 个结果
.searchRequest(SearchRequest.builder().similarityThreshold(0.0).topK(4).build()) .searchRequest(SearchRequest.builder().similarityThreshold(0.0).topK(4).build())
.build()) .build())
@ -121,7 +157,7 @@ public class AssistantApp {
private MultiQueryExpander multiQueryExpander; private MultiQueryExpander multiQueryExpander;
/** /**
* RAG 知识库进行对话
* RAG 知识库进行对话(另外一种使用方式)
* *
* @param message * @param message
* @param chatId * @param chatId
@ -129,10 +165,11 @@ public class AssistantApp {
*/ */
public String doChatWithRagEnhance(String message, String chatId) { public String doChatWithRagEnhance(String message, String chatId) {
Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
// todo 不生效
//.queryTransformers(queryTransformers) //.queryTransformers(queryTransformers)
//.queryExpander(multiQueryExpander) //.queryExpander(multiQueryExpander)
.documentRetriever(VectorStoreDocumentRetriever.builder() .documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.vectorStore(pgVectorVectorStore)
.similarityThreshold(0.5) .similarityThreshold(0.5)
.topK(4) .topK(4)
.build()) .build())

5
src/main/java/com/wok/supportbot/app/ProductInfoApp.java

@ -45,16 +45,13 @@ public class ProductInfoApp {
/** /**
* 商品信息结构化抽取 * 商品信息结构化抽取
* @param rawContent 爬取的商品网页内容 * @param rawContent 爬取的商品网页内容
* @param chatId 对话ID
* @return 结构化的商品信息对象 * @return 结构化的商品信息对象
*/ */
public ProductInfo extractProductInfo(String rawContent, String chatId) {
public ProductInfo extractProductInfo(String rawContent) {
ProductInfo productInfo = chatClient ProductInfo productInfo = chatClient
.prompt() .prompt()
.system(SYSTEM_PROMPT) .system(SYSTEM_PROMPT)
.user(rawContent) .user(rawContent)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.call() .call()
.entity(ProductInfo.class); .entity(ProductInfo.class);
log.info("Extracted product info: {}", productInfo); log.info("Extracted product info: {}", productInfo);

25
src/main/java/com/wok/supportbot/config/CorsConfig.java

@ -0,0 +1,25 @@
package com.wok.supportbot.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* 全局跨域配置
*/
@Configuration
public class CorsConfig implements WebMvcConfigurer {
@Override
public void addCorsMappings(CorsRegistry registry) {
// 覆盖所有请求
registry.addMapping("/**")
// 允许发送 Cookie
.allowCredentials(true)
// 放行哪些域名必须用 patterns否则 * 会和 allowCredentials 冲突
.allowedOriginPatterns("*")
.allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
.allowedHeaders("*")
.exposedHeaders("*");
}
}

89
src/main/java/com/wok/supportbot/controller/AiController.java

@ -0,0 +1,89 @@
package com.wok.supportbot.controller;
import com.wok.supportbot.app.AssistantApp;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
import java.io.IOException;
public class AiController {
@Resource
private AssistantApp assistantApp;
/**
* 同步调用 AI 智能客服应用
*
* @param message
* @param chatId
* @return
*/
@GetMapping("/assistant_app/chat/sync")
public String doChatWithAssistantAppSync(String message, String chatId) {
return assistantApp.doChat(message, chatId);
}
/**
* SSE 流式调用 AI 智能客服应用
* 返回Flux 响应式؜对象并且添加 SSE 对应的 MediaType
*
* @param message
* @param chatId
* @return
*/
@GetMapping(value = "/assistant_app/chat/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> doChatWithLoveAppSSE(String message, String chatId) {
return assistantApp.doChatByStream(message, chatId);
}
/**
* SSE 流式调用 AI 智能客服应用
* 返回 Flux 对象并且؜设置泛型为 ServerSentEvent使用这种方式可以省略 MediaType
*
* @param message
* @param chatId
* @return
*/
@GetMapping(value = "/assistant_app/chat/server_sent_event")
public Flux<ServerSentEvent<String>> doChatWithAssistantAppServerSentEvent(String message, String chatId) {
return assistantApp.doChatByStream(message, chatId)
.map(chunk -> ServerSentEvent.<String>builder()
.data(chunk)
.build());
}
/**
* SSE 流式调用 AI 智能客服应用
* 使用 SSEEmiter؜通过 send 方法持续向 SseEmitter 发送消息
*
* @param message
* @param chatId
* @return
*/
@GetMapping(value = "/assistant_app/chat/sse_emitter")
public SseEmitter doChatWithAssistantAppServerSseEmitter(String message, String chatId) {
// 创建一个超时时间较长的 SseEmitter
SseEmitter sseEmitter = new SseEmitter(180000L); // 3 分钟超时
// 获取 Flux 响应式数据流并且直接通过订阅推送给 SseEmitter
assistantApp.doChatByStream(message, chatId)
.subscribe(chunk -> {
try {
sseEmitter.send(chunk);
} catch (IOException e) {
sseEmitter.completeWithError(e);
}
}, sseEmitter::completeWithError, sseEmitter::complete);
// 返回
return sseEmitter;
}
}

15
src/main/java/com/wok/supportbot/controller/HealthController.java

@ -1,15 +0,0 @@
package com.wok.supportbot.controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/health")
public class HealthController {
@GetMapping
public String healthCheck() {
return "ok";
}
}

2
src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java → src/main/java/com/wok/supportbot/rag/config/QueryExpanderConfig.java

@ -1,4 +1,4 @@
package com.wok.supportbot.config;
package com.wok.supportbot.rag.config;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;

7
src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java → src/main/java/com/wok/supportbot/rag/config/QueryTransformerConfig.java

@ -1,19 +1,14 @@
package com.wok.supportbot.config;
package com.wok.supportbot.rag.config;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import java.util.List;
@Configuration @Configuration
public class QueryTransformerConfig { public class QueryTransformerConfig {

7
src/main/java/com/wok/supportbot/load/InMemoryVectorStoreConfig.java → src/main/java/com/wok/supportbot/rag/load/InMemoryVectorStoreConfig.java

@ -1,16 +1,11 @@
package com.wok.supportbot.load;
package com.wok.supportbot.rag.load;
import com.wok.supportbot.extract.MarkdownDocumentLoader;
import jakarta.annotation.Resource;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import java.util.List;
/** /**
* 向量数据库配置初始化基于内存的向量数据库 Bean * 向量数据库配置初始化基于内存的向量数据库 Bean
*/ */

7
src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java → src/main/java/com/wok/supportbot/rag/load/PgVectorStoreConfig.java

@ -1,8 +1,5 @@
package com.wok.supportbot.load;
package com.wok.supportbot.rag.load;
import com.wok.supportbot.extract.MarkdownDocumentLoader;
import jakarta.annotation.Resource;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore; import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
@ -11,8 +8,6 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import java.util.List;
import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgDistanceType.COSINE_DISTANCE; import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgDistanceType.COSINE_DISTANCE;
import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType.HNSW; import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType.HNSW;
/** /**

4
src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java → src/main/java/com/wok/supportbot/rag/preretrieval/CompressionQueryRewriter.java

@ -1,10 +1,8 @@
package com.wok.supportbot.preretrieval;
package com.wok.supportbot.rag.preretrieval;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query; import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;

2
src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java → src/main/java/com/wok/supportbot/rag/preretrieval/MultiQueryExpanderRewriter.java

@ -1,4 +1,4 @@
package com.wok.supportbot.preretrieval;
package com.wok.supportbot.rag.preretrieval;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;

2
src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java → src/main/java/com/wok/supportbot/rag/preretrieval/RewriteQueryRewriter.java

@ -1,4 +1,4 @@
package com.wok.supportbot.preretrieval;
package com.wok.supportbot.rag.preretrieval;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;

2
src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java → src/main/java/com/wok/supportbot/rag/preretrieval/TranslationQueryRewriter.java

@ -1,4 +1,4 @@
package com.wok.supportbot.preretrieval;
package com.wok.supportbot.rag.preretrieval;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;

1
src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java

@ -1,6 +1,5 @@
package com.wok.supportbot; package com.wok.supportbot;
import com.wok.supportbot.preretrieval.RewriteQueryRewriter;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;

10
src/test/java/com/wok/supportbot/QueryTransformerTests.java

@ -1,15 +1,13 @@
package com.wok.supportbot; package com.wok.supportbot;
import com.wok.supportbot.preretrieval.CompressionQueryRewriter;
import com.wok.supportbot.preretrieval.MultiQueryExpanderRewriter;
import com.wok.supportbot.preretrieval.RewriteQueryRewriter;
import com.wok.supportbot.preretrieval.TranslationQueryRewriter;
import jakarta.annotation.Resource;
import com.wok.supportbot.rag.preretrieval.CompressionQueryRewriter;
import com.wok.supportbot.rag.preretrieval.MultiQueryExpanderRewriter;
import com.wok.supportbot.rag.preretrieval.RewriteQueryRewriter;
import com.wok.supportbot.rag.preretrieval.TranslationQueryRewriter;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.rag.Query;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;

6
src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

@ -44,12 +44,8 @@ class SupportBotApplicationTests {
String rawContent = "这是商品标题:智能手表Pro 2025," + String rawContent = "这是商品标题:智能手表Pro 2025," +
"描述:这款智能手表支持心率监测和GPS," + "描述:这款智能手表支持心率监测和GPS," +
"价格:299美元,评分:4.7星,评论数:1567,品牌:TechBrand,分类:电子产品。"; "价格:299美元,评分:4.7星,评论数:1567,品牌:TechBrand,分类:电子产品。";
// 生成随机聊天ID模拟独立会话
String chatId = UUID.randomUUID().toString();
// 调用方法 // 调用方法
ProductInfo productInfo = productInfoApp.extractProductInfo(rawContent, chatId);
ProductInfo productInfo = productInfoApp.extractProductInfo(rawContent);
// 断言结果不为空 // 断言结果不为空
Assertions.assertNotNull(productInfo); Assertions.assertNotNull(productInfo);

Loading…
Cancel
Save