Browse Source

Update the example code to use the simplified API

jianghua 4 years ago
parent
commit
c9369ea63f
2 changed files with 90 additions and 171 deletions
  1. 4 9
      examples/pom.xml
  2. 86 162
      examples/src/main/java/MilvusClientExample.java

+ 4 - 9
examples/pom.xml

@@ -25,7 +25,7 @@
 
 
     <groupId>io.milvus</groupId>
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java-examples</artifactId>
     <artifactId>milvus-sdk-java-examples</artifactId>
-    <version>0.9.0</version>
+    <version>0.9.0-SNAPSHOT</version>
     <build>
     <build>
         <plugins>
         <plugins>
             <plugin>
             <plugin>
@@ -66,14 +66,9 @@
             <version>0.9.0-SNAPSHOT</version>
             <version>0.9.0-SNAPSHOT</version>
         </dependency>
         </dependency>
         <dependency>
         <dependency>
-            <groupId>com.google.code.gson</groupId>
-            <artifactId>gson</artifactId>
-            <version>2.8.6</version>
-        </dependency>
-        <dependency>
-            <groupId>org.slf4j</groupId>
-            <artifactId>slf4j-api</artifactId>
-            <version>1.7.30</version>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>testcontainers</artifactId>
+            <version>1.14.3</version>
         </dependency>
         </dependency>
     </dependencies>
     </dependencies>
 
 

+ 86 - 162
examples/src/main/java/MilvusClientExample.java

@@ -17,17 +17,18 @@
  * under the License.
  * under the License.
  */
  */
 
 
+import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListenableFuture;
 import io.milvus.client.*;
 import io.milvus.client.*;
+import org.testcontainers.containers.GenericContainer;
+
 import java.util.ArrayList;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.SplittableRandom;
 import java.util.SplittableRandom;
-import java.util.concurrent.ExecutionException;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
 import java.util.stream.DoubleStream;
-import org.json.JSONObject;
+import java.util.stream.LongStream;
 
 
 // This is a simple example demonstrating how to use Milvus Java SDK v0.9.0.
 // This is a simple example demonstrating how to use Milvus Java SDK v0.9.0.
 // For detailed API documentation, please refer to
 // For detailed API documentation, please refer to
@@ -58,147 +59,88 @@ public class MilvusClientExample {
     return vector;
     return vector;
   }
   }
 
 
-  // Helper function that generates default fields list for a collection
-  // In this example, we have 3 fields with names "int64", "float" and "float_vec".
-  // Their DataType must also be defined.
-  static List<Map<String, Object>> generateDefaultFields(int dimension) {
-    List<Map<String, Object>> fieldList = new ArrayList<>();
-    Map<String, Object> intField = new HashMap<>();
-    intField.put("field", "int64");
-    intField.put("type", DataType.INT64);
-
-    Map<String, Object> floatField = new HashMap<>();
-    floatField.put("field", "float");
-    floatField.put("type", DataType.FLOAT);
-
-    Map<String, Object> vecField = new HashMap<>();
-    vecField.put("field", "float_vec");
-    vecField.put("type", DataType.VECTOR_FLOAT);
-    JSONObject jsonObject = new JSONObject();
-    jsonObject.put("dim", dimension);
-    vecField.put("params", jsonObject.toString());
-
-    fieldList.add(intField);
-    fieldList.add(floatField);
-    fieldList.add(vecField);
-    return fieldList;
-  }
-
-  // Helper function that generates entity field values for inserting into a collection
-  // This corresponds to the above function that initializes fields.
-  static List<Map<String, Object>> generateDefaultFieldValues(int vectorCount, List<List<Float>> vectors) {
-    List<Map<String, Object>> fieldList = new ArrayList<>();
-    Map<String, Object> intField = new HashMap<>();
-    intField.put("field", "int64");
-    intField.put("type", DataType.INT64);
-
-    Map<String, Object> floatField = new HashMap<>();
-    floatField.put("field", "float");
-    floatField.put("type", DataType.FLOAT);
-
-    Map<String, Object> vecField = new HashMap<>();
-    vecField.put("field", "float_vec");
-    vecField.put("type", DataType.VECTOR_FLOAT);
-
-    List<Long> intValues = new ArrayList<>(vectorCount);
-    List<Float> floatValues = new ArrayList<>(vectorCount);
-    for (int i = 0; i < vectorCount; i++) {
-      intValues.add((long) i);
-      floatValues.add((float) i);
+  public static void main(String[] args) throws InterruptedException {
+    String dockerImage = System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu");
+    try (GenericContainer milvusContainer = new GenericContainer(dockerImage).withExposedPorts(19530)) {
+      milvusContainer.start();
+      ConnectParam connectParam = new ConnectParam.Builder()
+          .withHost("localhost")
+          .withPort(milvusContainer.getFirstMappedPort())
+          .build();
+      run(connectParam);
     }
     }
-    intField.put("values", intValues);
-    floatField.put("values", floatValues);
-    vecField.put("values", vectors);
-
-    fieldList.add(intField);
-    fieldList.add(floatField);
-    fieldList.add(vecField);
-    return fieldList;
   }
   }
 
 
-  public static void main(String[] args) throws InterruptedException, ConnectFailedException {
-
-    // You may need to change the following to the host and port of your Milvus server
-    String host = "localhost";
-    int port = 19530;
-    if (args.length >= 2) {
-      host = args[0];
-      port = Integer.parseInt(args[1]);
-    }
-
+  public static void run(ConnectParam connectParam) {
     // Create Milvus client
     // Create Milvus client
-    MilvusClient client = new MilvusGrpcClient();
-
-    // Connect to Milvus server
-    ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build();
-    try {
-      Response connectResponse = client.connect(connectParam);
-    } catch (ConnectFailedException e) {
-      System.out.println("Failed to connect to Milvus server: " + e.toString());
-      throw e;
-    }
+    MilvusClient client = new MilvusGrpcClient(connectParam);
 
 
     // Create a collection with the following collection mapping
     // Create a collection with the following collection mapping
     final String collectionName = "example"; // collection name
     final String collectionName = "example"; // collection name
     final int dimension = 128; // dimension of each vector
     final int dimension = 128; // dimension of each vector
     // we choose IP (Inner Product) as our metric type
     // we choose IP (Inner Product) as our metric type
-    CollectionMapping collectionMapping =
-        new CollectionMapping.Builder(collectionName)
-            .withFields(generateDefaultFields(dimension))
-            .withParamsInJson("{\"segment_row_limit\": 50000, \"auto_id\": true}")
-            .build();
-    Response createCollectionResponse = client.createCollection(collectionMapping);
+    CollectionMapping collectionMapping = CollectionMapping
+        .create(collectionName)
+        .addField("int64", DataType.INT64)
+        .addField("float", DataType.FLOAT)
+        .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
+        .setParamsInJson("{\"segment_row_limit\": 50000, \"auto_id\": true}");
+
+    client.createCollection(collectionMapping);
+
+    if (!client.hasCollection(collectionName)) {
+      throw new AssertionError("Collection not found");
+    }
 
 
-    // Check whether the collection exists
-    HasCollectionResponse hasCollectionResponse = client.hasCollection(collectionName);
+    System.out.println(collectionMapping.toString());
 
 
     // Get collection info
     // Get collection info
-    GetCollectionInfoResponse getCollectionInfoResponse = client.getCollectionInfo(collectionName);
+    CollectionMapping collectionInfo = client.getCollectionInfo(collectionName);
 
 
     // Insert randomly generated field values to collection
     // Insert randomly generated field values to collection
     final int vectorCount = 100000;
     final int vectorCount = 100000;
-    List<List<Float>> vectors = generateVectors(vectorCount, dimension);
-    vectors =
-        vectors.stream().map(MilvusClientExample::normalizeVector).collect(Collectors.toList());
-    List<Map<String, Object>> defaultFieldValues = generateDefaultFieldValues(vectorCount, vectors);
-    InsertParam insertParam =
-        new InsertParam.Builder(collectionName)
-            .withFields(defaultFieldValues)
-            .build();
-    InsertResponse insertResponse = client.insert(insertParam);
+
+    List<Long> longValues = LongStream.range(0, vectorCount).boxed().collect(Collectors.toList());
+    List<Float> floatValues = LongStream.range(0, vectorCount).boxed().map(Long::floatValue).collect(Collectors.toList());
+    List<List<Float>> vectors = generateVectors(vectorCount, dimension).stream()
+        .map(MilvusClientExample::normalizeVector)
+        .collect(Collectors.toList());
+
+    InsertParam insertParam = InsertParam
+        .create(collectionName)
+        .addField("int64", DataType.INT64, longValues)
+        .addField("float", DataType.FLOAT, floatValues)
+        .addVectorField("float_vec", DataType.VECTOR_FLOAT, vectors);
+
     // Insert returns a list of entity ids that you will be using (if you did not supply them
     // Insert returns a list of entity ids that you will be using (if you did not supply them
     // yourself) to reference the entities you just inserted
     // yourself) to reference the entities you just inserted
-    List<Long> vectorIds = insertResponse.getEntityIds();
+    List<Long> vectorIds = client.insert(insertParam);
 
 
     // Flush data in collection
     // Flush data in collection
-    Response flushResponse = client.flush(collectionName);
+    client.flush(collectionName);
 
 
     // Get current entity count of collection
     // Get current entity count of collection
-    CountEntitiesResponse countEntitiesResponse = client.countEntities(collectionName);
+    long entityCount = client.countEntities(collectionName);
 
 
     // Create index for the collection
     // Create index for the collection
     // We choose IVF_SQ8 as our index type here. Refer to Milvus documentation for a
     // We choose IVF_SQ8 as our index type here. Refer to Milvus documentation for a
     // complete explanation of different index types and their relative parameters.
     // complete explanation of different index types and their relative parameters.
-    Index index =
-        new Index.Builder(collectionName, "float_vec")
-            .withParamsInJson("{\"index_type\": \"IVF_SQ8\", \"metric_type\": \"L2\", "
-                + "\"params\": {\"nlist\": 2048}}")
-            .build();
-    Response createIndexResponse = client.createIndex(index);
+    Index index = Index
+        .create(collectionName, "float_vec")
+        .setIndexType(IndexType.IVF_SQ8)
+        .setMetricType(MetricType.L2)
+        .setParamsInJson(new JsonBuilder().param("nlist", 2048).build());
+
+    client.createIndex(index);
 
 
     // Get collection info
     // Get collection info
-    Response getCollectionStatsResponse = client.getCollectionStats(collectionName);
-    if (getCollectionStatsResponse.ok()) {
-      // Collection info is sent back with JSON type string
-      String jsonString = getCollectionStatsResponse.getMessage();
-      System.out.format("Collection Stats: %s\n", jsonString);
-    }
+    String collectionStats = client.getCollectionStats(collectionName);
+    System.out.format("Collection Stats: %s\n", collectionStats);
 
 
     // Check whether a partition exists in collection
     // Check whether a partition exists in collection
     // Obviously we do not have partition "tag" now
     // Obviously we do not have partition "tag" now
-    HasPartitionResponse testHasPartition = client.hasPartition(collectionName, "tag");
-    if (testHasPartition.ok() && testHasPartition.hasPartition()) {
-      throw new AssertionError("Wrong results!");
+    if (client.hasPartition(collectionName, "tag")) {
+      throw new AssertionError("Unexpected partition found!");
     }
     }
 
 
     // Search entities using DSL statement.
     // Search entities using DSL statement.
@@ -220,71 +162,53 @@ public class MilvusClientExample {
             + "%s, \"params\": {\"nprobe\": 50}"
             + "%s, \"params\": {\"nprobe\": 50}"
             + "    }}}]}}",
             + "    }}}]}}",
         topK, vectorsToSearch.toString());
         topK, vectorsToSearch.toString());
-    SearchParam searchParam =
-        new SearchParam.Builder(collectionName)
-            .withDSL(dsl)
-            .withParamsInJson("{\"fields\": [\"int64\", \"float\"]}")
-            .build();
-    SearchResponse searchResponse = client.search(searchParam);
-    if (searchResponse.ok()) {
-      List<List<SearchResponse.QueryResult>> queryResultsList =
-          searchResponse.getQueryResultsList();
-      final double epsilon = 0.01;
-      for (int i = 0; i < searchBatchSize; i++) {
-        // Since we are searching for vector that is already present in the collection,
-        // the first result vector should be itself and the distance (inner product) should be
-        // very close to 1 (some precision is lost during the process)
-        SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
-        if (firstQueryResult.getEntityId() != vectorIds.get(i)
-            || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
-          throw new AssertionError("Wrong results!");
-        }
+    SearchParam searchParam = SearchParam
+        .create(collectionName)
+        .setDsl(dsl)
+        .setParamsInJson("{\"fields\": [\"int64\", \"float\"]}");
+    SearchResult searchResult = client.search(searchParam);
+    List<List<SearchResult.QueryResult>> queryResultsList = searchResult.getQueryResultsList();
+    final double epsilon = 0.01;
+    for (int i = 0; i < searchBatchSize; i++) {
+      // Since we are searching for vector that is already present in the collection,
+      // the first result vector should be itself and the distance (inner product) should be
+      // very close to 1 (some precision is lost during the process)
+      SearchResult.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      if (firstQueryResult.getEntityId() != vectorIds.get(i)
+          || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
+        throw new AssertionError("Wrong results!");
       }
       }
     }
     }
+
     // You can also get result ids and distances separately
     // You can also get result ids and distances separately
-    List<List<Long>> resultIds = searchResponse.getResultIdsList();
-    List<List<Float>> resultDistances = searchResponse.getResultDistancesList();
+    List<List<Long>> resultIds = searchResult.getResultIdsList();
+    List<List<Float>> resultDistances = searchResult.getResultDistancesList();
 
 
     // You can send search request asynchronously, which returns a ListenableFuture object
     // You can send search request asynchronously, which returns a ListenableFuture object
-    ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
-    try {
-      // Get search response immediately. Obviously you will want to do more complicated stuff with
-      // ListenableFuture
-      searchResponseFuture.get();
-    } catch (ExecutionException e) {
-      e.printStackTrace();
-    }
+    ListenableFuture<SearchResult> searchResponseFuture = client.searchAsync(searchParam);
+    // Get search response immediately. Obviously you will want to do more complicated stuff with
+    // ListenableFuture
+    Futures.getUnchecked(searchResponseFuture);
 
 
     // Delete the first 5 entities you just searched
     // Delete the first 5 entities you just searched
-    Response deleteByIdsResponse =
-        client.deleteEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
-    flushResponse = client.flush(collectionName);
+    client.deleteEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
+    client.flush(collectionName);
 
 
     // After deleting them, we call getEntityByID and obviously all 5 entities should not be returned.
     // After deleting them, we call getEntityByID and obviously all 5 entities should not be returned.
-    GetEntityByIDResponse getEntityByIDResponse =
-        client.getEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
-    if (getEntityByIDResponse.getValidIds().size() > 0) {
-      throw new AssertionError("This can never happen!");
+    Map<Long, Map<String, Object>> entities = client.getEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
+    if (!entities.isEmpty()) {
+      throw new AssertionError("Unexpected entity count!");
     }
     }
 
 
     // Compact the collection, erase deleted data from disk and rebuild index in background (if
     // Compact the collection, erase deleted data from disk and rebuild index in background (if
     // the data size after compaction is still larger than indexFileSize). Data was only
     // the data size after compaction is still larger than indexFileSize). Data was only
     // soft-deleted until you call compact.
     // soft-deleted until you call compact.
-    Response compactResponse = client.compact(
-        new CompactParam.Builder(collectionName).withThreshold(0.2).build());
+    client.compact(CompactParam.create(collectionName).setThreshold(0.2));
 
 
     // Drop index for the collection
     // Drop index for the collection
-    Response dropIndexResponse = client.dropIndex(collectionName, "float_vec");
+    client.dropIndex(collectionName, "float_vec");
 
 
     // Drop collection
     // Drop collection
-    Response dropCollectionResponse = client.dropCollection(collectionName);
-
-    // Disconnect from Milvus server
-    try {
-      Response disconnectResponse = client.disconnect();
-    } catch (InterruptedException e) {
-      System.out.println("Failed to disconnect: " + e.toString());
-      throw e;
-    }
+    client.dropCollection(collectionName);
   }
   }
 }
 }