From 936bcbdc7005be3851745e9655239bc34c4c8c89 Mon Sep 17 00:00:00 2001 From: hygl <3154803225@qq.com> Date: Sat, 28 Jun 2025 17:39:07 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=9F=A5=E8=AF=A2=E9=87=8D?= =?UTF-8?q?=E5=86=99=E5=99=A8=E3=80=81=E7=BF=BB=E8=AF=91=E5=99=A8=E3=80=81?= =?UTF-8?q?=E5=8E=8B=E7=BC=A9=E5=99=A8=E5=92=8C=E5=A4=9A=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=89=A9=E5=B1=95=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/wok/supportbot/app/AssistantApp.java | 63 ++++++++++++++- .../config/QueryExpanderConfig.java | 20 +++++ .../config/QueryTransformerConfig.java | 41 ++++++++++ .../supportbot/load/PgVectorStoreConfig.java | 5 +- .../CompressionQueryRewriter.java | 46 +++++++++++ .../MultiQueryExpanderRewriter.java | 41 ++++++++++ .../preretrieval/RewriteQueryRewriter.java | 39 ++++++++++ .../TranslationQueryRewriter.java | 37 +++++++++ .../PgVectorVectorStoreConfigTest.java | 1 + .../wok/supportbot/QueryTransformerTests.java | 76 +++++++++++++++++++ .../SupportBotApplicationTests.java | 8 ++ 11 files changed, 371 insertions(+), 6 deletions(-) create mode 100644 src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java create mode 100644 src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java create mode 100644 src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java create mode 100644 src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java create mode 100644 src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java create mode 100644 src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java create mode 100644 src/test/java/com/wok/supportbot/QueryTransformerTests.java diff --git a/src/main/java/com/wok/supportbot/app/AssistantApp.java b/src/main/java/com/wok/supportbot/app/AssistantApp.java index cfd0467..9d7ee12 100644 --- a/src/main/java/com/wok/supportbot/app/AssistantApp.java +++ b/src/main/java/com/wok/supportbot/app/AssistantApp.java @@ -1,19 +1,31 @@ package com.wok.supportbot.app; import com.wok.supportbot.advisor.MyLoggerAdvisor; +import com.wok.supportbot.advisor.ReReadingAdvisor; import com.wok.supportbot.chatmemory.DatabaseChatMemory; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; -import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; +import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import java.util.ArrayList; +import java.util.List; + import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY; import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY; @@ -38,6 +50,7 @@ public class AssistantApp { /** * 初始化 ChatClient + * * @param dashscopeChatModel */ public AssistantApp(ChatModel dashscopeChatModel, DatabaseChatMemory chatMemory) { @@ -61,6 +74,7 @@ public class AssistantApp { /** * AI 基础对话(支持多轮对话记忆) + * * @param message * @param chatId * @return @@ -79,6 +93,7 @@ public class AssistantApp { /** * 和 RAG 知识库进行对话 + * * @param message * @param chatId * @return @@ -90,7 +105,49 @@ public class AssistantApp { .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) // 应用 RAG 知识库问答 - .advisors(new QuestionAnswerAdvisor(vectorStore)) + .advisors(QuestionAnswerAdvisor.builder(vectorStore) + // 相似度阈值为 0.0,并返回最相关的前 4 个结果 + .searchRequest(SearchRequest.builder().similarityThreshold(0.0).topK(4).build()) + .build()) + .call() + .chatResponse(); + return chatResponse.getResult().getOutput().getText(); + } + + + @Autowired + private List queryTransformers; + @Autowired + private MultiQueryExpander multiQueryExpander; + + /** + * 和 RAG 知识库进行对话 + * + * @param message + * @param chatId + * @return + */ + public String doChatWithRagEnhance(String message, String chatId) { + Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + //.queryTransformers(queryTransformers) + //.queryExpander(multiQueryExpander) + .documentRetriever(VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .similarityThreshold(0.5) + .topK(4) + .build()) + .queryAugmenter(ContextualQueryAugmenter.builder() + .allowEmptyContext(false) // 不允许模型在没有找到相关文档的情况下也生成回答 + .build()) + .build(); + + ChatResponse chatResponse = chatClient + .prompt() + .user(message) + .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) + .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) + // 应用 RAG 知识库问答 + .advisors(retrievalAugmentationAdvisor) .call() .chatResponse(); return chatResponse.getResult().getOutput().getText(); diff --git a/src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java b/src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java new file mode 100644 index 0000000..de35e86 --- /dev/null +++ b/src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java @@ -0,0 +1,20 @@ +package com.wok.supportbot.config; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class QueryExpanderConfig { + + @Bean + public MultiQueryExpander multiQueryExpander(ChatModel dashscopeChatModel) { + return MultiQueryExpander.builder() + .chatClientBuilder(ChatClient.builder(dashscopeChatModel)) + .numberOfQueries(3) + .includeOriginal(true) + .build(); + } +} diff --git a/src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java b/src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java new file mode 100644 index 0000000..d02526f --- /dev/null +++ b/src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java @@ -0,0 +1,41 @@ +package com.wok.supportbot.config; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; +import org.springframework.ai.rag.retrieval.search.DocumentRetriever; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.List; + +@Configuration +public class QueryTransformerConfig { + + @Bean + public QueryTransformer rewriteQueryTransformer(ChatModel dashscopeChatModel) { + return RewriteQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(dashscopeChatModel)) + .build(); + } + + @Bean + public QueryTransformer translationQueryTransformer(ChatModel dashscopeChatModel) { + return TranslationQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(dashscopeChatModel)) + .targetLanguage("chinese") + .build(); + } + + @Bean + public QueryTransformer compressionQueryTransformer(ChatModel dashscopeChatModel) { + return CompressionQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(dashscopeChatModel)) + .build(); + } +} diff --git a/src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java b/src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java index 64822e6..d44466f 100644 --- a/src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java +++ b/src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java @@ -22,9 +22,9 @@ import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexT public class PgVectorStoreConfig { @Bean - @Primary + @Primary // 默认使用pgsql储存向量 public VectorStore pgVectorVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel dashscopeEmbeddingModel) { - VectorStore vectorStore = PgVectorStore.builder(jdbcTemplate, dashscopeEmbeddingModel) + return PgVectorStore.builder(jdbcTemplate, dashscopeEmbeddingModel) .dimensions(1536) // Optional: defaults to model dimensions or 1536 .distanceType(COSINE_DISTANCE) // Optional: defaults to COSINE_DISTANCE .indexType(HNSW) // Optional: defaults to HNSW @@ -33,6 +33,5 @@ public class PgVectorStoreConfig { .vectorTableName("vector_store") // Optional: defaults to "vector_store" .maxDocumentBatchSize(10000) // Optional: defaults to 10000 .build(); - return vectorStore; } } diff --git a/src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java b/src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java new file mode 100644 index 0000000..160a01a --- /dev/null +++ b/src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java @@ -0,0 +1,46 @@ +package com.wok.supportbot.preretrieval; + +import org.springframework.ai.chat.client.ChatClient; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.stereotype.Component; + +import java.util.List; + +/** + * 查询压缩器 - CompressionQueryTransformer + */ +@Component +public class CompressionQueryRewriter { + + private final QueryTransformer queryTransformer; + + public CompressionQueryRewriter(ChatModel dashscopeChatModel) { + ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); + queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(builder) + .build(); + } + + /** + * 执行查询压缩(带对话历史) + * + * @param prompt 当前查询文本 + * @return 压缩后的查询文本 + */ + public String doQueryRewrite(String prompt, List history) { + Query query = Query.builder() + .text(prompt) + .history(history) + .build(); + + Query transformed = queryTransformer.transform(query); + return transformed.text(); + } +} diff --git a/src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java b/src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java new file mode 100644 index 0000000..fcc86ba --- /dev/null +++ b/src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java @@ -0,0 +1,41 @@ +package com.wok.supportbot.preretrieval; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * 多查询扩展器 - MultiQueryExpander + */ +@Component +public class MultiQueryExpanderRewriter { + + private final MultiQueryExpander queryExpander; + + public MultiQueryExpanderRewriter(ChatModel dashscopeChatModel) { + ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); + queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(builder) + .numberOfQueries(3) + .includeOriginal(true) //在扩展查询列表中包含原始查询 + .build(); + } + + /** + * 执行查询扩展,返回多个查询文本 + * + * @param prompt 原始查询 + * @return 多个语义不同的查询文本列表 + */ + public List doQueryRewrite(String prompt) { + List queries = queryExpander.expand(new Query(prompt)); + return queries.stream() + .map(Query::text) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java b/src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java new file mode 100644 index 0000000..66f4de1 --- /dev/null +++ b/src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java @@ -0,0 +1,39 @@ +package com.wok.supportbot.preretrieval; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; +import org.springframework.stereotype.Component; + +/** + * 查询重写器 - RewriteQueryTransformer + */ +@Component +public class RewriteQueryRewriter { + + private final QueryTransformer queryTransformer; + + public RewriteQueryRewriter(ChatModel dashscopeChatModel) { + ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); + // 创建查询重写转换器 + queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(builder) + .build(); + } + + /** + * 执行查询重写 + * + * @param prompt + * @return + */ + public String doQueryRewrite(String prompt) { + Query query = new Query(prompt); + // 执行查询重写 + Query transformedQuery = queryTransformer.transform(query); + // 输出重写后的查询 + return transformedQuery.text(); + } +} diff --git a/src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java b/src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java new file mode 100644 index 0000000..84df263 --- /dev/null +++ b/src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java @@ -0,0 +1,37 @@ +package com.wok.supportbot.preretrieval; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; +import org.springframework.stereotype.Component; + +/** + * 查询翻译器 - TranslationQueryTransformer + */ +@Component +public class TranslationQueryRewriter { + + private final QueryTransformer queryTransformer; + + public TranslationQueryRewriter(ChatModel dashscopeChatModel) { + ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); + queryTransformer = TranslationQueryTransformer.builder() + .chatClientBuilder(builder) + .targetLanguage("chinese") + .build(); + } + + /** + * 执行查询翻译 + * + * @param prompt 原始查询文本 + * @return 翻译后的查询文本 + */ + public String doQueryRewrite(String prompt) { + Query query = new Query(prompt); + Query transformedQuery = queryTransformer.transform(query); + return transformedQuery.text(); + } +} diff --git a/src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java b/src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java index 4cbbb2a..5038fdc 100644 --- a/src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java +++ b/src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java @@ -1,5 +1,6 @@ package com.wok.supportbot; +import com.wok.supportbot.preretrieval.RewriteQueryRewriter; import jakarta.annotation.Resource; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; diff --git a/src/test/java/com/wok/supportbot/QueryTransformerTests.java b/src/test/java/com/wok/supportbot/QueryTransformerTests.java new file mode 100644 index 0000000..e40eb6f --- /dev/null +++ b/src/test/java/com/wok/supportbot/QueryTransformerTests.java @@ -0,0 +1,76 @@ +package com.wok.supportbot; + +import com.wok.supportbot.preretrieval.CompressionQueryRewriter; +import com.wok.supportbot.preretrieval.MultiQueryExpanderRewriter; +import com.wok.supportbot.preretrieval.RewriteQueryRewriter; +import com.wok.supportbot.preretrieval.TranslationQueryRewriter; +import jakarta.annotation.Resource; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.rag.Query; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; + +@SpringBootTest +public class QueryTransformerTests { + + @Autowired + private TranslationQueryRewriter translationQueryRewriter; + + @Autowired + private CompressionQueryRewriter compressionQueryRewriter; + + @Autowired + private MultiQueryExpanderRewriter multiQueryExpanderRewriter; + + @Autowired + private RewriteQueryRewriter rewriteQueryRewriter; + + @Test + void testRewriteQueryRewriter() { + // 构造输入 + String originalQuery = "我想买一部拍照效果好的手机"; + // 执行 + String rewritten = rewriteQueryRewriter.doQueryRewrite(originalQuery); + // 输出结果 + System.out.println("重写结果:" + rewritten); + } + + @Test + public void testTranslationQueryRewriter() { + String prompt = "I want to buy a lightweight laptop suitable for students."; + String result = translationQueryRewriter.doQueryRewrite(prompt); + System.out.println("Translation result: " + result); + } + + @Test + public void testCompressionQueryRewriter() { + // 当前追问,用户说得很模糊 + String prompt = "那这款的电池续航如何?"; + + // 多轮上下文,用户逐渐缩小目标 + List history = List.of( + new UserMessage("我想买一台适合出差用的轻薄笔记本"), + new AssistantMessage("你可以看看戴尔 XPS 13,性能不错而且轻便"), + new UserMessage("能不能推荐一款支持长续航的?"), + new AssistantMessage("荣耀 MagicBook X16 电池续航表现优秀,适合长时间外出使用") + ); + // 执行压缩 + String result = compressionQueryRewriter.doQueryRewrite(prompt, history); + // 输出压缩后的独立查询 + System.out.println("Compression result: " + result); + } + + + @Test + public void testMultiQueryExpanderRewriter() { + String prompt = "推荐一些适合夏天穿的男士T恤"; + List expandedQueries = multiQueryExpanderRewriter.doQueryRewrite(prompt); + System.out.println("Expanded queries:"); + expandedQueries.forEach(System.out::println); + } +} diff --git a/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java b/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java index 22ea0cf..87bae76 100644 --- a/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java +++ b/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java @@ -76,5 +76,13 @@ class SupportBotApplicationTests { Assertions.assertNotNull(answer); } + @Test + void doChatWithRagEnhance() { + String chatId = "1069b88d-eb85-47ac-bd2e-c393d118a5aa"; + String message = "我之前询问了你什么问题?"; + String answer = assistantApp.doChatWithRagEnhance(message, chatId); + Assertions.assertNotNull(answer); + } + }