Browse Source

Use schema for index example

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 4 years ago
parent
commit
471b2a80e8
1 changed files with 38 additions and 30 deletions
  1. 38 30
      examples/src/main/java/MilvusIndexExample.java

+ 38 - 30
examples/src/main/java/MilvusIndexExample.java

@@ -17,7 +17,6 @@
  * under the License.
  * under the License.
  */
  */
 
 
-import io.milvus.client.CollectionMapping;
 import io.milvus.client.ConnectParam;
 import io.milvus.client.ConnectParam;
 import io.milvus.client.DataType;
 import io.milvus.client.DataType;
 import io.milvus.client.Index;
 import io.milvus.client.Index;
@@ -29,6 +28,9 @@ import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.MilvusGrpcClient;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchParam;
 import io.milvus.client.SearchResult;
 import io.milvus.client.SearchResult;
+import io.milvus.client.dsl.MilvusService;
+import io.milvus.client.dsl.Query;
+import io.milvus.client.dsl.Schema;
 import java.io.BufferedReader;
 import java.io.BufferedReader;
 import java.io.FileReader;
 import java.io.FileReader;
 import java.io.IOException;
 import java.io.IOException;
@@ -54,6 +56,8 @@ import org.json.JSONObject;
  */
  */
 public class MilvusIndexExample {
 public class MilvusIndexExample {
 
 
+  public static final int dimension = 8;
+
   // Helper function that generates random float vectors
   // Helper function that generates random float vectors
   private static List<List<Float>> randomFloatVectors(int vectorCount, int dimension) {
   private static List<List<Float>> randomFloatVectors(int vectorCount, int dimension) {
     SplittableRandom splitCollectionRandom = new SplittableRandom();
     SplittableRandom splitCollectionRandom = new SplittableRandom();
@@ -87,15 +91,15 @@ public class MilvusIndexExample {
       client.dropCollection(collectionName);
       client.dropCollection(collectionName);
     }
     }
 
 
-    // Create collection
-    final int dimension = 8;
-    CollectionMapping collectionMapping =
-        CollectionMapping.create(collectionName)
-            .addField("release_year", DataType.INT64)
-            .addVectorField("embedding", DataType.VECTOR_FLOAT, dimension)
-            .setParamsInJson("{\"segment_row_limit\": 4096, \"auto_id\": false}");
-
-    client.createCollection(collectionMapping);
+    /*
+     * Basic create collection:
+     *   Another way to create a collection is to predefine a schema.
+     *   The schema can later be used in querying.
+     */
+    FilmSchema filmSchema = new FilmSchema();
+    MilvusService service = new MilvusService(client, collectionName, filmSchema);
+    service.createCollection(
+        new JsonBuilder().param("auto_id", false).param("segment_row_limit", 4096).build());
 
 
     /*
     /*
      * Basic insert and create index:
      * Basic insert and create index:
@@ -110,7 +114,7 @@ public class MilvusIndexExample {
     BufferedReader csvReader = new BufferedReader(new FileReader(path));
     BufferedReader csvReader = new BufferedReader(new FileReader(path));
     List<Long> ids = new ArrayList<>();
     List<Long> ids = new ArrayList<>();
     List<String> titles = new ArrayList<>();
     List<String> titles = new ArrayList<>();
-    List<Long> releaseYears = new ArrayList<>();
+    List<Integer> releaseYears = new ArrayList<>();
     List<List<Float>> embeddings = new ArrayList<>();
     List<List<Float>> embeddings = new ArrayList<>();
     String row;
     String row;
     while ((row = csvReader.readLine()) != null) {
     while ((row = csvReader.readLine()) != null) {
@@ -118,7 +122,7 @@ public class MilvusIndexExample {
       // process four columns in order
       // process four columns in order
       ids.add(Long.parseLong(data[0]));
       ids.add(Long.parseLong(data[0]));
       titles.add(data[1]);
       titles.add(data[1]);
-      releaseYears.add(Long.parseLong(data[2]));
+      releaseYears.add(Integer.parseInt(data[2]));
       List<Float> embedding = new ArrayList<>(dimension);
       List<Float> embedding = new ArrayList<>(dimension);
       for (int i = 3; i < 11; i++) {
       for (int i = 3; i < 11; i++) {
         // 8 float values in a vector
         // 8 float values in a vector
@@ -137,7 +141,7 @@ public class MilvusIndexExample {
     // Now we can insert entities, the total row count should be 8657.
     // Now we can insert entities, the total row count should be 8657.
     InsertParam insertParam =
     InsertParam insertParam =
         InsertParam.create(collectionName)
         InsertParam.create(collectionName)
-            .addField("release_year", DataType.INT64, releaseYears)
+            .addField("release_year", DataType.INT32, releaseYears)
             .addVectorField("embedding", DataType.VECTOR_FLOAT, embeddings)
             .addVectorField("embedding", DataType.VECTOR_FLOAT, embeddings)
             .setEntityIds(ids);
             .setEntityIds(ids);
 
 
@@ -176,29 +180,27 @@ public class MilvusIndexExample {
      *
      *
      *   Based on the index you created, the available search parameters will be different. Refer to
      *   Based on the index you created, the available search parameters will be different. Refer to
      *   Milvus documentation for how to set the optimal parameters based on your needs.
      *   Milvus documentation for how to set the optimal parameters based on your needs.
+     *
+     *   Here we present a way to use predefined schema to search vectors.
      */
      */
     List<List<Float>> queryEmbedding = randomFloatVectors(1, dimension);
     List<List<Float>> queryEmbedding = randomFloatVectors(1, dimension);
     final long topK = 3;
     final long topK = 3;
-    String dsl =
-        String.format(
-            "{\"bool\": {"
-                + "\"must\": [{"
-                + "    \"term\": {"
-                + "        \"release_year\": [2002, 1995]"
-                + "    }},{"
-                + "    \"vector\": {"
-                + "        \"embedding\": {"
-                + "            \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": "
-                + "%s, \"params\": {\"nprobe\": 8}"
-                + "    }}}]}}",
-            topK, queryEmbedding.toString());
-
+    Query query =
+        Query.bool(
+            Query.must(
+                filmSchema.releaseYear.in(1995, 2002),
+                filmSchema
+                    .embedding
+                    .query(queryEmbedding)
+                    .metricType(MetricType.L2)
+                    .top((int) topK)
+                    .param("nprobe", 8)));
     SearchParam searchParam =
     SearchParam searchParam =
-        SearchParam.create(collectionName)
-            .setDsl(dsl)
+        service
+            .buildSearchParam(query)
             .setParamsInJson("{\"fields\": [\"release_year\", \"embedding\"]}");
             .setParamsInJson("{\"fields\": [\"release_year\", \"embedding\"]}");
     System.out.println("\n--------Search Result--------");
     System.out.println("\n--------Search Result--------");
-    SearchResult searchResult = client.search(searchParam);
+    SearchResult searchResult = service.search(searchParam);
     System.out.println("- ids: " + searchResult.getResultIdsList().toString());
     System.out.println("- ids: " + searchResult.getResultIdsList().toString());
     System.out.println("- distances: " + searchResult.getResultDistancesList().toString());
     System.out.println("- distances: " + searchResult.getResultDistancesList().toString());
     for (List<Map<String, Object>> singleQueryResult : searchResult.getFieldsMap()) {
     for (List<Map<String, Object>> singleQueryResult : searchResult.getFieldsMap()) {
@@ -226,4 +228,10 @@ public class MilvusIndexExample {
     // Close connection
     // Close connection
     client.close();
     client.close();
   }
   }
+
+  // Schema that defines a collection
+  private static class FilmSchema extends Schema {
+    public final Int32Field releaseYear = new Int32Field("release_year");
+    public final FloatVectorField embedding = new FloatVectorField("embedding", dimension);
+  }
 }
 }