feat(chat): 重构 LLM 流式输出并扩展 ChatSaveReq 字段

- 将原始整段 chunk 拆分为 3 字批次推送,降低前端卡顿
- ChatSaveReq 新增 userId、lang、liked 等 8 个字段并补充 Swagger 注解
- QdrantVectorService 改用 Map<String,JsonWithInt.Value> 载荷,新增 QdrantPayloadMapper 统一转换
This commit is contained in:
2025-12-09 14:49:14 +08:00
parent 39b19493e2
commit fba6f0d729
4 changed files with 134 additions and 15 deletions

View File

@@ -4,28 +4,32 @@ import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.util.IdUtil;
import com.yolo.keyborad.common.BaseResponse;
import com.yolo.keyborad.common.ResultUtils;
import com.yolo.keyborad.mapper.QdrantPayloadMapper;
import com.yolo.keyborad.model.dto.chat.ChatReq;
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
import com.yolo.keyborad.model.entity.KeyboardCharacter;
import com.yolo.keyborad.service.KeyboardCharacterService;
import com.yolo.keyborad.service.impl.QdrantVectorService;
import io.qdrant.client.grpc.JsonWithInt;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.boot.context.properties.bind.DefaultValue;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/*
* @author: ziin
* @date: 2025/12/8 15:05
@@ -52,7 +56,7 @@ public class ChatController {
@PostMapping("/talk")
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
public Flux<ServerSentEvent<ChatStreamMessage>> testTalk(@RequestBody ChatReq chatReq){
public Flux<ServerSentEvent<ChatStreamMessage>> talk(@RequestBody ChatReq chatReq){
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
// 1. LLM 流式输出
Flux<ChatStreamMessage> llmFlux = client
@@ -69,7 +73,30 @@ public class ChatController {
.build())
.stream()
.content()
.map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
.concatMap(chunk -> {
// 拆成单字符
List<String> chars = chunk.codePoints()
.mapToObj(cp -> new String(Character.toChars(cp)))
.toList();
// 你可以在这里按 3~5 个字符再拼一拼
List<String> batched = new ArrayList<>();
StringBuilder sb = new StringBuilder();
for (String ch : chars) {
sb.append(ch);
if (sb.length() >= 3) { // 这里的 3 可以自己调
batched.add(sb.toString());
sb.setLength(0);
}
}
if (!sb.isEmpty()) {
batched.add(sb.toString());
}
return Flux.fromIterable(batched)
.map(s -> new ChatStreamMessage("llm_chunk", s));
});
// .map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
// 2. 向量搜索Flux一次性发送搜索结果
Flux<ChatStreamMessage> searchFlux = Mono
@@ -99,8 +126,10 @@ public class ChatController {
/**
 * Saves a polished sentence: embeds the user's text and upserts the resulting
 * vector, together with a payload built from the request, via {@code qdrantVectorService}.
 *
 * @param chatSaveReq request body; its {@code userId} is overwritten with the
 *                    id of the currently logged-in user before the payload is built
 * @return a success response wrapping {@code true}
 */
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
// NOTE(review): this documents a parameter named "userInput", but the method's only
// argument is the ChatSaveReq request body — confirm the annotation is still accurate.
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
public BaseResponse<Boolean> testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) {
// NOTE(review): the next two statements use getUserInputMessage/getUserSelectMessage and
// pass a String payload, while the statements after them declare `embed` again and pass a
// Map payload. They look like the pre-change and post-change versions shown together —
// confirm which pair is current.
float[] embed = embeddingModel.embed(chatSaveReq.getUserInputMessage());
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, chatSaveReq.getUserSelectMessage());
// Embed the user's original text; this vector is what gets stored for the point.
float[] embed = embeddingModel.embed(chatSaveReq.getUserText());
// Take the user id from the Sa-Token login session, replacing whatever the request carried.
chatSaveReq.setUserId(StpUtil.getLoginIdAsLong());
// Convert the request's fields into Qdrant payload values.
Map<String, JsonWithInt.Value> map = QdrantPayloadMapper.toQdrantPayload(chatSaveReq);
// Store the point under a freshly generated snowflake id.
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, map);
return ResultUtils.success(true);
}
}

View File

@@ -0,0 +1,61 @@
package com.yolo.keyborad.mapper;
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
import io.qdrant.client.grpc.JsonWithInt;
import java.util.HashMap;
import java.util.Map;
/**
 * Converts a {@link ChatSaveReq} into the payload map attached to a Qdrant point.
 *
 * <p>Stateless static utility; it is not meant to be instantiated or subclassed.
 */
public final class QdrantPayloadMapper {

    private QdrantPayloadMapper() {
        // Static utility class: no instances.
    }

    /**
     * Builds the Qdrant payload for a save request.
     *
     * <p>Only properties that are non-null on the request are written, so an unset
     * property is absent from the returned map rather than stored as a null value.
     * Payload keys match the request's property names: {@code userId},
     * {@code userText}, {@code replyText}, {@code characterId}, {@code lang},
     * {@code liked}, {@code createdAt}, {@code source}, {@code appVersion}.
     *
     * @param p the request to convert; must not be {@code null}
     * @return a new, mutable map from payload key to Qdrant value
     * @throws NullPointerException if {@code p} is {@code null}
     */
    public static Map<String, JsonWithInt.Value> toQdrantPayload(ChatSaveReq p) {
        // Fail with a descriptive message instead of an anonymous NPE on the first getter.
        java.util.Objects.requireNonNull(p, "ChatSaveReq must not be null");

        Map<String, JsonWithInt.Value> map = new HashMap<>();
        if (p.getUserId() != null) {
            map.put("userId", integerValue(p.getUserId()));
        }
        if (p.getUserText() != null) {
            map.put("userText", stringValue(p.getUserText()));
        }
        if (p.getReplyText() != null) {
            map.put("replyText", stringValue(p.getReplyText()));
        }
        if (p.getCharacterId() != null) {
            map.put("characterId", integerValue(p.getCharacterId()));
        }
        if (p.getLang() != null) {
            map.put("lang", stringValue(p.getLang()));
        }
        if (p.getLiked() != null) {
            map.put("liked", boolValue(p.getLiked()));
        }
        if (p.getCreatedAt() != null) {
            map.put("createdAt", integerValue(p.getCreatedAt()));
        }
        if (p.getSource() != null) {
            map.put("source", stringValue(p.getSource()));
        }
        if (p.getAppVersion() != null) {
            map.put("appVersion", stringValue(p.getAppVersion()));
        }
        return map;
    }

    /** Wraps a string as a Qdrant string value. */
    private static JsonWithInt.Value stringValue(String v) {
        return JsonWithInt.Value.newBuilder().setStringValue(v).build();
    }

    /**
     * Wraps an integral number as a Qdrant integer value.
     *
     * <p>Serves both {@code Integer} and {@code Long} properties: callers have
     * already null-checked, and the boxed value is unboxed and widened to {@code long}.
     */
    private static JsonWithInt.Value integerValue(long v) {
        return JsonWithInt.Value.newBuilder().setIntegerValue(v).build();
    }

    /** Wraps a boolean as a Qdrant bool value. */
    private static JsonWithInt.Value boolValue(boolean v) {
        return JsonWithInt.Value.newBuilder().setBoolValue(v).build();
    }
}

View File

@@ -1,5 +1,6 @@
package com.yolo.keyborad.model.dto.chat;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/*
@@ -9,8 +10,39 @@ import lombok.Data;
/**
 * Request body for saving a user text / chosen reply pair together with
 * descriptive metadata (character, language, like flag, timestamp, source,
 * client version, user id).
 */
@Data
public class ChatSaveReq {
// NOTE(review): these two fields appear next to userText / replyText below, which look
// like their renamed replacements — confirm whether they are still part of the class.
private String userInputMessage;
private String userSelectMessage;
/** The user's original input text. */
@Schema(description="用户的原始输入文本")
private String userText;
/** The reply text the user selected or liked. */
@Schema(description="用户选择 / 点赞的回复文本")
private String replyText;
/** ID of the character currently in use, e.g. a particular keyboard persona. */
@Schema(description="当前使用的角色ID比如某个键盘人格")
private Integer characterId;
/** Language of the text: en / zh / ja / es ... */
@Schema(description="文本语言en / zh / ja / es ... ")
private String lang;
/** Whether the user explicitly selected or liked this content; used to filter high-quality samples. */
@Schema(description="是否是用户明确点选或点赞过的内容,用于高质量样本过滤")
private Boolean liked;
/** Creation time (storing a second- or millisecond-precision timestamp is suggested). */
@Schema(description="创建时间(建议存秒级或毫秒级时间戳)")
private Long createdAt;
/** Where the data came from, e.g. "llm" / "template" / "user". */
@Schema(description="数据来源:例如 \"llm\" / \"template\" / \"user\"")
private String source;
/** Optional: client version number, for debugging / analysis. */
@Schema(description="可选:用于调试 / 分析的客户端版本号")
private String appVersion;
/** User ID. */
@Schema(description = "用户 Id")
private Long userId;
}

View File

@@ -5,22 +5,19 @@ import com.yolo.keyborad.common.ErrorCode;
import com.yolo.keyborad.exception.BusinessException;
import com.yolo.keyborad.model.vo.QdrantSearchItem;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.stereotype.Service;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.QueryFactory.nearest;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vectors;
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
@@ -44,7 +41,7 @@ public class QdrantVectorService {
* @param vector 向量(和 collection 中定义的 size 一致)
* @param payload 额外信息例如原文、标题、userId 等
*/
public void upsertPoint(long id, float[] vector,String payload){
public void upsertPoint(long id, float[] vector,Map<String, JsonWithInt.Value> payload){
try {
qdrantClient.upsertAsync(
@@ -53,7 +50,7 @@ public class QdrantVectorService {
Points.PointStruct.newBuilder()
.setId(id(id))
.setVectors(vectors(vector))
.putAllPayload(Map.of("payload",value(payload)))
.putAllPayload(payload)
.build()
)
).get();
@@ -119,7 +116,7 @@ public class QdrantVectorService {
item.setScore(p.getScore());
var fieldsMap = p.getPayloadMap();
var payloadValue = fieldsMap.get("payload");
var payloadValue = fieldsMap.get("replyText");
if (payloadValue != null && payloadValue.hasStringValue()) {
item.setPayload(payloadValue.getStringValue());
}