feat(chat): 新增保存润色结果向量接口并重构向量类型
- ChatController 新增 /save_embed 接口,接收 ChatSaveReq 保存用户选中润色句子的向量 - 统一向量参数由 List<Float> 改为 float[],降低 GC 压力 - 向量搜索增加 ≥0.9 相似度过滤,仅返回高置信结果 - 精简 DemoController 测试接口,下线冗余的 testSaveEmbed/testSearch - 调整 Embedding 模型为 qwen3-embedding-4b,降低资源占用 - 放开 /chat/save_embed 匿名访问,适配前端直调
This commit is contained in:
@@ -53,7 +53,7 @@ public class LLMConfig {
|
|||||||
this.openAiApi(),
|
this.openAiApi(),
|
||||||
MetadataMode.EMBED,
|
MetadataMode.EMBED,
|
||||||
OpenAiEmbeddingOptions.builder()
|
OpenAiEmbeddingOptions.builder()
|
||||||
.model("qwen/qwen3-embedding-8b")
|
.model("qwen/qwen3-embedding-4b")
|
||||||
.dimensions(1536)
|
.dimensions(1536)
|
||||||
.user("user-6")
|
.user("user-6")
|
||||||
.build(),
|
.build(),
|
||||||
|
|||||||
@@ -84,7 +84,8 @@ public class SaTokenConfigure implements WebMvcConfigurer {
|
|||||||
"/api/apple/validate-receipt",
|
"/api/apple/validate-receipt",
|
||||||
"/character/list",
|
"/character/list",
|
||||||
"/user/resetPassWord",
|
"/user/resetPassWord",
|
||||||
"/chat/talk"
|
"/chat/talk",
|
||||||
|
"/chat/save_embed"
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@Bean
|
@Bean
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package com.yolo.keyborad.controller;
|
package com.yolo.keyborad.controller;
|
||||||
|
|
||||||
import cn.dev33.satoken.stp.StpUtil;
|
import cn.dev33.satoken.stp.StpUtil;
|
||||||
|
import cn.hutool.core.util.IdUtil;
|
||||||
|
import com.yolo.keyborad.common.BaseResponse;
|
||||||
|
import com.yolo.keyborad.common.ResultUtils;
|
||||||
import com.yolo.keyborad.model.dto.chat.ChatReq;
|
import com.yolo.keyborad.model.dto.chat.ChatReq;
|
||||||
|
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
||||||
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
|
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
|
||||||
import com.yolo.keyborad.model.entity.KeyboardCharacter;
|
import com.yolo.keyborad.model.entity.KeyboardCharacter;
|
||||||
import com.yolo.keyborad.service.KeyboardCharacterService;
|
import com.yolo.keyborad.service.KeyboardCharacterService;
|
||||||
@@ -12,6 +16,7 @@ import io.swagger.v3.oas.annotations.tags.Tag;
|
|||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.client.ChatClient;
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
||||||
import org.springframework.boot.context.properties.bind.DefaultValue;
|
import org.springframework.boot.context.properties.bind.DefaultValue;
|
||||||
@@ -47,15 +52,11 @@ public class ChatController {
|
|||||||
|
|
||||||
@PostMapping("/talk")
|
@PostMapping("/talk")
|
||||||
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
|
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
|
||||||
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
|
||||||
public Flux<ServerSentEvent<ChatStreamMessage>> testTalk(@RequestBody ChatReq chatReq){
|
public Flux<ServerSentEvent<ChatStreamMessage>> testTalk(@RequestBody ChatReq chatReq){
|
||||||
|
|
||||||
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
|
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
|
||||||
|
|
||||||
// 1. LLM 流式输出
|
// 1. LLM 流式输出
|
||||||
Flux<ChatStreamMessage> llmFlux = client
|
Flux<ChatStreamMessage> llmFlux = client
|
||||||
.prompt(character.getPrompt() +
|
.prompt(character.getPrompt())
|
||||||
"\nUser message: %s".formatted(chatReq.getMessage()))
|
|
||||||
.system("""
|
.system("""
|
||||||
Format rules:
|
Format rules:
|
||||||
- Return EXACTLY 3 replies.
|
- Return EXACTLY 3 replies.
|
||||||
@@ -97,8 +98,9 @@ public class ChatController {
|
|||||||
@PostMapping("/save_embed")
|
@PostMapping("/save_embed")
|
||||||
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
|
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
|
||||||
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
||||||
public Flux<String> testTalkWithVector(@RequestBody ChatReq chatReq) {
|
public BaseResponse<Boolean> testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) {
|
||||||
|
float[] embed = embeddingModel.embed(chatSaveReq.getUserInputMessage());
|
||||||
return null;
|
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, chatSaveReq.getUserSelectMessage());
|
||||||
|
return ResultUtils.success(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,23 +97,23 @@ public class DemoController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@PostMapping("/testSaveEmbed")
|
// @PostMapping("/testSaveEmbed")
|
||||||
@Operation(summary = "测试存储向量接口", description = "测试存储向量接口")
|
// @Operation(summary = "测试存储向量接口", description = "测试存储向量接口")
|
||||||
@Parameter(name = "userInput",required = true,description = "测试存储向量接口")
|
// @Parameter(name = "userInput",required = true,description = "测试存储向量接口")
|
||||||
public BaseResponse<Boolean> testSaveEmbed(@RequestBody EmbedSaveReq embedSaveReq) {
|
// public BaseResponse<Boolean> testSaveEmbed(@RequestBody EmbedSaveReq embedSaveReq) {
|
||||||
qdrantVectorService.upsertPoint(embedSaveReq.getRecordItem().getId()
|
// qdrantVectorService.upsertPoint(embedSaveReq.getRecordItem().getId()
|
||||||
, embedSaveReq.getVector()
|
// , embedSaveReq.getVector()
|
||||||
, JSONUtil.toJsonStr(embedSaveReq.getRecordItem()));
|
// , JSONUtil.toJsonStr(embedSaveReq.getRecordItem()));
|
||||||
return ResultUtils.success(true);
|
// return ResultUtils.success(true);
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
|
||||||
@PostMapping("/testSearch")
|
// @PostMapping("/testSearch")
|
||||||
@Operation(summary = "测试搜索向量接口", description = "测试搜索向量接口")
|
// @Operation(summary = "测试搜索向量接口", description = "测试搜索向量接口")
|
||||||
@Parameter(name = "userInput",required = true,description = "测试搜索向量接口")
|
// @Parameter(name = "userInput",required = true,description = "测试搜索向量接口")
|
||||||
public BaseResponse<List<QdrantSearchItem>> testSearch(@RequestBody SearchEmbedReq searchEmbedReq) {
|
// public BaseResponse<List<QdrantSearchItem>> testSearch(@RequestBody SearchEmbedReq searchEmbedReq) {
|
||||||
return ResultUtils.success(qdrantVectorService.searchPoint(searchEmbedReq.getUserInputEmbed(), 3));
|
// return ResultUtils.success(qdrantVectorService.searchPoint(searchEmbedReq.getUserInputEmbed(), 3));
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
|
||||||
@PostMapping("/tsetSearchText")
|
@PostMapping("/tsetSearchText")
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.yolo.keyborad.model.dto.chat;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @author: ziin
|
||||||
|
* @date: 2025/12/8 19:26
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class ChatSaveReq {
|
||||||
|
|
||||||
|
private String userInputMessage;
|
||||||
|
|
||||||
|
private String userSelectMessage;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import org.springframework.ai.embedding.EmbeddingModel;
|
|||||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
@@ -43,7 +44,7 @@ public class QdrantVectorService {
|
|||||||
* @param vector 向量(和 collection 中定义的 size 一致)
|
* @param vector 向量(和 collection 中定义的 size 一致)
|
||||||
* @param payload 额外信息,例如原文、标题、userId 等
|
* @param payload 额外信息,例如原文、标题、userId 等
|
||||||
*/
|
*/
|
||||||
public void upsertPoint(long id, List<Float> vector,String payload){
|
public void upsertPoint(long id, float[] vector,String payload){
|
||||||
|
|
||||||
try {
|
try {
|
||||||
qdrantClient.upsertAsync(
|
qdrantClient.upsertAsync(
|
||||||
@@ -90,12 +91,12 @@ public class QdrantVectorService {
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
public List<QdrantSearchItem> searchPoint(List<Float> userVector, int limit) {
|
public List<QdrantSearchItem> searchPoint(float[] userVector, int limit) {
|
||||||
try {
|
try {
|
||||||
Points.QueryPoints query = Points.QueryPoints.newBuilder()
|
Points.QueryPoints query = Points.QueryPoints.newBuilder()
|
||||||
.setCollectionName(COLLECTION_NAME) // ★ 必须
|
.setCollectionName(COLLECTION_NAME) // ★ 必须
|
||||||
.setQuery(nearest(userVector)) // ★ 语义向量
|
.setQuery(nearest(userVector)) // ★ 语义向量
|
||||||
.setLimit(limit) // TopK
|
.setLimit(limit) // 限制返回数量
|
||||||
.setWithPayload(enable(true)) // ★ 带上 payload
|
.setWithPayload(enable(true)) // ★ 带上 payload
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@@ -107,37 +108,31 @@ public class QdrantVectorService {
|
|||||||
|
|
||||||
// 3. 把 Protobuf 的 ScoredPoint 转成你的 DTO
|
// 3. 把 Protobuf 的 ScoredPoint 转成你的 DTO
|
||||||
return batchResult.getResultList().stream()
|
return batchResult.getResultList().stream()
|
||||||
|
.filter(p -> p.getScore() >= 0.9) // ★ 只要相似度 ≥ 90%
|
||||||
.map(p -> {
|
.map(p -> {
|
||||||
QdrantSearchItem item = new QdrantSearchItem();
|
QdrantSearchItem item = new QdrantSearchItem();
|
||||||
|
|
||||||
// id:你插入时用的是 setId(id(id)),所以这里取 num
|
|
||||||
if (p.getId().hasNum()) {
|
if (p.getId().hasNum()) {
|
||||||
item.setId(p.getId().getNum());
|
item.setId(p.getId().getNum());
|
||||||
}
|
}
|
||||||
|
|
||||||
// score
|
|
||||||
item.setScore(p.getScore());
|
item.setScore(p.getScore());
|
||||||
|
|
||||||
// payload:你之前是 putAllPayload(Map.of("payload", value(payload)))
|
|
||||||
// 这里从 Struct 里拿 "payload" 字段
|
|
||||||
var fieldsMap = p.getPayloadMap();
|
var fieldsMap = p.getPayloadMap();
|
||||||
var payloadValue = fieldsMap.get("payload");
|
var payloadValue = fieldsMap.get("payload");
|
||||||
if (payloadValue != null && payloadValue.hasStringValue()) {
|
if (payloadValue != null && payloadValue.hasStringValue()) {
|
||||||
item.setPayload(payloadValue.getStringValue());
|
item.setPayload(payloadValue.getStringValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// vector:你插入时用的是 vectors(vector),即 unnamed 单向量
|
|
||||||
// proto 结构一般是 Vectors.vector.data[]
|
|
||||||
if (p.getVectors().hasVector()) {
|
if (p.getVectors().hasVector()) {
|
||||||
List<Float> vec = p.getVectors().getVector().getDataList();
|
item.setVector(p.getVectors().getVector().getDataList());
|
||||||
item.setVector(vec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return item;
|
return item;
|
||||||
})
|
})
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
|
|
||||||
} catch (InterruptedException | ExecutionException e) {
|
} catch (InterruptedException | ExecutionException e) {
|
||||||
log.error("search point 失败", e);
|
log.error("search point 失败", e);
|
||||||
throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
||||||
@@ -149,38 +144,29 @@ public class QdrantVectorService {
|
|||||||
/**
|
/**
|
||||||
* 把一段文本做 embedding 然后写入 Qdrant
|
* 把一段文本做 embedding 然后写入 Qdrant
|
||||||
*
|
*
|
||||||
* @param id 业务 ID(比如业务表主键)
|
|
||||||
* @param text 用来做向量的文本(一般是内容)
|
* @param text 用来做向量的文本(一般是内容)
|
||||||
*/
|
*/
|
||||||
public void indexText(long id, String text) {
|
// public void indexText(long id, String text) {
|
||||||
// 1. 文本 → 向量
|
// // 1. 文本 → 向量
|
||||||
List<Float> vector = embedTextToVector(text);
|
// embedTextToVector(text);
|
||||||
|
//
|
||||||
|
// // 2. 存到 Qdrant,payload 里顺便存原文
|
||||||
|
// upsertPoint(id, vector, text);
|
||||||
|
// }
|
||||||
|
|
||||||
// 2. 存到 Qdrant,payload 里顺便存原文
|
private float[] embedTextToVector(String text) {
|
||||||
upsertPoint(id, vector, text);
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<Float> embedTextToVector(String text) {
|
return embeddingModel.embed(text);
|
||||||
EmbeddingResponse response = embeddingModel.embedForResponse(List.of(text));
|
|
||||||
|
|
||||||
|
|
||||||
Embedding embedding = response.getResult(); // 就一条
|
|
||||||
// Spring AI 里一般是 List<Double>
|
|
||||||
float[] output = embedding.getOutput();
|
|
||||||
// 转成 Qdrant 需要的 List<Float>
|
|
||||||
|
|
||||||
|
|
||||||
return Floats.asList(output);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<QdrantSearchItem> searchText(String userInput) {
|
public List<QdrantSearchItem> searchText(String userInput) {
|
||||||
long t0 = System.currentTimeMillis();
|
long t0 = System.currentTimeMillis();
|
||||||
|
|
||||||
List<Float> floats = this.embedTextToVector(userInput);
|
float[] floats = this.embedTextToVector(userInput);
|
||||||
long t1 = System.currentTimeMillis();
|
long t1 = System.currentTimeMillis();
|
||||||
|
|
||||||
List<QdrantSearchItem> qdrantSearchItems = this.searchPoint(floats, 3);
|
List<QdrantSearchItem> qdrantSearchItems = this.searchPoint(floats, 1);
|
||||||
long t2 = System.currentTimeMillis();
|
long t2 = System.currentTimeMillis();
|
||||||
|
|
||||||
log.info("embedding = {} ms, qdrant = {} ms, total = {} ms",
|
log.info("embedding = {} ms, qdrant = {} ms, total = {} ms",
|
||||||
|
|||||||
Reference in New Issue
Block a user