2
0
Zhiru Zhu 5 жил өмнө
parent
commit
4b3ba63adf

+ 1 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.3.0</version>
+    <version>0.4.0-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>

+ 18 - 0
src/main/java/io/milvus/client/InsertParam.java

@@ -28,11 +28,13 @@ public class InsertParam {
   private final String tableName;
   private final List<List<Float>> vectors;
   private final List<Long> vectorIds;
+  private final String partitionTag;
 
   private InsertParam(@Nonnull Builder builder) {
     this.tableName = builder.tableName;
     this.vectors = builder.vectors;
     this.vectorIds = builder.vectorIds;
+    this.partitionTag = builder.partitionTag;
   }
 
   public String getTableName() {
@@ -47,6 +49,10 @@ public class InsertParam {
     return vectorIds;
   }
 
+  public String getPartitionTag() {
+    return partitionTag;
+  }
+
   /** Builder for <code>InsertParam</code> */
   public static class Builder {
     // Required parameters
@@ -55,6 +61,7 @@ public class InsertParam {
 
     // Optional parameters - initialized to default values
     private List<Long> vectorIds = new ArrayList<>();
+    private String partitionTag = "";
 
     /**
      * @param tableName table to insert vectors to
@@ -77,6 +84,17 @@ public class InsertParam {
       return this;
     }
 
+    /**
+     * Optional. Default to an empty <code>String</code>
+     *
+     * @param partitionTag partition tag
+     * @return <code>Builder</code>
+     */
+    public Builder withPartitionTag(@Nonnull String partitionTag) {
+      this.partitionTag = partitionTag;
+      return this;
+    }
+
     public InsertParam build() {
       return new InsertParam(this);
     }

+ 45 - 0
src/main/java/io/milvus/client/MilvusClient.java

@@ -134,6 +134,50 @@ public interface MilvusClient {
    */
   Response createIndex(CreateIndexParam createIndexParam);
 
+  /**
+   * Creates a partition specified by <code>PartitionParam</code>
+   *
+   * @param partition the <code>PartitionParam</code> object
+   * <pre>
+   * example usage:
+   * <code>
+   * Partition partition = new Partition.Builder(tableName, partitionName, tag).build();
+   * </code>
+   * </pre>
+   *
+   * @return <code>Response</code>
+   * @see Partition
+   * @see Response
+   */
+  Response createPartition(Partition partition);
+
+  /**
+   * Shows current partitions of a table
+   *
+   * @param tableName table name
+   * @return <code>ShowPartitionsResponse</code>
+   * @see ShowPartitionsResponse
+   * @see Response
+   */
+  ShowPartitionsResponse showPartitions(String tableName);
+
+  /**
+   * Drops partition specified by <code>partitionName</code>
+   * @param partitionName partition name
+   * @see Response
+   */
+  Response dropPartition(String partitionName);
+
+
+  /**
+   * Drops partition specified by <code>tableName</code> and <code>tag</code>
+   *
+   * @param tableName table name
+   * @param tag partition tag
+   * @see Response
+   */
+  Response dropPartition(String tableName, String tag);
+
   /**
    * Inserts data specified by <code>insertParam</code>
    *
@@ -143,6 +187,7 @@ public interface MilvusClient {
    * <code>
    * InsertParam insertParam = new InsertParam.Builder(tableName, vectors)
    *                                          .withVectorIds(vectorIds)
+   *                                          .withPartitionTag(tag)
    *                                          .build();
    * </code>
    * </pre>

+ 139 - 0
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -23,6 +23,7 @@ import io.grpc.ConnectivityState;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.StatusRuntimeException;
+import io.milvus.grpc.PartitionParam;
 import org.apache.commons.collections4.ListUtils;
 
 import javax.annotation.Nonnull;
@@ -269,6 +270,144 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public Response createPartition(Partition partition) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
+    }
+
+    io.milvus.grpc.PartitionParam request =
+        io.milvus.grpc.PartitionParam.newBuilder()
+            .setTableName(partition.getTableName())
+            .setPartitionName(partition.getPartitionName())
+            .setTag(partition.getTag())
+            .build();
+
+    io.milvus.grpc.Status response;
+
+    try {
+      response = blockingStub.createPartition(request);
+
+      if (response.getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+        logInfo("Created partition successfully!\n{0}", partition.toString());
+        return new Response(Response.Status.SUCCESS);
+      } else {
+        logSevere("Create partition failed\n{0}\n{1}", partition.toString(), response.toString());
+        return new Response(
+            Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("createPartition RPC failed:\n{0}", e.getStatus().toString());
+      return new Response(Response.Status.RPC_ERROR, e.toString());
+    }
+  }
+
+  @Override
+  public ShowPartitionsResponse showPartitions(String tableName) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new ShowPartitionsResponse(
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+    }
+
+    io.milvus.grpc.TableName request =
+        io.milvus.grpc.TableName.newBuilder().setTableName(tableName).build();
+    io.milvus.grpc.PartitionList response;
+
+    try {
+      response = blockingStub.showPartitions(request);
+
+      if (response.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+        List<PartitionParam> partitionList = response.getPartitionArrayList();
+        List<Partition> partitions = new ArrayList<>();
+        for (PartitionParam partitionParam : partitionList) {
+          partitions.add(
+              new Partition.Builder(
+                      partitionParam.getTableName(),
+                      partitionParam.getPartitionName(),
+                      partitionParam.getTag())
+                  .build());
+        }
+        logInfo("Current partitions of table {0}: {1}", tableName, partitions.toString());
+        return new ShowPartitionsResponse(new Response(Response.Status.SUCCESS), partitions);
+      } else {
+        logSevere("Show partitions failed:\n{0}", response.toString());
+        return new ShowPartitionsResponse(
+            new Response(
+                Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
+                response.getStatus().getReason()),
+            new ArrayList<>());
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("showPartitions RPC failed:\n{0}", e.getStatus().toString());
+      return new ShowPartitionsResponse(
+          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+    }
+  }
+
+  @Override
+  public Response dropPartition(String partitionName) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
+    }
+
+    io.milvus.grpc.PartitionParam request =
+        io.milvus.grpc.PartitionParam.newBuilder().setPartitionName(partitionName).build();
+    io.milvus.grpc.Status response;
+
+    try {
+      response = blockingStub.dropPartition(request);
+
+      if (response.getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+        logInfo("Dropped partition `{0}` successfully!", partitionName);
+        return new Response(Response.Status.SUCCESS);
+      } else {
+        logSevere("Drop partition `{0}` failed:\n{1}", partitionName, response.toString());
+        return new Response(
+            Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("dropPartition RPC failed:\n{0}", e.getStatus().toString());
+      return new Response(Response.Status.RPC_ERROR, e.toString());
+    }
+  }
+
+  @Override
+  public Response dropPartition(String tableName, String tag) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
+    }
+
+    io.milvus.grpc.PartitionParam request =
+        io.milvus.grpc.PartitionParam.newBuilder().setTableName(tableName).setTag(tag).build();
+    io.milvus.grpc.Status response;
+
+    try {
+      response = blockingStub.dropPartition(request);
+
+      if (response.getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+        logInfo("Dropped partition of table `{0}` and tag `{1}` successfully!", tableName, tag);
+        return new Response(Response.Status.SUCCESS);
+      } else {
+        logSevere(
+            "Drop partition of table `{0}` and tag `{1}` failed:\n{1}",
+            tableName, tag, response.toString());
+        return new Response(
+            Response.Status.valueOf(response.getErrorCodeValue()), response.getReason());
+      }
+    } catch (StatusRuntimeException e) {
+      logSevere("dropPartition RPC failed:\n{0}", e.getStatus().toString());
+      return new Response(Response.Status.RPC_ERROR, e.toString());
+    }
+  }
+
   @Override
   public InsertResponse insert(@Nonnull InsertParam insertParam) {
 

+ 73 - 0
src/main/java/io/milvus/client/Partition.java

@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+import javax.annotation.Nonnull;
+
+public class Partition {
+    private final String tableName;
+    private final String partitionName;
+    private final String tag;
+
+    private Partition(@Nonnull Builder builder) {
+        this.tableName = builder.tableName;
+        this.partitionName = builder.partitionName;
+        this.tag = builder.tag;
+    }
+
+    public String getTableName() {
+        return tableName;
+    }
+
+    public String getPartitionName() {
+        return partitionName;
+    }
+
+    public String getTag() {
+        return tag;
+    }
+
+    @Override
+    public String toString() {
+        return "PartitionParam {" +
+                "tableName='" + tableName + '\'' +
+                ", partitionName='" + partitionName + '\'' +
+                ", tag='" + tag + '\'' +
+                '}';
+    }
+
+    /** Builder for <code>Partition</code> */
+    public static class Builder {
+        // Required parameters
+        private final String tableName;
+        private final String partitionName;
+        private final String tag;
+
+        public Builder(@Nonnull String tableName, @Nonnull String partitionName, @Nonnull String tag) {
+            this.tableName = tableName;
+            this.partitionName = partitionName;
+            this.tag = tag;
+        }
+
+        public Partition build() {
+            return new Partition(this);
+        }
+    }
+}

+ 18 - 0
src/main/java/io/milvus/client/SearchParam.java

@@ -31,6 +31,7 @@ public class SearchParam {
   private final List<DateRange> dateRanges;
   private final long topK;
   private final long nProbe;
+  private final List<String> partitionTags;
 
   private SearchParam(@Nonnull Builder builder) {
     this.tableName = builder.tableName;
@@ -38,6 +39,7 @@ public class SearchParam {
     this.dateRanges = builder.dateRanges;
     this.nProbe = builder.nProbe;
     this.topK = builder.topK;
+    this.partitionTags = builder.partitionTags;
   }
 
   public String getTableName() {
@@ -60,6 +62,10 @@ public class SearchParam {
     return nProbe;
   }
 
+  public List<String> getPartitionTags() {
+    return partitionTags;
+  }
+
   /** Builder for <code>SearchParam</code> */
   public static class Builder {
     // Required parameters
@@ -70,6 +76,7 @@ public class SearchParam {
     private List<DateRange> dateRanges = new ArrayList<>();
     private long topK = 1024;
     private long nProbe = 20;
+    private List<String> partitionTags = new ArrayList<>();
 
     /**
      * @param tableName table to search from
@@ -116,6 +123,17 @@ public class SearchParam {
       return this;
     }
 
+    /**
+     * Optional. Search vectors with corresponding <code>partitionTags</code>. Default to an empty <code>List</code>
+     *
+     * @param partitionTags a <code>List</code> of partition tags
+     * @return <code>Builder</code>
+     */
+    public Builder withPartitionTags(List<String> partitionTags) {
+      this.partitionTags = partitionTags;
+      return this;
+    }
+
     public SearchParam build() {
       return new SearchParam(this);
     }

+ 45 - 0
src/main/java/io/milvus/client/ShowPartitionsResponse.java

@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.client;
+
+import java.util.List;
+
+public class ShowPartitionsResponse {
+    private final Response response;
+    private final List<Partition> partitionList;
+
+    ShowPartitionsResponse(Response response, List<Partition> partitionList) {
+        this.response = response;
+        this.partitionList = partitionList;
+    }
+
+    public List<Partition> getPartitionList() {
+        return partitionList;
+    }
+
+    public Response getResponse() {
+        return response;
+    }
+
+    public boolean ok() {
+        return response.ok();
+    }
+
+}

+ 46 - 0
src/main/proto/milvus.proto

@@ -20,6 +20,16 @@ message TableName {
 
 }
 
+/**
+
+ * @brief Partition name
+
+ */
+
+message PartitionName {
+    string partition_name = 1;
+}
+
 /**
 
  * @brief Table Name List
@@ -80,6 +90,29 @@ message RowRecord {
 
 }
 
+/**
+
+ * @brief Partition param
+
+ */
+
+message PartitionParam {
+    string table_name = 1;
+    string partition_name = 2;
+    string tag = 3;
+}
+
+/**
+
+ * @brief Partition list
+
+ */
+
+message PartitionList {
+    Status status = 1;
+    repeated PartitionParam partition_array = 2;
+}
+
 /**
 
  * @brief params to be inserted
@@ -94,6 +127,8 @@ message InsertParam {
 
     repeated int64 row_id_array = 3; //optional
 
+    string partition_tag = 4; // default empty
+
 }
 
 /**
@@ -128,6 +163,8 @@ message SearchParam {
 
     int64 nprobe = 5;
 
+    repeated string partition_tag_array = 6; // default empty
+
 }
 
 /**
@@ -342,6 +379,15 @@ service MilvusService {
     rpc CreateIndex (IndexParam) returns (Status) {
     }
 
+    rpc CreatePartition (PartitionParam) returns (Status) {
+    }
+
+    rpc ShowPartitions (TableName) returns (PartitionList) {
+    }
+
+    rpc DropPartition (PartitionParam) returns (Status) {
+    }
+
     /**
 
      * @brief Add vector array to table

+ 56 - 0
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -172,6 +172,62 @@ class MilvusClientTest {
     assertEquals(Response.Status.TABLE_NOT_EXISTS, dropTableResponse.getStatus());
   }
 
+  @org.junit.jupiter.api.Test
+  void partitionTest() throws InterruptedException {
+    final String partitionName = "partition";
+    final String tag = "tag";
+
+    Partition partition = new Partition.Builder(randomTableName, partitionName, tag).build();
+    Response createPartitionResponse = client.createPartition(partition);
+    assertTrue(createPartitionResponse.ok());
+
+    List<List<Float>> vectors = generateVectors(size, dimension);
+    InsertParam insertParam = new InsertParam.Builder(randomTableName, vectors).withPartitionTag(tag).build();
+    InsertResponse insertResponse = client.insert(insertParam);
+    assertTrue(insertResponse.ok());
+
+    TimeUnit.SECONDS.sleep(1);
+
+    final int searchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+
+    List<String> partitionTags = new ArrayList<>();
+    partitionTags.add(tag);
+    final long topK = 10;
+    SearchParam searchParam =
+            new SearchParam.Builder(randomTableName, vectorsToSearch)
+                    .withTopK(topK)
+                    .withNProbe(20)
+                    .withPartitionTags(partitionTags)
+                    .build();
+    SearchResponse searchResponse = client.search(searchParam);
+    assertTrue(searchResponse.ok());
+    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    assertEquals(searchSize, resultIdsList.size());
+    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    assertEquals(searchSize, resultDistancesList.size());
+    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    assertEquals(searchSize, queryResultsList.size());
+
+    final String partitionName2 = "partition2";
+    final String tag2 = "tag2";
+
+    Partition partition2 = new Partition.Builder(randomTableName, partitionName2, tag2).build();
+    createPartitionResponse = client.createPartition(partition2);
+    assertTrue(createPartitionResponse.ok());
+
+    ShowPartitionsResponse showPartitionsResponse = client.showPartitions(randomTableName);
+    assertTrue(showPartitionsResponse.ok());
+    assertEquals(2, showPartitionsResponse.getPartitionList().size());
+
+    Response dropPartitionResponse = client.dropPartition(partitionName);
+    assertTrue(dropPartitionResponse.ok());
+
+    dropPartitionResponse = client.dropPartition(randomTableName, tag2);
+    assertTrue(dropPartitionResponse.ok());
+
+  }
+
   @org.junit.jupiter.api.Test
   void createIndex() {
     Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8).withNList(16384).build();