|
@@ -19,6 +19,9 @@
|
|
|
|
|
|
package io.milvus.client;
|
|
|
|
|
|
+import com.google.protobuf.ByteString;
|
|
|
+import io.milvus.Response.*;
|
|
|
+import io.milvus.exception.IllegalResponseException;
|
|
|
import io.milvus.exception.ParamException;
|
|
|
import io.milvus.grpc.*;
|
|
|
import io.milvus.param.*;
|
|
@@ -26,9 +29,7 @@ import io.milvus.param.alias.AlterAliasParam;
|
|
|
import io.milvus.param.alias.CreateAliasParam;
|
|
|
import io.milvus.param.alias.DropAliasParam;
|
|
|
import io.milvus.param.collection.*;
|
|
|
-import io.milvus.param.control.GetMetricsParam;
|
|
|
-import io.milvus.param.control.GetPersistentSegmentInfoParam;
|
|
|
-import io.milvus.param.control.GetQuerySegmentInfoParam;
|
|
|
+import io.milvus.param.control.*;
|
|
|
import io.milvus.param.dml.*;
|
|
|
import io.milvus.param.index.*;
|
|
|
import io.milvus.param.partition.*;
|
|
@@ -39,10 +40,7 @@ import org.junit.jupiter.api.Test;
|
|
|
import java.lang.reflect.InvocationTargetException;
|
|
|
import java.lang.reflect.Method;
|
|
|
import java.nio.ByteBuffer;
|
|
|
-import java.util.ArrayList;
|
|
|
-import java.util.Arrays;
|
|
|
-import java.util.Collections;
|
|
|
-import java.util.List;
|
|
|
+import java.util.*;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
@@ -97,6 +95,19 @@ class MilvusServiceClientTest {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ void r() {
|
|
|
+ String msg = "error";
|
|
|
+ R<RpcStatus> r = R.failed(ErrorCode.UnexpectedError, msg);
|
|
|
+ Exception e = r.getException();
|
|
|
+ assertEquals(msg.compareTo(e.getMessage()), 0);
|
|
|
+ System.out.println(r.toString());
|
|
|
+
|
|
|
+ r = R.success();
|
|
|
+ assertEquals(r.getStatus(), R.Status.Success.getCode());
|
|
|
+ System.out.println(r.toString());
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
void connectParam() {
|
|
|
System.out.println(System.getProperty("os.name"));
|
|
@@ -117,6 +128,8 @@ class MilvusServiceClientTest {
|
|
|
.keepAliveWithoutCalls(true)
|
|
|
.withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
.build();
|
|
|
+ System.out.println(connectParam.toString());
|
|
|
+
|
|
|
assertEquals(host.compareTo(connectParam.getHost()), 0);
|
|
|
assertEquals(connectParam.getPort(), port);
|
|
|
assertEquals(connectParam.getConnectTimeoutMs(), connectTimeoutMs);
|
|
@@ -124,6 +137,66 @@ class MilvusServiceClientTest {
|
|
|
assertEquals(connectParam.getKeepAliveTimeoutMs(), keepAliveTimeoutMs);
|
|
|
assertTrue(connectParam.isKeepAliveWithoutCalls());
|
|
|
assertEquals(connectParam.getIdleTimeoutMs(), idleTimeoutMs);
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ ConnectParam.newBuilder()
|
|
|
+ .withHost(host)
|
|
|
+ .withPort(0xFFFF + 1)
|
|
|
+ .withConnectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTimeout(keepAliveTimeoutMs, TimeUnit.NANOSECONDS)
|
|
|
+ .keepAliveWithoutCalls(true)
|
|
|
+ .withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ ConnectParam.newBuilder()
|
|
|
+ .withHost(host)
|
|
|
+ .withPort(port)
|
|
|
+ .withConnectTimeout(-1, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTimeout(keepAliveTimeoutMs, TimeUnit.NANOSECONDS)
|
|
|
+ .keepAliveWithoutCalls(true)
|
|
|
+ .withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ ConnectParam.newBuilder()
|
|
|
+ .withHost(host)
|
|
|
+ .withPort(port)
|
|
|
+ .withConnectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTime(-1, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTimeout(keepAliveTimeoutMs, TimeUnit.NANOSECONDS)
|
|
|
+ .keepAliveWithoutCalls(true)
|
|
|
+ .withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ ConnectParam.newBuilder()
|
|
|
+ .withHost(host)
|
|
|
+ .withPort(port)
|
|
|
+ .withConnectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTimeout(-1, TimeUnit.NANOSECONDS)
|
|
|
+ .keepAliveWithoutCalls(true)
|
|
|
+ .withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ ConnectParam.newBuilder()
|
|
|
+ .withHost(host)
|
|
|
+ .withPort(port)
|
|
|
+ .withConnectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS)
|
|
|
+ .withKeepAliveTimeout(keepAliveTimeoutMs, TimeUnit.NANOSECONDS)
|
|
|
+ .keepAliveWithoutCalls(true)
|
|
|
+ .withIdleTimeout(-1, TimeUnit.MILLISECONDS)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -188,9 +261,22 @@ class MilvusServiceClientTest {
|
|
|
CreateCollectionParam
|
|
|
.newBuilder()
|
|
|
.withCollectionName("collection1")
|
|
|
- .withShardsNum(0)
|
|
|
+ .withShardsNum(2)
|
|
|
.withFieldTypes(fields)
|
|
|
- .build()
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ Map<String, String> params = new HashMap<>();
|
|
|
+ params.put("1", "1");
|
|
|
+ assertThrows(ParamException.class, () ->
|
|
|
+ FieldType.newBuilder()
|
|
|
+ .withName("vec")
|
|
|
+ .withDescription("desc")
|
|
|
+ .withDataType(DataType.FloatVector)
|
|
|
+ .withTypeParams(params)
|
|
|
+ .addTypeParam("2", "2")
|
|
|
+ .withDimension(-1)
|
|
|
+ .build()
|
|
|
);
|
|
|
}
|
|
|
|
|
@@ -208,6 +294,7 @@ class MilvusServiceClientTest {
|
|
|
CreateCollectionParam param = CreateCollectionParam
|
|
|
.newBuilder()
|
|
|
.withCollectionName("collection1")
|
|
|
+ .withDescription("desc")
|
|
|
.withShardsNum(2)
|
|
|
.addFieldType(fieldType1)
|
|
|
.build();
|
|
@@ -478,10 +565,10 @@ class MilvusServiceClientTest {
|
|
|
// test throw exception with illegal input
|
|
|
List<String> names = new ArrayList<>();
|
|
|
names.add(null);
|
|
|
- assertThrows(ParamException.class, () ->
|
|
|
- ShowCollectionsParam.newBuilder()
|
|
|
- .withCollectionNames(names)
|
|
|
- .build()
|
|
|
+ assertThrows(NullPointerException.class, () ->
|
|
|
+ ShowCollectionsParam.newBuilder()
|
|
|
+ .withCollectionNames(names)
|
|
|
+ .build()
|
|
|
);
|
|
|
|
|
|
assertThrows(ParamException.class, () ->
|
|
@@ -514,6 +601,10 @@ class MilvusServiceClientTest {
|
|
|
@Test
|
|
|
void flushParam() {
|
|
|
// test throw exception with illegal input
|
|
|
+ assertThrows(ParamException.class, () -> FlushParam.newBuilder()
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
assertThrows(ParamException.class, () -> FlushParam.newBuilder()
|
|
|
.addCollectionName("")
|
|
|
.build()
|
|
@@ -657,7 +748,7 @@ class MilvusServiceClientTest {
|
|
|
|
|
|
List<String> names = new ArrayList<>();
|
|
|
names.add(null);
|
|
|
- assertThrows(ParamException.class, () -> LoadPartitionsParam.newBuilder()
|
|
|
+ assertThrows(NullPointerException.class, () -> LoadPartitionsParam.newBuilder()
|
|
|
.withCollectionName("collection1")
|
|
|
.withPartitionNames(names)
|
|
|
.build()
|
|
@@ -792,7 +883,7 @@ class MilvusServiceClientTest {
|
|
|
|
|
|
List<String> names = new ArrayList<>();
|
|
|
names.add(null);
|
|
|
- assertThrows(ParamException.class, () -> ReleasePartitionsParam.newBuilder()
|
|
|
+ assertThrows(NullPointerException.class, () -> ReleasePartitionsParam.newBuilder()
|
|
|
.withCollectionName("collection1")
|
|
|
.withPartitionNames(names)
|
|
|
.build()
|
|
@@ -852,7 +943,7 @@ class MilvusServiceClientTest {
|
|
|
|
|
|
List<String> names = new ArrayList<>();
|
|
|
names.add(null);
|
|
|
- assertThrows(ParamException.class, () -> ShowPartitionsParam.newBuilder()
|
|
|
+ assertThrows(NullPointerException.class, () -> ShowPartitionsParam.newBuilder()
|
|
|
.withCollectionName("collection1`")
|
|
|
.withPartitionNames(names)
|
|
|
.build()
|
|
@@ -1692,4 +1783,313 @@ class MilvusServiceClientTest {
|
|
|
|
|
|
testFuncByName("getQuerySegmentInfo", param);
|
|
|
}
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void loadBalanceParam() {
|
|
|
+ // test throw exception with illegal input
|
|
|
+ assertThrows(ParamException.class, () -> LoadBalanceParam
|
|
|
+ .newBuilder()
|
|
|
+ .withSourceNodeID(1L)
|
|
|
+ .withDestinationNodeID(Arrays.asList(2L, 3L))
|
|
|
+ .addDestinationNodeID(4L)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThrows(ParamException.class, () -> LoadBalanceParam
|
|
|
+ .newBuilder()
|
|
|
+ .withSourceNodeID(1L)
|
|
|
+ .withSegmentIDs(Arrays.asList(2L, 3L))
|
|
|
+ .addSegmentID(4L)
|
|
|
+ .build()
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void loadBalance() {
|
|
|
+ LoadBalanceParam param = LoadBalanceParam.newBuilder()
|
|
|
+ .withSourceNodeID(1L)
|
|
|
+ .addDestinationNodeID(2L)
|
|
|
+ .addSegmentID(3L)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ testFuncByName("loadBalance", param);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void getCompactionState() {
|
|
|
+ GetCompactionStateParam param = GetCompactionStateParam.newBuilder()
|
|
|
+ .withCompactionID(1L)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ testFuncByName("getCompactionState", param);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void manualCompaction() {
|
|
|
+ ManualCompactionParam param = ManualCompactionParam.newBuilder()
|
|
|
+ .withCollectionID(1L)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ testFuncByName("manualCompaction", param);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void getCompactionStateWithPlans() {
|
|
|
+ GetCompactionPlansParam param = GetCompactionPlansParam.newBuilder()
|
|
|
+ .withCompactionID(1L)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ testFuncByName("getCompactionStateWithPlans", param);
|
|
|
+ }
|
|
|
+
|
|
|
+ ////////////////////////////////////////////////////////////////////////////////////
|
|
|
+ // Response wrapper test
|
|
|
+ private void testScalarField(ScalarField field, DataType type, long rowCount) {
|
|
|
+ FieldData fieldData = FieldData.newBuilder()
|
|
|
+ .setFieldName("scalar")
|
|
|
+ .setFieldId(1L)
|
|
|
+ .setType(type)
|
|
|
+ .setScalars(field)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ FieldDataWrapper wrapper = new FieldDataWrapper(fieldData);
|
|
|
+ assertEquals(rowCount, wrapper.getRowCount());
|
|
|
+
|
|
|
+ List<?> data = wrapper.getFieldData();
|
|
|
+ assertEquals(rowCount, data.size());
|
|
|
+
|
|
|
+ assertThrows(IllegalResponseException.class, wrapper::getDim);
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ @Test
|
|
|
+ void fieldDataWrapper() {
|
|
|
+ // for float vector
|
|
|
+ long dim = 3;
|
|
|
+ List<Float> floatVectors = Arrays.asList(1F, 2F, 3F, 4F, 5F, 6F);
|
|
|
+ FieldData fieldData = FieldData.newBuilder()
|
|
|
+ .setFieldName("vec")
|
|
|
+ .setFieldId(1L)
|
|
|
+ .setType(DataType.FloatVector)
|
|
|
+ .setVectors(VectorField.newBuilder()
|
|
|
+ .setDim(dim)
|
|
|
+ .setFloatVector(FloatArray.newBuilder()
|
|
|
+ .addAllData(floatVectors)
|
|
|
+ .build())
|
|
|
+ .build())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ FieldDataWrapper wrapper = new FieldDataWrapper(fieldData);
|
|
|
+ assertEquals(dim, wrapper.getDim());
|
|
|
+ assertEquals(floatVectors.size() / dim, wrapper.getRowCount());
|
|
|
+
|
|
|
+ List<?> floatData = wrapper.getFieldData();
|
|
|
+ assertEquals(floatVectors.size() / dim, floatData.size());
|
|
|
+ for (Object obj : floatData) {
|
|
|
+ List<Float> vec = (List<Float>) obj;
|
|
|
+ assertEquals(dim, vec.size());
|
|
|
+ }
|
|
|
+
|
|
|
+ // for binary vector
|
|
|
+ dim = 16;
|
|
|
+ byte[] binary = new byte[(int) dim * 2];
|
|
|
+ for (int i = 0; i < binary.length; ++i) {
|
|
|
+ binary[i] = (byte) i;
|
|
|
+ }
|
|
|
+ fieldData = FieldData.newBuilder()
|
|
|
+ .setFieldName("vec")
|
|
|
+ .setFieldId(1L)
|
|
|
+ .setType(DataType.BinaryVector)
|
|
|
+ .setVectors(VectorField.newBuilder()
|
|
|
+ .setDim(dim)
|
|
|
+ .setBinaryVector(ByteString.copyFrom(binary))
|
|
|
+ .build())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ wrapper = new FieldDataWrapper(fieldData);
|
|
|
+ assertEquals(dim, wrapper.getDim());
|
|
|
+ assertEquals(binary.length / dim, wrapper.getRowCount());
|
|
|
+
|
|
|
+ List<?> binaryData = wrapper.getFieldData();
|
|
|
+ assertEquals(binary.length / dim, binaryData.size());
|
|
|
+ for (Object obj : binaryData) {
|
|
|
+ ByteBuffer vec = (ByteBuffer) obj;
|
|
|
+ assertEquals(dim, vec.position());
|
|
|
+ }
|
|
|
+
|
|
|
+ // for scalar field
|
|
|
+ LongArray.Builder int64Builder = LongArray.newBuilder();
|
|
|
+ for (long i = 0; i < dim; ++i) {
|
|
|
+ int64Builder.addData(i);
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setLongData(int64Builder).build(),
|
|
|
+ DataType.Int64, dim);
|
|
|
+
|
|
|
+ IntArray.Builder intBuilder = IntArray.newBuilder();
|
|
|
+ for (int i = 0; i < dim; ++i) {
|
|
|
+ intBuilder.addData(i);
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setIntData(intBuilder).build(),
|
|
|
+ DataType.Int32, dim);
|
|
|
+ testScalarField(ScalarField.newBuilder().setIntData(intBuilder).build(),
|
|
|
+ DataType.Int16, dim);
|
|
|
+ testScalarField(ScalarField.newBuilder().setIntData(intBuilder).build(),
|
|
|
+ DataType.Int8, dim);
|
|
|
+
|
|
|
+ BoolArray.Builder boolBuilder = BoolArray.newBuilder();
|
|
|
+ for (long i = 0; i < dim; ++i) {
|
|
|
+ boolBuilder.addData(i % 2 == 0);
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setBoolData(boolBuilder).build(),
|
|
|
+ DataType.Bool, dim);
|
|
|
+
|
|
|
+ FloatArray.Builder floatBuilder = FloatArray.newBuilder();
|
|
|
+ for (long i = 0; i < dim; ++i) {
|
|
|
+ floatBuilder.addData((float) i);
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setFloatData(floatBuilder).build(),
|
|
|
+ DataType.Float, dim);
|
|
|
+
|
|
|
+ DoubleArray.Builder doubleBuilder = DoubleArray.newBuilder();
|
|
|
+ for (long i = 0; i < dim; ++i) {
|
|
|
+ doubleBuilder.addData((double) i);
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setDoubleData(doubleBuilder).build(),
|
|
|
+ DataType.Double, dim);
|
|
|
+
|
|
|
+ StringArray.Builder strBuilder = StringArray.newBuilder();
|
|
|
+ for (long i = 0; i < dim; ++i) {
|
|
|
+ strBuilder.addData(String.valueOf(i));
|
|
|
+ }
|
|
|
+ testScalarField(ScalarField.newBuilder().setStringData(strBuilder).build(),
|
|
|
+ DataType.String, dim);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void getCollStatResponseWrapper() {
|
|
|
+ GetCollectionStatisticsResponse response = GetCollectionStatisticsResponse.newBuilder()
|
|
|
+ .addStats(KeyValuePair.newBuilder().setKey("row_count").setValue("invalid").build())
|
|
|
+ .build();
|
|
|
+ GetCollStatResponseWrapper invalidWrapper = new GetCollStatResponseWrapper(response);
|
|
|
+ assertThrows(NumberFormatException.class, invalidWrapper::GetRowCount);
|
|
|
+
|
|
|
+ response = GetCollectionStatisticsResponse.newBuilder()
|
|
|
+ .addStats(KeyValuePair.newBuilder().setKey("row_count").setValue("10").build())
|
|
|
+ .build();
|
|
|
+ GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper(response);
|
|
|
+ assertEquals(10, wrapper.GetRowCount());
|
|
|
+
|
|
|
+ response = GetCollectionStatisticsResponse.newBuilder().build();
|
|
|
+ wrapper = new GetCollStatResponseWrapper(response);
|
|
|
+ assertEquals(0, wrapper.GetRowCount());
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void InsertResultWrapper() {
|
|
|
+ List<Long> nID = Arrays.asList(1L, 2L, 3L);
|
|
|
+ MutationResult results = MutationResult.newBuilder()
|
|
|
+ .setInsertCnt(nID.size())
|
|
|
+ .setIDs(IDs.newBuilder()
|
|
|
+ .setIntId(LongArray.newBuilder()
|
|
|
+ .addAllData(nID)
|
|
|
+ .build()))
|
|
|
+ .build();
|
|
|
+ InsertResultWrapper longWrapper = new InsertResultWrapper(results);
|
|
|
+ assertEquals(nID.size(), longWrapper.getInsertCount());
|
|
|
+ assertThrows(ParamException.class, longWrapper::getStringIDs);
|
|
|
+
|
|
|
+ List<Long> longIDs = longWrapper.getLongIDs();
|
|
|
+ assertEquals(nID.size(), longIDs.size());
|
|
|
+ for (int i = 0; i < longIDs.size(); ++i) {
|
|
|
+ assertEquals(nID.get(i), longIDs.get(i));
|
|
|
+ }
|
|
|
+
|
|
|
+ List<String> sID = Arrays.asList("1", "2", "3");
|
|
|
+ results = MutationResult.newBuilder()
|
|
|
+ .setInsertCnt(sID.size())
|
|
|
+ .setIDs(IDs.newBuilder()
|
|
|
+ .setStrId(StringArray.newBuilder()
|
|
|
+ .addAllData(sID)
|
|
|
+ .build()))
|
|
|
+ .build();
|
|
|
+ InsertResultWrapper strWrapper = new InsertResultWrapper(results);
|
|
|
+ assertEquals(sID.size(), strWrapper.getInsertCount());
|
|
|
+ assertThrows(ParamException.class, strWrapper::getLongIDs);
|
|
|
+
|
|
|
+ List<String> strIDs = strWrapper.getStringIDs();
|
|
|
+ assertEquals(sID.size(), strIDs.size());
|
|
|
+ for (int i = 0; i < strIDs.size(); ++i) {
|
|
|
+ assertEquals(sID.get(i), strIDs.get(i));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void QueryResultsWrapper() {
|
|
|
+ String fieldName = "test";
|
|
|
+ QueryResults results = QueryResults.newBuilder()
|
|
|
+ .addFieldsData(FieldData.newBuilder()
|
|
|
+ .setFieldName(fieldName)
|
|
|
+ .build())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ QueryResultsWrapper wrapper = new QueryResultsWrapper(results);
|
|
|
+ assertThrows(ParamException.class, () -> wrapper.getFieldWrapper("invalid"));
|
|
|
+ assertNotNull(wrapper.getFieldWrapper(fieldName));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void SearchResultsWrapper() {
|
|
|
+ long topK = 5;
|
|
|
+ long numQueries = 2;
|
|
|
+ List<Long> longIDs = new ArrayList<>();
|
|
|
+ List<String> strIDs = new ArrayList<>();
|
|
|
+ List<Float> scores = new ArrayList<>();
|
|
|
+ for (long i = 0; i < topK * numQueries; ++i) {
|
|
|
+ longIDs.add(i);
|
|
|
+ strIDs.add(String.valueOf(i));
|
|
|
+ scores.add((float) i);
|
|
|
+ }
|
|
|
+
|
|
|
+ // for long id
|
|
|
+ String fieldName = "test";
|
|
|
+ SearchResultData results = SearchResultData.newBuilder()
|
|
|
+ .setTopK(topK)
|
|
|
+ .setNumQueries(numQueries)
|
|
|
+ .setIds(IDs.newBuilder()
|
|
|
+ .setIntId(LongArray.newBuilder()
|
|
|
+ .addAllData(longIDs)
|
|
|
+ .build()))
|
|
|
+ .addAllScores(scores)
|
|
|
+ .addFieldsData(FieldData.newBuilder()
|
|
|
+ .setFieldName(fieldName)
|
|
|
+ .build())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ SearchResultsWrapper intWrapper = new SearchResultsWrapper(results);
|
|
|
+ assertNotNull(intWrapper.GetFieldData(fieldName));
|
|
|
+ assertNull(intWrapper.GetFieldData("invalid"));
|
|
|
+
|
|
|
+ List<SearchResultsWrapper.IDScore> idScores = intWrapper.GetIDScore(1);
|
|
|
+ assertEquals(idScores.size(), topK);
|
|
|
+ assertThrows(ParamException.class, () -> intWrapper.GetIDScore((int) numQueries));
|
|
|
+
|
|
|
+ // for string id
|
|
|
+ results = SearchResultData.newBuilder()
|
|
|
+ .setTopK(topK)
|
|
|
+ .setNumQueries(numQueries)
|
|
|
+ .setIds(IDs.newBuilder()
|
|
|
+ .setStrId(StringArray.newBuilder()
|
|
|
+ .addAllData(strIDs)
|
|
|
+ .build()))
|
|
|
+ .addAllScores(scores)
|
|
|
+ .addFieldsData(FieldData.newBuilder()
|
|
|
+ .setFieldName(fieldName)
|
|
|
+ .build())
|
|
|
+ .build();
|
|
|
+
|
|
|
+ SearchResultsWrapper strWrapper = new SearchResultsWrapper(results);
|
|
|
+ idScores = strWrapper.GetIDScore(0);
|
|
|
+ assertEquals(idScores.size(), topK);
|
|
|
+ }
|
|
|
}
|