Ver código fonte

Fix struct example (#1646)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 meses atrás
pai
commit
20886c8e66
1 arquivos alterados com 36 adições e 24 exclusões
  1. 36 24
      examples/src/main/java/io/milvus/v2/StructExample.java

+ 36 - 24
examples/src/main/java/io/milvus/v2/StructExample.java

@@ -61,7 +61,7 @@ public class StructExample {
     private static final String CLIP_VECTOR_FIELD = "clip_embedding";
     private static final String DESC_FIELD = "clip_desc";
     private static final String DESC_VECTOR_FIELD = "description_embedding";
-    private static final Integer VECTOR_DIM = 4;
+    private static final Integer VECTOR_DIM = 128;
 
     private static void createCollection() {
         CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
@@ -121,16 +121,16 @@ public class StructExample {
         // struct vector uses special index/metric type
         List<IndexParam> indexParams = new ArrayList<>();
         indexParams.add(IndexParam.builder()
-                .fieldName(CLIP_VECTOR_FIELD)
+                .fieldName(String.format("%s[%s]", STRUCT_FIELD, CLIP_VECTOR_FIELD))
                 .indexName("index_1")
-                .indexType(IndexParam.IndexType.EMB_LIST_HNSW)
-                .metricType(IndexParam.MetricType.MAX_SIM)
+                .indexType(IndexParam.IndexType.HNSW)
+                .metricType(IndexParam.MetricType.MAX_SIM_L2)
                 .build());
         indexParams.add(IndexParam.builder()
-                .fieldName(DESC_VECTOR_FIELD)
+                .fieldName(String.format("%s[%s]", STRUCT_FIELD, DESC_VECTOR_FIELD))
                 .indexName("index_2")
-                .indexType(IndexParam.IndexType.EMB_LIST_HNSW)
-                .metricType(IndexParam.MetricType.MAX_SIM)
+                .indexType(IndexParam.IndexType.HNSW)
+                .metricType(IndexParam.MetricType.MAX_SIM_IP)
                 .build());
         client.createIndex(CreateIndexReq.builder()
                 .collectionName(COLLECTION_NAME)
@@ -161,7 +161,7 @@ public class StructExample {
                 JsonArray structArr = new JsonArray();
                 for (int k = 0; k < 5; k++) {
                     JsonObject struct = new JsonObject();
-                    struct.addProperty(FRAME_FIELD, ran.nextInt(1000000));
+                    struct.addProperty(FRAME_FIELD, ran.nextInt(10000));
                     struct.add(CLIP_VECTOR_FIELD, JsonUtils.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
                     struct.addProperty(DESC_FIELD, "clip_description_" + id);
                     struct.add(DESC_VECTOR_FIELD, JsonUtils.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
@@ -187,7 +187,7 @@ public class StructExample {
         System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
 
     }
-    private static void query(String filter) {
+    private static List<QueryResp.QueryResult> query(String filter) {
         System.out.println("===================================================");
         System.out.println("Query with filter expression: " + filter);
         QueryResp queryResp = client.query(QueryReq.builder()
@@ -200,29 +200,27 @@ public class StructExample {
         for (QueryResp.QueryResult result : queryResults) {
             System.out.println(result.getEntity());
         }
+        return queryResults;
     }
 
-    private static void search(String annsField, int nq, int targetVectorsPerNQ) {
+    private static void search(String annsField, List<BaseVector> searchData) {
         System.out.println("===================================================");
-        String msg = String.format("Search on field '%s' with nq=%d and vectors_per_nq=%d", annsField, nq, targetVectorsPerNQ);
+        String msg = String.format("Search on field '%s' in struct '%s' with nq=%d",
+                annsField, STRUCT_FIELD, searchData.size());
         System.out.println(msg);
-        List<BaseVector> searchData = new ArrayList<>();
-        for (int i = 0; i < nq; i++) {
-            EmbeddingList embList = new EmbeddingList();
-            for (int k = 0; k < targetVectorsPerNQ; k++) {
-                embList.add(new FloatVec(CommonUtils.generateFloatVector(VECTOR_DIM)));
-            }
-            searchData.add(embList);
-        }
 
+
+        String annFullName = String.format("%s[%s]", STRUCT_FIELD, annsField);
         int topK = 5;
         SearchResp searchResp = client.search(SearchReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .annsField(annsField)
+                .annsField(annFullName)
                 .data(searchData)
                 .limit(topK)
                 .consistencyLevel(ConsistencyLevel.BOUNDED)
-                .outputFields(Arrays.asList(NAME_FIELD, FRAME_FIELD, DESC_FIELD))
+                .outputFields(Arrays.asList(NAME_FIELD,
+                        String.format("%s[%s]", STRUCT_FIELD, FRAME_FIELD),
+                        String.format("%s[%s]", STRUCT_FIELD, DESC_FIELD)))
                 .build());
         List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
         for (int i = 0; i < searchResults.size(); i++) {
@@ -237,8 +235,22 @@ public class StructExample {
     public static void main(String[] args) {
         createCollection();
         insertData(2000);
-        query(ID_FIELD + " <= 5");
-        search(CLIP_VECTOR_FIELD, 2, 3);
-        search(DESC_VECTOR_FIELD, 1, 5);
+
+        // fetch 2 rows
+        List<QueryResp.QueryResult> results = query(ID_FIELD + " in [5, 8]");
+
+        // use the fetched data to search struct
+        for (QueryResp.QueryResult result : results) {
+            // in the insertData() method, we inserted 5 structures for each row
+            // in query results, each struct is represented as a Map
+            Map<String, Object> fetchedEntity = result.getEntity();
+            List<Map<String, Object>> structs = (List<Map<String, Object>>)fetchedEntity.get(STRUCT_FIELD);
+            EmbeddingList embList = new EmbeddingList();
+            for (Map<String, Object> struct : structs) {
+                List<Float> vector = (List<Float>)struct.get(CLIP_VECTOR_FIELD);
+                embList.add(new FloatVec(vector));
+            }
+            search(CLIP_VECTOR_FIELD, Collections.singletonList(embList));
+        }
     }
 }