Ver Fonte

Merge pull request #123 from dddddai/master

Optimize the code
Xiaohai Xu há 5 anos atrás
pai
commit
73fbf7f6e4
1 ficheiros alterados com 24 adições e 23 exclusões
  1. 24 23
      src/main/java/io/milvus/client/MilvusGrpcClient.java

+ 24 - 23
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -32,6 +32,7 @@ import io.milvus.grpc.*;
 import java.nio.Buffer;
 import java.nio.Buffer;
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import java.util.function.Function;
@@ -397,7 +398,7 @@ public class MilvusGrpcClient implements MilvusClient {
     if (!channelIsReadyOrIdle()) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return new ListPartitionsResponse(
       return new ListPartitionsResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
     }
     }
 
 
     CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
     CollectionName request = CollectionName.newBuilder().setCollectionName(collectionName).build();
@@ -419,12 +420,12 @@ public class MilvusGrpcClient implements MilvusClient {
             new Response(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
                 response.getStatus().getReason()),
-            new ArrayList<>());
+            Collections.emptyList());
       }
       }
     } catch (StatusRuntimeException e) {
     } catch (StatusRuntimeException e) {
       logError("listPartitions RPC failed:\n{}", e.getStatus().toString());
       logError("listPartitions RPC failed:\n{}", e.getStatus().toString());
       return new ListPartitionsResponse(
       return new ListPartitionsResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
     }
     }
   }
   }
 
 
@@ -467,7 +468,7 @@ public class MilvusGrpcClient implements MilvusClient {
     if (!channelIsReadyOrIdle()) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return new InsertResponse(
       return new InsertResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
     }
     }
 
 
     List<RowRecord> rowRecordList =
     List<RowRecord> rowRecordList =
@@ -498,12 +499,12 @@ public class MilvusGrpcClient implements MilvusClient {
             new Response(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
                 response.getStatus().getReason()),
-            new ArrayList<>());
+            Collections.emptyList());
       }
       }
     } catch (StatusRuntimeException e) {
     } catch (StatusRuntimeException e) {
       logError("insert RPC failed:\n{}", e.getStatus().toString());
       logError("insert RPC failed:\n{}", e.getStatus().toString());
       return new InsertResponse(
       return new InsertResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
     }
     }
   }
   }
 
 
@@ -514,7 +515,7 @@ public class MilvusGrpcClient implements MilvusClient {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(
       return Futures.immediateFuture(
           new InsertResponse(
           new InsertResponse(
-              new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>()));
+              new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList()));
     }
     }
 
 
     List<RowRecord> rowRecordList =
     List<RowRecord> rowRecordList =
@@ -564,7 +565,7 @@ public class MilvusGrpcClient implements MilvusClient {
                 new Response(
                 new Response(
                     Response.Status.valueOf(vectorIds.getStatus().getErrorCodeValue()),
                     Response.Status.valueOf(vectorIds.getStatus().getErrorCodeValue()),
                     vectorIds.getStatus().getReason()),
                     vectorIds.getStatus().getReason()),
-                new ArrayList<>());
+                Collections.emptyList());
           }
           }
         };
         };
 
 
@@ -748,7 +749,7 @@ public class MilvusGrpcClient implements MilvusClient {
     if (!channelIsReadyOrIdle()) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return new ListCollectionsResponse(
       return new ListCollectionsResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
     }
     }
 
 
     Command request = Command.newBuilder().setCmd("").build();
     Command request = Command.newBuilder().setCmd("").build();
@@ -767,12 +768,12 @@ public class MilvusGrpcClient implements MilvusClient {
             new Response(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
                 response.getStatus().getReason()),
-            new ArrayList<>());
+            Collections.emptyList());
       }
       }
     } catch (StatusRuntimeException e) {
     } catch (StatusRuntimeException e) {
       logError("listCollections RPC failed:\n{}", e.getStatus().toString());
       logError("listCollections RPC failed:\n{}", e.getStatus().toString());
       return new ListCollectionsResponse(
       return new ListCollectionsResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
     }
     }
   }
   }
 
 
@@ -987,7 +988,7 @@ public class MilvusGrpcClient implements MilvusClient {
     if (!channelIsReadyOrIdle()) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return new GetEntityByIDResponse(
       return new GetEntityByIDResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>(), null);
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList(), null);
     }
     }
 
 
     VectorsIdentity request =
     VectorsIdentity request =
@@ -1001,8 +1002,8 @@ public class MilvusGrpcClient implements MilvusClient {
 
 
         logInfo("getEntityByID in collection `{}` returned successfully!", collectionName);
         logInfo("getEntityByID in collection `{}` returned successfully!", collectionName);
 
 
-        List<List<Float>> floatVectors = new ArrayList<>();
-        List<ByteBuffer> binaryVectors = new ArrayList<>();
+        List<List<Float>> floatVectors = new ArrayList<>(ids.size());
+        List<ByteBuffer> binaryVectors = new ArrayList<>(ids.size());
         for (int i = 0; i < ids.size(); i++) {
         for (int i = 0; i < ids.size(); i++) {
           floatVectors.add(response.getVectorsData(i).getFloatDataList());
           floatVectors.add(response.getVectorsData(i).getFloatDataList());
           binaryVectors.add(response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer());
           binaryVectors.add(response.getVectorsData(i).getBinaryData().asReadOnlyByteBuffer());
@@ -1019,13 +1020,13 @@ public class MilvusGrpcClient implements MilvusClient {
             new Response(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
                 response.getStatus().getReason()),
-            new ArrayList<>(),
+            Collections.emptyList(),
             null);
             null);
       }
       }
     } catch (StatusRuntimeException e) {
     } catch (StatusRuntimeException e) {
       logError("getEntityByID RPC failed:\n{}", e.getStatus().toString());
       logError("getEntityByID RPC failed:\n{}", e.getStatus().toString());
       return new GetEntityByIDResponse(
       return new GetEntityByIDResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>(), null);
+          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList(), null);
     }
     }
   }
   }
 
 
@@ -1034,7 +1035,7 @@ public class MilvusGrpcClient implements MilvusClient {
     if (!channelIsReadyOrIdle()) {
     if (!channelIsReadyOrIdle()) {
       logWarning("You are not connected to Milvus server");
       logWarning("You are not connected to Milvus server");
       return new ListIDInSegmentResponse(
       return new ListIDInSegmentResponse(
-          new Response(Response.Status.CLIENT_NOT_CONNECTED), new ArrayList<>());
+          new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
     }
     }
 
 
     GetVectorIDsParam request =
     GetVectorIDsParam request =
@@ -1065,12 +1066,12 @@ public class MilvusGrpcClient implements MilvusClient {
             new Response(
             new Response(
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 Response.Status.valueOf(response.getStatus().getErrorCodeValue()),
                 response.getStatus().getReason()),
                 response.getStatus().getReason()),
-            new ArrayList<>());
+            Collections.emptyList());
       }
       }
     } catch (StatusRuntimeException e) {
     } catch (StatusRuntimeException e) {
       logError("listIDInSegment RPC failed:\n{}", e.getStatus().toString());
       logError("listIDInSegment RPC failed:\n{}", e.getStatus().toString());
       return new ListIDInSegmentResponse(
       return new ListIDInSegmentResponse(
-          new Response(Response.Status.RPC_ERROR, e.toString()), new ArrayList<>());
+          new Response(Response.Status.RPC_ERROR, e.toString()), Collections.emptyList());
     }
     }
   }
   }
 
 
@@ -1266,9 +1267,9 @@ public class MilvusGrpcClient implements MilvusClient {
 
 
   private List<RowRecord> buildRowRecordList(
   private List<RowRecord> buildRowRecordList(
       @Nonnull List<List<Float>> floatVectors, @Nonnull List<ByteBuffer> binaryVectors) {
       @Nonnull List<List<Float>> floatVectors, @Nonnull List<ByteBuffer> binaryVectors) {
-    List<RowRecord> rowRecordList = new ArrayList<>();
-
     int largerSize = Math.max(floatVectors.size(), binaryVectors.size());
     int largerSize = Math.max(floatVectors.size(), binaryVectors.size());
+    
+    List<RowRecord> rowRecordList = new ArrayList<>(largerSize); 
 
 
     for (int i = 0; i < largerSize; ++i) {
     for (int i = 0; i < largerSize; ++i) {
 
 
@@ -1297,8 +1298,8 @@ public class MilvusGrpcClient implements MilvusClient {
             : topKQueryResult.getIdsCount()
             : topKQueryResult.getIdsCount()
                 / numQueries; // Guaranteed to be divisible from server side
                 / numQueries; // Guaranteed to be divisible from server side
 
 
-    List<List<Long>> resultIdsList = new ArrayList<>();
-    List<List<Float>> resultDistancesList = new ArrayList<>();
+    List<List<Long>> resultIdsList = new ArrayList<>(numQueries);
+    List<List<Float>> resultDistancesList = new ArrayList<>(numQueries);
 
 
     if (topK > 0) {
     if (topK > 0) {
       for (int i = 0; i < numQueries; i++) {
       for (int i = 0; i < numQueries; i++) {