Browse Source

release v0.1.2: some api changes

Zhiru Zhu 5 years ago
parent
commit
b584d5416e

+ 1 - 1
examples/pom.xml

@@ -46,7 +46,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>0.1.1</version>
+            <version>0.1.2</version>
         </dependency>
     </dependencies>
 

+ 47 - 28
examples/src/main/java/MilvusClientExample.java

@@ -27,7 +27,7 @@ import java.util.stream.DoubleStream;
 public class MilvusClientExample {
 
   // Helper function that generates random vectors
-  static List<List<Float>> generateRandomVectors(long vectorCount, long dimension) {
+  static List<List<Float>> generateVectors(long vectorCount, long dimension) {
     SplittableRandom splittableRandom = new SplittableRandom();
     List<List<Float>> vectors = new ArrayList<>();
     for (int i = 0; i < vectorCount; ++i) {
@@ -41,15 +41,16 @@ public class MilvusClientExample {
 
   // Helper function that normalizes a vector if you are using IP (Inner product) as your metric
   // type
-  static List<Float> normalize(List<Float> vector) {
+  static List<Float> normalizeVector(List<Float> vector) {
     float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
     final float norm = (float) Math.sqrt(squareSum);
     vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
     return vector;
   }
 
-  public static void main(String[] args) throws InterruptedException {
+  public static void main(String[] args) throws InterruptedException, ConnectFailedException {
 
+    // You may need to change the following to the host and port of your Milvus server
     final String host = "localhost";
     final String port = "19530";
 
@@ -58,18 +59,22 @@ public class MilvusClientExample {
 
     // Connect to Milvus server
     ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build();
-    Response connectResponse = client.connect(connectParam);
-    System.out.println(connectResponse);
+    try {
+      Response connectResponse = client.connect(connectParam);
+    } catch (ConnectFailedException e) {
+      System.out.println(e.toString());
+      throw e;
+    }
 
     // Check whether we are connected
-    boolean connected = client.connected();
+    boolean connected = client.isConnected();
     System.out.println("Connected = " + connected);
 
     // Create a table with the following table schema
-    final String tableName = "example";
-    final long dimension = 128;
-    final long indexFileSize = 1024;
-    final MetricType metricType = MetricType.IP;
+    final String tableName = "example"; // table name
+    final long dimension = 128; // dimension of each vector
+    final long indexFileSize = 1024; // maximum size (in MB) of each index file
+    final MetricType metricType = MetricType.IP; // we choose IP (Inner project) as our metric type
     TableSchema tableSchema =
         new TableSchema.Builder(tableName, dimension)
             .withIndexFileSize(indexFileSize)
@@ -92,8 +97,8 @@ public class MilvusClientExample {
 
     // Insert randomly generated vectors to table
     final int vectorCount = 100000;
-    List<List<Float>> vectors = generateRandomVectors(vectorCount, dimension);
-    vectors.forEach(MilvusClientExample::normalize);
+    List<List<Float>> vectors = generateVectors(vectorCount, dimension);
+    vectors.forEach(MilvusClientExample::normalizeVector);
     InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(10).build();
     InsertResponse insertResponse = client.insert(insertParam);
     System.out.println(insertResponse);
@@ -101,7 +106,8 @@ public class MilvusClientExample {
     // yourself) to reference the vectors you just inserted
     List<Long> vectorIds = insertResponse.getVectorIds();
 
-    // Sleep for 1 second
+    // The data we just inserted won't be serialized and written to meta until the next second
+    // wait 1 second here
     TimeUnit.SECONDS.sleep(1);
 
     // Get current row count of table
@@ -111,7 +117,11 @@ public class MilvusClientExample {
     System.out.println(getTableRowCountResponse);
 
     // Create index for the table
-    final IndexType indexType = IndexType.IVF_SQ8;
+    // We choose IVF_SQ8 as our index type here. Refer to IndexType javadoc for a
+    // complete explanation of different index types
+    final IndexType indexType =
+        IndexType
+            .IVF_SQ8;
     Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8).build();
     CreateIndexParam createIndexParam =
         new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(10).build();
@@ -124,24 +134,27 @@ public class MilvusClientExample {
     System.out.println(describeIndexResponse);
 
     // Search vectors
-    final int searchSize = 5;
     // Searching the first 5 vectors of the vectors we just inserted
-    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+    final int searchBatchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchBatchSize);
     final long topK = 10;
     SearchParam searchParam =
         new SearchParam.Builder(tableName, vectorsToSearch).withTopK(topK).withTimeout(10).build();
     SearchResponse searchResponse = client.search(searchParam);
     System.out.println(searchResponse);
-    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
-    final double epsilon = 0.001;
-    for (int i = 0; i < searchSize; i++) {
-      // Since we are searching for vector that is already present in the table,
-      // the first result vector should be itself and the distance (inner product) should be
-      // very close to 1 (some precision is lost during the process)
-      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
-      if (firstQueryResult.getVectorId() != vectorIds.get(i)
-          || firstQueryResult.getDistance() <= (1 - epsilon)) {
-        throw new AssertionError();
+    if (searchResponse.getResponse().ok()) {
+      List<List<SearchResponse.QueryResult>> queryResultsList =
+          searchResponse.getQueryResultsList();
+      final double epsilon = 0.001;
+      for (int i = 0; i < searchBatchSize; i++) {
+        // Since we are searching for vector that is already present in the table,
+        // the first result vector should be itself and the distance (inner product) should be
+        // very close to 1 (some precision is lost during the process)
+        SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+        if (firstQueryResult.getVectorId() != vectorIds.get(i)
+            || firstQueryResult.getDistance() <= (1 - epsilon)) {
+          throw new AssertionError("Wrong results!");
+        }
       }
     }
 
@@ -156,7 +169,13 @@ public class MilvusClientExample {
     System.out.println(dropTableResponse);
 
     // Disconnect from Milvus server
-    Response disconnectResponse = client.disconnect();
-    System.out.println(disconnectResponse);
+    try {
+      Response disconnectResponse = client.disconnect();
+    } catch (InterruptedException e) {
+      System.out.println("Failed to disconnect: " + e.toString());
+      throw e;
+    }
+
+    return;
   }
 }

+ 0 - 75
milvus-sdk-java.iml

@@ -1,75 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
-  <component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
-    <output url="file://$MODULE_DIR$/target/classes" />
-    <output-test url="file://$MODULE_DIR$/target/test-classes" />
-    <content url="file://$MODULE_DIR$">
-      <sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
-      <sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
-      <sourceFolder url="file://$MODULE_DIR$/target/generated-sources/protobuf/grpc-java" isTestSource="false" generated="true" />
-      <sourceFolder url="file://$MODULE_DIR$/target/generated-sources/protobuf/java" isTestSource="false" generated="true" />
-      <sourceFolder url="file://$MODULE_DIR$/src/main/proto" type="java-resource" />
-      <excludeFolder url="file://$MODULE_DIR$/target" />
-    </content>
-    <orderEntry type="inheritedJdk" />
-    <orderEntry type="sourceFolder" forTests="false" />
-    <orderEntry type="library" name="Bundled Protobuf Distribution" level="application" />
-    <orderEntry type="library" name="Maven: org.apache.maven.plugins:maven-gpg-plugin:1.6" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-api:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-project:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-settings:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-profile:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-artifact-manager:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven.wagon:wagon-provider-api:1.0-beta-6" level="project" />
-    <orderEntry type="library" name="Maven: backport-util-concurrent:backport-util-concurrent:3.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-registry:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-interpolation:1.11" level="project" />
-    <orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-container-default:1.0-alpha-9-stable-1" level="project" />
-    <orderEntry type="library" name="Maven: classworlds:classworlds:1.1-alpha-2" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-artifact:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-repository-metadata:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.maven:maven-model:2.2.1" level="project" />
-    <orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-utils:3.0.20" level="project" />
-    <orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-sec-dispatcher:1.4" level="project" />
-    <orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-cipher:1.4" level="project" />
-    <orderEntry type="library" name="Maven: com.google.googlejavaformat:google-java-format:1.7" level="project" />
-    <orderEntry type="library" name="Maven: com.google.guava:guava:27.0.1-jre" level="project" />
-    <orderEntry type="library" name="Maven: com.google.guava:failureaccess:1.0.1" level="project" />
-    <orderEntry type="library" name="Maven: com.google.guava:listenablefuture:9999.0-empty-to-avoid-conflict-with-guava" level="project" />
-    <orderEntry type="library" name="Maven: com.google.code.findbugs:jsr305:3.0.2" level="project" />
-    <orderEntry type="library" name="Maven: org.checkerframework:checker-qual:2.5.2" level="project" />
-    <orderEntry type="library" name="Maven: com.google.j2objc:j2objc-annotations:1.1" level="project" />
-    <orderEntry type="library" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.17" level="project" />
-    <orderEntry type="library" name="Maven: com.google.errorprone:javac-shaded:9+181-r4173-1" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.24.0" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.24.0" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:4.1.1.4" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.17.0" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: io.opencensus:opencensus-api:0.21.0" level="project" />
-    <orderEntry type="library" scope="RUNTIME" name="Maven: io.opencensus:opencensus-contrib-grpc-metrics:0.21.0" level="project" />
-    <orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.24.0" level="project" />
-    <orderEntry type="library" name="Maven: io.grpc:grpc-api:1.24.0" level="project" />
-    <orderEntry type="library" name="Maven: io.grpc:grpc-context:1.24.0" level="project" />
-    <orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.9.0" level="project" />
-    <orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.12.0" level="project" />
-    <orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.24.0" level="project" />
-    <orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.24.0" level="project" />
-    <orderEntry type="library" scope="PROVIDED" name="Maven: javax.annotation:javax.annotation-api:1.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: io.grpc:grpc-testing:1.24.0" level="project" />
-    <orderEntry type="library" name="Maven: junit:junit:4.12" level="project" />
-    <orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
-    <orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java-util:3.10.0" level="project" />
-    <orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.2" level="project" />
-    <orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.5" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter:5.5.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.5.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.1.0" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.2.0" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-commons:1.5.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.5.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.5.2" level="project" />
-    <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.5.2" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.6" level="project" />
-    <orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
-  </component>
-</module>

+ 4 - 1
pom.xml

@@ -23,7 +23,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.1.1</version>
+    <version>0.1.2</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>
@@ -156,6 +156,9 @@
                         <groupId>org.apache.maven.plugins</groupId>
                         <artifactId>maven-javadoc-plugin</artifactId>
                         <version>3.1.1</version>
+                        <configuration>
+                            <javadocExecutable>${java.home}/bin/javadoc</javadocExecutable>
+                        </configuration>
                         <executions>
                             <execution>
                                 <id>attach-javadocs</id>

+ 9 - 0
src/main/java/io/milvus/client/ConnectFailedException.java

@@ -0,0 +1,9 @@
+package io.milvus.client;
+
+/** Thrown when client failed to connect to server */
+public class ConnectFailedException extends Exception {
+
+  public ConnectFailedException(String message) {
+    super(message);
+  }
+}

+ 24 - 3
src/main/java/io/milvus/client/IndexType.java

@@ -20,14 +20,35 @@ package io.milvus.client;
 import java.util.Arrays;
 import java.util.Optional;
 
-/** Represents available index types */
+/**
+ * Represents different types of indexing method to query the table:
+ * <pre>
+ *
+ * 1. FLAT - Provides 100% accuracy for recalls. However, performance might be downgraded due to huge computation effort;
+ *
+ * 2. IVFLAT - K-means based similarity search which is balanced between accuracy and performance;
+ *
+ * 3. IVF_SQ8 - Vector indexing that adopts a scalar quantization strategy that significantly reduces the size of a
+ * vector (by about 3/4), thus improving the overall throughput of vector processing;
+ *
+ * 4. NSG - NSG (Navigating Spreading-out Graph) is a graph-base search algorithm that a) lowers the average
+ * out-degree of the graph for fast traversal; b) shortens the search path; c) reduces the index
+ * size; d) lowers the indexing complexity. Extensive tests show that NSG can achieve very high
+ * search performance at high precision, and needs much less memory. Compared to non-graph-based
+ * algorithms, it is faster to achieve the same search precision.
+ *
+ * 5. IVF_SQ8H - An enhanced index algorithm of IVF_SQ8. It supports hybrid computation on both CPU and GPU,
+ * which significantly improves the search performance. To use this index type, make sure both cpu and gpu are added as
+ * resources for search usage in the Milvus configuration file.
+ * </pre>
+ */
 public enum IndexType {
   INVALID(0),
   FLAT(1),
   IVFLAT(2),
   IVF_SQ8(3),
-  MIX_NSG(4),
-  IVF_SQ8_H(5),
+  NSG(4),
+  IVF_SQ8H(5),
 
   UNKNOWN(-1);
 

+ 2 - 0
src/main/java/io/milvus/client/MetricType.java

@@ -22,7 +22,9 @@ import java.util.Optional;
 
 /** Represents available metric types */
 public enum MetricType {
+  /** Euclidean distance */
   L2(1),
+  /** Inner product */
   IP(2),
 
   UNKNOWN(-1);

+ 8 - 6
src/main/java/io/milvus/client/MilvusClient.java

@@ -20,10 +20,10 @@ package io.milvus.client;
 /** The Milvus Client Interface */
 public interface MilvusClient {
 
-  String clientVersion = "0.1.1";
+  String clientVersion = "0.1.2";
 
   /** @return the current Milvus client version */
-  default String clientVersion() {
+  default String getClientVersion() {
     return clientVersion;
   }
 
@@ -42,13 +42,15 @@ public interface MilvusClient {
    * </pre>
    *
    * @return <code>Response</code>
+   * @throws ConnectFailedException if client failed to connect
    * @see ConnectParam
    * @see Response
+   * @see ConnectFailedException
    */
-  Response connect(ConnectParam connectParam);
+  Response connect(ConnectParam connectParam) throws ConnectFailedException;
 
   /** @return <code>true</code> if the client is connected to Milvus server */
-  boolean connected();
+  boolean isConnected();
 
   /**
    * Disconnects from Milvus server
@@ -277,7 +279,7 @@ public interface MilvusClient {
    * @return <code>Response</code>
    * @see Response
    */
-  Response serverStatus();
+  Response getServerStatus();
 
   /**
    * Prints server version
@@ -285,7 +287,7 @@ public interface MilvusClient {
    * @return <code>Response</code>
    * @see Response
    */
-  Response serverVersion();
+  Response getServerVersion();
 
   /**
    * Deletes vectors by date range, specified by <code>deleteByRangeParam</code>

+ 39 - 33
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -46,17 +46,17 @@ public class MilvusGrpcClient implements MilvusClient {
   /////////////////////// Client Calls///////////////////////
 
   @Override
-  public Response connect(ConnectParam connectParam) {
-    if (channel != null) {
-      logWarning("You have already connected!");
-      return new Response(Response.Status.CONNECT_FAILED, "You have already connected!");
+  public Response connect(ConnectParam connectParam) throws ConnectFailedException {
+    if (channel != null && !(channel.isShutdown() || channel.isTerminated())) {
+      logWarning("Channel is not shutdown or terminated");
+      throw new ConnectFailedException("Channel is not shutdown or terminated");
     }
 
     try {
       int port = Integer.parseInt(connectParam.getPort());
       if (port < 0 || port > 0xFFFF) {
         logSevere("Connect failed! Port {0} out of range", connectParam.getPort());
-        return new Response(Response.Status.CONNECT_FAILED, "Port " + port + " out of range");
+        throw new ConnectFailedException("Port " + port + " out of range");
       }
 
       channel =
@@ -73,16 +73,17 @@ public class MilvusGrpcClient implements MilvusClient {
 
       connectivityState = channel.getState(false);
       if (connectivityState != ConnectivityState.READY) {
-        logSevere("Connect failed! {0}", connectParam.toString());
-        return new Response(
-            Response.Status.CONNECT_FAILED, "connectivity state = " + connectivityState);
+        logSevere(
+            "Connect failed! {0}\nConnectivity state = {1}",
+            connectParam.toString(), connectivityState);
+        throw new ConnectFailedException("Connectivity state = " + connectivityState);
       }
 
       blockingStub = io.milvus.grpc.MilvusServiceGrpc.newBlockingStub(channel);
 
     } catch (Exception e) {
       logSevere("Connect failed! {0}\n{1}", connectParam.toString(), e.toString());
-      return new Response(Response.Status.CONNECT_FAILED, e.toString());
+      throw new ConnectFailedException("Exception occurred: " + e.toString());
     }
 
     logInfo("Connected successfully!\n{0}", connectParam.toString());
@@ -90,7 +91,7 @@ public class MilvusGrpcClient implements MilvusClient {
   }
 
   @Override
-  public boolean connected() {
+  public boolean isConnected() {
     if (channel == null) {
       return false;
     }
@@ -100,15 +101,20 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response disconnect() throws InterruptedException {
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     } else {
-      if (channel.shutdown().awaitTermination(60, TimeUnit.SECONDS)) {
-        logInfo("Channel terminated");
-      } else {
-        logSevere("Encountered error when terminating channel");
-        return new Response(Response.Status.RPC_ERROR);
+      try {
+        if (channel.shutdown().awaitTermination(60, TimeUnit.SECONDS)) {
+          logInfo("Channel terminated");
+        } else {
+          logSevere("Encountered error when terminating channel");
+          return new Response(Response.Status.RPC_ERROR);
+        }
+      } catch (InterruptedException e) {
+        logSevere("Exception thrown when terminating channel: {0}", e.toString());
+        throw e;
       }
     }
     return new Response(Response.Status.SUCCESS);
@@ -117,7 +123,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createTable(@Nonnull TableSchemaParam tableSchemaParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -160,7 +166,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasTableResponse hasTable(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new HasTableResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -196,7 +202,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropTable(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -229,7 +235,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createIndex(@Nonnull CreateIndexParam createIndexParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -271,7 +277,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public InsertResponse insert(@Nonnull InsertParam insertParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new InsertResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
@@ -323,7 +329,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public SearchResponse search(@Nonnull SearchParam searchParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new SearchResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
@@ -374,7 +380,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public SearchResponse searchInFiles(@Nonnull SearchInFilesParam searchInFilesParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new SearchResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
@@ -434,7 +440,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public DescribeTableResponse describeTable(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new DescribeTableResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
     }
@@ -475,7 +481,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ShowTablesResponse showTables() {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new ShowTablesResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
@@ -509,7 +515,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetTableRowCountResponse getTableRowCount(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new GetTableRowCountResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
     }
@@ -544,20 +550,20 @@ public class MilvusGrpcClient implements MilvusClient {
   }
 
   @Override
-  public Response serverStatus() {
+  public Response getServerStatus() {
     CommandParam commandParam = new CommandParam.Builder("OK").build();
     return command(commandParam);
   }
 
   @Override
-  public Response serverVersion() {
+  public Response getServerVersion() {
     CommandParam commandParam = new CommandParam.Builder("version").build();
     return command(commandParam);
   }
 
   private Response command(@Nonnull CommandParam commandParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -587,7 +593,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   public Response deleteByRange(@Nonnull DeleteByRangeParam deleteByRangeParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -628,7 +634,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response preloadTable(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -661,7 +667,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public DescribeIndexResponse describeIndex(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new DescribeIndexResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
     }
@@ -702,7 +708,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropIndex(@Nonnull TableParam tableParam) {
 
-    if (!connected()) {
+    if (!isConnected()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }

+ 47 - 44
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -22,6 +22,7 @@ import org.apache.commons.text.RandomStringGenerator;
 import java.util.*;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
 
 import static org.junit.jupiter.api.Assertions.*;
 
@@ -35,7 +36,28 @@ class MilvusGrpcClientTest {
   private long size;
   private long dimension;
   private TableParam tableParam;
-  private TableSchema tableSchema;
+
+  // Helper function that generates random vectors
+  static List<List<Float>> generateVectors(long vectorCount, long dimension) {
+    SplittableRandom splittableRandom = new SplittableRandom();
+    List<List<Float>> vectors = new ArrayList<>();
+    for (int i = 0; i < vectorCount; ++i) {
+      DoubleStream doubleStream = splittableRandom.doubles(dimension);
+      List<Float> vector =
+          doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
+      vectors.add(vector);
+    }
+    return vectors;
+  }
+
+  // Helper function that normalizes a vector if you are using IP (Inner product) as your metric
+  // type
+  static List<Float> normalizeVector(List<Float> vector) {
+    float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
+    final float norm = (float) Math.sqrt(squareSum);
+    vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
+    return vector;
+  }
 
   @org.junit.jupiter.api.BeforeEach
   void setUp() throws Exception {
@@ -47,11 +69,10 @@ class MilvusGrpcClientTest {
 
     generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
     randomTableName = generator.generate(10);
-    size = 100;
+    size = 100000;
     dimension = 128;
     tableParam = new TableParam.Builder(randomTableName).build();
-    tableSchema =
-        new TableSchema.Builder(randomTableName, dimension)
+    TableSchema tableSchema = new TableSchema.Builder(randomTableName, dimension)
             .withIndexFileSize(1024)
             .withMetricType(MetricType.IP)
             .build();
@@ -67,8 +88,8 @@ class MilvusGrpcClientTest {
   }
 
   @org.junit.jupiter.api.Test
-  void connected() {
-    assertTrue(client.connected());
+  void isConnected() {
+    assertTrue(client.isConnected());
   }
 
   @org.junit.jupiter.api.Test
@@ -108,54 +129,28 @@ class MilvusGrpcClientTest {
 
   @org.junit.jupiter.api.Test
   void insert() {
-    Random random = new Random();
-    List<List<Float>> vectors = new ArrayList<>();
-    for (int i = 0; i < size; ++i) {
-      List<Float> vector = new ArrayList<>();
-      for (int j = 0; j < dimension; ++j) {
-        vector.add(random.nextFloat());
-      }
-      vectors.add(vector);
-    }
+    List<List<Float>> vectors = generateVectors(size, dimension);
     InsertParam insertParam = new InsertParam.Builder(randomTableName, vectors).build();
     InsertResponse insertResponse = client.insert(insertParam);
     assertTrue(insertResponse.getResponse().ok());
     assertEquals(size, insertResponse.getVectorIds().size());
   }
 
-  List<Float> normalize(List<Float> vector) {
-    float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
-    final float norm = (float) Math.sqrt(squareSum);
-    vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
-    return vector;
-  }
-
   @org.junit.jupiter.api.Test
   void search() throws InterruptedException {
-    Random random = new Random();
-    List<List<Float>> vectors = new ArrayList<>();
-    List<List<Float>> vectorsToSearch = new ArrayList<>();
-    int searchSize = 5;
-    for (int i = 0; i < size; ++i) {
-      List<Float> vector = new ArrayList<>();
-      for (int j = 0; j < dimension; ++j) {
-        vector.add(random.nextFloat());
-      }
-      if (tableSchema.getMetricType() == MetricType.IP) {
-        vector = normalize(vector);
-      }
-      vectors.add(vector);
-      if (i < searchSize) {
-        vectorsToSearch.add(vector);
-      }
-    }
+    List<List<Float>> vectors = generateVectors(size, dimension);
+    vectors.forEach(MilvusGrpcClientTest::normalizeVector);
     InsertParam insertParam = new InsertParam.Builder(randomTableName, vectors).build();
     InsertResponse insertResponse = client.insert(insertParam);
     assertTrue(insertResponse.getResponse().ok());
-    assertEquals(size, insertResponse.getVectorIds().size());
+    List<Long> vectorIds = insertResponse.getVectorIds();
+    assertEquals(size, vectorIds.size());
 
     TimeUnit.SECONDS.sleep(1);
 
+    final int searchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+
     List<DateRange> queryRanges = new ArrayList<>();
     Date today = new Date();
     Calendar c = Calendar.getInstance();
@@ -167,16 +162,24 @@ class MilvusGrpcClientTest {
     Date tomorrow = c.getTime();
     queryRanges.add(new DateRange(yesterday, tomorrow));
     System.out.println(queryRanges);
+    final long topK = 1000;
     SearchParam searchParam =
         new SearchParam.Builder(randomTableName, vectorsToSearch)
-            .withTopK(1)
+            .withTopK(topK)
             .withNProbe(20)
             .withDateRanges(queryRanges)
             .build();
     SearchResponse searchResponse = client.search(searchParam);
     assertTrue(searchResponse.getResponse().ok());
     System.out.println(searchResponse);
-    assertEquals(searchSize, searchResponse.getQueryResultsList().size());
+    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    assertEquals(searchSize, queryResultsList.size());
+    final double epsilon = 0.001;
+    for (int i = 0; i < searchSize; i++) {
+      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
+      assertTrue(firstQueryResult.getDistance() > (1 - epsilon));
+    }
   }
 
   //    @org.junit.jupiter.api.Test
@@ -204,13 +207,13 @@ class MilvusGrpcClientTest {
 
   @org.junit.jupiter.api.Test
   void serverStatus() {
-    Response serverStatusResponse = client.serverStatus();
+    Response serverStatusResponse = client.getServerStatus();
     assertTrue(serverStatusResponse.ok());
   }
 
   @org.junit.jupiter.api.Test
   void serverVersion() {
-    Response serverVersionResponse = client.serverVersion();
+    Response serverVersionResponse = client.getServerVersion();
     assertTrue(serverVersionResponse.ok());
   }