feat(chat): 重构 LLM 流式输出并扩展 ChatSaveReq 字段
- 将原始整段 chunk 拆分为 3 字批次推送,降低前端卡顿
- ChatSaveReq 新增 userId、lang、liked 等 8 个字段并补充 Swagger 注解
- QdrantVectorService 改用 Map<String,JsonWithInt.Value> 载荷,新增 QdrantPayloadMapper 统一转换
This commit is contained in:
@@ -4,28 +4,32 @@ import cn.dev33.satoken.stp.StpUtil;
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import com.yolo.keyborad.common.BaseResponse;
|
||||
import com.yolo.keyborad.common.ResultUtils;
|
||||
import com.yolo.keyborad.mapper.QdrantPayloadMapper;
|
||||
import com.yolo.keyborad.model.dto.chat.ChatReq;
|
||||
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
||||
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
|
||||
import com.yolo.keyborad.model.entity.KeyboardCharacter;
|
||||
import com.yolo.keyborad.service.KeyboardCharacterService;
|
||||
import com.yolo.keyborad.service.impl.QdrantVectorService;
|
||||
import io.qdrant.client.grpc.JsonWithInt;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.boot.context.properties.bind.DefaultValue;
|
||||
import org.springframework.http.codec.ServerSentEvent;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/*
|
||||
* @author: ziin
|
||||
* @date: 2025/12/8 15:05
|
||||
@@ -52,7 +56,7 @@ public class ChatController {
|
||||
|
||||
@PostMapping("/talk")
|
||||
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
|
||||
public Flux<ServerSentEvent<ChatStreamMessage>> testTalk(@RequestBody ChatReq chatReq){
|
||||
public Flux<ServerSentEvent<ChatStreamMessage>> talk(@RequestBody ChatReq chatReq){
|
||||
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
|
||||
// 1. LLM 流式输出
|
||||
Flux<ChatStreamMessage> llmFlux = client
|
||||
@@ -69,7 +73,30 @@ public class ChatController {
|
||||
.build())
|
||||
.stream()
|
||||
.content()
|
||||
.map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
|
||||
.concatMap(chunk -> {
|
||||
// 拆成单字符
|
||||
List<String> chars = chunk.codePoints()
|
||||
.mapToObj(cp -> new String(Character.toChars(cp)))
|
||||
.toList();
|
||||
|
||||
// 你可以在这里按 3~5 个字符再拼一拼
|
||||
List<String> batched = new ArrayList<>();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (String ch : chars) {
|
||||
sb.append(ch);
|
||||
if (sb.length() >= 3) { // 这里的 3 可以自己调
|
||||
batched.add(sb.toString());
|
||||
sb.setLength(0);
|
||||
}
|
||||
}
|
||||
if (!sb.isEmpty()) {
|
||||
batched.add(sb.toString());
|
||||
}
|
||||
|
||||
return Flux.fromIterable(batched)
|
||||
.map(s -> new ChatStreamMessage("llm_chunk", s));
|
||||
});
|
||||
// .map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
|
||||
|
||||
// 2. 向量搜索Flux(一次性发送搜索结果)
|
||||
Flux<ChatStreamMessage> searchFlux = Mono
|
||||
@@ -99,8 +126,10 @@ public class ChatController {
|
||||
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
|
||||
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
||||
public BaseResponse<Boolean> testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) {
|
||||
float[] embed = embeddingModel.embed(chatSaveReq.getUserInputMessage());
|
||||
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, chatSaveReq.getUserSelectMessage());
|
||||
float[] embed = embeddingModel.embed(chatSaveReq.getUserText());
|
||||
chatSaveReq.setUserId(StpUtil.getLoginIdAsLong());
|
||||
Map<String, JsonWithInt.Value> map = QdrantPayloadMapper.toQdrantPayload(chatSaveReq);
|
||||
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, map);
|
||||
return ResultUtils.success(true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package com.yolo.keyborad.mapper;
|
||||
|
||||
|
||||
|
||||
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
||||
import io.qdrant.client.grpc.JsonWithInt;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class QdrantPayloadMapper {
|
||||
|
||||
public static Map<String, JsonWithInt.Value> toQdrantPayload(ChatSaveReq p) {
|
||||
Map<String, JsonWithInt.Value> map = new HashMap<>();
|
||||
|
||||
if (p.getUserId() != null)
|
||||
map.put("userId", longValue(p.getUserId()));
|
||||
|
||||
if (p.getUserText() != null)
|
||||
map.put("userText", stringValue(p.getUserText()));
|
||||
|
||||
if (p.getReplyText() != null)
|
||||
map.put("replyText", stringValue(p.getReplyText()));
|
||||
|
||||
if (p.getCharacterId() != null)
|
||||
map.put("characterId", intValue(p.getCharacterId()));
|
||||
|
||||
if (p.getLang() != null)
|
||||
map.put("lang", stringValue(p.getLang()));
|
||||
|
||||
if (p.getLiked() != null)
|
||||
map.put("liked", boolValue(p.getLiked()));
|
||||
|
||||
if (p.getCreatedAt() != null)
|
||||
map.put("createdAt", longValue(p.getCreatedAt()));
|
||||
|
||||
if (p.getSource() != null)
|
||||
map.put("source", stringValue(p.getSource()));
|
||||
|
||||
if (p.getAppVersion() != null)
|
||||
map.put("appVersion", stringValue(p.getAppVersion()));
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
private static JsonWithInt.Value stringValue(String v) {
|
||||
return JsonWithInt.Value.newBuilder().setStringValue(v).build();
|
||||
}
|
||||
|
||||
private static JsonWithInt.Value intValue(Integer v) {
|
||||
return JsonWithInt.Value.newBuilder().setIntegerValue(v).build();
|
||||
}
|
||||
|
||||
private static JsonWithInt.Value longValue(Long v) {
|
||||
return JsonWithInt.Value.newBuilder().setIntegerValue(v).build();
|
||||
}
|
||||
|
||||
private static JsonWithInt.Value boolValue(Boolean v) {
|
||||
return JsonWithInt.Value.newBuilder().setBoolValue(v).build();
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.yolo.keyborad.model.dto.chat;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
/*
|
||||
@@ -9,8 +10,39 @@ import lombok.Data;
|
||||
/**
 * Request body for saving a chat reply.
 *
 * <p>Plain data holder: Lombok's {@code @Data} supplies the accessors, and the
 * {@code @Schema} annotations carry the field descriptions for the API docs.
 * QdrantPayloadMapper copies the non-null values of userId, userText,
 * replyText, characterId, lang, liked, createdAt, source and appVersion into
 * the Qdrant payload.
 */
@Data
|
||||
public class ChatSaveReq {
|
||||
|
||||
// NOTE(review): looks like the earlier name for the user's input text, now
// covered by userText below - confirm whether this field is still needed.
private String userInputMessage;
|
||||
|
||||
// NOTE(review): looks like the earlier name for the selected reply, now
// covered by replyText below - confirm whether this field is still needed.
private String userSelectMessage;
|
||||
/** The user's original input text. */
|
||||
@Schema(description="用户的原始输入文本")
|
||||
private String userText;
|
||||
|
||||
/** The reply text the user selected / liked. */
|
||||
@Schema(description="用户选择 / 点赞的回复文本")
|
||||
private String replyText;
|
||||
|
||||
/** ID of the character currently in use (for example, a keyboard persona). */
|
||||
@Schema(description="当前使用的角色ID(比如某个键盘人格)")
|
||||
private Integer characterId;
|
||||
|
||||
/** Language of the text: en / zh / ja / es ... */
|
||||
@Schema(description="文本语言:en / zh / ja / es ... ")
|
||||
private String lang;
|
||||
|
||||
/** Whether the user explicitly picked or liked this content; used to filter high-quality samples. */
|
||||
@Schema(description="是否是用户明确点选或点赞过的内容,用于高质量样本过滤")
|
||||
private Boolean liked;
|
||||
|
||||
/** Creation time (a second- or millisecond-precision timestamp is suggested). */
|
||||
@Schema(description="创建时间(建议存秒级或毫秒级时间戳)")
|
||||
private Long createdAt;
|
||||
|
||||
/** Data source, for example "llm" / "template" / "user". */
|
||||
@Schema(description="数据来源:例如 \"llm\" / \"template\" / \"user\"")
|
||||
private String source;
|
||||
|
||||
/** Optional: client version number, for debugging / analysis. */
|
||||
@Schema(description="可选:用于调试 / 分析的客户端版本号")
|
||||
private String appVersion;
|
||||
|
||||
/**
 * User id.
 * NOTE(review): the save endpoint shown alongside this class sets this from
 * the logged-in session before mapping, so a client-supplied value appears
 * to be overwritten - confirm.
 */
@Schema(description = "用户 Id")
|
||||
private Long userId;
|
||||
}
|
||||
|
||||
@@ -5,22 +5,19 @@ import com.yolo.keyborad.common.ErrorCode;
|
||||
import com.yolo.keyborad.exception.BusinessException;
|
||||
import com.yolo.keyborad.model.vo.QdrantSearchItem;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.grpc.JsonWithInt;
|
||||
import io.qdrant.client.grpc.Points;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static io.qdrant.client.PointIdFactory.id;
|
||||
import static io.qdrant.client.QueryFactory.nearest;
|
||||
import static io.qdrant.client.ValueFactory.value;
|
||||
import static io.qdrant.client.VectorsFactory.vectors;
|
||||
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
||||
|
||||
@@ -44,7 +41,7 @@ public class QdrantVectorService {
|
||||
* @param vector 向量(和 collection 中定义的 size 一致)
|
||||
* @param payload 额外信息,例如原文、标题、userId 等
|
||||
*/
|
||||
public void upsertPoint(long id, float[] vector,String payload){
|
||||
public void upsertPoint(long id, float[] vector,Map<String, JsonWithInt.Value> payload){
|
||||
|
||||
try {
|
||||
qdrantClient.upsertAsync(
|
||||
@@ -53,7 +50,7 @@ public class QdrantVectorService {
|
||||
Points.PointStruct.newBuilder()
|
||||
.setId(id(id))
|
||||
.setVectors(vectors(vector))
|
||||
.putAllPayload(Map.of("payload",value(payload)))
|
||||
.putAllPayload(payload)
|
||||
.build()
|
||||
)
|
||||
).get();
|
||||
@@ -119,7 +116,7 @@ public class QdrantVectorService {
|
||||
item.setScore(p.getScore());
|
||||
|
||||
var fieldsMap = p.getPayloadMap();
|
||||
var payloadValue = fieldsMap.get("payload");
|
||||
var payloadValue = fieldsMap.get("replyText");
|
||||
if (payloadValue != null && payloadValue.hasStringValue()) {
|
||||
item.setPayload(payloadValue.getStringValue());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user