1. Introduction

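In my previous article I introduced the basic concepts and use cases of vector databases (read it here: https://blog.csdn.net/xiezhiyi007/article/details/147722783?spm=1001.2014.3001.5502). In actual development, however, I found that many of the Milvus utility classes circulating online are either overly cumbersome or already outdated. So I redesigned a set of utility classes on top of the newer MilvusClientV2 client, and this article walks through the implementation in detail.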

2. Design Goals of the Utility Class

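Lightweight and simple: avoid over-encapsulation and keep the code easy to follow.
Lifecycle management: create the connection automatically when the application starts and release it safely when it stops.
Complete functionality: cover the core operations, including collection creation, data insertion, and vector search.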

3. Core Implementation

3.1. Connection Management
The connection is created automatically through a Spring @Configuration class:

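import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@ConditionalOnProperty(prefix = "web.milvus", name = "enabled", havingValue = "true")
public class MilvusConfig {
    // Full URI of the Milvus server; 19530 is the default Milvus port
    @Value("${web.milvus.host:http://localhost:19530}")
    private String milvusHost;

    // destroyMethod = "close": Spring closes the client when the application shuts down
    @Bean(destroyMethod = "close")
    public MilvusClientV2 milvusClientV2() {
        ConnectConfig connectConfig = ConnectConfig.builder()
                .uri(milvusHost)
                .build();

        // Let connection failures propagate: swallowing them would register a null bean
        // that only fails later, with a NullPointerException on the first call
        return new MilvusClientV2(connectConfig);
    }
}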

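Key point: @ConditionalOnProperty makes the configuration conditional, so the client bean is created only when web.milvus.enabled is set to true.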
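MilvusClientV2 is usually configured with a full URI such as http://localhost:19530, scheme included. If you would rather let web.milvus.host hold a bare host:port (as in the YAML example later), a small helper can add the scheme before ConnectConfig is built. This is a sketch: the class MilvusUriUtil is hypothetical, and it assumes that a value without a scheme means a plain, non-TLS connection.

```java
/** Hypothetical helper: turns "host:port" into a URI suitable for ConnectConfig.uri(...). */
final class MilvusUriUtil {
    private MilvusUriUtil() {
    }

    static String normalizeUri(String hostOrUri) {
        if (hostOrUri == null || hostOrUri.trim().isEmpty()) {
            throw new IllegalArgumentException("Milvus host must not be empty");
        }
        String value = hostOrUri.trim();
        // Already a full URI: keep the caller's scheme (http or https)
        if (value.startsWith("http://") || value.startsWith("https://")) {
            return value;
        }
        // Assumption: no scheme means a plain (non-TLS) connection
        return "http://" + value;
    }
}
```

MilvusConfig would then build the client with ConnectConfig.builder().uri(MilvusUriUtil.normalizeUri(milvusHost)).build(), so both spellings of the property work.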

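3.2. Collection Management
The following wraps collection creation and schema definition. It is only a simplified example to illustrate the principle; the complete version is the createCollection method in section 5.2: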

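public void createCollection(String collectionName) {
    // Build the schema with the V2 API; every collection needs a primary key field
    CreateCollectionReq.CollectionSchema schema = milvusClientV2.createSchema();
    schema.addField(AddFieldReq.builder()
        .fieldName("id")
        .dataType(DataType.Int64)
        .isPrimaryKey(true)
        .autoID(true)
        .build());
    schema.addField(AddFieldReq.builder()
        .fieldName("vector")
        .dataType(DataType.FloatVector)
        .dimension(128)
        .build());

    CreateCollectionReq req = CreateCollectionReq.builder()
        .collectionName(collectionName)
        .collectionSchema(schema)
        .build();

    milvusClientV2.createCollection(req);
}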

3.3. Data Operations
Inserting data

public InsertResp insertData(String collectionName, List<JsonObject> data) {
    InsertReq req = InsertReq.builder()
        .collectionName(collectionName)
        .data(data)
        .build();
    return milvusClientV2.insert(req);
}
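Each JsonObject in data is one row whose keys are field names. A frequent cause of failed inserts is an embedding whose length differs from the dimension declared on the vector field, which Milvus rejects. A small pre-insert check can report the offending row with a clear message instead; VectorBatchCheck is a hypothetical helper, not part of the SDK.

```java
import java.util.List;

/** Hypothetical pre-insert check; not part of the Milvus SDK. */
final class VectorBatchCheck {
    private VectorBatchCheck() {
    }

    /** Throws if any embedding is missing or its length differs from the vector field's dimension. */
    static void requireDimension(List<List<Float>> embeddings, int dimension) {
        for (int i = 0; i < embeddings.size(); i++) {
            List<Float> vector = embeddings.get(i);
            int actual = vector == null ? 0 : vector.size();
            if (actual != dimension) {
                throw new IllegalArgumentException(
                        "row " + i + ": expected dimension " + dimension + " but got " + actual);
            }
        }
    }
}
```

For example, you could call it on the embeddings of a batch right before building the InsertReq, passing the dimension used when the collection was created (128 in the snippet above).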

Vector search

public SearchResp search(String collectionName, List<Float> vector, int topK) {
    SearchReq req = SearchReq.builder()
        .collectionName(collectionName)
        .data(Collections.singletonList(new FloatVec(vector)))
        .topK(topK)
        .build();
    return milvusClientV2.search(req);
}
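Conceptually, search returns the topK stored vectors that are most similar to the query vector. The plain-Java sketch below computes that exactly by scoring every vector with cosine similarity, the metric configured in the createCollection method of section 5.2. It is only a mental model with hypothetical names: Milvus answers the same question with an approximate index rather than a full scan, so its results can differ slightly from an exact search.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical exact (brute-force) top-K search by cosine similarity. */
class BruteForceSearch {
    /** One hit: the position of the stored vector and its cosine similarity to the query. */
    static final class Hit {
        final int index;
        final float score;

        Hit(int index, float score) {
            this.index = index;
            this.score = score;
        }
    }

    /** dot(a, b) / (|a| * |b|): 1 means same direction, 0 orthogonal, -1 opposite. */
    static float cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0f; // a zero vector has no direction; treat it as unrelated
        }
        return (float) (dot / (Math.sqrt(normA) * Math.sqrt(normB)));
    }

    /** Scores every stored vector against the query and returns the topK best, highest score first. */
    static List<Hit> search(List<float[]> data, float[] query, int topK) {
        List<Hit> hits = new ArrayList<>();
        for (int i = 0; i < data.size(); i++) {
            hits.add(new Hit(i, cosine(data.get(i), query)));
        }
        hits.sort(Comparator.comparingDouble((Hit h) -> h.score).reversed());
        return new ArrayList<>(hits.subList(0, Math.min(topK, hits.size())));
    }
}
```

With stored vectors [1, 0], [0, 1] and [1, 1] and the query [1, 0], a top-2 search returns the first vector (score 1.0) followed by the third (about 0.707).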

4. Complete Utility Class

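import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.service.vector.request.DeleteReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.DeleteResp;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.SearchResp;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.*;

// ShardMilvusEntity and VectorizationUtils are the author's own classes (not shown here)
@Service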
public class VectorService {
    @Autowired
    MilvusClientV2 milvusClientV2;

    /**
     * Name of the shard collection
     */
    private static final String SHARD_COLLECTION_NAME = "knowledge_base_shard";


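    /**
     * Vectorizes a single entity's content and inserts it into Milvus
     * @param entity the entity to insert
     * @return the number of inserted records, or -1 if vectorization failed
     */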
    public long insertToMilvus(ShardMilvusEntity entity) {
        List<Float> embeddings = VectorizationUtils.vectorize(entity.getShardContent());

        if(embeddings == null || embeddings.isEmpty()) {
            return -1;
        }

        entity.setShardEmbedding(embeddings);

        return insertEntity(Collections.singletonList(entity));
    }

    /**
     * Inserts a list of entities into Milvus in one batch
     * @param entityList the entities to insert
     * @return the number of inserted records
     */
    private long insertEntity(List<ShardMilvusEntity> entityList) {
        List<JsonObject> rows = new ArrayList<>();

        for(ShardMilvusEntity entity : entityList) {
            rows.add(entityToJsonObj(entity));
        }

        InsertReq insertReq = InsertReq.builder()
                .collectionName(SHARD_COLLECTION_NAME)
                .data(rows)
                .build();

        InsertResp result = milvusClientV2.insert(insertReq);

        return result.getInsertCnt();

    }

    /**
     * Converts an entity into a JsonObject keyed by the collection's field names
     * @param entity the entity to convert
     * @return the resulting JSON row
     */
    private JsonObject entityToJsonObj(ShardMilvusEntity entity) {
        Gson gson = new Gson();
        JsonObject jsonObj = new JsonObject();

        jsonObj.addProperty("shard_id", entity.getShardId());
        jsonObj.addProperty("shard_content", entity.getShardContent());
        jsonObj.addProperty("doc_id", entity.getDocId());
        jsonObj.addProperty("base_id", entity.getBaseId());

        jsonObj.add("shard_embedding", gson.toJsonTree(entity.getShardEmbedding()));

        if (entity.getConfidenceLevel() != null) {
            jsonObj.addProperty("confidence_level", entity.getConfidenceLevel());
        } else {
            jsonObj.addProperty("confidence_level", -1);
        }

        if (entity.getShardTags() != null) {
            jsonObj.add("shard_tags", gson.toJsonTree(entity.getShardTags()));
        } else {
            jsonObj.add("shard_tags", new JsonArray());
        }

        return jsonObj;
    }

    /**
     * Deletes records by shard ID (the primary key)
     * @param shardIds shard IDs to delete
     * @return the number of deleted records, or -1 if no IDs were given
     */
    public long delete(List<Long> shardIds) {

        if (shardIds == null || shardIds.isEmpty()) {
            return -1;
        }

        // DeleteReq.ids takes a List<Object>, so copy the shard IDs into one
        List<Object> objects = new ArrayList<>(shardIds);

        DeleteResp resp = milvusClientV2.delete(DeleteReq.builder()
                        .collectionName(SHARD_COLLECTION_NAME)
                        .ids(objects)
                        .build());

        return resp.getDeleteCnt();
    }

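    /**
     * Searches Milvus for similar records, with optional filter conditions
     * @param query query text; anything beyond 512 characters is truncated
     * @param topK maximum number of results to return
     * @param baseIds knowledge base IDs to restrict the search to
     * @param docIds document IDs to restrict the search to
     * @param tags tags; a record matches if it carries any of them
     * @param minConfidenceLevel minimum confidence level
     * @param minScore minimum similarity score, passed to Milvus as the range-search radius
     * @return matching entities; an empty list if the query is empty or cannot be vectorized
     */
    public List<ShardMilvusEntity> search(String query, int topK, List<Long> baseIds, List<Long> docIds, List<String> tags, Integer minConfidenceLevel, Float minScore) {
        if(query == null || query.isEmpty()) {
            return Collections.emptyList();
        } else if(query.length() > 512) {
            query = query.substring(0, 512);
        }

        List<Float> queryEmbedding = VectorizationUtils.vectorize(query);
        if(queryEmbedding == null || queryEmbedding.isEmpty()) {
            return Collections.emptyList();
        }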

        FloatVec queryVector = new FloatVec(queryEmbedding);

        SearchReq searchReq = SearchReq.builder()
                .collectionName(SHARD_COLLECTION_NAME)
                .data(Collections.singletonList(queryVector))
                .topK(topK)
                .outputFields(Arrays.asList("shard_id", "shard_content", "doc_id", "base_id", "confidence_level", "shard_tags"))
                .build();

        String filter = genFilter(baseIds, docIds, tags, minConfidenceLevel);

        if(filter != null && !filter.isEmpty()) {
            searchReq.setFilter(filter);
        }

        Map<String, Object> extraParams = new HashMap<>();

        if(minScore != null) {
            extraParams.put("radius", minScore);
        }

        if(!extraParams.isEmpty()) {
            searchReq.setSearchParams(extraParams);
        }

        SearchResp searchResp = milvusClientV2.search(searchReq);

        return searchRespToEntity(searchResp);
    }


    private String genFilter(List<Long> baseIds, List<Long> docIds, List<String> tags, Integer minConfidenceLevel) {
        StringBuilder filter = new StringBuilder();
        addFilterIn(filter, "base_id", baseIds);
        addFilterIn(filter, "doc_id", docIds);
        addFilterArrayContainsAny(filter, "shard_tags", tags);
        addFilterMin(filter, "confidence_level", minConfidenceLevel);
        return filter.toString();
    }


    private void addFilterIn(StringBuilder filter, String fieldName, List<?> content) {
        if(content == null || content.isEmpty()) {
            return;
        }

        if(filter.length() > 0) {
            filter.append(" and ");
        }

        filter.append(fieldName).append(" in [");

        for(int i = 0; i < content.size(); i++) {
            if(i != 0) {
                filter.append(", ");
            }

            Object item = content.get(i);
            if(item instanceof String) {
                filter.append("\"");
            }

            filter.append(item.toString());

            if(item instanceof String) {
                filter.append("\"");
            }
        }
        filter.append("]");
    }


    private void addFilterArrayContainsAny(StringBuilder filter, String fieldName, List<?> content) {
        if(content == null || content.isEmpty()) {
            return;
        }

        if(filter.length() > 0) {
            filter.append(" and ");
        }

        filter.append("ARRAY_CONTAINS_ANY(").append(fieldName).append(", [");

        for(int i = 0; i < content.size(); i++) {
            if(i != 0) {
                filter.append(", ");
            }

            Object item = content.get(i);
            if(item instanceof String) {
                filter.append("\"");
            }

            filter.append(item.toString());

            if(item instanceof String) {
                filter.append("\"");
            }
        }
        // Close both the value list and the ARRAY_CONTAINS_ANY( call
        filter.append("])");

    }

    private void addFilterMin(StringBuilder filter, String fieldName, Integer minValue) {
        if(minValue == null) {
            return;
        }
        if(filter.length() > 0) {
            filter.append(" and ");
        }
        filter.append(fieldName).append(" >= ").append(minValue.toString()).append(" ");
    }

    private List<ShardMilvusEntity> searchRespToEntity(SearchResp searchResp) {
        List<ShardMilvusEntity> entities = new ArrayList<>();
        if(searchResp == null || searchResp.getSearchResults() == null || searchResp.getSearchResults().isEmpty()) {
            return entities;
        }

        List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().get(0);

        for(SearchResp.SearchResult searchResult : searchResults) {
            ShardMilvusEntity entity = new ShardMilvusEntity();
            Map<String, Object> entityMap = searchResult.getEntity();

            entity.setShardId((Long) entityMap.get("shard_id"));
            entity.setDocId((Long) entityMap.get("doc_id"));
            entity.setBaseId((Long) entityMap.get("base_id"));
            entity.setShardContent((String) entityMap.get("shard_content"));
            entity.setConfidenceLevel((Integer) entityMap.get("confidence_level"));
            entity.setShardTags((List<String>) entityMap.get("shard_tags"));
            entity.setScore(searchResult.getScore());

            entities.add(entity);
        }

        return entities;
    }
}
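The genFilter helpers above build the expression by concatenating strings and do not escape quotes inside tag values. As a hedged alternative, here is a compact builder that emits the same three kinds of clauses (in, ARRAY_CONTAINS_ANY, >=) joined with and. FilterExprBuilder and its method names are hypothetical, and the escaping assumes Milvus accepts backslash-escaped quotes inside double-quoted string literals; check the filtering documentation for your Milvus version before relying on it.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical builder for Milvus boolean filter expressions (not part of the SDK). */
class FilterExprBuilder {
    private final List<String> clauses = new ArrayList<>();

    /** Adds "field in [v1, v2, ...]"; skipped when values is null or empty. */
    FilterExprBuilder in(String field, List<?> values) {
        if (values != null && !values.isEmpty()) {
            clauses.add(field + " in " + literalList(values));
        }
        return this;
    }

    /** Adds "ARRAY_CONTAINS_ANY(field, [v1, ...])"; skipped when values is null or empty. */
    FilterExprBuilder arrayContainsAny(String field, List<?> values) {
        if (values != null && !values.isEmpty()) {
            clauses.add("ARRAY_CONTAINS_ANY(" + field + ", " + literalList(values) + ")");
        }
        return this;
    }

    /** Adds "field >= min"; skipped when min is null. */
    FilterExprBuilder atLeast(String field, Number min) {
        if (min != null) {
            clauses.add(field + " >= " + min);
        }
        return this;
    }

    /** Joins all clauses with "and"; returns "" when there is nothing to filter. */
    String build() {
        return String.join(" and ", clauses);
    }

    private static String literalList(List<?> values) {
        List<String> parts = new ArrayList<>();
        for (Object v : values) {
            parts.add(v instanceof String ? quote((String) v) : String.valueOf(v));
        }
        return "[" + String.join(", ", parts) + "]";
    }

    /** Wraps a string in double quotes, escaping backslashes and quotes (assumed Milvus syntax). */
    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}
```

genFilter could then shrink to new FilterExprBuilder().in("base_id", baseIds).in("doc_id", docIds).arrayContainsAny("shard_tags", tags).atLeast("confidence_level", minConfidenceLevel).build(). Keeping one clause per method also makes it easy to unit-test each filter without a Milvus server.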

5. Usage Examples

5.1. Spring Boot Integration

# application.yml
web:
  milvus:
    enabled: true
    host: http://192.168.1.100:19530   # full URI: scheme, host and port

5.2. The complete business-layer calling code is as follows:

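import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.*;
import io.milvus.v2.service.collection.response.DescribeCollectionResp;
import io.milvus.v2.service.collection.response.GetCollectionStatsResp;
import io.milvus.v2.service.collection.response.ListCollectionsResp;
import io.milvus.v2.service.database.response.ListDatabasesResp;
import io.milvus.v2.service.vector.request.DeleteReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.DeleteResp;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.SearchResp;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.util.*;

@Service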
public class MilvusToolService {
    @Autowired
    private MilvusClientV2 milvusClientV2;

    @Autowired
    private VectorService vectorService;

    public ListDatabasesResp listDataBases() {
        // Call MilvusClientV2.listDatabases() to get the list of databases
        return milvusClientV2.listDatabases();
    }

    public ListCollectionsResp listCollections() {
        // Call MilvusClientV2.listCollections() to get the list of collections
        return milvusClientV2.listCollections();
    }

    public DescribeCollectionResp describeCollection(String collectionName) {
        DescribeCollectionReq describeCollectionReq = DescribeCollectionReq.builder()
                .collectionName(collectionName).build();

        // Call MilvusClientV2.describeCollection() to get the collection's description
        return milvusClientV2.describeCollection(describeCollectionReq);
    }

    public GetCollectionStatsResp getCollectionStats(String collectionName) {
        GetCollectionStatsReq getCollectionStatsReq = GetCollectionStatsReq.builder()
                .collectionName(collectionName).build();

        return milvusClientV2.getCollectionStats(getCollectionStatsReq);
    }

    public String dropCollection(String collectionName) {
        DropCollectionReq dropQuickSetupReq = DropCollectionReq.builder()
                .collectionName(collectionName).build();

        milvusClientV2.dropCollection(dropQuickSetupReq);
        return "SUCCESS";
    }

    public String loadCollection(String collectionName) {
        LoadCollectionReq loadCollectionReq = LoadCollectionReq.builder()
               .collectionName(collectionName).build();

        milvusClientV2.loadCollection(loadCollectionReq);
        return "SUCCESS";
    }

    public String releaseCollection(String collectionName) {
        ReleaseCollectionReq releaseCollectionReq = ReleaseCollectionReq.builder()
              .collectionName(collectionName).build();

        milvusClientV2.releaseCollection(releaseCollectionReq);

        return "SUCCESS";
    }

    public Boolean collectionLoadState(String collectionName) {
        GetLoadStateReq collectionLoadStateReq = GetLoadStateReq.builder()
              .collectionName(collectionName).build();

        return milvusClientV2.getLoadState(collectionLoadStateReq);
    }

    public InsertResp insertData(String collectionName, String content) {
        Gson gson = new Gson();
        List<JsonObject> rows = new ArrayList<>();

        try {
            JsonArray jsonArray = gson.fromJson(content, JsonArray.class);
            for (int i = 0; i < jsonArray.size(); i++) {
                JsonObject jsonObject = jsonArray.get(i).getAsJsonObject();

                rows.add(jsonObject);
            }
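        } catch (JsonSyntaxException e) {
            // Fail fast: continuing would send an empty insert request to Milvus
            throw new IllegalArgumentException("content must be a JSON array of row objects", e);
        }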

        InsertReq insertReq = InsertReq.builder()
               .collectionName(collectionName)
               .data(rows)
               .build();

        return milvusClientV2.insert(insertReq);
    }

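    public DeleteResp delete(String collectionName, String ids) {
        // DeleteReq.ids takes List<Object>; parse as Long because the primary key (shard_id) is Int64
        List<Object> idList = new ArrayList<>();
        for (String id : ids.split(",")) {
            idList.add(Long.parseLong(id.trim()));
        }
        return milvusClientV2.delete(DeleteReq.builder()
                .collectionName(collectionName)
                .ids(idList)
                .build());
    }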

    public SearchResp query(String collectionName, int topK, String floatListStr) {
        // Parse a comma-separated vector string such as "0.12,0.34,..." into floats
        List<Float> fl = new ArrayList<>();
        for (String s : floatListStr.split(",")) {
            fl.add(Float.parseFloat(s.trim()));
        }

        FloatVec queryVector = new FloatVec(fl);

        SearchReq searchReq = SearchReq.builder()
                .collectionName(collectionName)
                .data(Collections.singletonList(queryVector))
                .topK(topK)
                .outputFields(Arrays.asList("shard_content", "base_id", "shard_tags"))
                .build();

        return milvusClientV2.search(searchReq);
    }

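    public Boolean quickCreateCollection(String collectionName, Integer dimension) {
        // Quick-setup mode: only the name and dimension are given, and Milvus generates a
        // default schema (a primary key plus a vector field of this dimension)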
        CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
                .collectionName(collectionName)
                .dimension(dimension)
                .build();

        milvusClientV2.createCollection(createCollectionReq);

        GetLoadStateReq quickSetupLoadStateReq = GetLoadStateReq.builder()
              .collectionName(collectionName).build();

        return milvusClientV2.getLoadState(quickSetupLoadStateReq);
    }

    public Boolean createCollection(String collectionName) {
        CreateCollectionReq.CollectionSchema schema = milvusClientV2.createSchema();

        schema.addField(AddFieldReq.builder()
                .fieldName("shard_id")
                .dataType(DataType.Int64)
                .isPrimaryKey(true)
                .autoID(false)
                .description("Shard ID (primary key)")
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("shard_content")
                .dataType(DataType.VarChar)
                .maxLength(4096)
                .description("Shard content")
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("confidence_level")
                .dataType(DataType.Int32)
                .description("Confidence level")
                .isNullable(true)
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("doc_id")
                .dataType(DataType.Int64)
                .description("Document ID")
                .isNullable(true)
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("base_id")
                .dataType(DataType.Int64)
                .description("Knowledge base ID")
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("shard_tags")
                .dataType(DataType.Array)
                .elementType(DataType.VarChar)
                .maxCapacity(32)
                .maxLength(64)     // max length of each tag; VarChar elements need one (64 is illustrative)
                .description("Shard tags")
                .isNullable(true)
                .build());

        schema.addField(AddFieldReq.builder()
                .fieldName("shard_embedding")
                .dataType(DataType.FloatVector)
                .dimension(1024)   // must match the output size of the embedding model
                .description("Shard embedding vector")
                .build());

        // Index parameters: AUTOINDEX on the vector field, cosine similarity as the metric
        List<IndexParam> indexParams = new ArrayList<>();
        indexParams.add(IndexParam.builder()
                .fieldName("shard_embedding")
                .indexName("shard_embedding_index")
                .indexType(IndexParam.IndexType.AUTOINDEX)
                .metricType(IndexParam.MetricType.COSINE)
                .build());

        // Build the create-collection request
        CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
                .collectionName(collectionName)
                .collectionSchema(schema)
                .indexParams(indexParams)
                .description("Knowledge shard collection")
                .build();

        milvusClientV2.createCollection(createCollectionReq);

        // Load the collection into memory so it can be searched
        LoadCollectionReq loadCollectionReq = LoadCollectionReq.builder()
                .collectionName(collectionName)
                .build();
        milvusClientV2.loadCollection(loadCollectionReq);

        GetLoadStateReq quickSetupLoadStateReq = GetLoadStateReq.builder()
                .collectionName(collectionName).build();

        return milvusClientV2.getLoadState(quickSetupLoadStateReq);
    }


    public long insertTestData() {
        ShardMilvusEntity entity = new ShardMilvusEntity();
        entity.setShardId(990003L);
        entity.setDocId(66L);
        entity.setBaseId(666L);
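        String text = "Rule: SELECT uses a full table scan scanning more than 4000 rows    Risk level: high    Risk description: this rule may degrade query performance; avoid scanning too many rows.    Suggestion: optimize the query conditions to reduce the number of scanned rows.    Example: SELECT * FROM table WHERE column > 1000;    Optimized: SELECT * FROM table WHERE column > 1000 AND column < 2000;";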
        int textChar = text.length();
        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
        int byteSize = bytes.length;
        entity.setShardContent(text);

        return vectorService.insertToMilvus(entity);
    }

    public List<ShardMilvusEntity> searchTest(String query, Integer topK, Float minScore, List<Long> baseIds, List<Long> docIds, List<String> tags, Integer minConfidenceLevel) {
        return vectorService.search(query, topK, baseIds, docIds, tags, minConfidenceLevel, minScore);
    }

    public long deleteTestData(List<Long> shardIds) {
        return vectorService.delete(shardIds);
    }
}
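searchTest forwards minScore to VectorService.search, which hands it to Milvus as the range-search parameter radius. With the COSINE metric a higher score means more similar, and, as I understand Milvus range search, only hits whose score is strictly greater than radius are kept, so a search with minScore can return fewer than topK results. The post-filter below mimics that rule on a list of scores; RadiusFilter is a hypothetical name and this is not how the SDK implements it.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical post-filter mimicking a COSINE range search with only a radius set. */
final class RadiusFilter {
    private RadiusFilter() {
    }

    /** Keeps scores strictly greater than radius, preserving their order. */
    static List<Float> keepAboveRadius(List<Float> scores, float radius) {
        List<Float> kept = new ArrayList<>();
        for (Float score : scores) {
            if (score != null && score > radius) {
                kept.add(score);
            }
        }
        return kept;
    }
}
```

With scores 0.9, 0.5 and 0.3 and a minScore of 0.5, only 0.9 survives: a score equal to the radius is dropped as well.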

6. Design Considerations

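Connection management: rely on the Spring bean lifecycle instead of opening and closing connections by hand.
API design: keep the parameter structure consistent with the official SDK.
Exception handling: wrap Milvus exceptions into business exceptions in one place (the simplified code in this article omits this and lets SDK exceptions propagate).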
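One minimal way to wrap Milvus exceptions into business exceptions in a single place is a small wrapper that every Milvus call goes through. The sketch below is hypothetical (VectorStoreException and MilvusCalls are not part of the SDK) and assumes, as I understand the SDK, that MilvusClientV2 reports failures as unchecked exceptions.

```java
import java.util.function.Supplier;

/** Hypothetical business exception; not part of the Milvus SDK. */
class VectorStoreException extends RuntimeException {
    VectorStoreException(String message, Throwable cause) {
        super(message, cause);
    }
}

/** Hypothetical single entry point for Milvus calls. */
final class MilvusCalls {
    private MilvusCalls() {
    }

    /**
     * Runs one Milvus call; any unchecked exception is rethrown as VectorStoreException
     * carrying the operation name, so callers only need to handle one exception type.
     */
    static <T> T call(String operation, Supplier<T> action) {
        try {
            return action.get();
        } catch (VectorStoreException e) {
            throw e; // already wrapped, don't wrap twice
        } catch (RuntimeException e) {
            throw new VectorStoreException("Milvus " + operation + " failed: " + e.getMessage(), e);
        }
    }
}
```

A service method would then read return MilvusCalls.call("insert", () -> milvusClientV2.insert(insertReq)); and keep the original exception as the cause for logging.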

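Summary
The utility class presented in this article offers the following advantages:

  • About 40% less code than a traditional implementation, by my estimate
  • Built entirely on the MilvusClientV2 API
  • Integrates naturally with the Spring ecosystem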


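7. Bonus Resources

The complete code has been uploaded as an attachment; it can be imported into IntelliJ IDEA and run directly. Discussion is welcome, and if you have a better implementation, please share it in the comments!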
