diff --git a/pom.xml b/pom.xml index 074a073..a5c01d4 100644 --- a/pom.xml +++ b/pom.xml @@ -101,6 +101,11 @@ spring-ai-starter-vector-store-pgvector 1.0.0-M7 --> + + com.baomidou + mybatis-plus-spring-boot3-starter + 3.5.12 + diff --git a/src/main/java/com/wok/supportbot/app/AssistantApp.java b/src/main/java/com/wok/supportbot/app/AssistantApp.java index c9f3824..cfd0467 100644 --- a/src/main/java/com/wok/supportbot/app/AssistantApp.java +++ b/src/main/java/com/wok/supportbot/app/AssistantApp.java @@ -1,6 +1,7 @@ package com.wok.supportbot.app; import com.wok.supportbot.advisor.MyLoggerAdvisor; +import com.wok.supportbot.chatmemory.DatabaseChatMemory; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; @@ -39,12 +40,13 @@ public class AssistantApp { * 初始化 ChatClient * @param dashscopeChatModel */ - public AssistantApp(ChatModel dashscopeChatModel) { + public AssistantApp(ChatModel dashscopeChatModel, DatabaseChatMemory chatMemory) { // 初始化基于文件的对话记忆 //String fileDir = System.getProperty("user.dir") + "/tmp/chat-memory"; //ChatMemory chatMemory = new FileBasedChatMemory(fileDir); // 初始化基于内存的对话记忆 - ChatMemory chatMemory = new InMemoryChatMemory(); + // ChatMemory chatMemory = new InMemoryChatMemory(); + chatClient = ChatClient.builder(dashscopeChatModel) .defaultSystem(SYSTEM_PROMPT) .defaultAdvisors( @@ -71,9 +73,7 @@ public class AssistantApp { .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) .call() .chatResponse(); - String content = chatResponse.getResult().getOutput().getText(); - //log.info("content: {}", content); - return content; + return chatResponse.getResult().getOutput().getText(); } @@ -89,14 +89,10 @@ public class AssistantApp { .user(message) .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId) .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)) - // 开启日志,便于观察效果 - .advisors(new MyLoggerAdvisor()) // 应用 RAG 知识库问答 .advisors(new QuestionAnswerAdvisor(vectorStore)) .call() .chatResponse(); - String content = chatResponse.getResult().getOutput().getText(); - log.info("content: {}", content); - return content; + return chatResponse.getResult().getOutput().getText(); } } diff --git a/src/main/java/com/wok/supportbot/chatmemory/DatabaseChatMemory.java b/src/main/java/com/wok/supportbot/chatmemory/DatabaseChatMemory.java new file mode 100644 index 0000000..145c334 --- /dev/null +++ b/src/main/java/com/wok/supportbot/chatmemory/DatabaseChatMemory.java @@ -0,0 +1,60 @@ +package com.wok.supportbot.chatmemory; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.wok.supportbot.converter.MessageConverter; +import com.wok.supportbot.entity.ChatMessage; +import com.wok.supportbot.repository.ChatMessageRepository; +import lombok.RequiredArgsConstructor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +@Component +@RequiredArgsConstructor +public class DatabaseChatMemory implements ChatMemory { + + @Autowired + private final ChatMessageRepository chatMessageRepository; + + @Override + public void add(String conversationId, List messages) { + List chatMessages = messages.stream() + .map(message -> MessageConverter.toChatMessage(message, conversationId)) + .collect(Collectors.toList()); + + chatMessageRepository.saveBatch(chatMessages, chatMessages.size()); + } + + @Override + public List get(String conversationId, int lastN) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); + // 查询最近的 lastN 条消息 + queryWrapper.eq(ChatMessage::getConversationId, conversationId) + .orderByDesc(ChatMessage::getCreateTime) + .last(lastN > 0, "LIMIT " + lastN); + + List chatMessages = chatMessageRepository.list(queryWrapper); + + // 按照时间顺序返回 + if (!chatMessages.isEmpty()) { + Collections.reverse(chatMessages); + } + + return chatMessages + .stream() + .map(MessageConverter::toMessage) + .collect(Collectors.toList()); + } + + @Override + public void clear(String conversationId) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); + queryWrapper.eq(ChatMessage::getConversationId, conversationId); + chatMessageRepository.remove(queryWrapper); + } +} diff --git a/src/main/java/com/wok/supportbot/converter/MessageConverter.java b/src/main/java/com/wok/supportbot/converter/MessageConverter.java new file mode 100644 index 0000000..84e3a22 --- /dev/null +++ b/src/main/java/com/wok/supportbot/converter/MessageConverter.java @@ -0,0 +1,44 @@ +package com.wok.supportbot.converter; + +import com.wok.supportbot.entity.ChatMessage; +import org.springframework.ai.chat.messages.*; + +import java.util.List; +import java.util.Map; + +/** + * @Classname MessageConverter + * @Description + * @Version 1.0.0 + * @Date 2025/06/28 13:30 + * @Author lyx + */ +public class MessageConverter { + + /** + * 将 Message 转换为 ChatMessage + */ + public static ChatMessage toChatMessage(Message message, String conversationId) { + return ChatMessage.builder() + .conversationId(conversationId) + .messageType(message.getMessageType()) + .content(message.getText()) + .metadata(message.getMetadata()) + .build(); + } + + /** + * 将 ChatMessage 转换为 Message + */ + public static Message toMessage(ChatMessage chatMessage) { + MessageType messageType = chatMessage.getMessageType(); + String text = chatMessage.getContent(); + Map metadata = chatMessage.getMetadata(); + return switch (messageType) { + case USER -> new UserMessage(text); + case ASSISTANT -> new AssistantMessage(text, metadata); + case SYSTEM -> new SystemMessage(text); + case TOOL -> new ToolResponseMessage(List.of(), metadata); + }; + } +} diff --git a/src/main/java/com/wok/supportbot/dao/ChatMessageMapper.java b/src/main/java/com/wok/supportbot/dao/ChatMessageMapper.java new file mode 100644 index 0000000..64e8431 --- /dev/null +++ b/src/main/java/com/wok/supportbot/dao/ChatMessageMapper.java @@ -0,0 +1,9 @@ +package com.wok.supportbot.dao; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.wok.supportbot.entity.ChatMessage; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface ChatMessageMapper extends BaseMapper { +} diff --git a/src/main/java/com/wok/supportbot/entity/ChatMessage.java b/src/main/java/com/wok/supportbot/entity/ChatMessage.java new file mode 100644 index 0000000..37ad2e0 --- /dev/null +++ b/src/main/java/com/wok/supportbot/entity/ChatMessage.java @@ -0,0 +1,73 @@ +package com.wok.supportbot.entity; + +import com.baomidou.mybatisplus.annotation.*; +import com.wok.supportbot.handler.PostgresJsonTypeHandler; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.ai.chat.messages.MessageType; + +import java.io.Serial; +import java.io.Serializable; +import java.util.Date; +import java.util.Map; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +@TableName(value = "chat_message", autoResultMap = true) +public class ChatMessage implements Serializable { + + @Serial + @TableField(exist = false) + private static final long serialVersionUID = 1L; + + @TableId(value = "id", type = IdType.ASSIGN_ID) + private Long id; + + /** + * 会话ID + */ + @TableField("conversation_id") + private String conversationId; + + /** + * 消息类型 + */ + @TableField("message_type") + private MessageType messageType; + + /** + * 消息内容 + */ + @TableField("content") + private String content; + + /** + * 元数据 + */ + @TableField(value = "metadata", typeHandler = PostgresJsonTypeHandler.class) + private Map metadata; + + /** + * 创建时间 + */ + @TableField(value = "create_time", fill = FieldFill.INSERT) + private Date createTime; + + /** + * 更新时间 + */ + @Version + @TableField(value = "update_time", fill = FieldFill.INSERT_UPDATE) + private Date updateTime; + + /** + * 是否删除 false-未删除 true-已删除 + */ + @TableField("is_delete") + @TableLogic + private boolean isDelete; +} \ No newline at end of file diff --git a/src/main/java/com/wok/supportbot/handler/MyMetaObjectHandler.java b/src/main/java/com/wok/supportbot/handler/MyMetaObjectHandler.java new file mode 100644 index 0000000..c6b5a77 --- /dev/null +++ b/src/main/java/com/wok/supportbot/handler/MyMetaObjectHandler.java @@ -0,0 +1,25 @@ +package com.wok.supportbot.handler; + +import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler; +import org.apache.ibatis.reflection.MetaObject; +import org.springframework.stereotype.Component; + +import java.util.Date; + +/** + * 在PostgreSQL中,@TableField(fill = FieldFill.INSERT) 和 @TableField(fill = FieldFill.INSERT_UPDATE) 这样的注解本身并不能直接触发数据库级别的自动填充。 + * 这些注解是MyBatis-Plus框架的一部分,它们需要配合 MetaObjectHandler 实现类才能工作。这种机制是在Java应用层面实现的,而非数据库层面。 + */ +@Component +public class MyMetaObjectHandler implements MetaObjectHandler { + @Override + public void insertFill(MetaObject metaObject) { + this.strictInsertFill(metaObject, "createTime", Date.class, new Date()); + this.strictInsertFill(metaObject, "updateTime", Date.class, new Date()); + } + + @Override + public void updateFill(MetaObject metaObject) { + this.strictUpdateFill(metaObject, "updateTime", Date.class, new Date()); + } +} \ No newline at end of file diff --git a/src/main/java/com/wok/supportbot/handler/PostgresJsonTypeHandler.java b/src/main/java/com/wok/supportbot/handler/PostgresJsonTypeHandler.java new file mode 100644 index 0000000..0495a60 --- /dev/null +++ b/src/main/java/com/wok/supportbot/handler/PostgresJsonTypeHandler.java @@ -0,0 +1,61 @@ +package com.wok.supportbot.handler; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.ibatis.type.BaseTypeHandler; +import org.apache.ibatis.type.JdbcType; +import org.apache.ibatis.type.MappedJdbcTypes; +import org.apache.ibatis.type.MappedTypes; + +import java.sql.*; +import java.util.HashMap; +import java.util.Map; + +/** + * PostgreSQL使用JSONB类型存储JSON数据,需要创建自定义类型处理器: + */ +@MappedJdbcTypes(JdbcType.OTHER) +@MappedTypes({Map.class}) +public class PostgresJsonTypeHandler extends BaseTypeHandler> { + private static final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public void setNonNullParameter(PreparedStatement ps, int i, Map parameter, JdbcType jdbcType) + throws SQLException { + try { + ps.setObject(i, objectMapper.writeValueAsString(parameter), Types.OTHER); + } catch (JsonProcessingException e) { + throw new SQLException("Error converting Map to JSON", e); + } + } + + @Override + public Map getNullableResult(ResultSet rs, String columnName) throws SQLException { + String json = rs.getString(columnName); + return parseJson(json); + } + + @Override + public Map getNullableResult(ResultSet rs, int columnIndex) throws SQLException { + String json = rs.getString(columnIndex); + return parseJson(json); + } + + @Override + public Map getNullableResult(CallableStatement cs, int columnIndex) throws SQLException { + String json = cs.getString(columnIndex); + return parseJson(json); + } + + private Map parseJson(String json) throws SQLException { + try { + if (json == null) { + return new HashMap<>(); + } + return objectMapper.readValue(json, new TypeReference>() {}); + } catch (JsonProcessingException e) { + throw new SQLException("Error parsing JSON to Map", e); + } + } +} \ No newline at end of file diff --git a/src/main/java/com/wok/supportbot/repository/ChatMessageRepository.java b/src/main/java/com/wok/supportbot/repository/ChatMessageRepository.java new file mode 100644 index 0000000..a7efbd3 --- /dev/null +++ b/src/main/java/com/wok/supportbot/repository/ChatMessageRepository.java @@ -0,0 +1,11 @@ +package com.wok.supportbot.repository; + +import com.baomidou.mybatisplus.extension.repository.CrudRepository; +import com.wok.supportbot.dao.ChatMessageMapper; +import com.wok.supportbot.entity.ChatMessage; +import org.springframework.stereotype.Component; + +@Component +public class ChatMessageRepository extends CrudRepository { + +} diff --git a/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java b/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java index c8aeca4..22ea0cf 100644 --- a/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java +++ b/src/test/java/com/wok/supportbot/SupportBotApplicationTests.java @@ -70,8 +70,8 @@ class SupportBotApplicationTests { @Test void doChatWithRag() { - String chatId = UUID.randomUUID().toString(); - String message = "T恤怎么搭配?"; + String chatId = "1069b88d-eb85-47ac-bd2e-c393d118a5aa"; + String message = "我之前询问了你什么问题?"; String answer = assistantApp.doChatWithRag(message, chatId); Assertions.assertNotNull(answer); }