feat(chat): 重构 LLM 流式输出并扩展 ChatSaveReq 字段
- 将原始整段 chunk 拆分为 3 字批次推送,降低前端卡顿 - ChatSaveReq 新增 userId、lang、liked 等 8 个字段并补充 Swagger 注解 - QdrantVectorService 改用 Map<String,JsonWithInt.Value> 载荷,新增 QdrantPayloadMapper 统一转换
This commit is contained in:
@@ -4,28 +4,32 @@ import cn.dev33.satoken.stp.StpUtil;
|
|||||||
import cn.hutool.core.util.IdUtil;
|
import cn.hutool.core.util.IdUtil;
|
||||||
import com.yolo.keyborad.common.BaseResponse;
|
import com.yolo.keyborad.common.BaseResponse;
|
||||||
import com.yolo.keyborad.common.ResultUtils;
|
import com.yolo.keyborad.common.ResultUtils;
|
||||||
|
import com.yolo.keyborad.mapper.QdrantPayloadMapper;
|
||||||
import com.yolo.keyborad.model.dto.chat.ChatReq;
|
import com.yolo.keyborad.model.dto.chat.ChatReq;
|
||||||
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
||||||
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
|
import com.yolo.keyborad.model.dto.chat.ChatStreamMessage;
|
||||||
import com.yolo.keyborad.model.entity.KeyboardCharacter;
|
import com.yolo.keyborad.model.entity.KeyboardCharacter;
|
||||||
import com.yolo.keyborad.service.KeyboardCharacterService;
|
import com.yolo.keyborad.service.KeyboardCharacterService;
|
||||||
import com.yolo.keyborad.service.impl.QdrantVectorService;
|
import com.yolo.keyborad.service.impl.QdrantVectorService;
|
||||||
|
import io.qdrant.client.grpc.JsonWithInt;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.client.ChatClient;
|
import org.springframework.ai.chat.client.ChatClient;
|
||||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
||||||
import org.springframework.boot.context.properties.bind.DefaultValue;
|
|
||||||
import org.springframework.http.codec.ServerSentEvent;
|
import org.springframework.http.codec.ServerSentEvent;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.core.scheduler.Schedulers;
|
import reactor.core.scheduler.Schedulers;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @author: ziin
|
* @author: ziin
|
||||||
* @date: 2025/12/8 15:05
|
* @date: 2025/12/8 15:05
|
||||||
@@ -52,7 +56,7 @@ public class ChatController {
|
|||||||
|
|
||||||
@PostMapping("/talk")
|
@PostMapping("/talk")
|
||||||
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
|
@Operation(summary = "聊天润色接口", description = "聊天润色接口")
|
||||||
public Flux<ServerSentEvent<ChatStreamMessage>> testTalk(@RequestBody ChatReq chatReq){
|
public Flux<ServerSentEvent<ChatStreamMessage>> talk(@RequestBody ChatReq chatReq){
|
||||||
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
|
KeyboardCharacter character = keyboardCharacterService.getById(chatReq.getCharacterId());
|
||||||
// 1. LLM 流式输出
|
// 1. LLM 流式输出
|
||||||
Flux<ChatStreamMessage> llmFlux = client
|
Flux<ChatStreamMessage> llmFlux = client
|
||||||
@@ -69,7 +73,30 @@ public class ChatController {
|
|||||||
.build())
|
.build())
|
||||||
.stream()
|
.stream()
|
||||||
.content()
|
.content()
|
||||||
.map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
|
.concatMap(chunk -> {
|
||||||
|
// 拆成单字符
|
||||||
|
List<String> chars = chunk.codePoints()
|
||||||
|
.mapToObj(cp -> new String(Character.toChars(cp)))
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
// 你可以在这里按 3~5 个字符再拼一拼
|
||||||
|
List<String> batched = new ArrayList<>();
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (String ch : chars) {
|
||||||
|
sb.append(ch);
|
||||||
|
if (sb.length() >= 3) { // 这里的 3 可以自己调
|
||||||
|
batched.add(sb.toString());
|
||||||
|
sb.setLength(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!sb.isEmpty()) {
|
||||||
|
batched.add(sb.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Flux.fromIterable(batched)
|
||||||
|
.map(s -> new ChatStreamMessage("llm_chunk", s));
|
||||||
|
});
|
||||||
|
// .map(chunk -> new ChatStreamMessage("llm_chunk", chunk));
|
||||||
|
|
||||||
// 2. 向量搜索Flux(一次性发送搜索结果)
|
// 2. 向量搜索Flux(一次性发送搜索结果)
|
||||||
Flux<ChatStreamMessage> searchFlux = Mono
|
Flux<ChatStreamMessage> searchFlux = Mono
|
||||||
@@ -99,8 +126,10 @@ public class ChatController {
|
|||||||
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
|
@Operation(summary = "保存润色后的句子", description = "保存润色后的句子")
|
||||||
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
@Parameter(name = "userInput",required = true,description = "测试聊天接口",example = "talk to something")
|
||||||
public BaseResponse<Boolean> testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) {
|
public BaseResponse<Boolean> testTalkWithVector(@RequestBody ChatSaveReq chatSaveReq) {
|
||||||
float[] embed = embeddingModel.embed(chatSaveReq.getUserInputMessage());
|
float[] embed = embeddingModel.embed(chatSaveReq.getUserText());
|
||||||
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, chatSaveReq.getUserSelectMessage());
|
chatSaveReq.setUserId(StpUtil.getLoginIdAsLong());
|
||||||
|
Map<String, JsonWithInt.Value> map = QdrantPayloadMapper.toQdrantPayload(chatSaveReq);
|
||||||
|
qdrantVectorService.upsertPoint(IdUtil.getSnowflakeNextId(), embed, map);
|
||||||
return ResultUtils.success(true);
|
return ResultUtils.success(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package com.yolo.keyborad.mapper;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import com.yolo.keyborad.model.dto.chat.ChatSaveReq;
|
||||||
|
import io.qdrant.client.grpc.JsonWithInt;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class QdrantPayloadMapper {
|
||||||
|
|
||||||
|
public static Map<String, JsonWithInt.Value> toQdrantPayload(ChatSaveReq p) {
|
||||||
|
Map<String, JsonWithInt.Value> map = new HashMap<>();
|
||||||
|
|
||||||
|
if (p.getUserId() != null)
|
||||||
|
map.put("userId", longValue(p.getUserId()));
|
||||||
|
|
||||||
|
if (p.getUserText() != null)
|
||||||
|
map.put("userText", stringValue(p.getUserText()));
|
||||||
|
|
||||||
|
if (p.getReplyText() != null)
|
||||||
|
map.put("replyText", stringValue(p.getReplyText()));
|
||||||
|
|
||||||
|
if (p.getCharacterId() != null)
|
||||||
|
map.put("characterId", intValue(p.getCharacterId()));
|
||||||
|
|
||||||
|
if (p.getLang() != null)
|
||||||
|
map.put("lang", stringValue(p.getLang()));
|
||||||
|
|
||||||
|
if (p.getLiked() != null)
|
||||||
|
map.put("liked", boolValue(p.getLiked()));
|
||||||
|
|
||||||
|
if (p.getCreatedAt() != null)
|
||||||
|
map.put("createdAt", longValue(p.getCreatedAt()));
|
||||||
|
|
||||||
|
if (p.getSource() != null)
|
||||||
|
map.put("source", stringValue(p.getSource()));
|
||||||
|
|
||||||
|
if (p.getAppVersion() != null)
|
||||||
|
map.put("appVersion", stringValue(p.getAppVersion()));
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonWithInt.Value stringValue(String v) {
|
||||||
|
return JsonWithInt.Value.newBuilder().setStringValue(v).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonWithInt.Value intValue(Integer v) {
|
||||||
|
return JsonWithInt.Value.newBuilder().setIntegerValue(v).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonWithInt.Value longValue(Long v) {
|
||||||
|
return JsonWithInt.Value.newBuilder().setIntegerValue(v).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonWithInt.Value boolValue(Boolean v) {
|
||||||
|
return JsonWithInt.Value.newBuilder().setBoolValue(v).build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.yolo.keyborad.model.dto.chat;
|
package com.yolo.keyborad.model.dto.chat;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -9,8 +10,39 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class ChatSaveReq {
|
public class ChatSaveReq {
|
||||||
|
|
||||||
private String userInputMessage;
|
|
||||||
|
|
||||||
private String userSelectMessage;
|
/** 用户的原始输入文本 */
|
||||||
|
@Schema(description="用户的原始输入文本")
|
||||||
|
private String userText;
|
||||||
|
|
||||||
|
/** 用户选择 / 点赞的回复文本 */
|
||||||
|
@Schema(description="用户选择 / 点赞的回复文本")
|
||||||
|
private String replyText;
|
||||||
|
|
||||||
|
/** 当前使用的角色ID(比如某个键盘人格) */
|
||||||
|
@Schema(description="当前使用的角色ID(比如某个键盘人格)")
|
||||||
|
private Integer characterId;
|
||||||
|
|
||||||
|
/** 文本语言:en / zh / ja / es ... */
|
||||||
|
@Schema(description="文本语言:en / zh / ja / es ... ")
|
||||||
|
private String lang;
|
||||||
|
|
||||||
|
/** 是否是用户明确点选或点赞过的内容,用于高质量样本过滤 */
|
||||||
|
@Schema(description="是否是用户明确点选或点赞过的内容,用于高质量样本过滤")
|
||||||
|
private Boolean liked;
|
||||||
|
|
||||||
|
/** 创建时间(建议存秒级或毫秒级时间戳) */
|
||||||
|
@Schema(description="创建时间(建议存秒级或毫秒级时间戳)")
|
||||||
|
private Long createdAt;
|
||||||
|
|
||||||
|
/** 数据来源:例如 "llm" / "template" / "user" */
|
||||||
|
@Schema(description="数据来源:例如 \"llm\" / \"template\" / \"user\"")
|
||||||
|
private String source;
|
||||||
|
|
||||||
|
/** 可选:用于调试 / 分析的客户端版本号 */
|
||||||
|
@Schema(description="可选:用于调试 / 分析的客户端版本号")
|
||||||
|
private String appVersion;
|
||||||
|
|
||||||
|
@Schema(description = "用户 Id")
|
||||||
|
private Long userId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,22 +5,19 @@ import com.yolo.keyborad.common.ErrorCode;
|
|||||||
import com.yolo.keyborad.exception.BusinessException;
|
import com.yolo.keyborad.exception.BusinessException;
|
||||||
import com.yolo.keyborad.model.vo.QdrantSearchItem;
|
import com.yolo.keyborad.model.vo.QdrantSearchItem;
|
||||||
import io.qdrant.client.QdrantClient;
|
import io.qdrant.client.QdrantClient;
|
||||||
|
import io.qdrant.client.grpc.JsonWithInt;
|
||||||
import io.qdrant.client.grpc.Points;
|
import io.qdrant.client.grpc.Points;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.embedding.Embedding;
|
|
||||||
import org.springframework.ai.embedding.EmbeddingModel;
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
import static io.qdrant.client.PointIdFactory.id;
|
import static io.qdrant.client.PointIdFactory.id;
|
||||||
import static io.qdrant.client.QueryFactory.nearest;
|
import static io.qdrant.client.QueryFactory.nearest;
|
||||||
import static io.qdrant.client.ValueFactory.value;
|
|
||||||
import static io.qdrant.client.VectorsFactory.vectors;
|
import static io.qdrant.client.VectorsFactory.vectors;
|
||||||
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
||||||
|
|
||||||
@@ -44,7 +41,7 @@ public class QdrantVectorService {
|
|||||||
* @param vector 向量(和 collection 中定义的 size 一致)
|
* @param vector 向量(和 collection 中定义的 size 一致)
|
||||||
* @param payload 额外信息,例如原文、标题、userId 等
|
* @param payload 额外信息,例如原文、标题、userId 等
|
||||||
*/
|
*/
|
||||||
public void upsertPoint(long id, float[] vector,String payload){
|
public void upsertPoint(long id, float[] vector,Map<String, JsonWithInt.Value> payload){
|
||||||
|
|
||||||
try {
|
try {
|
||||||
qdrantClient.upsertAsync(
|
qdrantClient.upsertAsync(
|
||||||
@@ -53,7 +50,7 @@ public class QdrantVectorService {
|
|||||||
Points.PointStruct.newBuilder()
|
Points.PointStruct.newBuilder()
|
||||||
.setId(id(id))
|
.setId(id(id))
|
||||||
.setVectors(vectors(vector))
|
.setVectors(vectors(vector))
|
||||||
.putAllPayload(Map.of("payload",value(payload)))
|
.putAllPayload(payload)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
).get();
|
).get();
|
||||||
@@ -119,7 +116,7 @@ public class QdrantVectorService {
|
|||||||
item.setScore(p.getScore());
|
item.setScore(p.getScore());
|
||||||
|
|
||||||
var fieldsMap = p.getPayloadMap();
|
var fieldsMap = p.getPayloadMap();
|
||||||
var payloadValue = fieldsMap.get("payload");
|
var payloadValue = fieldsMap.get("replyText");
|
||||||
if (payloadValue != null && payloadValue.hasStringValue()) {
|
if (payloadValue != null && payloadValue.hasStringValue()) {
|
||||||
item.setPayload(payloadValue.getStringValue());
|
item.setPayload(payloadValue.getStringValue());
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user