feat(vector): 新增向量搜索与文本语义检索接口
- DemoController 增加 /testSearch、/tsetSearchText 端点
- QdrantVectorService 补充 searchPoint、searchText、indexText 方法
- 新增 SearchEmbedReq、TextSearchReq、QdrantSearchItem 等 DTO/VO
- 调整 LLM 模型为 qwen3-embedding-0.6b 并开放对应接口免鉴权
This commit is contained in:
@@ -53,7 +53,8 @@ public class LLMConfig {
|
|||||||
this.openAiApi(),
|
this.openAiApi(),
|
||||||
MetadataMode.EMBED,
|
MetadataMode.EMBED,
|
||||||
OpenAiEmbeddingOptions.builder()
|
OpenAiEmbeddingOptions.builder()
|
||||||
.model("qwen/qwen3-embedding-8b")
|
.model("qwen/qwen3-embedding-0.6b")
|
||||||
|
.dimensions(2048)
|
||||||
.user("user-6")
|
.user("user-6")
|
||||||
.build(),
|
.build(),
|
||||||
RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||||
|
|||||||
@@ -38,7 +38,9 @@ public class SaTokenConfigure implements WebMvcConfigurer {
|
|||||||
"/demo/talk",
|
"/demo/talk",
|
||||||
"/user/appleLogin",
|
"/user/appleLogin",
|
||||||
"/demo/embed",
|
"/demo/embed",
|
||||||
"/demo/testSaveEmbed"
|
"/demo/testSaveEmbed",
|
||||||
|
"/demo/testSearch",
|
||||||
|
"/demo/tsetSearchText"
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@Bean
|
@Bean
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
package com.yolo.keyborad.controller;
|
package com.yolo.keyborad.controller;
|
||||||
|
|
||||||
import cn.hutool.json.JSON;
|
|
||||||
import cn.hutool.json.JSONUtil;
|
import cn.hutool.json.JSONUtil;
|
||||||
import com.yolo.keyborad.common.BaseResponse;
|
import com.yolo.keyborad.common.BaseResponse;
|
||||||
import com.yolo.keyborad.common.ResultUtils;
|
import com.yolo.keyborad.common.ResultUtils;
|
||||||
|
|
||||||
import com.yolo.keyborad.model.dto.EmbedSaveReq;
|
import com.yolo.keyborad.model.dto.EmbedSaveReq;
|
||||||
import com.yolo.keyborad.model.dto.IosPayVerifyReq;
|
import com.yolo.keyborad.model.dto.IosPayVerifyReq;
|
||||||
|
import com.yolo.keyborad.model.dto.SearchEmbedReq;
|
||||||
|
import com.yolo.keyborad.model.dto.TextSearchReq;
|
||||||
|
import com.yolo.keyborad.model.vo.QdrantSearchItem;
|
||||||
import com.yolo.keyborad.service.impl.QdrantVectorService;
|
import com.yolo.keyborad.service.impl.QdrantVectorService;
|
||||||
import io.qdrant.client.QdrantClient;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
@@ -91,4 +92,20 @@ public class DemoController {
|
|||||||
, JSONUtil.toJsonStr(embedSaveReq.getRecordItem()));
|
, JSONUtil.toJsonStr(embedSaveReq.getRecordItem()));
|
||||||
return ResultUtils.success(true);
|
return ResultUtils.success(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@PostMapping("/testSearch")
|
||||||
|
@Operation(summary = "测试搜索向量接口", description = "测试搜索向量接口")
|
||||||
|
@Parameter(name = "userInput",required = true,description = "测试搜索向量接口")
|
||||||
|
public BaseResponse<List<QdrantSearchItem>> testSearch(@RequestBody SearchEmbedReq searchEmbedReq) {
|
||||||
|
return ResultUtils.success(qdrantVectorService.searchPoint(searchEmbedReq.getUserInputEmbed(), 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@PostMapping("/tsetSearchText")
|
||||||
|
@Operation(summary = "测试搜索语义接口", description = "测试搜索语义接口")
|
||||||
|
@Parameter(name = "userInput",required = true,description = "测试搜索语义接口")
|
||||||
|
public BaseResponse<List<QdrantSearchItem>> testSearchText(@RequestBody TextSearchReq textSearchReq) {
|
||||||
|
return ResultUtils.success(qdrantVectorService.searchText(textSearchReq.getUserInput()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.yolo.keyborad.model.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @author: ziin
|
||||||
|
* @date: 2025/11/14 18:12
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class SearchEmbedReq {
|
||||||
|
|
||||||
|
private List<Float> userInputEmbed;
|
||||||
|
}
|
||||||
12
src/main/java/com/yolo/keyborad/model/dto/TextSearchReq.java
Normal file
12
src/main/java/com/yolo/keyborad/model/dto/TextSearchReq.java
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package com.yolo.keyborad.model.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @author: ziin
|
||||||
|
* @date: 2025/11/14 19:50
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class TextSearchReq {
|
||||||
|
private String userInput;
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.yolo.keyborad.model.vo;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
// package 自己按项目结构放
|
||||||
|
@Data
|
||||||
|
public class QdrantSearchItem {
|
||||||
|
|
||||||
|
/** 向量 ID(你插入时用的是 long) */
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
/** 相似度得分 */
|
||||||
|
private Float score;
|
||||||
|
|
||||||
|
/** 你存进去的 payload 文本(或者 JSON 字符串) */
|
||||||
|
private String payload;
|
||||||
|
|
||||||
|
/** 完整向量(如果不想暴露可以去掉这个字段) */
|
||||||
|
private List<Float> vector;
|
||||||
|
|
||||||
|
// getter / setter 省略,也可以用 Lombok @Data
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.yolo.keyborad.model.vo;
|
||||||
|
|
||||||
|
import io.qdrant.client.grpc.JsonWithInt;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @author: ziin
|
||||||
|
* @date: 2025/11/14 18:36
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class VectorSearchResultVO {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
private double score;
|
||||||
|
private String payload;
|
||||||
|
}
|
||||||
@@ -1,13 +1,16 @@
|
|||||||
package com.yolo.keyborad.service.impl;
|
package com.yolo.keyborad.service.impl;
|
||||||
|
|
||||||
|
import com.google.common.primitives.Floats;
|
||||||
import com.yolo.keyborad.common.ErrorCode;
|
import com.yolo.keyborad.common.ErrorCode;
|
||||||
import com.yolo.keyborad.exception.BusinessException;
|
import com.yolo.keyborad.exception.BusinessException;
|
||||||
|
import com.yolo.keyborad.model.vo.QdrantSearchItem;
|
||||||
import io.qdrant.client.QdrantClient;
|
import io.qdrant.client.QdrantClient;
|
||||||
import io.qdrant.client.grpc.Collections;
|
|
||||||
import io.qdrant.client.grpc.JsonWithInt;
|
|
||||||
import io.qdrant.client.grpc.Points;
|
import io.qdrant.client.grpc.Points;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.ai.embedding.Embedding;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingModel;
|
||||||
|
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -15,8 +18,10 @@ import java.util.Map;
|
|||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
import static io.qdrant.client.PointIdFactory.id;
|
import static io.qdrant.client.PointIdFactory.id;
|
||||||
|
import static io.qdrant.client.QueryFactory.nearest;
|
||||||
import static io.qdrant.client.ValueFactory.value;
|
import static io.qdrant.client.ValueFactory.value;
|
||||||
import static io.qdrant.client.VectorsFactory.vectors;
|
import static io.qdrant.client.VectorsFactory.vectors;
|
||||||
|
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
||||||
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@@ -28,6 +33,9 @@ public class QdrantVectorService {
|
|||||||
|
|
||||||
private static final String COLLECTION_NAME = "test_document";
|
private static final String COLLECTION_NAME = "test_document";
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private EmbeddingModel embeddingModel;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 插入/更新一条向量数据
|
* 插入/更新一条向量数据
|
||||||
*
|
*
|
||||||
@@ -37,20 +45,6 @@ public class QdrantVectorService {
|
|||||||
*/
|
*/
|
||||||
public void upsertPoint(long id, List<Float> vector,String payload){
|
public void upsertPoint(long id, List<Float> vector,String payload){
|
||||||
|
|
||||||
// // 1. 确保 collection 存在(没有就创建一次即可)
|
|
||||||
// try {
|
|
||||||
// qdrantClient.createCollectionAsync(
|
|
||||||
// COLLECTION_NAME,
|
|
||||||
// Collections.VectorParams.newBuilder()
|
|
||||||
// .setSize(vector.size()) // 向量维度
|
|
||||||
// .setDistance(Collections.Distance.Cosine) // 相似度度量
|
|
||||||
// .build()
|
|
||||||
// ).get(); // 简单起见直接 get(),生产建议在启动时提前创建好
|
|
||||||
// } catch (InterruptedException | ExecutionException e) {
|
|
||||||
// log.error("创建 collection 失败", e);
|
|
||||||
// throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
|
||||||
// }
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
qdrantClient.upsertAsync(
|
qdrantClient.upsertAsync(
|
||||||
COLLECTION_NAME,
|
COLLECTION_NAME,
|
||||||
@@ -66,6 +60,131 @@ public class QdrantVectorService {
|
|||||||
log.error("upsert point 失败", e);
|
log.error("upsert point 失败", e);
|
||||||
throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// public List<VectorSearchResultVO> searchPoint(List<Float> userInput) {
|
||||||
|
// try {
|
||||||
|
// Points.QueryPoints query = Points.QueryPoints.newBuilder()
|
||||||
|
// .setCollectionName(COLLECTION_NAME) // ★ 必须设置
|
||||||
|
// .setQuery(nearest(userInput))
|
||||||
|
// .build();
|
||||||
|
//
|
||||||
|
// List<Points.BatchResult> batchResults = qdrantClient.queryBatchAsync(
|
||||||
|
// COLLECTION_NAME,
|
||||||
|
// List.of(query)
|
||||||
|
// ).get();
|
||||||
|
//
|
||||||
|
// return batchResults.stream()
|
||||||
|
// .map(p -> {
|
||||||
|
// VectorSearchResultVO vo = new VectorSearchResultVO();
|
||||||
|
// vo.setId(String.valueOf(p.getResult(0).getId())); // 或者 p.getId().getUuid()
|
||||||
|
// vo.setScore(p.getResult(0).getScore());
|
||||||
|
// vo.setPayload(p.getResult(0).getPayloadMap());
|
||||||
|
// return vo;
|
||||||
|
// })
|
||||||
|
// .toList();
|
||||||
|
//
|
||||||
|
// } catch (InterruptedException | ExecutionException e) {
|
||||||
|
// log.error("search point 失败", e);
|
||||||
|
// throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
public List<QdrantSearchItem> searchPoint(List<Float> userVector, int limit) {
|
||||||
|
try {
|
||||||
|
Points.QueryPoints query = Points.QueryPoints.newBuilder()
|
||||||
|
.setCollectionName(COLLECTION_NAME) // ★ 必须
|
||||||
|
.setQuery(nearest(userVector)) // ★ 语义向量
|
||||||
|
.setLimit(limit) // TopK
|
||||||
|
.setWithPayload(enable(true)) // ★ 带上 payload
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<Points.BatchResult> batchResults = qdrantClient.queryBatchAsync(
|
||||||
|
COLLECTION_NAME,
|
||||||
|
List.of(query)
|
||||||
|
).get();
|
||||||
|
Points.BatchResult batchResult = batchResults.get(0);
|
||||||
|
|
||||||
|
// 3. 把 Protobuf 的 ScoredPoint 转成你的 DTO
|
||||||
|
return batchResult.getResultList().stream()
|
||||||
|
.map(p -> {
|
||||||
|
QdrantSearchItem item = new QdrantSearchItem();
|
||||||
|
|
||||||
|
// id:你插入时用的是 setId(id(id)),所以这里取 num
|
||||||
|
if (p.getId().hasNum()) {
|
||||||
|
item.setId(p.getId().getNum());
|
||||||
|
}
|
||||||
|
|
||||||
|
// score
|
||||||
|
item.setScore(p.getScore());
|
||||||
|
|
||||||
|
// payload:你之前是 putAllPayload(Map.of("payload", value(payload)))
|
||||||
|
// 这里从 Struct 里拿 "payload" 字段
|
||||||
|
var fieldsMap = p.getPayloadMap();
|
||||||
|
var payloadValue = fieldsMap.get("payload");
|
||||||
|
if (payloadValue != null && payloadValue.hasStringValue()) {
|
||||||
|
item.setPayload(payloadValue.getStringValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
// vector:你插入时用的是 vectors(vector),即 unnamed 单向量
|
||||||
|
// proto 结构一般是 Vectors.vector.data[]
|
||||||
|
if (p.getVectors().hasVector()) {
|
||||||
|
List<Float> vec = p.getVectors().getVector().getDataList();
|
||||||
|
item.setVector(vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
return item;
|
||||||
|
})
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
|
||||||
|
} catch (InterruptedException | ExecutionException e) {
|
||||||
|
log.error("search point 失败", e);
|
||||||
|
throw new BusinessException(ErrorCode.OPERATION_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 把一段文本做 embedding 然后写入 Qdrant
|
||||||
|
*
|
||||||
|
* @param id 业务 ID(比如业务表主键)
|
||||||
|
* @param text 用来做向量的文本(一般是内容)
|
||||||
|
*/
|
||||||
|
public void indexText(long id, String text) {
|
||||||
|
// 1. 文本 → 向量
|
||||||
|
List<Float> vector = embedTextToVector(text);
|
||||||
|
|
||||||
|
// 2. 存到 Qdrant,payload 里顺便存原文
|
||||||
|
upsertPoint(id, vector, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Float> embedTextToVector(String text) {
|
||||||
|
EmbeddingResponse response = embeddingModel.embedForResponse(List.of(text));
|
||||||
|
|
||||||
|
|
||||||
|
Embedding embedding = response.getResult(); // 就一条
|
||||||
|
// Spring AI 里一般是 List<Double>
|
||||||
|
float[] output = embedding.getOutput();
|
||||||
|
// 转成 Qdrant 需要的 List<Float>
|
||||||
|
|
||||||
|
|
||||||
|
return Floats.asList(output);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<QdrantSearchItem> searchText(String userInput) {
|
||||||
|
long t0 = System.currentTimeMillis();
|
||||||
|
|
||||||
|
List<Float> floats = this.embedTextToVector(userInput);
|
||||||
|
long t1 = System.currentTimeMillis();
|
||||||
|
|
||||||
|
List<QdrantSearchItem> qdrantSearchItems = this.searchPoint(floats, 3);
|
||||||
|
long t2 = System.currentTimeMillis();
|
||||||
|
|
||||||
|
log.info("embedding = {} ms, qdrant = {} ms, total = {} ms",
|
||||||
|
(t1 - t0), (t2 - t1), (t2 - t0));
|
||||||
|
return qdrantSearchItems;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user