Browse Source

Simplify `createIndex`

jianghua 4 years ago
parent
commit
dca9c869d0

+ 62 - 65
src/main/java/io/milvus/client/Index.java

@@ -19,98 +19,95 @@
 
 
 package io.milvus.client;
 package io.milvus.client;
 
 
+import io.milvus.grpc.IndexParam;
+import io.milvus.grpc.KeyValuePair;
+import org.json.JSONObject;
+
 import javax.annotation.Nonnull;
 import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 
 /** Represents an index containing <code>fieldName</code>, <code>indexName</code> and
 /** Represents an index containing <code>fieldName</code>, <code>indexName</code> and
  * <code>paramsInJson</code>, which contains index_type, params etc.
  * <code>paramsInJson</code>, which contains index_type, params etc.
  */
  */
 public class Index {
 public class Index {
-  private final String collectionName;
-  private final String fieldName;
-  private final String indexName;
-  private final String paramsInJson;
-
-  private Index(@Nonnull Builder builder) {
-    this.collectionName = builder.collectionName;
-    this.fieldName = builder.fieldName;
-    this.indexName = builder.indexName;
-    this.paramsInJson = builder.paramsInJson;
+  private final IndexParam.Builder builder;
+
+  public static Index create(@Nonnull String collectionName, @Nonnull String fieldName) {
+    return new Index(collectionName, fieldName);
+  }
+
+  private Index(String collectionName, String fieldName) {
+    this.builder = IndexParam.newBuilder()
+        .setCollectionName(collectionName)
+        .setFieldName(fieldName);
   }
   }
 
 
   public String getCollectionName() {
   public String getCollectionName() {
-    return collectionName;
+    return builder.getCollectionName();
+  }
+
+  public Index setCollectionName(@Nonnull String collectionName) {
+    builder.setCollectionName(collectionName);
+    return this;
   }
   }
 
 
   public String getFieldName() {
   public String getFieldName() {
-    return fieldName;
+    return builder.getFieldName();
+  }
+
+  public Index setFieldName(@Nonnull String collectionName) {
+    builder.setFieldName(collectionName);
+    return this;
   }
   }
 
 
   public String getIndexName() {
   public String getIndexName() {
-    return indexName;
+    return builder.getIndexName();
+  }
+
+  public Map<String, String> getExtraParams() {
+    return toMap(builder.getExtraParamsList());
+  }
+
+  public Index setIndexType(IndexType indexType) {
+    return addParam("index_type", indexType.name());
+  }
+
+  public Index setMetricType(MetricType metricType) {
+    return addParam("metric_type", metricType.name());
   }
   }
 
 
-  public String getParamsInJson() {
-    return paramsInJson;
+  public Index setParamsInJson(String paramsInJson) {
+    return addParam(MilvusClient.extraParamKey, paramsInJson);
+  }
+
+  private Index addParam(String key, Object value) {
+    builder.addExtraParams(
+        KeyValuePair.newBuilder()
+            .setKey(key)
+            .setValue(String.valueOf(value))
+            .build());
+    return this;
   }
   }
 
 
   @Override
   @Override
   public String toString() {
   public String toString() {
     return "Index {"
     return "Index {"
         + "collectionName="
         + "collectionName="
-        + collectionName
+        + getCollectionName()
         + ", fieldName="
         + ", fieldName="
-        + fieldName
+        + getFieldName()
         + ", params="
         + ", params="
-        + paramsInJson
+        + getExtraParams()
         + '}';
         + '}';
   }
   }
 
 
-  /** Builder for <code>Index</code> */
-  public static class Builder {
-    // Required parameters
-    private final String collectionName;
-    private final String fieldName;
-
-    // Optional parameters - initialized to default values
-    private String paramsInJson = "{}";
-    private String indexName = "";
-
-    /**
-     * @param collectionName collection to create index for
-     * @param fieldName name of the field on which index is built.
-     */
-    public Builder(@Nonnull String collectionName, @Nonnull String fieldName) {
-      this.collectionName = collectionName;
-      this.fieldName = fieldName;
-    }
-
-    /**
-     * Optional. The parameters for building an index. Index parameters are different for different
-     * index types. Refer to Milvus documentation for more information.
-     * <pre>
-     * "index_type": one of the values: FLAT, IVF_FLAT, IVF_SQ8, NSG, IVF_SQ8_HYBRID, IVF_PQ,
-     *                                  HNSW, RHNSW_FLAT, RHNSW_PQ, RHNSW_SQ, ANNOY
-     * "metric_type": one of the values: L2, IP, HAMMING, JACCARD, TANIMOTO,
-     *                                   SUBSTRUCTURE, SUPERSTRUCTURE
-     * "params": optional parameters for index, including <code>nlist</code>
-     *
-     * Example param:
-     * <code>
-     *   {"index_type": "IVF_FLAT", "metric_type": "IP", "params": {nlist": 2048}}
-     * </code>
-     * </pre>
-     *
-     * @param paramsInJson extra parameters in JSON format
-     * @see JsonBuilder
-     * @return <code>Builder</code>
-     */
-    public Builder withParamsInJson(@Nonnull String paramsInJson) {
-      this.paramsInJson = paramsInJson;
-      return this;
-    }
-
-    public Index build() {
-      return new Index(this);
-    }
+  IndexParam grpc() {
+    return builder.build();
+  }
+
+  private Map<String, String> toMap(List<KeyValuePair> extraParams) {
+    return extraParams.stream().collect(Collectors.toMap(KeyValuePair::getKey, KeyValuePair::getValue));
   }
   }
 }
 }

+ 16 - 0
src/main/java/io/milvus/client/IndexType.java

@@ -0,0 +1,16 @@
+package io.milvus.client;
+
+public enum IndexType {
+  ANNOY,
+  BIN_IVF_FLAT,
+  FLAT,
+  HNSW,
+  IVF_FLAT,
+  IVF_PQ,
+  IVF_SQ8,
+  IVF_SQ8_HYBRID,
+  NSG,
+  RHNSW_FLAT,
+  RHNSW_PQ,
+  RHNSW_SQ
+}

+ 11 - 0
src/main/java/io/milvus/client/MetricType.java

@@ -0,0 +1,11 @@
+package io.milvus.client;
+
+public enum MetricType {
+  HAMMING,
+  IP,
+  JACCARD,
+  L2,
+  SUBSTRUCTURE,
+  SUPERSTRUCTURE,
+  TANIMOTO
+}

+ 2 - 5
src/main/java/io/milvus/client/MilvusClient.java

@@ -121,11 +121,9 @@ public interface MilvusClient {
    * </code>
    * </code>
    * </pre>
    * </pre>
    *
    *
-   * @return <code>Response</code>
    * @see Index
    * @see Index
-   * @see Response
    */
    */
-  Response createIndex(Index index);
+  void createIndex(Index index);
 
 
   /**
   /**
    * Creates index specified by <code>index</code> asynchronously
    * Creates index specified by <code>index</code> asynchronously
@@ -144,10 +142,9 @@ public interface MilvusClient {
    *
    *
    * @return a <code>ListenableFuture</code> object which holds the <code>Response</code>
    * @return a <code>ListenableFuture</code> object which holds the <code>Response</code>
    * @see Index
    * @see Index
-   * @see Response
    * @see ListenableFuture
    * @see ListenableFuture
    */
    */
-  ListenableFuture<Response> createIndexAsync(Index index);
+  ListenableFuture<Void> createIndexAsync(Index index);
 
 
   /**
   /**
    * Creates a partition specified by <code>collectionName</code> and <code>tag</code>
    * Creates a partition specified by <code>collectionName</code> and <code>tag</code>

+ 13 - 107
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -228,8 +228,10 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   private <R> R translate(Throwable e) {
   private <R> R translate(Throwable e) {
     if (e instanceof MilvusException) {
     if (e instanceof MilvusException) {
       throw (MilvusException) e;
       throw (MilvusException) e;
-    } else {
+    } else if (e.getCause() == null || e.getCause() == e) {
       throw new ClientSideMilvusException(target(), e);
       throw new ClientSideMilvusException(target(), e);
+    } else {
+      return translate(e.getCause());
     }
     }
   }
   }
 
 
@@ -268,115 +270,19 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   }
   }
 
 
   @Override
   @Override
-  public Response createIndex(@Nonnull Index index) {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
-    }
-
-    List<KeyValuePair> extraParams = new ArrayList<>();
-
-    try {
-      JSONObject jsonInfo = new JSONObject(index.getParamsInJson());
-      Iterator<String> keys = jsonInfo.keys();
-      while (keys.hasNext()) {
-        String key = keys.next();
-        KeyValuePair extraParam = KeyValuePair.newBuilder()
-            .setKey(key)
-            .setValue(jsonInfo.get(key).toString())
-            .build();
-        extraParams.add(extraParam);
-      }
-    } catch (JSONException err){
-      logError("Params must be in json format.\n`{}`", err.toString());
-      return new Response(Response.Status.ILLEGAL_ARGUMENT);
-    }
-
-    IndexParam request =
-        IndexParam.newBuilder()
-            .setCollectionName(index.getCollectionName())
-            .setFieldName(index.getFieldName())
-            .addAllExtraParams(extraParams)
-            .build();
-
-    Status response;
-
-    try {
-      response = blockingStub().createIndex(request);
-
-      if (response.getErrorCode() == ErrorCode.SUCCESS) {
-        logInfo("Created index successfully!\n{}", index.toString());
-        return new Response(Response.Status.SUCCESS);
-      } else {
-        logError("Create index failed:\n{}\n{}", index.toString(), response.toString());
-        return new Response(
-            Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
-      }
-    } catch (StatusRuntimeException e) {
-      logError("createIndex RPC failed:\n{}", e.getStatus().toString());
-      return new Response(Response.Status.RPC_ERROR, e.toString());
-    }
+  public void createIndex(@Nonnull Index index) {
+    translateExceptions(() -> {
+      Futures.getUnchecked(createIndexAsync(index));
+    });
   }
   }
 
 
   @Override
   @Override
-  public ListenableFuture<Response> createIndexAsync(@Nonnull Index index) {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
-    }
-
-    List<KeyValuePair> extraParams = new ArrayList<>();
-
-    try {
-      JSONObject jsonInfo = new JSONObject(index.getParamsInJson());
-      Iterator<String> keys = jsonInfo.keys();
-      while (keys.hasNext()) {
-        String key = keys.next();
-        KeyValuePair extraParam = KeyValuePair.newBuilder()
-            .setKey(key)
-            .setValue(jsonInfo.get(key).toString())
-            .build();
-        extraParams.add(extraParam);
-      }
-    } catch (JSONException err){
-      logError("Params must be in json format.\n`{}`", err.toString());
-      return Futures.immediateFuture(new Response(Response.Status.ILLEGAL_ARGUMENT));
-    }
-
-    IndexParam request =
-        IndexParam.newBuilder()
-            .setCollectionName(index.getCollectionName())
-            .setFieldName(index.getFieldName())
-            .addAllExtraParams(extraParams)
-            .build();
-
-    ListenableFuture<Status> response;
-
-    response = futureStub().createIndex(request);
-
-    Futures.addCallback(
-        response,
-        new FutureCallback<Status>() {
-          @Override
-          public void onSuccess(Status result) {
-            if (result.getErrorCode() == ErrorCode.SUCCESS) {
-              logInfo("Created index successfully!\n{}", index.toString());
-            } else {
-              logError("CreateIndexAsync failed:\n{}\n{}", index.toString(), result.toString());
-            }
-          }
-
-          @Override
-          public void onFailure(Throwable t) {
-            logError("CreateIndexAsync failed:\n{}", t.getMessage());
-          }
-        },
-        MoreExecutors.directExecutor());
-
-    return Futures.transform(
-        response, transformStatusToResponseFunc::apply, MoreExecutors.directExecutor());
+  public ListenableFuture<Void> createIndexAsync(@Nonnull Index index) {
+    return translateExceptions(() -> {
+      IndexParam request = index.grpc();
+      ListenableFuture<Status> responseFuture = futureStub().createIndex(request);
+      return Futures.transform(responseFuture, this::checkResponseStatus, MoreExecutors.directExecutor());
+    });
   }
   }
 
 
   @Override
   @Override

+ 24 - 58
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -19,12 +19,15 @@
 
 
 package io.milvus.client;
 package io.milvus.client;
 
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.MoreExecutors;
 import io.grpc.NameResolverProvider;
 import io.grpc.NameResolverProvider;
 import io.grpc.NameResolverRegistry;
 import io.grpc.NameResolverRegistry;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
 import io.milvus.client.InsertParam.Builder;
 import io.milvus.client.InsertParam.Builder;
 import io.milvus.client.exception.ClientSideMilvusException;
 import io.milvus.client.exception.ClientSideMilvusException;
 import io.milvus.client.exception.InitializationException;
 import io.milvus.client.exception.InitializationException;
@@ -129,6 +132,12 @@ class MilvusClientTest {
     assertEquals(errorCode, assertThrows(ServerSideMilvusException.class, runnable::run).getErrorCode());
     assertEquals(errorCode, assertThrows(ServerSideMilvusException.class, runnable::run).getErrorCode());
   }
   }
 
 
+  protected void assertGrpcStatusCode(Status.Code statusCode, Runnable runnable) {
+    ClientSideMilvusException error = assertThrows(ClientSideMilvusException.class, runnable::run);
+    assertTrue(error.getCause() instanceof StatusRuntimeException);
+    assertEquals(statusCode, ((StatusRuntimeException) error.getCause()).getStatus().getCode());
+  }
+
   // Helper function that generates random float vectors
   // Helper function that generates random float vectors
   static List<List<Float>> generateFloatVectors(int vectorCount, int dimension) {
   static List<List<Float>> generateFloatVectors(int vectorCount, int dimension) {
     SplittableRandom splittableRandom = new SplittableRandom();
     SplittableRandom splittableRandom = new SplittableRandom();
@@ -301,15 +310,11 @@ class MilvusClientTest {
   void grpcTimeout() {
   void grpcTimeout() {
     insert();
     insert();
     MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
     MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
-    Response response = timeoutClient.createIndex(
-        new Index.Builder(randomCollectionName, "float_vec")
-            .withParamsInJson(new JsonBuilder()
-                .param("index_type", "IVF_FLAT")
-                .param("metric_type", "L2")
-                .indexParam("nlist", 2048)
-                .build())
-            .build());
-    assertEquals(Response.Status.RPC_ERROR, response.getStatus());
+    Index index = Index.create(randomCollectionName, "float_vec")
+        .setIndexType(IndexType.IVF_FLAT)
+        .setMetricType(MetricType.L2)
+        .setParamsInJson(new JsonBuilder().param("nlist", 2048).build());
+    assertGrpcStatusCode(Status.Code.DEADLINE_EXCEEDED, () -> timeoutClient.createIndex(index));
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
@@ -460,53 +465,18 @@ class MilvusClientTest {
     insert();
     insert();
     assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.flush(randomCollectionName).ok());
 
 
-    Index index =
-        new Index.Builder(randomCollectionName, "float_vec")
-            .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
-                                               .param("metric_type", "L2")
-                                               .indexParam("nlist", 2048)
-                                               .build())
-            .build();
+    Index index = Index.create(randomCollectionName, "float_vec")
+        .setIndexType(IndexType.IVF_SQ8)
+        .setMetricType(MetricType.L2)
+        .setParamsInJson(new JsonBuilder().param("nlist", 2048).build());
 
 
-    Response createIndexResponse = client.createIndex(index);
-    assertTrue(createIndexResponse.ok());
+    client.createIndex(index);
 
 
     // also test drop index here
     // also test drop index here
     Response dropIndexResponse = client.dropIndex(randomCollectionName, "float_vec");
     Response dropIndexResponse = client.dropIndex(randomCollectionName, "float_vec");
     assertTrue(dropIndexResponse.ok());
     assertTrue(dropIndexResponse.ok());
   }
   }
 
 
-  @org.junit.jupiter.api.Test
-  void createIndexAsync() throws ExecutionException, InterruptedException {
-    insert();
-    assertTrue(client.flush(randomCollectionName).ok());
-
-    Index index =
-        new Index.Builder(randomCollectionName, "float_vec")
-            .withParamsInJson(new JsonBuilder().param("index_type", "IVF_SQ8")
-                                               .param("metric_type", "L2")
-                                               .indexParam("nlist", 2048)
-                                               .build())
-            .build();
-
-    ListenableFuture<Response> createIndexResponseFuture = client.createIndexAsync(index);
-    Futures.addCallback(
-        createIndexResponseFuture,
-        new FutureCallback<Response>() {
-          @Override
-          public void onSuccess(@NullableDecl Response createIndexResponse) {
-            assert createIndexResponse != null;
-            assertTrue(createIndexResponse.ok());
-          }
-
-          @Override
-          public void onFailure(Throwable t) {
-            System.out.println(t.getMessage());
-          }
-        }, MoreExecutors.directExecutor()
-    );
-  }
-
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
   void insert() {
   void insert() {
     List<Long> intValues = new ArrayList<>(size);
     List<Long> intValues = new ArrayList<>(size);
@@ -600,16 +570,12 @@ class MilvusClientTest {
     assertTrue(insertResponse.ok());
     assertTrue(insertResponse.ok());
     assertEquals(size, insertResponse.getEntityIds().size());
     assertEquals(size, insertResponse.getEntityIds().size());
 
 
-    Index index =
-        new Index.Builder(binaryCollectionName, "binary_vec")
-            .withParamsInJson(new JsonBuilder().param("index_type", "BIN_IVF_FLAT")
-                .param("metric_type", "JACCARD")
-                .indexParam("nlist", 100)
-                .build())
-            .build();
+    Index index = Index.create(binaryCollectionName, "binary_vec")
+        .setIndexType(IndexType.BIN_IVF_FLAT)
+        .setMetricType(MetricType.JACCARD)
+        .setParamsInJson(new JsonBuilder().param("nlist", 100).build());
 
 
-    Response createIndexResponse = client.createIndex(index);
-    assertTrue(createIndexResponse.ok());
+    client.createIndex(index);
 
 
     // also test drop index here
     // also test drop index here
     Response dropIndexResponse = client.dropIndex(binaryCollectionName, "binary_vec");
     Response dropIndexResponse = client.dropIndex(binaryCollectionName, "binary_vec");