From 39b19493e2fd5187d7ed5d3410e67e0f8096ed3f Mon Sep 17 00:00:00 2001 From: ziin Date: Mon, 8 Dec 2025 20:45:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(chat):=20=E6=96=B0=E5=A2=9E=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E6=B6=A6=E8=89=B2=E7=BB=93=E6=9E=9C=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E5=B9=B6=E9=87=8D=E6=9E=84=E5=90=91=E9=87=8F?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ChatController 新增 /save_embed 接口,接收 ChatSaveReq 保存用户选中润色句子的向量 - 统一向量参数由 List<Float> 改为 float[],降低 GC 压力 - 向量搜索增加 ≥0.9 相似度过滤,仅返回高置信结果 - 精简 DemoController 测试接口,下线冗余的 testSaveEmbed/testSearch - 调整 Embedding 模型为 qwen3-embedding-4b,降低资源占用 - 放开 /chat/save_embed 匿名访问,适配前端直调 --- .../com/yolo/keyborad/config/LLMConfig.java | 2 +- .../keyborad/config/SaTokenConfigure.java | 3 +- .../keyborad/controller/ChatController.java | 18 ++++--- .../keyborad/controller/DemoController.java | 30 +++++------ .../keyborad/model/dto/chat/ChatSaveReq.java | 16 ++++++ .../service/impl/QdrantVectorService.java | 50 +++++++------------ 6 files changed, 62 insertions(+), 57 deletions(-) create mode 100644 src/main/java/com/yolo/keyborad/model/dto/chat/ChatSaveReq.java diff --git a/src/main/java/com/yolo/keyborad/config/LLMConfig.java b/src/main/java/com/yolo/keyborad/config/LLMConfig.java index 5d1a721..eb08d6f 100644 --- a/src/main/java/com/yolo/keyborad/config/LLMConfig.java +++ b/src/main/java/com/yolo/keyborad/config/LLMConfig.java @@ -53,7 +53,7 @@ public class LLMConfig { this.openAiApi(), MetadataMode.EMBED, OpenAiEmbeddingOptions.builder() - .model("qwen/qwen3-embedding-8b") + .model("qwen/qwen3-embedding-4b") .dimensions(1536) .user("user-6") .build(), diff --git a/src/main/java/com/yolo/keyborad/config/SaTokenConfigure.java b/src/main/java/com/yolo/keyborad/config/SaTokenConfigure.java index 5a12d72..a1084b1 100644 --- a/src/main/java/com/yolo/keyborad/config/SaTokenConfigure.java +++ 
b/src/main/java/com/yolo/keyborad/config/SaTokenConfigure.java @@ -84,7 +84,8 @@ public class SaTokenConfigure implements WebMvcConfigurer { "/api/apple/validate-receipt", "/character/list", "/user/resetPassWord", - "/chat/talk" + "/chat/talk", + "/chat/save_embed" }; } @Bean diff --git a/src/main/java/com/yolo/keyborad/controller/ChatController.java b/src/main/java/com/yolo/keyborad/controller/ChatController.java index 11c2094..76b81ea 100644 --- a/src/main/java/com/yolo/keyborad/controller/ChatController.java +++ b/src/main/java/com/yolo/keyborad/controller/ChatController.java @@ -1,7 +1,11 @@ package com.yolo.keyborad.controller; import cn.dev33.satoken.stp.StpUtil; +import cn.hutool.core.util.IdUtil; +import com.yolo.keyborad.common.BaseResponse; +import com.yolo.keyborad.common.ResultUtils; import com.yolo.keyborad.model.dto.chat.ChatReq; +import com.yolo.keyborad.model.dto.chat.ChatSaveReq; import com.yolo.keyborad.model.dto.chat.ChatStreamMessage; import com.yolo.keyborad.model.entity.KeyboardCharacter; import com.yolo.keyborad.service.KeyboardCharacterService; @@ -12,6 +16,7 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.boot.context.properties.bind.DefaultValue; @@ -47,15 +52,11 @@ public class ChatController { @PostMapping("/talk") @Operation(summary = "聊天润色接口", description = "聊天润色接口") - @Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something") public Flux> testTalk(@RequestBody ChatReq chatReq){ - KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId()); - // 1. 
LLM 流式输出 Flux llmFlux = client - .prompt(character.getPrompt() + - "\nUser message: %s".formatted(chatReq.getMessage())) + .prompt(character.getPrompt()) .system(""" Format rules: - Return EXACTLY 3 replies. @@ -97,8 +98,9 @@ public class ChatController { @PostMapping("/save_embed") @Operation(summary = "保存润色后的句子", description = "保存润色后的句子") @Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something") - public Flux testTalkWithVector(@RequestBody ChatReq chatReq) { - - return null; + public BaseResponse testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) { + float[] embed = embeddingModel.embed(chatSaveReq.getUserInputMessage()); + qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, chatSaveReq.getUserSelectMessage()); + return ResultUtils.success(true); } } diff --git a/src/main/java/com/yolo/keyborad/controller/DemoController.java b/src/main/java/com/yolo/keyborad/controller/DemoController.java index fa55543..c09818c 100644 --- a/src/main/java/com/yolo/keyborad/controller/DemoController.java +++ b/src/main/java/com/yolo/keyborad/controller/DemoController.java @@ -97,23 +97,23 @@ public class DemoController { } - @PostMapping("/testSaveEmbed") - @Operation(summary = "测试存储向量接口", description = "测试存储向量接口") - @Parameter(name = "userInput",required = true,description = "测试存储向量接口") - public BaseResponse testSaveEmbed(@RequestBody EmbedSaveReq embedSaveReq) { - qdrantVectorService.upsertPoint(embedSaveReq.getRecordItem().getId() - , embedSaveReq.getVector() - , JSONUtil.toJsonStr(embedSaveReq.getRecordItem())); - return ResultUtils.success(true); - } +// @PostMapping("/testSaveEmbed") +// @Operation(summary = "测试存储向量接口", description = "测试存储向量接口") +// @Parameter(name = "userInput",required = true,description = "测试存储向量接口") +// public BaseResponse testSaveEmbed(@RequestBody EmbedSaveReq embedSaveReq) { +// qdrantVectorService.upsertPoint(embedSaveReq.getRecordItem().getId() +// , embedSaveReq.getVector() +// , 
JSONUtil.toJsonStr(embedSaveReq.getRecordItem())); +// return ResultUtils.success(true); +// } - @PostMapping("/testSearch") - @Operation(summary = "测试搜索向量接口", description = "测试搜索向量接口") - @Parameter(name = "userInput",required = true,description = "测试搜索向量接口") - public BaseResponse> testSearch(@RequestBody SearchEmbedReq searchEmbedReq) { - return ResultUtils.success(qdrantVectorService.searchPoint(searchEmbedReq.getUserInputEmbed(), 3)); - } +// @PostMapping("/testSearch") +// @Operation(summary = "测试搜索向量接口", description = "测试搜索向量接口") +// @Parameter(name = "userInput",required = true,description = "测试搜索向量接口") +// public BaseResponse> testSearch(@RequestBody SearchEmbedReq searchEmbedReq) { +// return ResultUtils.success(qdrantVectorService.searchPoint(searchEmbedReq.getUserInputEmbed(), 3)); +// } @PostMapping("/tsetSearchText") diff --git a/src/main/java/com/yolo/keyborad/model/dto/chat/ChatSaveReq.java b/src/main/java/com/yolo/keyborad/model/dto/chat/ChatSaveReq.java new file mode 100644 index 0000000..edcb816 --- /dev/null +++ b/src/main/java/com/yolo/keyborad/model/dto/chat/ChatSaveReq.java @@ -0,0 +1,16 @@ +package com.yolo.keyborad.model.dto.chat; + +import lombok.Data; + +/* + * @author: ziin + * @date: 2025/12/8 19:26 + */ +@Data +public class ChatSaveReq { + + private String userInputMessage; + + private String userSelectMessage; + +} diff --git a/src/main/java/com/yolo/keyborad/service/impl/QdrantVectorService.java b/src/main/java/com/yolo/keyborad/service/impl/QdrantVectorService.java index 1bc85c5..47b65d7 100644 --- a/src/main/java/com/yolo/keyborad/service/impl/QdrantVectorService.java +++ b/src/main/java/com/yolo/keyborad/service/impl/QdrantVectorService.java @@ -13,6 +13,7 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.stereotype.Service; +import java.util.Arrays; import java.util.List; import java.util.Map; import 
java.util.concurrent.ExecutionException; @@ -43,7 +44,7 @@ public class QdrantVectorService { * @param vector 向量(和 collection 中定义的 size 一致) * @param payload 额外信息,例如原文、标题、userId 等 */ - public void upsertPoint(long id, List vector,String payload){ + public void upsertPoint(long id, float[] vector,String payload){ try { qdrantClient.upsertAsync( @@ -90,12 +91,12 @@ public class QdrantVectorService { // } // } - public List searchPoint(List userVector, int limit) { + public List searchPoint(float[] userVector, int limit) { try { Points.QueryPoints query = Points.QueryPoints.newBuilder() .setCollectionName(COLLECTION_NAME) // ★ 必须 .setQuery(nearest(userVector)) // ★ 语义向量 - .setLimit(limit) // TopK + .setLimit(limit) // 限制返回数量 .setWithPayload(enable(true)) // ★ 带上 payload .build(); @@ -107,37 +108,31 @@ public class QdrantVectorService { // 3. 把 Protobuf 的 ScoredPoint 转成你的 DTO return batchResult.getResultList().stream() + .filter(p -> p.getScore() >= 0.9) // ★ 只要相似度 ≥ 90% .map(p -> { QdrantSearchItem item = new QdrantSearchItem(); - // id:你插入时用的是 setId(id(id)),所以这里取 num if (p.getId().hasNum()) { item.setId(p.getId().getNum()); } - // score item.setScore(p.getScore()); - // payload:你之前是 putAllPayload(Map.of("payload", value(payload))) - // 这里从 Struct 里拿 "payload" 字段 var fieldsMap = p.getPayloadMap(); var payloadValue = fieldsMap.get("payload"); if (payloadValue != null && payloadValue.hasStringValue()) { item.setPayload(payloadValue.getStringValue()); } - // vector:你插入时用的是 vectors(vector),即 unnamed 单向量 - // proto 结构一般是 Vectors.vector.data[] if (p.getVectors().hasVector()) { - List vec = p.getVectors().getVector().getDataList(); - item.setVector(vec); + item.setVector(p.getVectors().getVector().getDataList()); } return item; }) .toList(); - + } catch (InterruptedException | ExecutionException e) { log.error("search point 失败", e); throw new BusinessException(ErrorCode.OPERATION_ERROR); @@ -149,38 +144,29 @@ public class QdrantVectorService { /** * 把一段文本做 embedding 然后写入 
Qdrant * - * @param id 业务 ID(比如业务表主键) * @param text 用来做向量的文本(一般是内容) */ - public void indexText(long id, String text) { - // 1. 文本 → 向量 - List vector = embedTextToVector(text); +// public void indexText(long id, String text) { +// // 1. 文本 → 向量 +// embedTextToVector(text); +// +// // 2. 存到 Qdrant,payload 里顺便存原文 +// upsertPoint(id, vector, text); +// } - // 2. 存到 Qdrant,payload 里顺便存原文 - upsertPoint(id, vector, text); - } + private float[] embedTextToVector(String text) { - private List embedTextToVector(String text) { - EmbeddingResponse response = embeddingModel.embedForResponse(List.of(text)); - - - Embedding embedding = response.getResult(); // 就一条 - // Spring AI 里一般是 List - float[] output = embedding.getOutput(); - // 转成 Qdrant 需要的 List - - - return Floats.asList(output); + return embeddingModel.embed(text); } public List searchText(String userInput) { long t0 = System.currentTimeMillis(); - List floats = this.embedTextToVector(userInput); + float[] floats = this.embedTextToVector(userInput); long t1 = System.currentTimeMillis(); - List qdrantSearchItems = this.searchPoint(floats, 3); + List qdrantSearchItems = this.searchPoint(floats, 1); long t2 = System.currentTimeMillis(); log.info("embedding = {} ms, qdrant = {} ms, total = {} ms",