Browse Source

实现基于数据库的对话记忆

master
hygl 1 year ago
parent
commit
53132f301a
  1. 5
      pom.xml
  2. 16
      src/main/java/com/wok/supportbot/app/AssistantApp.java
  3. 60
      src/main/java/com/wok/supportbot/chatmemory/DatabaseChatMemory.java
  4. 44
      src/main/java/com/wok/supportbot/converter/MessageConverter.java
  5. 9
      src/main/java/com/wok/supportbot/dao/ChatMessageMapper.java
  6. 73
      src/main/java/com/wok/supportbot/entity/ChatMessage.java
  7. 25
      src/main/java/com/wok/supportbot/handler/MyMetaObjectHandler.java
  8. 61
      src/main/java/com/wok/supportbot/handler/PostgresJsonTypeHandler.java
  9. 11
      src/main/java/com/wok/supportbot/repository/ChatMessageRepository.java
  10. 4
      src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

5
pom.xml

@ -101,6 +101,11 @@
<artifactId>spring-ai-starter-vector-store-pgvector</artifactId> <artifactId>spring-ai-starter-vector-store-pgvector</artifactId>
<version>1.0.0-M7</version> <version>1.0.0-M7</version>
</dependency>--> </dependency>-->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-spring-boot3-starter</artifactId>
<version>3.5.12</version>
</dependency>
</dependencies> </dependencies>
<build> <build>

16
src/main/java/com/wok/supportbot/app/AssistantApp.java

@ -1,6 +1,7 @@
package com.wok.supportbot.app; package com.wok.supportbot.app;
import com.wok.supportbot.advisor.MyLoggerAdvisor; import com.wok.supportbot.advisor.MyLoggerAdvisor;
import com.wok.supportbot.chatmemory.DatabaseChatMemory;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
@ -39,12 +40,13 @@ public class AssistantApp {
* 初始化 ChatClient * 初始化 ChatClient
* @param dashscopeChatModel * @param dashscopeChatModel
*/ */
public AssistantApp(ChatModel dashscopeChatModel) {
public AssistantApp(ChatModel dashscopeChatModel, DatabaseChatMemory chatMemory) {
// 初始化基于文件的对话记忆 // 初始化基于文件的对话记忆
//String fileDir = System.getProperty("user.dir") + "/tmp/chat-memory"; //String fileDir = System.getProperty("user.dir") + "/tmp/chat-memory";
//ChatMemory chatMemory = new FileBasedChatMemory(fileDir); //ChatMemory chatMemory = new FileBasedChatMemory(fileDir);
// 初始化基于内存的对话记忆 // 初始化基于内存的对话记忆
ChatMemory chatMemory = new InMemoryChatMemory();
// ChatMemory chatMemory = new InMemoryChatMemory();
chatClient = ChatClient.builder(dashscopeChatModel) chatClient = ChatClient.builder(dashscopeChatModel)
.defaultSystem(SYSTEM_PROMPT) .defaultSystem(SYSTEM_PROMPT)
.defaultAdvisors( .defaultAdvisors(
@ -71,9 +73,7 @@ public class AssistantApp {
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.call() .call()
.chatResponse(); .chatResponse();
String content = chatResponse.getResult().getOutput().getText();
//log.info("content: {}", content);
return content;
return chatResponse.getResult().getOutput().getText();
} }
@ -89,14 +89,10 @@ public class AssistantApp {
.user(message) .user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
// 开启日志便于观察效果
.advisors(new MyLoggerAdvisor())
// 应用 RAG 知识库问答 // 应用 RAG 知识库问答
.advisors(new QuestionAnswerAdvisor(vectorStore)) .advisors(new QuestionAnswerAdvisor(vectorStore))
.call() .call()
.chatResponse(); .chatResponse();
String content = chatResponse.getResult().getOutput().getText();
log.info("content: {}", content);
return content;
return chatResponse.getResult().getOutput().getText();
} }
} }

60
src/main/java/com/wok/supportbot/chatmemory/DatabaseChatMemory.java

@ -0,0 +1,60 @@
package com.wok.supportbot.chatmemory;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.wok.supportbot.converter.MessageConverter;
import com.wok.supportbot.entity.ChatMessage;
import com.wok.supportbot.repository.ChatMessageRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Component
@RequiredArgsConstructor
public class DatabaseChatMemory implements ChatMemory {
@Autowired
private final ChatMessageRepository chatMessageRepository;
@Override
public void add(String conversationId, List<Message> messages) {
List<ChatMessage> chatMessages = messages.stream()
.map(message -> MessageConverter.toChatMessage(message, conversationId))
.collect(Collectors.toList());
chatMessageRepository.saveBatch(chatMessages, chatMessages.size());
}
@Override
public List<Message> get(String conversationId, int lastN) {
LambdaQueryWrapper<ChatMessage> queryWrapper = new LambdaQueryWrapper<>();
// 查询最近的 lastN 条消息
queryWrapper.eq(ChatMessage::getConversationId, conversationId)
.orderByDesc(ChatMessage::getCreateTime)
.last(lastN > 0, "LIMIT " + lastN);
List<ChatMessage> chatMessages = chatMessageRepository.list(queryWrapper);
// 按照时间顺序返回
if (!chatMessages.isEmpty()) {
Collections.reverse(chatMessages);
}
return chatMessages
.stream()
.map(MessageConverter::toMessage)
.collect(Collectors.toList());
}
@Override
public void clear(String conversationId) {
LambdaQueryWrapper<ChatMessage> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(ChatMessage::getConversationId, conversationId);
chatMessageRepository.remove(queryWrapper);
}
}

44
src/main/java/com/wok/supportbot/converter/MessageConverter.java

@ -0,0 +1,44 @@
package com.wok.supportbot.converter;
import com.wok.supportbot.entity.ChatMessage;
import org.springframework.ai.chat.messages.*;
import java.util.List;
import java.util.Map;
/**
* @Classname MessageConverter
* @Description
* @Version 1.0.0
* @Date 2025/06/28 13:30
* @Author lyx
*/
public class MessageConverter {
/**
* Message 转换为 ChatMessage
*/
public static ChatMessage toChatMessage(Message message, String conversationId) {
return ChatMessage.builder()
.conversationId(conversationId)
.messageType(message.getMessageType())
.content(message.getText())
.metadata(message.getMetadata())
.build();
}
/**
* ChatMessage 转换为 Message
*/
public static Message toMessage(ChatMessage chatMessage) {
MessageType messageType = chatMessage.getMessageType();
String text = chatMessage.getContent();
Map<String, Object> metadata = chatMessage.getMetadata();
return switch (messageType) {
case USER -> new UserMessage(text);
case ASSISTANT -> new AssistantMessage(text, metadata);
case SYSTEM -> new SystemMessage(text);
case TOOL -> new ToolResponseMessage(List.of(), metadata);
};
}
}

9
src/main/java/com/wok/supportbot/dao/ChatMessageMapper.java

@ -0,0 +1,9 @@
package com.wok.supportbot.dao;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.wok.supportbot.entity.ChatMessage;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatMessageMapper extends BaseMapper<ChatMessage> {
}

73
src/main/java/com/wok/supportbot/entity/ChatMessage.java

@ -0,0 +1,73 @@
package com.wok.supportbot.entity;
import com.baomidou.mybatisplus.annotation.*;
import com.wok.supportbot.handler.PostgresJsonTypeHandler;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.messages.MessageType;
import java.io.Serial;
import java.io.Serializable;
import java.util.Date;
import java.util.Map;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@TableName(value = "chat_message", autoResultMap = true)
public class ChatMessage implements Serializable {
@Serial
@TableField(exist = false)
private static final long serialVersionUID = 1L;
@TableId(value = "id", type = IdType.ASSIGN_ID)
private Long id;
/**
* 会话ID
*/
@TableField("conversation_id")
private String conversationId;
/**
* 消息类型
*/
@TableField("message_type")
private MessageType messageType;
/**
* 消息内容
*/
@TableField("content")
private String content;
/**
* 元数据
*/
@TableField(value = "metadata", typeHandler = PostgresJsonTypeHandler.class)
private Map<String, Object> metadata;
/**
* 创建时间
*/
@TableField(value = "create_time", fill = FieldFill.INSERT)
private Date createTime;
/**
* 更新时间
*/
@Version
@TableField(value = "update_time", fill = FieldFill.INSERT_UPDATE)
private Date updateTime;
/**
* 是否删除 false-未删除 true-已删除
*/
@TableField("is_delete")
@TableLogic
private boolean isDelete;
}

25
src/main/java/com/wok/supportbot/handler/MyMetaObjectHandler.java

@ -0,0 +1,25 @@
package com.wok.supportbot.handler;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import org.apache.ibatis.reflection.MetaObject;
import org.springframework.stereotype.Component;
import java.util.Date;
/**
* 在PostgreSQL中@TableField(fill = FieldFill.INSERT) @TableField(fill = FieldFill.INSERT_UPDATE) 这样的注解本身并不能直接触发数据库级别的自动填充
* 这些注解是MyBatis-Plus框架的一部分它们需要配合 MetaObjectHandler 实现类才能工作这种机制是在Java应用层面实现的而非数据库层面
*/
@Component
public class MyMetaObjectHandler implements MetaObjectHandler {
@Override
public void insertFill(MetaObject metaObject) {
this.strictInsertFill(metaObject, "createTime", Date.class, new Date());
this.strictInsertFill(metaObject, "updateTime", Date.class, new Date());
}
@Override
public void updateFill(MetaObject metaObject) {
this.strictUpdateFill(metaObject, "updateTime", Date.class, new Date());
}
}

61
src/main/java/com/wok/supportbot/handler/PostgresJsonTypeHandler.java

@ -0,0 +1,61 @@
package com.wok.supportbot.handler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import java.sql.*;
import java.util.HashMap;
import java.util.Map;
/**
* PostgreSQL使用JSONB类型存储JSON数据需要创建自定义类型处理器
*/
@MappedJdbcTypes(JdbcType.OTHER)
@MappedTypes({Map.class})
public class PostgresJsonTypeHandler extends BaseTypeHandler<Map<String, Object>> {
private static final ObjectMapper objectMapper = new ObjectMapper();
@Override
public void setNonNullParameter(PreparedStatement ps, int i, Map<String, Object> parameter, JdbcType jdbcType)
throws SQLException {
try {
ps.setObject(i, objectMapper.writeValueAsString(parameter), Types.OTHER);
} catch (JsonProcessingException e) {
throw new SQLException("Error converting Map to JSON", e);
}
}
@Override
public Map<String, Object> getNullableResult(ResultSet rs, String columnName) throws SQLException {
String json = rs.getString(columnName);
return parseJson(json);
}
@Override
public Map<String, Object> getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
String json = rs.getString(columnIndex);
return parseJson(json);
}
@Override
public Map<String, Object> getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
String json = cs.getString(columnIndex);
return parseJson(json);
}
private Map<String, Object> parseJson(String json) throws SQLException {
try {
if (json == null) {
return new HashMap<>();
}
return objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException e) {
throw new SQLException("Error parsing JSON to Map", e);
}
}
}

11
src/main/java/com/wok/supportbot/repository/ChatMessageRepository.java

@ -0,0 +1,11 @@
package com.wok.supportbot.repository;
import com.baomidou.mybatisplus.extension.repository.CrudRepository;
import com.wok.supportbot.dao.ChatMessageMapper;
import com.wok.supportbot.entity.ChatMessage;
import org.springframework.stereotype.Component;
@Component
public class ChatMessageRepository extends CrudRepository<ChatMessageMapper, ChatMessage> {
}

4
src/test/java/com/wok/supportbot/SupportBotApplicationTests.java

@ -70,8 +70,8 @@ class SupportBotApplicationTests {
@Test @Test
void doChatWithRag() { void doChatWithRag() {
String chatId = UUID.randomUUID().toString();
String message = "T恤怎么搭配?";
String chatId = "1069b88d-eb85-47ac-bd2e-c393d118a5aa";
String message = "我之前询问了你什么问题?";
String answer = assistantApp.doChatWithRag(message, chatId); String answer = assistantApp.doChatWithRag(message, chatId);
Assertions.assertNotNull(answer); Assertions.assertNotNull(answer);
} }

Loading…
Cancel
Save