Browse Source

新增查询重写器、翻译器、压缩器和多查询扩展器

master
hygl 1 year ago
parent
commit
936bcbdc70
  1. 63
      src/main/java/com/wok/supportbot/app/AssistantApp.java
  2. 20
      src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java
  3. 41
      src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java
  4. 5
      src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java
  5. 46
      src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java
  6. 41
      src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java
  7. 39
      src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java
  8. 37
      src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java
  9. 1
      src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java
  10. 76
      src/test/java/com/wok/supportbot/QueryTransformerTests.java
  11. 8
      src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

63
src/main/java/com/wok/supportbot/app/AssistantApp.java

@ -1,19 +1,31 @@
package com.wok.supportbot.app;
import com.wok.supportbot.advisor.MyLoggerAdvisor;
import com.wok.supportbot.advisor.ReReadingAdvisor;
import com.wok.supportbot.chatmemory.DatabaseChatMemory;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;
@ -38,6 +50,7 @@ public class AssistantApp {
/**
* 初始化 ChatClient
*
* @param dashscopeChatModel
*/
public AssistantApp(ChatModel dashscopeChatModel, DatabaseChatMemory chatMemory) {
@ -61,6 +74,7 @@ public class AssistantApp {
/**
* AI 基础对话支持多轮对话记忆
*
* @param message
* @param chatId
* @return
@ -79,6 +93,7 @@ public class AssistantApp {
/**
* RAG 知识库进行对话
*
* @param message
* @param chatId
* @return
@ -90,7 +105,49 @@ public class AssistantApp {
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
// 应用 RAG 知识库问答
.advisors(new QuestionAnswerAdvisor(vectorStore))
.advisors(QuestionAnswerAdvisor.builder(vectorStore)
// 相似度阈值为 0.0并返回最相关的前 4 个结果
.searchRequest(SearchRequest.builder().similarityThreshold(0.0).topK(4).build())
.build())
.call()
.chatResponse();
return chatResponse.getResult().getOutput().getText();
}
// Every QueryTransformer bean in the application context, injected by Spring as a list.
// NOTE(review): in the visible code this field is only referenced from a commented-out
// builder call in doChatWithRagEnhance — confirm whether it is still needed.
@Autowired
private List<QueryTransformer> queryTransformers;
// Multi-query expander bean injected by Spring.
// NOTE(review): likewise only referenced from a commented-out builder call in the visible code.
@Autowired
private MultiQueryExpander multiQueryExpander;
/**
* Chats against the knowledge base through a RetrievalAugmentationAdvisor.
* <p>
* Documents are fetched from {@code vectorStore} with a similarity threshold of 0.5 and
* a top-K of 4, and the query is augmented with a ContextualQueryAugmenter configured
* with {@code allowEmptyContext(false)}.
*
* @param message the user message sent to the chat client
* @param chatId  conversation id passed to the advisors as CHAT_MEMORY_CONVERSATION_ID_KEY
* @return the text of the first result's output in the chat response
*/
public String doChatWithRagEnhance(String message, String chatId) {
Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
//.queryTransformers(queryTransformers)
//.queryExpander(multiQueryExpander)
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.similarityThreshold(0.5)
.topK(4)
.build())
.queryAugmenter(ContextualQueryAugmenter.builder()
.allowEmptyContext(false) // do not let the model answer when no relevant documents were found
.build())
.build();
ChatResponse chatResponse = chatClient
.prompt()
.user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
// apply knowledge-base (RAG) question answering
.advisors(retrievalAugmentationAdvisor)
.call()
.chatResponse();
return chatResponse.getResult().getOutput().getText();

20
src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java

@ -0,0 +1,20 @@
package com.wok.supportbot.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class QueryExpanderConfig {
@Bean
public MultiQueryExpander multiQueryExpander(ChatModel dashscopeChatModel) {
return MultiQueryExpander.builder()
.chatClientBuilder(ChatClient.builder(dashscopeChatModel))
.numberOfQueries(3)
.includeOriginal(true)
.build();
}
}

41
src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java

@ -0,0 +1,41 @@
package com.wok.supportbot.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
@Configuration
public class QueryTransformerConfig {
@Bean
public QueryTransformer rewriteQueryTransformer(ChatModel dashscopeChatModel) {
return RewriteQueryTransformer.builder()
.chatClientBuilder(ChatClient.builder(dashscopeChatModel))
.build();
}
@Bean
public QueryTransformer translationQueryTransformer(ChatModel dashscopeChatModel) {
return TranslationQueryTransformer.builder()
.chatClientBuilder(ChatClient.builder(dashscopeChatModel))
.targetLanguage("chinese")
.build();
}
@Bean
public QueryTransformer compressionQueryTransformer(ChatModel dashscopeChatModel) {
return CompressionQueryTransformer.builder()
.chatClientBuilder(ChatClient.builder(dashscopeChatModel))
.build();
}
}

5
src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java

@ -22,9 +22,9 @@ import static org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexT
public class PgVectorStoreConfig {
@Bean
@Primary
@Primary // 默认使用pgsql储存向量
public VectorStore pgVectorVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel dashscopeEmbeddingModel) {
VectorStore vectorStore = PgVectorStore.builder(jdbcTemplate, dashscopeEmbeddingModel)
return PgVectorStore.builder(jdbcTemplate, dashscopeEmbeddingModel)
.dimensions(1536) // Optional: defaults to model dimensions or 1536
.distanceType(COSINE_DISTANCE) // Optional: defaults to COSINE_DISTANCE
.indexType(HNSW) // Optional: defaults to HNSW
@ -33,6 +33,5 @@ public class PgVectorStoreConfig {
.vectorTableName("vector_store") // Optional: defaults to "vector_store"
.maxDocumentBatchSize(10000) // Optional: defaults to 10000
.build();
return vectorStore;
}
}

46
src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java

@ -0,0 +1,46 @@
package com.wok.supportbot.preretrieval;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Query compressor: wraps a {@link CompressionQueryTransformer} so callers can pass a
 * plain prompt plus the conversation history and get back the transformed query text.
 */
@Component
public class CompressionQueryRewriter {

    private final QueryTransformer queryTransformer;

    /**
     * @param dashscopeChatModel chat model used to build the transformer's chat client
     */
    public CompressionQueryRewriter(ChatModel dashscopeChatModel) {
        ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel);
        queryTransformer = CompressionQueryTransformer.builder()
                .chatClientBuilder(builder)
                .build();
    }

    /**
     * Compresses the current prompt together with the conversation history.
     *
     * @param prompt  the current query text
     * @param history earlier messages of the conversation; {@code null} is treated as
     *                an empty history
     * @return the text of the transformed query
     */
    public String doQueryRewrite(String prompt, List<Message> history) {
        // A null history means "no previous turns"; normalise it here so callers
        // without context do not have to build an empty list themselves.
        List<Message> safeHistory = (history == null) ? List.of() : history;
        Query query = Query.builder()
                .text(prompt)
                .history(safeHistory)
                .build();
        Query transformed = queryTransformer.transform(query);
        return transformed.text();
    }
}

41
src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java

@ -0,0 +1,41 @@
package com.wok.supportbot.preretrieval;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Multi-query expander: wraps a {@link MultiQueryExpander} and exposes the expanded
 * queries as plain strings.
 */
@Component
public class MultiQueryExpanderRewriter {

    private final MultiQueryExpander queryExpander;

    /**
     * Builds the expander with {@code numberOfQueries(3)} and {@code includeOriginal(true)}.
     *
     * @param dashscopeChatModel chat model used to build the expander's chat client
     */
    public MultiQueryExpanderRewriter(ChatModel dashscopeChatModel) {
        this.queryExpander = MultiQueryExpander.builder()
                .chatClientBuilder(ChatClient.builder(dashscopeChatModel))
                .numberOfQueries(3)
                // include the original query in the list of expanded queries
                .includeOriginal(true)
                .build();
    }

    /**
     * Expands a single query into several query texts.
     *
     * @param prompt the original query
     * @return the texts of the queries returned by the expander
     */
    public List<String> doQueryRewrite(String prompt) {
        Query original = new Query(prompt);
        return queryExpander.expand(original)
                .stream()
                .map(expanded -> expanded.text())
                .collect(Collectors.toList());
    }
}

39
src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java

@ -0,0 +1,39 @@
package com.wok.supportbot.preretrieval;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.stereotype.Component;
/**
 * Query rewriter: wraps a {@link RewriteQueryTransformer} behind a string-in,
 * string-out method.
 */
@Component
public class RewriteQueryRewriter {

    private final QueryTransformer rewriter;

    /**
     * @param dashscopeChatModel chat model used to build the transformer's chat client
     */
    public RewriteQueryRewriter(ChatModel dashscopeChatModel) {
        // Build the rewriting transformer on top of a chat client for the given model.
        this.rewriter = RewriteQueryTransformer.builder()
                .chatClientBuilder(ChatClient.builder(dashscopeChatModel))
                .build();
    }

    /**
     * Runs the query rewrite.
     *
     * @param prompt the original query text
     * @return the text of the transformed query
     */
    public String doQueryRewrite(String prompt) {
        return rewriter.transform(new Query(prompt)).text();
    }
}

37
src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java

@ -0,0 +1,37 @@
package com.wok.supportbot.preretrieval;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.stereotype.Component;
/**
 * Query translator: wraps a {@link TranslationQueryTransformer} whose target language
 * is set to {@code "chinese"}.
 */
@Component
public class TranslationQueryRewriter {

    private final QueryTransformer translator;

    /**
     * @param dashscopeChatModel chat model used to build the transformer's chat client
     */
    public TranslationQueryRewriter(ChatModel dashscopeChatModel) {
        this.translator = TranslationQueryTransformer.builder()
                .chatClientBuilder(ChatClient.builder(dashscopeChatModel))
                .targetLanguage("chinese")
                .build();
    }

    /**
     * Runs the query translation.
     *
     * @param prompt the original query text
     * @return the text of the transformed query
     */
    public String doQueryRewrite(String prompt) {
        return translator.transform(new Query(prompt)).text();
    }
}

1
src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java

@ -1,5 +1,6 @@
package com.wok.supportbot;
import com.wok.supportbot.preretrieval.RewriteQueryRewriter;
import jakarta.annotation.Resource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

76
src/test/java/com/wok/supportbot/QueryTransformerTests.java

@ -0,0 +1,76 @@
package com.wok.supportbot;
import com.wok.supportbot.preretrieval.CompressionQueryRewriter;
import com.wok.supportbot.preretrieval.MultiQueryExpanderRewriter;
import com.wok.supportbot.preretrieval.RewriteQueryRewriter;
import com.wok.supportbot.preretrieval.TranslationQueryRewriter;
import jakarta.annotation.Resource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.rag.Query;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import java.util.List;
/**
 * Spring Boot integration tests for the pre-retrieval query helpers (rewrite,
 * translation, compression and multi-query expansion).
 * <p>
 * Each test prints the model output for manual inspection and additionally asserts
 * that a usable (non-null, non-blank / non-empty) result came back, so a broken
 * pipeline fails the test instead of only printing.
 */
@SpringBootTest
public class QueryTransformerTests {

    @Autowired
    private TranslationQueryRewriter translationQueryRewriter;

    @Autowired
    private CompressionQueryRewriter compressionQueryRewriter;

    @Autowired
    private MultiQueryExpanderRewriter multiQueryExpanderRewriter;

    @Autowired
    private RewriteQueryRewriter rewriteQueryRewriter;

    @Test
    void testRewriteQueryRewriter() {
        // Build the input
        String originalQuery = "我想买一部拍照效果好的手机";
        // Run the rewrite
        String rewritten = rewriteQueryRewriter.doQueryRewrite(originalQuery);
        // Print the result
        System.out.println("重写结果:" + rewritten);
        assertUsableText(rewritten);
    }

    @Test
    public void testTranslationQueryRewriter() {
        String prompt = "I want to buy a lightweight laptop suitable for students.";
        String result = translationQueryRewriter.doQueryRewrite(prompt);
        System.out.println("Translation result: " + result);
        assertUsableText(result);
    }

    @Test
    public void testCompressionQueryRewriter() {
        // The follow-up question is vague on its own
        String prompt = "那这款的电池续航如何?";
        // Multi-turn context in which the user gradually narrows the target
        List<Message> history = List.of(
                new UserMessage("我想买一台适合出差用的轻薄笔记本"),
                new AssistantMessage("你可以看看戴尔 XPS 13,性能不错而且轻便"),
                new UserMessage("能不能推荐一款支持长续航的?"),
                new AssistantMessage("荣耀 MagicBook X16 电池续航表现优秀,适合长时间外出使用")
        );
        // Run the compression
        String result = compressionQueryRewriter.doQueryRewrite(prompt, history);
        // Print the compressed query
        System.out.println("Compression result: " + result);
        assertUsableText(result);
    }

    @Test
    public void testMultiQueryExpanderRewriter() {
        String prompt = "推荐一些适合夏天穿的男士T恤";
        List<String> expandedQueries = multiQueryExpanderRewriter.doQueryRewrite(prompt);
        System.out.println("Expanded queries:");
        expandedQueries.forEach(System.out::println);
        Assertions.assertNotNull(expandedQueries);
        Assertions.assertFalse(expandedQueries.isEmpty(), "expander returned no queries");
    }

    /** Fails the test when the returned query text is null or blank. */
    private static void assertUsableText(String text) {
        Assertions.assertNotNull(text);
        Assertions.assertFalse(text.isBlank(), "query text must not be blank");
    }
}

8
src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

@ -76,5 +76,13 @@ class SupportBotApplicationTests {
Assertions.assertNotNull(answer);
}
/**
 * Sends one message through the RAG-enhanced chat entry point under a fixed
 * conversation id and checks that an answer is returned.
 */
@Test
void doChatWithRagEnhance() {
    final String conversationId = "1069b88d-eb85-47ac-bd2e-c393d118a5aa";
    final String question = "我之前询问了你什么问题?";
    Assertions.assertNotNull(assistantApp.doChatWithRagEnhance(question, conversationId));
}
}
Loading…
Cancel
Save