|
@@ -28,47 +28,51 @@ import io.milvus.param.collection.*;
|
|
import io.milvus.param.dml.*;
|
|
import io.milvus.param.dml.*;
|
|
import io.milvus.param.index.*;
|
|
import io.milvus.param.index.*;
|
|
import io.milvus.response.*;
|
|
import io.milvus.response.*;
|
|
|
|
+import org.tensorflow.types.TBfloat16;
|
|
|
|
+import org.tensorflow.types.TFloat16;
|
|
|
|
|
|
import java.nio.ByteBuffer;
|
|
import java.nio.ByteBuffer;
|
|
import java.util.*;
|
|
import java.util.*;
|
|
|
|
|
|
-import org.tensorflow.Tensor;
|
|
|
|
-import org.tensorflow.ndarray.Shape;
|
|
|
|
-import org.tensorflow.ndarray.buffer.ByteDataBuffer;
|
|
|
|
-import org.tensorflow.ndarray.buffer.DataBuffers;
|
|
|
|
-import org.tensorflow.types.TBfloat16;
|
|
|
|
-import org.tensorflow.types.TFloat16;
|
|
|
|
-
|
|
|
|
|
|
|
|
public class Float16VectorExample {
|
|
public class Float16VectorExample {
|
|
private static final String COLLECTION_NAME = "java_sdk_example_float16";
|
|
private static final String COLLECTION_NAME = "java_sdk_example_float16";
|
|
private static final String ID_FIELD = "id";
|
|
private static final String ID_FIELD = "id";
|
|
private static final String VECTOR_FIELD = "vector";
|
|
private static final String VECTOR_FIELD = "vector";
|
|
private static final Integer VECTOR_DIM = 128;
|
|
private static final Integer VECTOR_DIM = 128;
|
|
-
|
|
|
|
-
|
|
|
|
- private static void testFloat16(boolean bfloat16) {
|
|
|
|
- DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
|
|
|
|
- System.out.printf("=================== %s ===================\n", dataType.name());
|
|
|
|
|
|
|
|
|
|
+ private static final MilvusServiceClient milvusClient;
|
|
|
|
+ static {
|
|
// Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
|
|
// Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
|
|
- MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
|
|
|
|
|
|
+ milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
|
|
.withHost("localhost")
|
|
.withHost("localhost")
|
|
.withPort(19530)
|
|
.withPort(19530)
|
|
.build());
|
|
.build());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // For float16 values between 0.0~1.0, the precision can be controlled under 0.001f
|
|
|
|
+ // For bfloat16 values between 0.0~1.0, the precision can be controlled under 0.01f
|
|
|
|
+ private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
|
|
|
|
+ if (bfloat16) {
|
|
|
|
+ return Math.abs(a - b) <= 0.01f;
|
|
|
|
+ } else {
|
|
|
|
+ return Math.abs(a - b) <= 0.001f;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static void createCollection(boolean bfloat16) {
|
|
|
|
|
|
// drop the collection if you don't need the collection anymore
|
|
// drop the collection if you don't need the collection anymore
|
|
R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
|
|
R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .build());
|
|
|
|
|
|
+ .withCollectionName(COLLECTION_NAME)
|
|
|
|
+ .build());
|
|
CommonUtils.handleResponseStatus(hasR);
|
|
CommonUtils.handleResponseStatus(hasR);
|
|
if (hasR.getData()) {
|
|
if (hasR.getData()) {
|
|
- milvusClient.dropCollection(DropCollectionParam.newBuilder()
|
|
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .build());
|
|
|
|
|
|
+ dropCollection();
|
|
}
|
|
}
|
|
|
|
|
|
// Define fields
|
|
// Define fields
|
|
|
|
+ DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
|
|
List<FieldType> fieldsSchema = Arrays.asList(
|
|
List<FieldType> fieldsSchema = Arrays.asList(
|
|
FieldType.newBuilder()
|
|
FieldType.newBuilder()
|
|
.withName(ID_FIELD)
|
|
.withName(ID_FIELD)
|
|
@@ -84,6 +88,8 @@ public class Float16VectorExample {
|
|
);
|
|
);
|
|
|
|
|
|
// Create the collection
|
|
// Create the collection
|
|
|
|
+ // Note that we set default consistency level to "STRONG",
|
|
|
|
+ // to ensure data is visible to search, for validation the result
|
|
R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
|
|
R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
|
|
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
|
|
@@ -92,31 +98,73 @@ public class Float16VectorExample {
|
|
CommonUtils.handleResponseStatus(ret);
|
|
CommonUtils.handleResponseStatus(ret);
|
|
System.out.println("Collection created");
|
|
System.out.println("Collection created");
|
|
|
|
|
|
- // Insert entities by columns
|
|
|
|
- int rowCount = 10000;
|
|
|
|
|
|
+ // Specify an index type on the vector field.
|
|
|
|
+ ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
|
|
|
|
+ .withCollectionName(COLLECTION_NAME)
|
|
|
|
+ .withFieldName(VECTOR_FIELD)
|
|
|
|
+ .withIndexType(IndexType.IVF_FLAT)
|
|
|
|
+ .withMetricType(MetricType.L2)
|
|
|
|
+ .withExtraParam("{\"nlist\":128}")
|
|
|
|
+ .build());
|
|
|
|
+ CommonUtils.handleResponseStatus(ret);
|
|
|
|
+ System.out.println("Index created");
|
|
|
|
+
|
|
|
|
+ // Call loadCollection() to enable automatically loading data into memory for searching
|
|
|
|
+ ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
|
|
|
|
+ .withCollectionName(COLLECTION_NAME)
|
|
|
|
+ .build());
|
|
|
|
+ CommonUtils.handleResponseStatus(ret);
|
|
|
|
+ System.out.println("Collection loaded");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static void dropCollection() {
|
|
|
|
+ milvusClient.dropCollection(DropCollectionParam.newBuilder()
|
|
|
|
+ .withCollectionName(COLLECTION_NAME)
|
|
|
|
+ .build());
|
|
|
|
+ System.out.println("Collection dropped");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static void testFloat16(boolean bfloat16) {
|
|
|
|
+ DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
|
|
|
|
+ System.out.printf("============ testFloat16 %s ===================\n", dataType.name());
|
|
|
|
+
|
|
|
|
+ createCollection(bfloat16);
|
|
|
|
+
|
|
|
|
+ // Insert 5000 entities by columns
|
|
|
|
+ // Prepare original vectors, then encode into ByteBuffer
|
|
|
|
+ int batchRowCount = 5000;
|
|
|
|
+ List<List<Float>> originVectors = CommonUtils.generateFloatVectors(VECTOR_DIM, batchRowCount);
|
|
|
|
+ List<ByteBuffer> encodedVectors = CommonUtils.encodeFloat16Vectors(originVectors, bfloat16);
|
|
|
|
+
|
|
List<Long> ids = new ArrayList<>();
|
|
List<Long> ids = new ArrayList<>();
|
|
- for (long i = 0L; i < rowCount; ++i) {
|
|
|
|
|
|
+ for (long i = 0L; i < batchRowCount; ++i) {
|
|
ids.add(i);
|
|
ids.add(i);
|
|
}
|
|
}
|
|
- List<ByteBuffer> vectors = CommonUtils.generateFloat16Vectors(VECTOR_DIM, rowCount, bfloat16);
|
|
|
|
-
|
|
|
|
List<InsertParam.Field> fieldsInsert = new ArrayList<>();
|
|
List<InsertParam.Field> fieldsInsert = new ArrayList<>();
|
|
fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
|
|
fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
|
|
- fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
|
|
|
|
|
|
+ fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
|
|
|
|
|
|
R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
|
|
R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withFields(fieldsInsert)
|
|
.withFields(fieldsInsert)
|
|
.build());
|
|
.build());
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
|
|
+ System.out.println(ids.size() + " rows inserted");
|
|
|
|
|
|
- // Insert entities by rows
|
|
|
|
|
|
+ // Insert 5000 entities by rows
|
|
List<JsonObject> rows = new ArrayList<>();
|
|
List<JsonObject> rows = new ArrayList<>();
|
|
Gson gson = new Gson();
|
|
Gson gson = new Gson();
|
|
- for (long i = 1L; i <= rowCount; ++i) {
|
|
|
|
|
|
+ for (int i = 0; i < batchRowCount; ++i) {
|
|
JsonObject row = new JsonObject();
|
|
JsonObject row = new JsonObject();
|
|
- row.addProperty(ID_FIELD, rowCount + i);
|
|
|
|
- row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloat16Vector(VECTOR_DIM, bfloat16).array()));
|
|
|
|
|
|
+ row.addProperty(ID_FIELD, batchRowCount + i);
|
|
|
|
+
|
|
|
|
+ List<Float> originVector = CommonUtils.generateFloatVector(VECTOR_DIM);
|
|
|
|
+ originVectors.add(originVector);
|
|
|
|
+
|
|
|
|
+ ByteBuffer buf = CommonUtils.encodeFloat16Vector(originVector, bfloat16);
|
|
|
|
+ encodedVectors.add(buf);
|
|
|
|
+
|
|
|
|
+ row.add(VECTOR_FIELD, gson.toJsonTree(buf.array()));
|
|
rows.add(row);
|
|
rows.add(row);
|
|
}
|
|
}
|
|
|
|
|
|
@@ -125,39 +173,14 @@ public class Float16VectorExample {
|
|
.withRows(rows)
|
|
.withRows(rows)
|
|
.build());
|
|
.build());
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
|
|
+ System.out.println(ids.size() + " rows inserted");
|
|
|
|
|
|
- // Flush the data to storage for testing purpose
|
|
|
|
- // Note that no need to manually call flush interface in practice
|
|
|
|
- R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
|
|
|
|
- addCollectionName(COLLECTION_NAME).
|
|
|
|
- build());
|
|
|
|
- CommonUtils.handleResponseStatus(flushR);
|
|
|
|
- System.out.println("Entities inserted");
|
|
|
|
-
|
|
|
|
- // Specify an index type on the vector field.
|
|
|
|
- ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
|
|
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .withFieldName(VECTOR_FIELD)
|
|
|
|
- .withIndexType(IndexType.IVF_FLAT)
|
|
|
|
- .withMetricType(MetricType.L2)
|
|
|
|
- .withExtraParam("{\"nlist\":128}")
|
|
|
|
- .build());
|
|
|
|
- CommonUtils.handleResponseStatus(ret);
|
|
|
|
- System.out.println("Index created");
|
|
|
|
-
|
|
|
|
- // Call loadCollection() to enable automatically loading data into memory for searching
|
|
|
|
- ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
|
|
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .build());
|
|
|
|
- CommonUtils.handleResponseStatus(ret);
|
|
|
|
- System.out.println("Collection loaded");
|
|
|
|
-
|
|
|
|
- // Pick some vectors from the inserted vectors to search
|
|
|
|
|
|
+ // Pick some random vectors from the original vectors to search
|
|
// Ensure the returned top1 item's ID should be equal to target vector's ID
|
|
// Ensure the returned top1 item's ID should be equal to target vector's ID
|
|
for (int i = 0; i < 10; i++) {
|
|
for (int i = 0; i < 10; i++) {
|
|
Random ran = new Random();
|
|
Random ran = new Random();
|
|
- int k = ran.nextInt(rowCount);
|
|
|
|
- ByteBuffer targetVector = vectors.get(k);
|
|
|
|
|
|
+ int k = ran.nextInt(batchRowCount*2);
|
|
|
|
+ ByteBuffer targetVector = encodedVectors.get(k);
|
|
SearchParam.Builder builder = SearchParam.newBuilder()
|
|
SearchParam.Builder builder = SearchParam.newBuilder()
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withMetricType(MetricType.L2)
|
|
.withMetricType(MetricType.L2)
|
|
@@ -181,128 +204,142 @@ public class Float16VectorExample {
|
|
for (SearchResultsWrapper.IDScore score : scores) {
|
|
for (SearchResultsWrapper.IDScore score : scores) {
|
|
System.out.println(score);
|
|
System.out.println(score);
|
|
}
|
|
}
|
|
- if (scores.get(0).getLongID() != k) {
|
|
|
|
|
|
+
|
|
|
|
+ SearchResultsWrapper.IDScore firstScore = scores.get(0);
|
|
|
|
+ if (firstScore.getLongID() != k) {
|
|
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
|
|
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
|
|
- scores.get(0).getLongID(), k));
|
|
|
|
|
|
+ firstScore.getLongID(), k));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ByteBuffer outputBuf = (ByteBuffer)firstScore.get(VECTOR_FIELD);
|
|
|
|
+ if (!outputBuf.equals(targetVector)) {
|
|
|
|
+ throw new RuntimeException(String.format("The output vector is not equal to target vector: ID %d", k));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
|
|
|
|
+ List<Float> originVector = originVectors.get(k);
|
|
|
|
+ for (int j = 0; j < outputVector.size(); j++) {
|
|
|
|
+ if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
|
|
|
|
+ throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
System.out.println("Search result is correct");
|
|
System.out.println("Search result is correct");
|
|
|
|
|
|
- // Retrieve some data
|
|
|
|
- int n = 99;
|
|
|
|
- R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
|
|
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .withExpr(String.format("id == %d", n))
|
|
|
|
- .addOutField(VECTOR_FIELD)
|
|
|
|
- .build());
|
|
|
|
- CommonUtils.handleResponseStatus(queryR);
|
|
|
|
- QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
|
|
|
|
- FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
|
|
|
|
- List<?> r = field.getFieldData();
|
|
|
|
- if (r.isEmpty()) {
|
|
|
|
- throw new RuntimeException("The query result is empty");
|
|
|
|
- } else {
|
|
|
|
- ByteBuffer bf = (ByteBuffer) r.get(0);
|
|
|
|
- if (!bf.equals(vectors.get(n))) {
|
|
|
|
- throw new RuntimeException("The query result is incorrect");
|
|
|
|
|
|
+ // Retrieve some data and verify the output
|
|
|
|
+ for (int i = 0; i < 10; i++) {
|
|
|
|
+ Random ran = new Random();
|
|
|
|
+ int k = ran.nextInt(batchRowCount*2);
|
|
|
|
+ R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
|
|
|
|
+ .withCollectionName(COLLECTION_NAME)
|
|
|
|
+ .withExpr(String.format("id == %d", k))
|
|
|
|
+ .addOutField(VECTOR_FIELD)
|
|
|
|
+ .build());
|
|
|
|
+ CommonUtils.handleResponseStatus(queryR);
|
|
|
|
+ QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
|
|
|
|
+ FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
|
|
|
|
+ List<?> r = field.getFieldData();
|
|
|
|
+ if (r.isEmpty()) {
|
|
|
|
+ throw new RuntimeException("The query result is empty");
|
|
|
|
+ } else {
|
|
|
|
+ ByteBuffer outputBuf = (ByteBuffer) r.get(0);
|
|
|
|
+ ByteBuffer targetVector = encodedVectors.get(k);
|
|
|
|
+ if (!outputBuf.equals(targetVector)) {
|
|
|
|
+ throw new RuntimeException("The query result is incorrect");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
|
|
|
|
+ List<Float> originVector = originVectors.get(k);
|
|
|
|
+ for (int j = 0; j < outputVector.size(); j++) {
|
|
|
|
+ if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
|
|
|
|
+ throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
System.out.println("Query result is correct");
|
|
System.out.println("Query result is correct");
|
|
|
|
|
|
- // insert a single row
|
|
|
|
- JsonObject row = new JsonObject();
|
|
|
|
- row.addProperty(ID_FIELD, 9999999);
|
|
|
|
- List<Float> newVector = CommonUtils.generateFloatVector(VECTOR_DIM);
|
|
|
|
- ByteBuffer vector16Buf = encodeTF(newVector, bfloat16);
|
|
|
|
- row.add(VECTOR_FIELD, gson.toJsonTree(vector16Buf.array()));
|
|
|
|
- insertR = milvusClient.insert(InsertParam.newBuilder()
|
|
|
|
|
|
+ // drop the collection if you don't need the collection anymore
|
|
|
|
+ dropCollection();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static void testTensorflowFloat16(boolean bfloat16) {
|
|
|
|
+ DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
|
|
|
|
+ System.out.printf("============ testTensorflowFloat16 %s ===================\n", dataType.name());
|
|
|
|
+ createCollection(bfloat16);
|
|
|
|
+
|
|
|
|
+ // Prepare tensorflow vectors, convert to ByteBuffer and insert
|
|
|
|
+ int rowCount = 10000;
|
|
|
|
+ List<Long> ids = new ArrayList<>();
|
|
|
|
+ for (long i = 0L; i < rowCount; ++i) {
|
|
|
|
+ ids.add(i);
|
|
|
|
+ }
|
|
|
|
+ List<InsertParam.Field> fieldsInsert = new ArrayList<>();
|
|
|
|
+ fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
|
|
|
|
+
|
|
|
|
+ List<ByteBuffer> encodedVectors;
|
|
|
|
+ if (bfloat16) {
|
|
|
|
+ List<TBfloat16> tfVectors = CommonUtils.genTensorflowBF16Vectors(VECTOR_DIM, rowCount);
|
|
|
|
+ encodedVectors = CommonUtils.encodeTensorBF16Vectors(tfVectors);
|
|
|
|
+ } else {
|
|
|
|
+ List<TFloat16> tfVectors = CommonUtils.genTensorflowFP16Vectors(VECTOR_DIM, rowCount);
|
|
|
|
+ encodedVectors = CommonUtils.encodeTensorFP16Vectors(tfVectors);
|
|
|
|
+ }
|
|
|
|
+ fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
|
|
|
|
+
|
|
|
|
+ R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withCollectionName(COLLECTION_NAME)
|
|
- .withRows(Collections.singletonList(row))
|
|
|
|
|
|
+ .withFields(fieldsInsert)
|
|
.build());
|
|
.build());
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
CommonUtils.handleResponseStatus(insertR);
|
|
|
|
+ System.out.println(ids.size() + " rows inserted");
|
|
|
|
|
|
- // retrieve the single row
|
|
|
|
- queryR = milvusClient.query(QueryParam.newBuilder()
|
|
|
|
|
|
+ // Retrieve some data and verify the output
|
|
|
|
+ Random ran = new Random();
|
|
|
|
+ int k = ran.nextInt(rowCount);
|
|
|
|
+ R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
|
|
.withCollectionName(COLLECTION_NAME)
|
|
.withCollectionName(COLLECTION_NAME)
|
|
- .withExpr("id == 9999999")
|
|
|
|
|
|
+ .withExpr(String.format("id == %d", k))
|
|
.addOutField(VECTOR_FIELD)
|
|
.addOutField(VECTOR_FIELD)
|
|
.build());
|
|
.build());
|
|
CommonUtils.handleResponseStatus(queryR);
|
|
CommonUtils.handleResponseStatus(queryR);
|
|
- queryWrapper = new QueryResultsWrapper(queryR.getData());
|
|
|
|
- field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
|
|
|
|
- r = field.getFieldData();
|
|
|
|
|
|
+ QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
|
|
|
|
+ FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
|
|
|
|
+ List<?> r = field.getFieldData();
|
|
if (r.isEmpty()) {
|
|
if (r.isEmpty()) {
|
|
- throw new RuntimeException("The retrieve result is empty");
|
|
|
|
- } else {
|
|
|
|
- ByteBuffer outBuf = (ByteBuffer) r.get(0);
|
|
|
|
- List<Float> outVector = decodeTF(outBuf, bfloat16);
|
|
|
|
- if (outVector.size() != newVector.size()) {
|
|
|
|
- throw new RuntimeException("The retrieve result is incorrect");
|
|
|
|
- }
|
|
|
|
- for (int i = 0; i < outVector.size(); i++) {
|
|
|
|
- if (!isFloat16Eauql(outVector.get(i), newVector.get(i), bfloat16)) {
|
|
|
|
- throw new RuntimeException("The retrieve result is incorrect");
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
|
|
+ throw new RuntimeException("The query result is empty");
|
|
}
|
|
}
|
|
- System.out.println("Retrieve result is correct");
|
|
|
|
|
|
|
|
- // drop the collection if you don't need the collection anymore
|
|
|
|
- milvusClient.dropCollection(DropCollectionParam.newBuilder()
|
|
|
|
- .withCollectionName(COLLECTION_NAME)
|
|
|
|
- .build());
|
|
|
|
- System.out.println("Collection dropped");
|
|
|
|
-
|
|
|
|
- milvusClient.close();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private static ByteBuffer encodeTF(List<Float> vector, boolean bfloat16) {
|
|
|
|
- ByteBuffer buf = ByteBuffer.allocate(vector.size() * 2);
|
|
|
|
- for (Float value : vector) {
|
|
|
|
- ByteDataBuffer bf;
|
|
|
|
- if (bfloat16) {
|
|
|
|
- TBfloat16 tt = TBfloat16.scalarOf(value);
|
|
|
|
- bf = tt.asRawTensor().data();
|
|
|
|
- } else {
|
|
|
|
- TFloat16 tt = TFloat16.scalarOf(value);
|
|
|
|
- bf = tt.asRawTensor().data();
|
|
|
|
- }
|
|
|
|
- buf.put(bf.getByte(0));
|
|
|
|
- buf.put(bf.getByte(1));
|
|
|
|
|
|
+ ByteBuffer outputBuf = (ByteBuffer) r.get(0);
|
|
|
|
+ ByteBuffer originVector = encodedVectors.get(k);
|
|
|
|
+ if (!outputBuf.equals(originVector)) {
|
|
|
|
+ throw new RuntimeException("The query result is incorrect");
|
|
}
|
|
}
|
|
- return buf;
|
|
|
|
- }
|
|
|
|
|
|
|
|
- private static List<Float> decodeTF(ByteBuffer buf, boolean bfloat16) {
|
|
|
|
- int dim = buf.limit()/2;
|
|
|
|
- ByteDataBuffer bf = DataBuffers.of(buf.array());
|
|
|
|
- List<Float> vec = new ArrayList<>();
|
|
|
|
|
|
+ List<Float> vector = new ArrayList<>();
|
|
if (bfloat16) {
|
|
if (bfloat16) {
|
|
- TBfloat16 tf = Tensor.of(TBfloat16.class, Shape.of(dim), bf);
|
|
|
|
- for (long k = 0; k < tf.size(); k++) {
|
|
|
|
- vec.add(tf.getFloat(k));
|
|
|
|
|
|
+ TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf);
|
|
|
|
+ for (long i = 0; i < tf.size(); i++) {
|
|
|
|
+ vector.add(tf.getFloat(i));
|
|
}
|
|
}
|
|
} else {
|
|
} else {
|
|
- TFloat16 tf = Tensor.of(TFloat16.class, Shape.of(dim), bf);
|
|
|
|
- for (long k = 0; k < tf.size(); k++) {
|
|
|
|
- vec.add(tf.getFloat(k));
|
|
|
|
|
|
+ TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf);
|
|
|
|
+ for (long i = 0; i < tf.size(); i++) {
|
|
|
|
+ vector.add(tf.getFloat(i));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ System.out.println(vector);
|
|
|
|
+ System.out.println("Query result is correct");
|
|
|
|
|
|
- return vec;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
|
|
|
|
- if (bfloat16) {
|
|
|
|
- return Math.abs(a - b) <= 0.01f;
|
|
|
|
- } else {
|
|
|
|
- return Math.abs(a - b) <= 0.001f;
|
|
|
|
- }
|
|
|
|
|
|
+ // drop the collection if you don't need the collection anymore
|
|
|
|
+ dropCollection();
|
|
}
|
|
}
|
|
|
|
|
|
-
|
|
|
|
public static void main(String[] args) {
|
|
public static void main(String[] args) {
|
|
testFloat16(true);
|
|
testFloat16(true);
|
|
testFloat16(false);
|
|
testFloat16(false);
|
|
|
|
+
|
|
|
|
+ testTensorflowFloat16(true);
|
|
|
|
+ testTensorflowFloat16(false);
|
|
}
|
|
}
|
|
}
|
|
}
|