Quellcode durchsuchen

Fix a bug of highlevel interface (#658)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot vor 1 Jahr
Ursprung
Commit
75629f0ceb

+ 1 - 1
examples/main/java/io/milvus/HighLevelExample.java

@@ -214,7 +214,7 @@ public class HighLevelExample {
 
         R<SearchResponse> response = milvusClient.search(searchSimpleParam);
         handleResponseStatus(response);
-        for (QueryResultsWrapper.RowRecord rowRecord : response.getData().getRowRecords()) {
+        for (QueryResultsWrapper.RowRecord rowRecord : response.getData().getRowRecords(0)) {
             System.out.println(rowRecord);
         }
         return response;

+ 5 - 1
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -3034,7 +3034,11 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
             }
 
             SearchResultsWrapper searchResultsWrapper = new SearchResultsWrapper(response.getData().getResults());
-            return R.success(SearchResponse.builder().rowRecords(searchResultsWrapper.getRowRecords()).build());
+            List<List<QueryResultsWrapper.RowRecord>> records = new ArrayList<>();
+            for (int i = 0; i < vectors.size(); ++i) {
+                records.add(searchResultsWrapper.getRowRecords(i));
+            }
+            return R.success(SearchResponse.builder().rowRecords(records).build());
         } catch (StatusRuntimeException e) {
             logError("Search RPC failed! Collection name:{}",
                     requestParam.getCollectionName(), e);

+ 20 - 2
src/main/java/io/milvus/param/highlevel/dml/response/SearchResponse.java

@@ -19,6 +19,7 @@
 
 package io.milvus.param.highlevel.dml.response;
 
+import io.milvus.exception.ParamException;
 import io.milvus.response.QueryResultsWrapper;
 import lombok.Builder;
 import lombok.Getter;
@@ -29,7 +30,24 @@ import java.util.List;
  * Parameters for <code>search</code> interface.
  */
 @Builder
-@Getter
 public class SearchResponse {
-    public List<QueryResultsWrapper.RowRecord> rowRecords;
+    public List<List<QueryResultsWrapper.RowRecord>> rowRecords;
+
+    /**
+     * In old versions(<=2.3.2), this method only returns results of the first target vector
+     * Mark is as deprecated, keep it to compatible with the legacy code
+     */
+    @Deprecated
+    public List<QueryResultsWrapper.RowRecord> getRowRecords() {
+        return getRowRecords(0);
+    }
+
+    public List<QueryResultsWrapper.RowRecord> getRowRecords(int indexOfTarget) {
+        if (indexOfTarget >= rowRecords.size()) {
+            throw new ParamException("The indexOfTarget value " + indexOfTarget
+                    + " exceeds results count " + rowRecords.size());
+        }
+
+        return rowRecords.get(indexOfTarget);
+    }
 }

+ 22 - 20
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -41,34 +41,33 @@ public class SearchResultsWrapper extends RowRecordWrapper {
         throw new ParamException("The field name doesn't exist");
     }
 
+    /**
+     * Note: this method only can return the first target vector's topk result
+     *       and its function is duplicated with getIDScore(), so we mark it as deprecated.
+     */
+    @Deprecated
     @Override
     public List<QueryResultsWrapper.RowRecord> getRowRecords() {
+        return getRowRecords(0);
+    }
+
+    /**
+     * Note: this method's function is duplicated with getIDScore(), it is for high-level search.
+     */
+    public List<QueryResultsWrapper.RowRecord> getRowRecords(int indexOfTarget) {
         List<QueryResultsWrapper.RowRecord> records = new ArrayList<>();
         long topK = results.getTopK();
+        List<IDScore> idScore = getIDScore(indexOfTarget);
         for (int i = 0; i < topK; ++i) {
-            QueryResultsWrapper.RowRecord rowRecord = buildRowRecord(i);
-            records.add(rowRecord);
+            QueryResultsWrapper.RowRecord record = new QueryResultsWrapper.RowRecord();
+            record.put("id", idScore.get(i).getLongID());
+            record.put("distance", idScore.get(i).getScore());
+            buildRowRecord(record, i);
+            records.add(record);
         }
         return records;
     }
 
-    /**
-     * Gets a row record from result.
-     *  Throws {@link ParamException} if the index is illegal.
-     *
-     * @return <code>RowRecord</code> a row record of the result
-     */
-    protected QueryResultsWrapper.RowRecord buildRowRecord(long index) {
-        QueryResultsWrapper.RowRecord record = new QueryResultsWrapper.RowRecord();
-
-        List<IDScore> idScore = getIDScore(0);
-        record.put("id", idScore.get((int) index).getLongID());
-        record.put("distance", idScore.get((int)index).getScore());
-
-        buildRowRecord(record, index);
-        return record;
-    }
-
     @Override
     protected List<FieldData> getFieldDataList() {
         return results.getFieldsDataList();
@@ -152,7 +151,10 @@ public class SearchResultsWrapper extends RowRecordWrapper {
                 idScores.add(new IDScore(strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
             }
         } else {
-            throw new IllegalResponseException("Result ids is illegal");
+            // in v2.3.3, return an empty list instead of throwing exception
+            // because search in an empty collection will run into this exception
+//            throw new IllegalResponseException("Result ids is illegal");
+            return idScores;
         }
 
         // set output fields