Bladeren bron

Simplify `countEntities`

jianghua 4 jaren geleden
bovenliggende
commit
e4da087594

+ 0 - 55
src/main/java/io/milvus/client/CountEntitiesResponse.java

@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package io.milvus.client;
-
-/**
- * Contains the returned <code>response</code> and <code>collectionEntityCount</code> for <code>
- * countEntities</code>
- */
-public class CountEntitiesResponse {
-  private final Response response;
-  private final long collectionEntityCount;
-
-  CountEntitiesResponse(Response response, long collectionEntityCount) {
-    this.response = response;
-    this.collectionEntityCount = collectionEntityCount;
-  }
-
-  /** @return collection entity count */
-  public long getCollectionEntityCount() {
-    return collectionEntityCount;
-  }
-
-  public Response getResponse() {
-    return response;
-  }
-
-  /** @return <code>true</code> if the response status equals SUCCESS */
-  public boolean ok() {
-    return response.ok();
-  }
-
-  @Override
-  public String toString() {
-    return String.format(
-        "CountCollectionResponse {%s, collection entity count = %d}",
-        response.toString(), collectionEntityCount);
-  }
-}

+ 4 - 8
src/main/java/io/milvus/client/MilvusClient.java

@@ -275,21 +275,17 @@ public interface MilvusClient {
   /**
   /**
    * Lists current collections
    * Lists current collections
    *
    *
-   * @return <code>ListCollectionsResponse</code>
-   * @see ListCollectionsResponse
-   * @see Response
+   * @return a list of collection names
    */
    */
-  ListCollectionsResponse listCollections();
+  List<String> listCollections();
 
 
   /**
   /**
    * Gets current entity count of a collection
    * Gets current entity count of a collection
    *
    *
    * @param collectionName collection name
    * @param collectionName collection name
-   * @return <code>CountEntitiesResponse</code>
-   * @see CountEntitiesResponse
-   * @see Response
+   * @return a count of entities in the collection
    */
    */
-  CountEntitiesResponse countEntities(String collectionName);
+  long countEntities(String collectionName);
 
 
   /**
   /**
    * Gets server status
    * Gets server status

+ 14 - 63
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -367,72 +367,23 @@ abstract class AbstractMilvusGrpcClient implements MilvusClient {
   }
   }
 
 
   @Override
   @Override
-  public ListCollectionsResponse listCollections() {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      return new ListCollectionsResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
-    }
-
-    Command request = Command.newBuilder().setCmd("").build();
-    CollectionNameList response;
-
-    try {
-      response = blockingStub().showCollections(request);
-
-      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-        List<String> collectionNames = response.getCollectionNamesList();
-        logInfo("Current collections: {}", collectionNames.toString());
-        return new ListCollectionsResponse(new Response(Response.Status.SUCCESS), collectionNames);
-      } else {
-        logError("List collections failed:\n{}", response.getStatus().toString());
-        return new ListCollectionsResponse(
-            new Response(
-                Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            Collections.emptyList());
-      }
-    } catch (StatusRuntimeException e) {
-      logError("listCollections RPC failed:\n{}", e.getStatus().toString());
-      return new ListCollectionsResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
-    }
+  public List<String> listCollections() {
+    return translateExceptions(() -> {
+      Command request = Command.newBuilder().setCmd("").build();
+      CollectionNameList response = blockingStub().showCollections(request);
+      checkResponseStatus(response.getStatus());
+      return response.getCollectionNamesList();
+    });
   }
   }
 
 
   @Override
   @Override
-  public CountEntitiesResponse countEntities(@Nonnull String collectionName) {
-
-    if (!maybeAvailable()) {
-      logWarning("You are not connected to Milvus server");
-      return new CountEntitiesResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
-    }
-
-    CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
-    CollectionRowCount response;
-
-    try {
-      response = blockingStub().countCollection(request);
-
-      if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
-        long collectionRowCount = response.getCollectionRowCount();
-        logInfo("Collection `{}` has {} entities", collectionName, collectionRowCount);
-        return new CountEntitiesResponse(new Response(Response.Status.SUCCESS), collectionRowCount);
-      } else {
-        logError(
-            "Get collection `{}` entity count failed:\n{}",
-            collectionName,
-            response.getStatus().toString());
-        return new CountEntitiesResponse(
-            new Response(
-                Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
-                response.getStatus().getReason()),
-            0);
-      }
-    } catch (StatusRuntimeException e) {
-      logError("countEntities RPC failed:\n{}", e.getStatus().toString());
-      return new CountEntitiesResponse(new Response(Response.Status.RPC_ERROR, e.toString()), 0);
-    }
+  public long countEntities(@Nonnull String collectionName) {
+    return translateExceptions(() -> {
+      CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
+      CollectionRowCount response = blockingStub().countCollection(request);
+      checkResponseStatus(response.getStatus());
+      return response.getCollectionRowCount();
+    });
   }
   }
 
 
   @Override
   @Override

+ 6 - 11
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -241,7 +241,7 @@ class MilvusClientTest {
     MilvusClient client = new MilvusGrpcClient(connectParam);
     MilvusClient client = new MilvusGrpcClient(connectParam);
     TimeUnit.SECONDS.sleep(2);
     TimeUnit.SECONDS.sleep(2);
     // A new RPC would take the channel out of idle mode
     // A new RPC would take the channel out of idle mode
-    assertTrue(client.listCollections().ok());
+    client.listCollections();
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
@@ -385,8 +385,7 @@ class MilvusClientTest {
 
 
     assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.flush(randomCollectionName).ok());
 
 
-    assertEquals(size * 2,
-        client.countEntities(randomCollectionName).getCollectionEntityCount());
+    assertEquals(size * 2, client.countEntities(randomCollectionName));
 
 
     final int searchSize = 1;
     final int searchSize = 1;
     final long topK = 10;
     final long topK = 10;
@@ -621,9 +620,8 @@ class MilvusClientTest {
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
   void listCollections() {
   void listCollections() {
-    ListCollectionsResponse listCollectionsResponse = client.listCollections();
-    assertTrue(listCollectionsResponse.ok());
-    assertTrue(listCollectionsResponse.getCollectionNames().contains(randomCollectionName));
+    List<String> collectionList = client.listCollections();
+    assertTrue(collectionList.contains(randomCollectionName));
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
@@ -642,10 +640,7 @@ class MilvusClientTest {
   void countEntities() {
   void countEntities() {
     insert();
     insert();
     assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.flush(randomCollectionName).ok());
-
-    CountEntitiesResponse countEntitiesResponse = client.countEntities(randomCollectionName);
-    assertTrue(countEntitiesResponse.ok());
-    assertEquals(size, countEntitiesResponse.getCollectionEntityCount());
+    assertEquals(size, client.countEntities(randomCollectionName));
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test
@@ -796,7 +791,7 @@ class MilvusClientTest {
 
 
     assertTrue(client.deleteEntityByID(randomCollectionName, entityIds.subList(0, 100)).ok());
     assertTrue(client.deleteEntityByID(randomCollectionName, entityIds.subList(0, 100)).ok());
     assertTrue(client.flush(randomCollectionName).ok());
     assertTrue(client.flush(randomCollectionName).ok());
-    assertEquals(client.countEntities(randomCollectionName).getCollectionEntityCount(), size - 100);
+    assertEquals(size - 100, client.countEntities(randomCollectionName));
   }
   }
 
 
   @org.junit.jupiter.api.Test
   @org.junit.jupiter.api.Test