Browse Source

Fix bug of SearchResultsWrapper (#240)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
116c836568

+ 67 - 60
examples/main/io/milvus/GeneralExample.java

@@ -28,9 +28,15 @@ import io.milvus.param.index.*;
 import io.milvus.param.partition.*;
 import io.milvus.Response.*;
 
-import java.nio.ByteBuffer;
 import java.util.*;
 
+///////////////////////////////////////////////////////////////////////////////////////////////////////////
+// Note:
+// Due do a technical limitation, the Milvus 2.0 not allow to create multi-vector-fields within a collection.
+// So this example only create a single vector field in the collection, but we suppose the next version
+// should support this function.
+///////////////////////////////////////////////////////////////////////////////////////////////////////////
+
 public class GeneralExample {
     private static final MilvusServiceClient milvusClient;
 
@@ -47,8 +53,8 @@ public class GeneralExample {
     private static final String VECTOR_FIELD = "userFace";
     private static final Integer VECTOR_DIM = 64;
     private static final String AGE_FIELD = "userAge";
-    private static final String PROFILE_FIELD = "userProfile";
-    private static final Integer BINARY_DIM = 128;
+//    private static final String PROFILE_FIELD = "userProfile";
+//    private static final Integer BINARY_DIM = 128;
 
     private static final IndexType INDEX_TYPE = IndexType.IVF_FLAT;
     private static final String INDEX_PARAM = "{\"nlist\":128}";
@@ -79,12 +85,12 @@ public class GeneralExample {
                 .withDataType(DataType.Int8)
                 .build();
 
-        FieldType fieldType4 = FieldType.newBuilder()
-                .withName(PROFILE_FIELD)
-                .withDescription("user profile")
-                .withDataType(DataType.BinaryVector)
-                .withDimension(BINARY_DIM)
-                .build();
+//        FieldType fieldType4 = FieldType.newBuilder()
+//                .withName(PROFILE_FIELD)
+//                .withDescription("user profile")
+//                .withDataType(DataType.BinaryVector)
+//                .withDimension(BINARY_DIM)
+//                .build();
 
         CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
@@ -93,7 +99,7 @@ public class GeneralExample {
                 .addFieldType(fieldType1)
                 .addFieldType(fieldType2)
                 .addFieldType(fieldType3)
-                .addFieldType(fieldType4)
+//                .addFieldType(fieldType4)
                 .build();
         R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
 
@@ -311,43 +317,44 @@ public class GeneralExample {
             System.out.println("Search result of No." + i);
             List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
             System.out.println(scores);
+            System.out.println("Output field data for No." + i);
+            System.out.println(wrapper.getFieldData(AGE_FIELD, i));
         }
-        System.out.println(wrapper.getFieldData(AGE_FIELD).getFieldData());
 
         return response;
     }
 
-    private R<SearchResults> searchProfile(String expr) {
-        System.out.println("========== searchProfile() ==========");
-
-        List<String> outFields = Collections.singletonList(AGE_FIELD);
-        List<ByteBuffer> vectors = generateBinaryVectors(5);
-
-        SearchParam searchParam = SearchParam.newBuilder()
-                .withCollectionName(COLLECTION_NAME)
-                .withMetricType(MetricType.HAMMING)
-                .withOutFields(outFields)
-                .withTopK(SEARCH_K)
-                .withVectors(vectors)
-                .withVectorFieldName(PROFILE_FIELD)
-                .withExpr(expr)
-                .withParams(SEARCH_PARAM)
-                .build();
-
-
-        R<SearchResults> response = milvusClient.search(searchParam);
-
-        SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
-        for (int i = 0; i < vectors.size(); ++i) {
-            System.out.println("Search result of No." + i);
-            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
-            System.out.println(scores);
-        }
-
-        System.out.println(wrapper.getFieldData(AGE_FIELD).getFieldData());
-
-        return response;
-    }
+//    private R<SearchResults> searchProfile(String expr) {
+//        System.out.println("========== searchProfile() ==========");
+//
+//        List<String> outFields = Collections.singletonList(AGE_FIELD);
+//        List<ByteBuffer> vectors = generateBinaryVectors(5);
+//
+//        SearchParam searchParam = SearchParam.newBuilder()
+//                .withCollectionName(COLLECTION_NAME)
+//                .withMetricType(MetricType.HAMMING)
+//                .withOutFields(outFields)
+//                .withTopK(SEARCH_K)
+//                .withVectors(vectors)
+//                .withVectorFieldName(PROFILE_FIELD)
+//                .withExpr(expr)
+//                .withParams(SEARCH_PARAM)
+//                .build();
+//
+//
+//        R<SearchResults> response = milvusClient.search(searchParam);
+//
+//        SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
+//        for (int i = 0; i < vectors.size(); ++i) {
+//            System.out.println("Search result of No." + i);
+//            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
+//            System.out.println(scores);
+//            System.out.println("Output field data for No." + i);
+//            System.out.println(wrapper.getFieldData(AGE_FIELD, i));
+//        }
+//
+//        return response;
+//    }
 
     private R<CalcDistanceResults> calDistance() {
         System.out.println("========== calDistance() ==========");
@@ -388,7 +395,7 @@ public class GeneralExample {
     private R<MutationResult> insert(String partitionName, int count) {
         System.out.println("========== insert() ==========");
         List<List<Float>> vectors = generateFloatVectors(count);
-        List<ByteBuffer> profiles = generateBinaryVectors(count);
+//        List<ByteBuffer> profiles = generateBinaryVectors(count);
 
         Random ran = new Random();
         List<Integer> ages = new ArrayList<>();
@@ -398,7 +405,7 @@ public class GeneralExample {
 
         List<InsertParam.Field> fields = new ArrayList<>();
         fields.add(new InsertParam.Field(VECTOR_FIELD, DataType.FloatVector, vectors));
-        fields.add(new InsertParam.Field(PROFILE_FIELD, DataType.BinaryVector, profiles));
+//        fields.add(new InsertParam.Field(PROFILE_FIELD, DataType.BinaryVector, profiles));
         fields.add(new InsertParam.Field(AGE_FIELD, DataType.Int8, ages));
 
         InsertParam insertParam = InsertParam.newBuilder()
@@ -424,20 +431,20 @@ public class GeneralExample {
         return vectors;
     }
 
-    private List<ByteBuffer> generateBinaryVectors(int count) {
-        Random ran = new Random();
-        List<ByteBuffer> vectors = new ArrayList<>();
-        int byteCount = BINARY_DIM/8;
-        for (int n = 0; n < count; ++n) {
-            ByteBuffer vector = ByteBuffer.allocate(byteCount);
-            for (int i = 0; i < byteCount; ++i) {
-                vector.put((byte)ran.nextInt(Byte.MAX_VALUE));
-            }
-            vectors.add(vector);
-        }
-        return vectors;
-
-    }
+//    private List<ByteBuffer> generateBinaryVectors(int count) {
+//        Random ran = new Random();
+//        List<ByteBuffer> vectors = new ArrayList<>();
+//        int byteCount = BINARY_DIM/8;
+//        for (int n = 0; n < count; ++n) {
+//            ByteBuffer vector = ByteBuffer.allocate(byteCount);
+//            for (int i = 0; i < byteCount; ++i) {
+//                vector.put((byte)ran.nextInt(Byte.MAX_VALUE));
+//            }
+//            vectors.add(vector);
+//        }
+//        return vectors;
+//
+//    }
 
     public static void main(String[] args) {
         GeneralExample example = new GeneralExample();
@@ -476,8 +483,8 @@ public class GeneralExample {
         example.query(queryExpr);
         String searchExpr = AGE_FIELD + " > 50";
         example.searchFace(searchExpr);
-        searchExpr = AGE_FIELD + " <= 30";
-        example.searchProfile(searchExpr);
+//        searchExpr = AGE_FIELD + " <= 30";
+//        example.searchProfile(searchExpr);
         example.calDistance();
 
         example.releasePartition(partitionName);

+ 3 - 2
src/main/java/io/milvus/Response/DescCollResponseWrapper.java

@@ -101,13 +101,14 @@ public class DescCollResponseWrapper {
      * Get schema of a field by name.
      * Return null if the field doesn't exist
      *
+     * @param fieldName field name to get field description
      * @return <code>FieldType</code> schema of the field
      */
-    public FieldType getFieldByName(@NonNull String name) {
+    public FieldType getFieldByName(@NonNull String fieldName) {
         CollectionSchema schema = response.getSchema();
         for (int i = 0; i < schema.getFieldsCount(); ++i) {
             FieldSchema field = schema.getFields(i);
-            if (name.compareTo(field.getName()) == 0) {
+            if (fieldName.compareTo(field.getName()) == 0) {
                 return convertField(field);
             }
         }

+ 3 - 2
src/main/java/io/milvus/Response/DescIndexResponseWrapper.java

@@ -45,12 +45,13 @@ public class DescIndexResponseWrapper {
      * Get index description by field name.
      * Return null if the field doesn't exist
      *
+     * @param fieldName field name to get index description
      * @return <code>IndexDesc</code> description of the index
      */
-    public IndexDesc getIndexDescByFieldName(@NonNull String name) {
+    public IndexDesc getIndexDescByFieldName(@NonNull String fieldName) {
         for (int i = 0; i < response.getIndexDescriptionsCount(); ++i) {
             IndexDescription desc = response.getIndexDescriptions(i);
-            if (name.compareTo(desc.getFieldName()) == 0) {
+            if (fieldName.compareTo(desc.getFieldName()) == 0) {
                 IndexDesc res = new IndexDesc(desc.getFieldName(), desc.getIndexName(), desc.getIndexID());
                 desc.getParamsList().forEach((kv)-> res.addParam(kv.getKey(), kv.getValue()));
                 return res;

+ 1 - 0
src/main/java/io/milvus/Response/QueryResultsWrapper.java

@@ -21,6 +21,7 @@ public class QueryResultsWrapper {
      * Gets {@link FieldDataWrapper} for a field.
      * Throws {@link ParamException} if the field doesn't exist.
      *
+     * @param fieldName field name to get output data
      * @return <code>FieldDataWrapper</code>
      */
     public FieldDataWrapper getFieldWrapper(@NonNull String fieldName) throws ParamException {

+ 61 - 25
src/main/java/io/milvus/Response/SearchResultsWrapper.java

@@ -20,20 +20,37 @@ public class SearchResultsWrapper {
     }
 
     /**
-     * Gets {@link FieldDataWrapper} for a field.
+     * Gets data for an output field which is specified by search request.
      * Throws {@link ParamException} if the field doesn't exist.
+     * Throws {@link ParamException} if the indexOfTarget is illegal.
      *
+     * @param fieldName field name to get output data
+     * @param indexOfTarget which target vector the field data belongs to
      * @return <code>FieldDataWrapper</code>
      */
-    public FieldDataWrapper getFieldData(@NonNull String fieldName) {
+    public List<?> getFieldData(@NonNull String fieldName, int indexOfTarget) {
+        FieldDataWrapper wrapper = null;
         for (int i = 0; i < results.getFieldsDataCount(); ++i) {
             FieldData data = results.getFieldsData(i);
             if (fieldName.compareTo(data.getFieldName()) == 0) {
-                return new FieldDataWrapper(data);
+                wrapper = new FieldDataWrapper(data);
             }
         }
 
-        return null;
+        if (wrapper == null) {
+            throw new ParamException("Illegal field name: " + fieldName);
+        }
+
+        Position position = getOffsetByIndex(indexOfTarget);
+        long offset = position.getOffset();
+        long k = position.getK();
+
+        List<?> allData = wrapper.getFieldData();
+        if (offset + k > allData.size()) {
+            throw new IllegalResponseException("Field data row count is wrong");
+        }
+
+        return allData.subList((int)offset, (int)offset + (int)k);
     }
 
     /**
@@ -41,29 +58,14 @@ public class SearchResultsWrapper {
      * Throws {@link ParamException} if the indexOfTarget is illegal.
      * Throws {@link IllegalResponseException} if the returned results is illegal.
      *
+     * @param indexOfTarget which target vector the result belongs to
      * @return <code>List<IDScore></code> ID-score pairs returned by search interface
      */
     public List<IDScore> getIDScore(int indexOfTarget) throws ParamException, IllegalResponseException {
-        List<Long> kList = results.getTopksList();
+        Position position = getOffsetByIndex(indexOfTarget);
 
-        // if the server didn't return separate topK, use same topK value
-        if (kList.isEmpty()) {
-            kList = new ArrayList<>();
-            for (long i = 0; i < results.getNumQueries(); ++i) {
-                kList.add(results.getTopK());
-            }
-        }
-
-        if (indexOfTarget < 0 || indexOfTarget >= kList.size()) {
-            throw new ParamException("Illegal index of target: " + indexOfTarget);
-        }
-
-        int offset = 0;
-        for (int i = 0; i < indexOfTarget; ++i) {
-            offset += kList.get(i);
-        }
-
-        long k = kList.get(indexOfTarget);
+        long offset = position.getOffset();
+        long k = position.getK();
         if (offset + k > results.getScoresCount()) {
             throw new IllegalResponseException("Result scores count is wrong");
         }
@@ -78,7 +80,7 @@ public class SearchResultsWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScore.add(new IDScore("", longIDs.getData(offset + n), results.getScores(offset + n)));
+                idScore.add(new IDScore("", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
             }
         } else if (ids.hasStrId()) {
             StringArray strIDs = ids.getStrId();
@@ -87,7 +89,7 @@ public class SearchResultsWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScore.add(new IDScore(strIDs.getData(offset + n), 0, results.getScores(offset + n)));
+                idScore.add(new IDScore(strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
             }
         } else {
             throw new IllegalResponseException("Result ids is illegal");
@@ -96,6 +98,40 @@ public class SearchResultsWrapper {
         return idScore;
     }
 
+    @Getter
+    private static final class Position {
+        private final long offset;
+        private final long k;
+
+        public Position(long offset, long k) {
+            this.offset = offset;
+            this.k = k;
+        }
+    }
+    private Position getOffsetByIndex(int indexOfTarget) {
+        List<Long> kList = results.getTopksList();
+
+        // if the server didn't return separate topK, use same topK value
+        if (kList.isEmpty()) {
+            kList = new ArrayList<>();
+            for (long i = 0; i < results.getNumQueries(); ++i) {
+                kList.add(results.getTopK());
+            }
+        }
+
+        if (indexOfTarget < 0 || indexOfTarget >= kList.size()) {
+            throw new ParamException("Illegal index of target: " + indexOfTarget);
+        }
+
+        long offset = 0;
+        for (int i = 0; i < indexOfTarget; ++i) {
+            offset += kList.get(i);
+        }
+
+        long k = kList.get(indexOfTarget);
+        return new Position(offset, k);
+    }
+
     /**
      * Internal-use class to wrap response of <code>search</code> interface.
      */

+ 3 - 2
src/main/java/io/milvus/Response/ShowCollResponseWrapper.java

@@ -46,11 +46,12 @@ public class ShowCollResponseWrapper {
     /**
      * Get information of one collection by name.
      *
+     * @param collectionName collection name to get information
      * @return <code>CollectionInfo</code> information of the collection
      */
-    public CollectionInfo getCollectionInfoByName(@NonNull String name) {
+    public CollectionInfo getCollectionInfoByName(@NonNull String collectionName) {
         for (int i = 0; i < response.getCollectionNamesCount(); ++i) {
-            if ( name.compareTo(response.getCollectionNames(i)) == 0) {
+            if ( collectionName.compareTo(response.getCollectionNames(i)) == 0) {
                 CollectionInfo info = new CollectionInfo(response.getCollectionNames(i), response.getCollectionIds(i),
                         response.getCreatedUtcTimestamps(i));
                 if (response.getInMemoryPercentagesCount() > i) {

+ 3 - 2
src/main/java/io/milvus/Response/ShowPartResponseWrapper.java

@@ -46,11 +46,12 @@ public class ShowPartResponseWrapper {
     /**
      * Get information of one partition by name.
      *
+     * @param partitionName partition name to get information
      * @return <code>PartitionInfo</code> information of the partition
      */
-    public PartitionInfo getPartitionInfoByName(@NonNull String name) {
+    public PartitionInfo getPartitionInfoByName(@NonNull String partitionName) {
         for (int i = 0; i < response.getPartitionNamesCount(); ++i) {
-            if ( name.compareTo(response.getPartitionNames(i)) == 0) {
+            if ( partitionName.compareTo(response.getPartitionNames(i)) == 0) {
                 PartitionInfo info = new PartitionInfo(response.getPartitionNames(i), response.getPartitionIDs(i),
                         response.getCreatedUtcTimestamps(i));
                 if (response.getInMemoryPercentagesCount() > i) {

+ 1 - 0
src/main/java/io/milvus/param/dml/InsertParam.java

@@ -23,6 +23,7 @@ import io.milvus.exception.ParamException;
 import io.milvus.grpc.DataType;
 import io.milvus.param.ParamUtils;
 
+import io.milvus.param.partition.LoadPartitionsParam;
 import lombok.Getter;
 import lombok.NonNull;
 import java.nio.ByteBuffer;

+ 29 - 2
src/main/java/io/milvus/param/dml/QueryParam.java

@@ -23,6 +23,7 @@ import com.google.common.collect.Lists;
 import io.milvus.exception.ParamException;
 import io.milvus.param.ParamUtils;
 
+import io.milvus.param.partition.LoadPartitionsParam;
 import lombok.Getter;
 import lombok.NonNull;
 import java.util.ArrayList;
@@ -79,7 +80,20 @@ public class QueryParam {
          * @return <code>Builder</code>
          */
         public Builder withPartitionNames(@NonNull List<String> partitionNames) {
-            this.partitionNames = partitionNames;
+            partitionNames.forEach(this::addPartitionName);
+            return this;
+        }
+
+        /**
+         * Adds a partition to specify query scope (Optional).
+         *
+         * @param partitionName partition name
+         * @return <code>Builder</code>
+         */
+        public Builder addPartitionName(@NonNull String partitionName) {
+            if (!this.partitionNames.contains(partitionName)) {
+                this.partitionNames.add(partitionName);
+            }
             return this;
         }
 
@@ -90,7 +104,20 @@ public class QueryParam {
          * @return <code>Builder</code>
          */
         public Builder withOutFields(@NonNull List<String> outFields) {
-            this.outFields = outFields;
+            outFields.forEach(this::addOutField);
+            return this;
+        }
+
+        /**
+         * Specifies an output field (Optional).
+         *
+         * @param fieldName field name
+         * @return <code>Builder</code>
+         */
+        public Builder addOutField(@NonNull String fieldName) {
+            if (!this.outFields.contains(fieldName)) {
+                this.outFields.add(fieldName);
+            }
             return this;
         }
 

+ 27 - 1
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -99,7 +99,20 @@ public class SearchParam {
          * @return <code>Builder</code>
          */
         public Builder withPartitionNames(@NonNull List<String> partitionNames) {
-            this.partitionNames = partitionNames;
+            partitionNames.forEach(this::addPartitionName);
+            return this;
+        }
+
+        /**
+         * Adds a partition to specify search scope (Optional).
+         *
+         * @param partitionName partition name
+         * @return <code>Builder</code>
+         */
+        public Builder addPartitionName(@NonNull String partitionName) {
+            if (!this.partitionNames.contains(partitionName)) {
+                this.partitionNames.add(partitionName);
+            }
             return this;
         }
 
@@ -159,6 +172,19 @@ public class SearchParam {
             return this;
         }
 
+        /**
+         * Specifies an output field (Optional).
+         *
+         * @param fieldName filed name
+         * @return <code>Builder</code>
+         */
+        public Builder addOutField(@NonNull String fieldName) {
+            if (!this.outFields.contains(fieldName)) {
+                this.outFields.add(fieldName);
+            }
+            return this;
+        }
+
         /**
          * Sets the target vectors.
          *

+ 51 - 46
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -168,6 +168,7 @@ public class MilvusClientDockerTest {
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testFloatVectors() {
         String randomCollectionName = generator.generate(10);
 
@@ -219,7 +220,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<RpcStatus> createR = client.createCollection(createParam);
-        assertEquals(createR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
 
         R<DescribeCollectionResponse> response = client.describeCollection(DescribeCollectionParam.newBuilder()
                 .withCollectionName(randomCollectionName)
@@ -229,7 +230,7 @@ public class MilvusClientDockerTest {
         System.out.println(desc.toString());
 
         // insert data
-        int rowCount = 10000;
+        int rowCount = 10;
         List<Long> ids = new ArrayList<>();
         List<Boolean> genders = new ArrayList<>();
         List<Double> weights = new ArrayList<>();
@@ -237,8 +238,8 @@ public class MilvusClientDockerTest {
         for (long i = 0L; i < rowCount; ++i) {
             ids.add(i);
             genders.add(i%3 == 0 ? Boolean.TRUE : Boolean.FALSE);
-            weights.add((double) (i / 100));
-            ages.add((short)(i%99));
+            weights.add( ((double)(i + 1) / 100));
+            ages.add((short)((i + 1)%99));
         }
         List<List<Float>> vectors = generateFloatVectors(rowCount);
 
@@ -255,7 +256,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<MutationResult> insertR = client.insert(insertParam);
-        assertEquals(insertR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
 //        System.out.println(insertR.getData());
 
         // get collection statistics
@@ -264,7 +265,7 @@ public class MilvusClientDockerTest {
                 .withCollectionName(randomCollectionName)
                 .withFlush(true)
                 .build());
-        assertEquals(statR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), statR.getStatus().intValue());
 
         GetCollStatResponseWrapper stat = new GetCollStatResponseWrapper(statR.getData());
         System.out.println("Collection row count: " + stat.getRowCount());
@@ -276,7 +277,7 @@ public class MilvusClientDockerTest {
                 .withPartitionName("_default") // each collection has '_default' partition
                 .withFlush(true)
                 .build());
-        assertEquals(statPartR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), statPartR.getStatus().intValue());
 
         GetPartStatResponseWrapper statPart = new GetPartStatResponseWrapper(statPartR.getData());
         System.out.println("Partition row count: " + statPart.getRowCount());
@@ -294,7 +295,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<RpcStatus> createIndexR = client.createIndex(indexParam);
-        assertEquals(createIndexR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
 
         // get index description
         DescribeIndexParam descIndexParam = DescribeIndexParam.newBuilder()
@@ -302,7 +303,7 @@ public class MilvusClientDockerTest {
                 .withFieldName(field2Name)
                 .build();
         R<DescribeIndexResponse> descIndexR = client.describeIndex(descIndexParam);
-        assertEquals(descIndexR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), descIndexR.getStatus().intValue());
 
         DescIndexResponseWrapper indexDesc = new DescIndexResponseWrapper(descIndexR.getData());
         System.out.println("Index description: " + indexDesc.toString());
@@ -311,13 +312,13 @@ public class MilvusClientDockerTest {
         R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
                 .withCollectionName(randomCollectionName)
                 .build());
-        assertEquals(loadR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
         // show collections
         R<ShowCollectionsResponse> showR = client.showCollections(ShowCollectionsParam.newBuilder()
                 .addCollectionName(randomCollectionName)
                 .build());
-        assertEquals(showR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), showR.getStatus().intValue());
         ShowCollResponseWrapper info = new ShowCollResponseWrapper(showR.getData());
         System.out.println("Collection info: " + info.toString());
 
@@ -326,21 +327,19 @@ public class MilvusClientDockerTest {
                 .withCollectionName(randomCollectionName)
                 .addPartitionName("_default") // each collection has a '_default' partition
                 .build());
-        assertEquals(showPartR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), showPartR.getStatus().intValue());
         ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
         System.out.println("Partition info: " + infoPart.toString());
 
         // query vectors to verify
         List<Long> queryIDs = new ArrayList<>();
-        List<Boolean> compareGenders = new ArrayList<>();
         List<Double> compareWeights = new ArrayList<>();
         int nq = 5;
         Random ran = new Random();
-        for (int i = 0; i < nq; ++i) {
-            int randomIndex = ran.nextInt(rowCount);
-            queryIDs.add(ids.get(randomIndex));
-            compareGenders.add(genders.get(randomIndex));
-            compareWeights.add(weights.get(randomIndex));
+        int randomIndex = ran.nextInt(rowCount - nq);
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            queryIDs.add(ids.get(i));
+            compareWeights.add(weights.get(i));
         }
         String expr = field1Name + " in " + queryIDs.toString();
         List<String> outputFields = Arrays.asList(field1Name, field2Name, field3Name, field4Name, field4Name);
@@ -351,7 +350,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<QueryResults> queryR= client.query(queryParam);
-        assertEquals(queryR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
 
         // verify query result
         QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
@@ -359,14 +358,15 @@ public class MilvusClientDockerTest {
             FieldDataWrapper wrapper = queryResultsWrapper.getFieldWrapper(fieldName);
             System.out.println("Query data of " + fieldName + ", row count: " + wrapper.getRowCount());
             System.out.println(wrapper.getFieldData());
-        }
-
-        if (outputFields.contains(field1Name)) {
-            List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
-            assertEquals(out.size(), nq);
-            for (Object o : out) {
-                long id = (Long) o;
-                assertTrue(queryIDs.contains(id));
+            assertEquals(nq, wrapper.getFieldData().size());
+
+            if (fieldName.compareTo(field1Name) == 0) {
+                List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
+                assertEquals(nq, out.size());
+                for (Object o : out) {
+                    long id = (Long) o;
+                    assertTrue(queryIDs.contains(id));
+                }
             }
         }
 
@@ -376,17 +376,17 @@ public class MilvusClientDockerTest {
         if (outputFields.contains(field2Name)) {
             assertTrue(queryResultsWrapper.getFieldWrapper(field2Name).isVectorField());
             List<?> out = queryResultsWrapper.getFieldWrapper(field2Name).getFieldData();
-            assertEquals(out.size(), nq);
+            assertEquals(nq, out.size());
         }
 
         if (outputFields.contains(field3Name)) {
             List<?> out = queryResultsWrapper.getFieldWrapper(field3Name).getFieldData();
-            assertEquals(out.size(), nq);
+            assertEquals(nq, out.size());
         }
 
         if (outputFields.contains(field4Name)) {
             List<?> out = queryResultsWrapper.getFieldWrapper(field4Name).getFieldData();
-            assertEquals(out.size(), nq);
+            assertEquals(nq, out.size());
             for (Object o : out) {
                 double d = (Double)o;
                 assertTrue(compareWeights.contains(d));
@@ -397,10 +397,9 @@ public class MilvusClientDockerTest {
         // pick some vectors to search
         List<Long> targetVectorIDs = new ArrayList<>();
         List<List<Float>> targetVectors = new ArrayList<>();
-        for (int i = 0; i < nq; ++i) {
-            int randomIndex = ran.nextInt(rowCount);
-            targetVectorIDs.add(ids.get(randomIndex));
-            targetVectors.add(vectors.get(randomIndex));
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            targetVectorIDs.add(ids.get(i));
+            targetVectors.add(vectors.get(i));
         }
 
         int topK = 5;
@@ -411,11 +410,12 @@ public class MilvusClientDockerTest {
                 .withVectors(targetVectors)
                 .withVectorFieldName(field2Name)
                 .withParams("{\"nprobe\":8}")
+                .addOutField(field4Name)
                 .build();
 
         R<SearchResults> searchR = client.search(searchParam);
 //        System.out.println(searchR);
-        assertEquals(searchR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
 
         // verify the search result
         SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
@@ -426,13 +426,18 @@ public class MilvusClientDockerTest {
             assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
         }
 
+        List<?> fieldData = results.getFieldData(field4Name, 0);
+        assertEquals(topK, fieldData.size());
+        fieldData = results.getFieldData(field4Name, nq - 1);
+        assertEquals(topK, fieldData.size());
+
         // drop collection
         DropCollectionParam dropParam = DropCollectionParam.newBuilder()
                 .withCollectionName(randomCollectionName)
                 .build();
 
         R<RpcStatus> dropR = client.dropCollection(dropParam);
-        assertEquals(dropR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
     }
 
     @Test
@@ -466,7 +471,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<RpcStatus> createR = client.createCollection(createParam);
-        assertEquals(createR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
 
         // insert data
         int rowCount = 10000;
@@ -482,7 +487,7 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<MutationResult> insertR = client.insert(insertParam);
-        assertEquals(insertR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
 //        System.out.println(insertR.getData());
         MutationResultWrapper insertResultWrapper = new MutationResultWrapper(insertR.getData());
         System.out.println(insertResultWrapper.getInsertCount() + " rows inserted");
@@ -495,7 +500,7 @@ public class MilvusClientDockerTest {
                 .withCollectionName(randomCollectionName)
                 .withFlush(true)
                 .build());
-        assertEquals(statR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), statR.getStatus().intValue());
 
         GetCollStatResponseWrapper stat = new GetCollStatResponseWrapper(statR.getData());
         System.out.println("Collection row count: " + stat.getRowCount());
@@ -504,17 +509,17 @@ public class MilvusClientDockerTest {
         R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
                 .withCollectionName(randomCollectionName)
                 .build());
-        assertEquals(loadR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
         // pick some vectors to search
         int nq = 5;
         List<Long> targetVectorIDs = new ArrayList<>();
         List<ByteBuffer> targetVectors = new ArrayList<>();
         Random ran = new Random();
-        for (int i = 0; i < nq; ++i) {
-            int randomIndex = ran.nextInt(rowCount);
-            targetVectorIDs.add(ids.get(randomIndex));
-            targetVectors.add(vectors.get(randomIndex));
+        int randomIndex = ran.nextInt(rowCount - nq);
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            targetVectorIDs.add(ids.get(i));
+            targetVectors.add(vectors.get(i));
         }
 
         int topK = 5;
@@ -528,7 +533,7 @@ public class MilvusClientDockerTest {
 
         R<SearchResults> searchR = client.search(searchParam);
 //        System.out.println(searchR);
-        assertEquals(searchR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
 
         // verify the search result
         SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
@@ -545,6 +550,6 @@ public class MilvusClientDockerTest {
                 .build();
 
         R<RpcStatus> dropR = client.dropCollection(dropParam);
-        assertEquals(dropR.getStatus().intValue(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
     }
 }

+ 28 - 19
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -101,11 +101,11 @@ class MilvusServiceClientTest {
         String msg = "error";
         R<RpcStatus> r = R.failed(ErrorCode.UnexpectedError, msg);
         Exception e = r.getException();
-        assertEquals(msg.compareTo(e.getMessage()), 0);
+        assertEquals(0, msg.compareTo(e.getMessage()));
         System.out.println(r.toString());
 
         r = R.success();
-        assertEquals(r.getStatus(), R.Status.Success.getCode());
+        assertEquals(R.Status.Success.getCode(), r.getStatus());
         System.out.println(r.toString());
     }
 
@@ -131,13 +131,13 @@ class MilvusServiceClientTest {
                 .build();
         System.out.println(connectParam.toString());
 
-        assertEquals(host.compareTo(connectParam.getHost()), 0);
-        assertEquals(connectParam.getPort(), port);
-        assertEquals(connectParam.getConnectTimeoutMs(), connectTimeoutMs);
-        assertEquals(connectParam.getKeepAliveTimeMs(), keepAliveTimeMs);
-        assertEquals(connectParam.getKeepAliveTimeoutMs(), keepAliveTimeoutMs);
+        assertEquals(0, host.compareTo(connectParam.getHost()));
+        assertEquals(port, connectParam.getPort());
+        assertEquals(connectTimeoutMs, connectParam.getConnectTimeoutMs());
+        assertEquals(keepAliveTimeMs, connectParam.getKeepAliveTimeMs());
+        assertEquals(keepAliveTimeoutMs, connectParam.getKeepAliveTimeoutMs());
         assertTrue(connectParam.isKeepAliveWithoutCalls());
-        assertEquals(connectParam.getIdleTimeoutMs(), idleTimeoutMs);
+        assertEquals(idleTimeoutMs, connectParam.getIdleTimeoutMs());
 
         assertThrows(ParamException.class, () ->
                 ConnectParam.newBuilder()
@@ -581,12 +581,12 @@ class MilvusServiceClientTest {
         // verify internal param
         ShowCollectionsParam param = ShowCollectionsParam.newBuilder()
                 .build();
-        assertEquals(param.getShowType(), ShowType.All);
+        assertEquals(ShowType.All, param.getShowType());
 
         param = ShowCollectionsParam.newBuilder()
                 .addCollectionName("collection1")
                 .build();
-        assertEquals(param.getShowType(), ShowType.InMemory);
+        assertEquals(ShowType.InMemory, param.getShowType());
     }
 
     @Test
@@ -966,13 +966,13 @@ class MilvusServiceClientTest {
         ShowPartitionsParam param = ShowPartitionsParam.newBuilder()
                 .withCollectionName("collection1`")
                 .build();
-        assertEquals(param.getShowType(), ShowType.All);
+        assertEquals(ShowType.All, param.getShowType());
 
         param = ShowPartitionsParam.newBuilder()
                 .withCollectionName("collection1`")
                 .addPartitionName("partition1")
                 .build();
-        assertEquals(param.getShowType(), ShowType.InMemory);
+        assertEquals(ShowType.InMemory, param.getShowType());
     }
 
     @Test
@@ -2193,13 +2193,18 @@ class MilvusServiceClientTest {
         List<Long> longIDs = new ArrayList<>();
         List<String> strIDs = new ArrayList<>();
         List<Float> scores = new ArrayList<>();
+        List<Double> outputField = new ArrayList<>();
         for (long i = 0; i < topK * numQueries; ++i) {
             longIDs.add(i);
             strIDs.add(String.valueOf(i));
             scores.add((float) i);
+            outputField.add((double) i);
         }
 
         // for long id
+        DoubleArray.Builder doubleArrayBuilder = DoubleArray.newBuilder();
+        outputField.forEach(doubleArrayBuilder::addData);
+
         String fieldName = "test";
         SearchResultData results = SearchResultData.newBuilder()
                 .setTopK(topK)
@@ -2211,16 +2216,20 @@ class MilvusServiceClientTest {
                 .addAllScores(scores)
                 .addFieldsData(FieldData.newBuilder()
                         .setFieldName(fieldName)
-                        .build())
+                        .setType(DataType.Double)
+                        .setScalars(ScalarField.newBuilder()
+                                .setDoubleData(doubleArrayBuilder.build())
+                                .build()))
                 .build();
 
         SearchResultsWrapper intWrapper = new SearchResultsWrapper(results);
-        assertNotNull(intWrapper.getFieldData(fieldName));
-        assertNull(intWrapper.getFieldData("invalid"));
+        assertThrows(ParamException.class, () -> intWrapper.getFieldData(fieldName, -1));
+        assertThrows(ParamException.class, () -> intWrapper.getFieldData("invalid", 0));
+        assertEquals(topK, intWrapper.getFieldData(fieldName, (int)numQueries-1).size());
 
         List<SearchResultsWrapper.IDScore> idScores = intWrapper.getIDScore(1);
         assertFalse(idScores.toString().isEmpty());
-        assertEquals(idScores.size(), topK);
+        assertEquals(topK, idScores.size());
         assertThrows(ParamException.class, () -> intWrapper.getIDScore((int) numQueries));
 
         // for string id
@@ -2240,7 +2249,7 @@ class MilvusServiceClientTest {
         SearchResultsWrapper strWrapper = new SearchResultsWrapper(results);
         idScores = strWrapper.getIDScore(0);
         assertFalse(idScores.toString().isEmpty());
-        assertEquals(idScores.size(), topK);
+        assertEquals(topK, idScores.size());
 
         idScores.forEach((score)->assertFalse(score.toString().isEmpty()));
     }
@@ -2264,7 +2273,7 @@ class MilvusServiceClientTest {
 
         for (int i = 0; i < 2; ++i) {
             ShowCollResponseWrapper.CollectionInfo info = wrapper.getCollectionInfoByName(names.get(i));
-            assertEquals(names.get(i).compareTo(info.getName()), 0);
+            assertEquals(0, names.get(i).compareTo(info.getName()));
             assertEquals(ids.get(i), info.getId());
             assertEquals(ts.get(i), info.getUtcTimestamp());
             assertEquals(inMemory.get(i), info.getInMemoryPercentage());
@@ -2292,7 +2301,7 @@ class MilvusServiceClientTest {
 
         for (int i = 0; i < 2; ++i) {
             ShowPartResponseWrapper.PartitionInfo info = wrapper.getPartitionInfoByName(names.get(i));
-            assertEquals(names.get(i).compareTo(info.getName()), 0);
+            assertEquals(0, names.get(i).compareTo(info.getName()));
             assertEquals(ids.get(i), info.getId());
             assertEquals(ts.get(i), info.getUtcTimestamp());
             assertEquals(inMemory.get(i), info.getInMemoryPercentage());