Browse Source

Example for two vector fields (#228)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
d0880b0975
1 changed files with 90 additions and 30 deletions
  1. 90 30
      examples/main/io/milvus/GeneralExample.java

+ 90 - 30
examples/main/io/milvus/GeneralExample.java

@@ -28,6 +28,7 @@ import io.milvus.param.index.*;
 import io.milvus.param.partition.*;
 import io.milvus.param.partition.*;
 import io.milvus.Response.*;
 import io.milvus.Response.*;
 
 
+import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.*;
 
 
 public class GeneralExample {
 public class GeneralExample {
@@ -46,10 +47,11 @@ public class GeneralExample {
     private static final String VECTOR_FIELD = "userFace";
     private static final String VECTOR_FIELD = "userFace";
     private static final Integer VECTOR_DIM = 64;
     private static final Integer VECTOR_DIM = 64;
     private static final String AGE_FIELD = "userAge";
     private static final String AGE_FIELD = "userAge";
+    private static final String PROFILE_FIELD = "userProfile";
+    private static final Integer BINARY_DIM = 128;
 
 
     private static final IndexType INDEX_TYPE = IndexType.IVF_FLAT;
     private static final IndexType INDEX_TYPE = IndexType.IVF_FLAT;
     private static final String INDEX_PARAM = "{\"nlist\":128}";
     private static final String INDEX_PARAM = "{\"nlist\":128}";
-    private static final MetricType METRIC_TYPE = MetricType.IP;
 
 
     private static final Integer SEARCH_K = 5;
     private static final Integer SEARCH_K = 5;
     private static final String SEARCH_PARAM = "{\"nprobe\":10}";
     private static final String SEARCH_PARAM = "{\"nprobe\":10}";
@@ -77,6 +79,13 @@ public class GeneralExample {
                 .withDataType(DataType.Int8)
                 .withDataType(DataType.Int8)
                 .build();
                 .build();
 
 
+        FieldType fieldType4 = FieldType.newBuilder()
+                .withName(PROFILE_FIELD)
+                .withDescription("user profile")
+                .withDataType(DataType.BinaryVector)
+                .withDimension(BINARY_DIM)
+                .build();
+
         CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
         CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
                 .withDescription("customer info")
                 .withDescription("customer info")
@@ -84,6 +93,7 @@ public class GeneralExample {
                 .addFieldType(fieldType1)
                 .addFieldType(fieldType1)
                 .addFieldType(fieldType2)
                 .addFieldType(fieldType2)
                 .addFieldType(fieldType3)
                 .addFieldType(fieldType3)
+                .addFieldType(fieldType4)
                 .build();
                 .build();
         R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
         R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
 
 
@@ -215,7 +225,7 @@ public class GeneralExample {
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
                 .withFieldName(VECTOR_FIELD)
                 .withFieldName(VECTOR_FIELD)
                 .withIndexType(INDEX_TYPE)
                 .withIndexType(INDEX_TYPE)
-                .withMetricType(METRIC_TYPE)
+                .withMetricType(MetricType.L2)
                 .withExtraParam(INDEX_PARAM)
                 .withExtraParam(INDEX_PARAM)
                 .withSyncMode(Boolean.TRUE)
                 .withSyncMode(Boolean.TRUE)
                 .build());
                 .build());
@@ -275,20 +285,11 @@ public class GeneralExample {
         return response;
         return response;
     }
     }
 
 
-    private R<SearchResults> search(String expr) {
-        System.out.println("========== search() ==========");
-        List<String> outFields = Collections.singletonList(ID_FIELD);
+    private R<SearchResults> searchFace(String expr) {
+        System.out.println("========== searchFace() ==========");
 
 
-        Random ran=new Random();
-        int nq = 5;
-        List<List<Float>> vectors = new ArrayList<>();
-        for (int i = 0; i < nq; ++i) {
-            List<Float> vector = new ArrayList<>();
-            for (int d = 0; d < VECTOR_DIM; ++d) {
-                vector.add(ran.nextFloat());
-            }
-            vectors.add(vector);
-        }
+        List<String> outFields = Collections.singletonList(AGE_FIELD);
+        List<List<Float>> vectors = generateFloatVectors(5);
 
 
         SearchParam searchParam = SearchParam.newBuilder()
         SearchParam searchParam = SearchParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
@@ -310,6 +311,39 @@ public class GeneralExample {
             List<SearchResultsWrapper.IDScore> scores = wrapper.GetIDScore(i);
             List<SearchResultsWrapper.IDScore> scores = wrapper.GetIDScore(i);
             System.out.println(scores);
             System.out.println(scores);
         }
         }
+        System.out.println(wrapper.GetFieldData(AGE_FIELD).getFieldData());
+
+        return response;
+    }
+
+    private R<SearchResults> searchProfile(String expr) {
+        System.out.println("========== searchProfile() ==========");
+
+        List<String> outFields = Collections.singletonList(AGE_FIELD);
+        List<ByteBuffer> vectors = generateBinaryVectors(5);
+
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withMetricType(MetricType.HAMMING)
+                .withOutFields(outFields)
+                .withTopK(SEARCH_K)
+                .withVectors(vectors)
+                .withVectorFieldName(PROFILE_FIELD)
+                .withExpr(expr)
+                .withParams(SEARCH_PARAM)
+                .build();
+
+
+        R<SearchResults> response = milvusClient.search(searchParam);
+
+        SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
+        for (int i = 0; i < vectors.size(); ++i) {
+            System.out.println("Search result of No." + i);
+            List<SearchResultsWrapper.IDScore> scores = wrapper.GetIDScore(i);
+            System.out.println(scores);
+        }
+
+        System.out.println(wrapper.GetFieldData(AGE_FIELD).getFieldData());
 
 
         return response;
         return response;
     }
     }
@@ -350,23 +384,20 @@ public class GeneralExample {
         return response;
         return response;
     }
     }
 
 
-    private R<MutationResult> insert(String partitionName, Long count) {
+    private R<MutationResult> insert(String partitionName, int count) {
         System.out.println("========== insert() ==========");
         System.out.println("========== insert() ==========");
-        List<List<Float>> vectors = new ArrayList<>();
-        List<Integer> ages = new ArrayList<>();
+        List<List<Float>> vectors = generateFloatVectors(count);
+        List<ByteBuffer> profiles = generateBinaryVectors(count);
 
 
-        Random ran=new Random();
+        Random ran = new Random();
+        List<Integer> ages = new ArrayList<>();
         for (long i = 0L; i < count; ++i) {
         for (long i = 0L; i < count; ++i) {
-            List<Float> vector = new ArrayList<>();
-            for (int d = 0; d < VECTOR_DIM; ++d) {
-                vector.add(ran.nextFloat());
-            }
-            vectors.add(vector);
             ages.add(ran.nextInt(99));
             ages.add(ran.nextInt(99));
         }
         }
 
 
         List<InsertParam.Field> fields = new ArrayList<>();
         List<InsertParam.Field> fields = new ArrayList<>();
         fields.add(new InsertParam.Field(VECTOR_FIELD, DataType.FloatVector, vectors));
         fields.add(new InsertParam.Field(VECTOR_FIELD, DataType.FloatVector, vectors));
+        fields.add(new InsertParam.Field(PROFILE_FIELD, DataType.BinaryVector, profiles));
         fields.add(new InsertParam.Field(AGE_FIELD, DataType.Int8, ages));
         fields.add(new InsertParam.Field(AGE_FIELD, DataType.Int8, ages));
 
 
         InsertParam insertParam = InsertParam.newBuilder()
         InsertParam insertParam = InsertParam.newBuilder()
@@ -375,10 +406,36 @@ public class GeneralExample {
                 .withFields(fields)
                 .withFields(fields)
                 .build();
                 .build();
 
 
-        R<MutationResult> response = milvusClient.insert(insertParam);
-//        System.out.println(response);
+        return milvusClient.insert(insertParam);
+    }
+
+    private List<List<Float>> generateFloatVectors(int count) {
+        Random ran = new Random();
+        List<List<Float>> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+            List<Float> vector = new ArrayList<>();
+            for (int i = 0; i < VECTOR_DIM; ++i) {
+                vector.add(ran.nextFloat());
+            }
+            vectors.add(vector);
+        }
+
+        return vectors;
+    }
+
+    private List<ByteBuffer> generateBinaryVectors(int count) {
+        Random ran = new Random();
+        List<ByteBuffer> vectors = new ArrayList<>();
+        int byteCount = BINARY_DIM/8;
+        for (int n = 0; n < count; ++n) {
+            ByteBuffer vector = ByteBuffer.allocate(byteCount);
+            for (int i = 0; i < byteCount; ++i) {
+                vector.put((byte)ran.nextInt(Byte.MAX_VALUE));
+            }
+            vectors.add(vector);
+        }
+        return vectors;
 
 
-        return response;
     }
     }
 
 
     public static void main(String[] args) {
     public static void main(String[] args) {
@@ -396,14 +453,14 @@ public class GeneralExample {
         example.hasPartition(partitionName);
         example.hasPartition(partitionName);
         example.showPartitions();
         example.showPartitions();
 
 
-        final Long row_count = 10000L;
+        final int row_count = 10000;
         List<Long> deleteIds = new ArrayList<>();
         List<Long> deleteIds = new ArrayList<>();
         Random ran = new Random();
         Random ran = new Random();
         for (int i = 0; i < 100; ++i) {
         for (int i = 0; i < 100; ++i) {
             R<MutationResult> result = example.insert(partitionName, row_count);
             R<MutationResult> result = example.insert(partitionName, row_count);
             InsertResultWrapper wrapper = new InsertResultWrapper(result.getData());
             InsertResultWrapper wrapper = new InsertResultWrapper(result.getData());
             List<Long> ids = wrapper.getLongIDs();
             List<Long> ids = wrapper.getLongIDs();
-            deleteIds.add(ids.get(ran.nextInt(row_count.intValue())));
+            deleteIds.add(ids.get(ran.nextInt(row_count)));
         }
         }
         example.getCollectionStatistics();
         example.getCollectionStatistics();
 
 
@@ -416,7 +473,10 @@ public class GeneralExample {
         example.delete(partitionName, deleteExpr);
         example.delete(partitionName, deleteExpr);
         String queryExpr = AGE_FIELD + " == 60";
         String queryExpr = AGE_FIELD + " == 60";
         example.query(queryExpr);
         example.query(queryExpr);
-        example.search("");
+        String searchExpr = AGE_FIELD + " > 50";
+        example.searchFace(searchExpr);
+        searchExpr = AGE_FIELD + " <= 30";
+        example.searchProfile(searchExpr);
         example.calDistance();
         example.calDistance();
 
 
         example.releasePartition(partitionName);
         example.releasePartition(partitionName);