11 changed files with 371 additions and 6 deletions
-
63src/main/java/com/wok/supportbot/app/AssistantApp.java
-
20src/main/java/com/wok/supportbot/config/QueryExpanderConfig.java
-
41src/main/java/com/wok/supportbot/config/QueryTransformerConfig.java
-
5src/main/java/com/wok/supportbot/load/PgVectorStoreConfig.java
-
46src/main/java/com/wok/supportbot/preretrieval/CompressionQueryRewriter.java
-
41src/main/java/com/wok/supportbot/preretrieval/MultiQueryExpanderRewriter.java
-
39src/main/java/com/wok/supportbot/preretrieval/RewriteQueryRewriter.java
-
37src/main/java/com/wok/supportbot/preretrieval/TranslationQueryRewriter.java
-
1src/test/java/com/wok/supportbot/PgVectorVectorStoreConfigTest.java
-
76src/test/java/com/wok/supportbot/QueryTransformerTests.java
-
8src/test/java/com/wok/supportbot/SupportBotApplicationTests.java
@ -0,0 +1,20 @@ |
|||||
|
package com.wok.supportbot.config; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; |
||||
|
import org.springframework.context.annotation.Bean; |
||||
|
import org.springframework.context.annotation.Configuration; |
||||
|
|
||||
|
@Configuration |
||||
|
public class QueryExpanderConfig { |
||||
|
|
||||
|
@Bean |
||||
|
public MultiQueryExpander multiQueryExpander(ChatModel dashscopeChatModel) { |
||||
|
return MultiQueryExpander.builder() |
||||
|
.chatClientBuilder(ChatClient.builder(dashscopeChatModel)) |
||||
|
.numberOfQueries(3) |
||||
|
.includeOriginal(true) |
||||
|
.build(); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,41 @@ |
|||||
|
package com.wok.supportbot.config; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; |
||||
|
import org.springframework.ai.rag.retrieval.search.DocumentRetriever; |
||||
|
import org.springframework.context.annotation.Bean; |
||||
|
import org.springframework.context.annotation.Configuration; |
||||
|
|
||||
|
import java.util.List; |
||||
|
|
||||
|
@Configuration |
||||
|
public class QueryTransformerConfig { |
||||
|
|
||||
|
@Bean |
||||
|
public QueryTransformer rewriteQueryTransformer(ChatModel dashscopeChatModel) { |
||||
|
return RewriteQueryTransformer.builder() |
||||
|
.chatClientBuilder(ChatClient.builder(dashscopeChatModel)) |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
@Bean |
||||
|
public QueryTransformer translationQueryTransformer(ChatModel dashscopeChatModel) { |
||||
|
return TranslationQueryTransformer.builder() |
||||
|
.chatClientBuilder(ChatClient.builder(dashscopeChatModel)) |
||||
|
.targetLanguage("chinese") |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
@Bean |
||||
|
public QueryTransformer compressionQueryTransformer(ChatModel dashscopeChatModel) { |
||||
|
return CompressionQueryTransformer.builder() |
||||
|
.chatClientBuilder(ChatClient.builder(dashscopeChatModel)) |
||||
|
.build(); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,46 @@ |
|||||
|
package com.wok.supportbot.preretrieval; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
|
||||
|
import org.springframework.ai.chat.messages.AssistantMessage; |
||||
|
import org.springframework.ai.chat.messages.Message; |
||||
|
import org.springframework.ai.chat.messages.UserMessage; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.Query; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; |
||||
|
import org.springframework.stereotype.Component; |
||||
|
|
||||
|
import java.util.List; |
||||
|
|
||||
|
/** |
||||
|
* 查询压缩器 - CompressionQueryTransformer |
||||
|
*/ |
||||
|
@Component |
||||
|
public class CompressionQueryRewriter { |
||||
|
|
||||
|
private final QueryTransformer queryTransformer; |
||||
|
|
||||
|
public CompressionQueryRewriter(ChatModel dashscopeChatModel) { |
||||
|
ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); |
||||
|
queryTransformer = CompressionQueryTransformer.builder() |
||||
|
.chatClientBuilder(builder) |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 执行查询压缩(带对话历史) |
||||
|
* |
||||
|
* @param prompt 当前查询文本 |
||||
|
* @return 压缩后的查询文本 |
||||
|
*/ |
||||
|
public String doQueryRewrite(String prompt, List<Message> history) { |
||||
|
Query query = Query.builder() |
||||
|
.text(prompt) |
||||
|
.history(history) |
||||
|
.build(); |
||||
|
|
||||
|
Query transformed = queryTransformer.transform(query); |
||||
|
return transformed.text(); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,41 @@ |
|||||
|
package com.wok.supportbot.preretrieval; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.Query; |
||||
|
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; |
||||
|
import org.springframework.stereotype.Component; |
||||
|
|
||||
|
import java.util.List; |
||||
|
import java.util.stream.Collectors; |
||||
|
|
||||
|
/** |
||||
|
* 多查询扩展器 - MultiQueryExpander |
||||
|
*/ |
||||
|
@Component |
||||
|
public class MultiQueryExpanderRewriter { |
||||
|
|
||||
|
private final MultiQueryExpander queryExpander; |
||||
|
|
||||
|
public MultiQueryExpanderRewriter(ChatModel dashscopeChatModel) { |
||||
|
ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); |
||||
|
queryExpander = MultiQueryExpander.builder() |
||||
|
.chatClientBuilder(builder) |
||||
|
.numberOfQueries(3) |
||||
|
.includeOriginal(true) //在扩展查询列表中包含原始查询 |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 执行查询扩展,返回多个查询文本 |
||||
|
* |
||||
|
* @param prompt 原始查询 |
||||
|
* @return 多个语义不同的查询文本列表 |
||||
|
*/ |
||||
|
public List<String> doQueryRewrite(String prompt) { |
||||
|
List<Query> queries = queryExpander.expand(new Query(prompt)); |
||||
|
return queries.stream() |
||||
|
.map(Query::text) |
||||
|
.collect(Collectors.toList()); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,39 @@ |
|||||
|
package com.wok.supportbot.preretrieval; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.Query; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; |
||||
|
import org.springframework.stereotype.Component; |
||||
|
|
||||
|
/** |
||||
|
* 查询重写器 - RewriteQueryTransformer |
||||
|
*/ |
||||
|
@Component |
||||
|
public class RewriteQueryRewriter { |
||||
|
|
||||
|
private final QueryTransformer queryTransformer; |
||||
|
|
||||
|
public RewriteQueryRewriter(ChatModel dashscopeChatModel) { |
||||
|
ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); |
||||
|
// 创建查询重写转换器 |
||||
|
queryTransformer = RewriteQueryTransformer.builder() |
||||
|
.chatClientBuilder(builder) |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 执行查询重写 |
||||
|
* |
||||
|
* @param prompt |
||||
|
* @return |
||||
|
*/ |
||||
|
public String doQueryRewrite(String prompt) { |
||||
|
Query query = new Query(prompt); |
||||
|
// 执行查询重写 |
||||
|
Query transformedQuery = queryTransformer.transform(query); |
||||
|
// 输出重写后的查询 |
||||
|
return transformedQuery.text(); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,37 @@ |
|||||
|
package com.wok.supportbot.preretrieval; |
||||
|
|
||||
|
import org.springframework.ai.chat.client.ChatClient; |
||||
|
import org.springframework.ai.chat.model.ChatModel; |
||||
|
import org.springframework.ai.rag.Query; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; |
||||
|
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; |
||||
|
import org.springframework.stereotype.Component; |
||||
|
|
||||
|
/** |
||||
|
* 查询翻译器 - TranslationQueryTransformer |
||||
|
*/ |
||||
|
@Component |
||||
|
public class TranslationQueryRewriter { |
||||
|
|
||||
|
private final QueryTransformer queryTransformer; |
||||
|
|
||||
|
public TranslationQueryRewriter(ChatModel dashscopeChatModel) { |
||||
|
ChatClient.Builder builder = ChatClient.builder(dashscopeChatModel); |
||||
|
queryTransformer = TranslationQueryTransformer.builder() |
||||
|
.chatClientBuilder(builder) |
||||
|
.targetLanguage("chinese") |
||||
|
.build(); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 执行查询翻译 |
||||
|
* |
||||
|
* @param prompt 原始查询文本 |
||||
|
* @return 翻译后的查询文本 |
||||
|
*/ |
||||
|
public String doQueryRewrite(String prompt) { |
||||
|
Query query = new Query(prompt); |
||||
|
Query transformedQuery = queryTransformer.transform(query); |
||||
|
return transformedQuery.text(); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,76 @@ |
|||||
|
package com.wok.supportbot; |
||||
|
|
||||
|
import com.wok.supportbot.preretrieval.CompressionQueryRewriter; |
||||
|
import com.wok.supportbot.preretrieval.MultiQueryExpanderRewriter; |
||||
|
import com.wok.supportbot.preretrieval.RewriteQueryRewriter; |
||||
|
import com.wok.supportbot.preretrieval.TranslationQueryRewriter; |
||||
|
import jakarta.annotation.Resource; |
||||
|
import org.junit.jupiter.api.Test; |
||||
|
import org.springframework.ai.chat.messages.AssistantMessage; |
||||
|
import org.springframework.ai.chat.messages.Message; |
||||
|
import org.springframework.ai.chat.messages.UserMessage; |
||||
|
import org.springframework.ai.rag.Query; |
||||
|
import org.springframework.beans.factory.annotation.Autowired; |
||||
|
import org.springframework.boot.test.context.SpringBootTest; |
||||
|
|
||||
|
import java.util.List; |
||||
|
|
||||
|
@SpringBootTest |
||||
|
public class QueryTransformerTests { |
||||
|
|
||||
|
@Autowired |
||||
|
private TranslationQueryRewriter translationQueryRewriter; |
||||
|
|
||||
|
@Autowired |
||||
|
private CompressionQueryRewriter compressionQueryRewriter; |
||||
|
|
||||
|
@Autowired |
||||
|
private MultiQueryExpanderRewriter multiQueryExpanderRewriter; |
||||
|
|
||||
|
@Autowired |
||||
|
private RewriteQueryRewriter rewriteQueryRewriter; |
||||
|
|
||||
|
@Test |
||||
|
void testRewriteQueryRewriter() { |
||||
|
// 构造输入 |
||||
|
String originalQuery = "我想买一部拍照效果好的手机"; |
||||
|
// 执行 |
||||
|
String rewritten = rewriteQueryRewriter.doQueryRewrite(originalQuery); |
||||
|
// 输出结果 |
||||
|
System.out.println("重写结果:" + rewritten); |
||||
|
} |
||||
|
|
||||
|
@Test |
||||
|
public void testTranslationQueryRewriter() { |
||||
|
String prompt = "I want to buy a lightweight laptop suitable for students."; |
||||
|
String result = translationQueryRewriter.doQueryRewrite(prompt); |
||||
|
System.out.println("Translation result: " + result); |
||||
|
} |
||||
|
|
||||
|
@Test |
||||
|
public void testCompressionQueryRewriter() { |
||||
|
// 当前追问,用户说得很模糊 |
||||
|
String prompt = "那这款的电池续航如何?"; |
||||
|
|
||||
|
// 多轮上下文,用户逐渐缩小目标 |
||||
|
List<Message> history = List.of( |
||||
|
new UserMessage("我想买一台适合出差用的轻薄笔记本"), |
||||
|
new AssistantMessage("你可以看看戴尔 XPS 13,性能不错而且轻便"), |
||||
|
new UserMessage("能不能推荐一款支持长续航的?"), |
||||
|
new AssistantMessage("荣耀 MagicBook X16 电池续航表现优秀,适合长时间外出使用") |
||||
|
); |
||||
|
// 执行压缩 |
||||
|
String result = compressionQueryRewriter.doQueryRewrite(prompt, history); |
||||
|
// 输出压缩后的独立查询 |
||||
|
System.out.println("Compression result: " + result); |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@Test |
||||
|
public void testMultiQueryExpanderRewriter() { |
||||
|
String prompt = "推荐一些适合夏天穿的男士T恤"; |
||||
|
List<String> expandedQueries = multiQueryExpanderRewriter.doQueryRewrite(prompt); |
||||
|
System.out.println("Expanded queries:"); |
||||
|
expandedQueries.forEach(System.out::println); |
||||
|
} |
||||
|
} |
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue