|
@@ -18,28 +18,30 @@
|
|
|
*/
|
|
|
|
|
|
import com.google.common.util.concurrent.ListenableFuture;
|
|
|
-import com.google.gson.JsonObject;
|
|
|
import io.milvus.client.*;
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
import java.util.SplittableRandom;
|
|
|
import java.util.concurrent.ExecutionException;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.DoubleStream;
|
|
|
+import org.json.JSONObject;
|
|
|
|
|
|
-// This is a simple example demonstrating how to use Milvus Java SDK.
|
|
|
-// For detailed API document, please refer to
|
|
|
+// This is a simple example demonstrating how to use Milvus Java SDK v0.9.0.
|
|
|
+// For detailed API documentation, please refer to
|
|
|
// https://milvus-io.github.io/milvus-sdk-java/javadoc/io/milvus/client/package-summary.html
|
|
|
-// You can also find more information on https://milvus.io/
|
|
|
+// You can also find more information on https://milvus.io/docs/overview.md
|
|
|
public class MilvusClientExample {
|
|
|
|
|
|
// Helper function that generates random vectors
|
|
|
static List<List<Float>> generateVectors(int vectorCount, int dimension) {
|
|
|
- SplittableRandom splitcollectionRandom = new SplittableRandom();
|
|
|
+ SplittableRandom splitCollectionRandom = new SplittableRandom();
|
|
|
List<List<Float>> vectors = new ArrayList<>(vectorCount);
|
|
|
for (int i = 0; i < vectorCount; ++i) {
|
|
|
- splitcollectionRandom = splitcollectionRandom.split();
|
|
|
- DoubleStream doubleStream = splitcollectionRandom.doubles(dimension);
|
|
|
+ splitCollectionRandom = splitCollectionRandom.split();
|
|
|
+ DoubleStream doubleStream = splitCollectionRandom.doubles(dimension);
|
|
|
List<Float> vector =
|
|
|
doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
|
|
|
vectors.add(vector);
|
|
@@ -56,7 +58,65 @@ public class MilvusClientExample {
|
|
|
return vector;
|
|
|
}
|
|
|
|
|
|
- public static void main(String[] args) throws InterruptedException {
|
|
|
+ // Helper function that generates the default list of fields for a collection.
|
|
|
+ // In this example, we have 3 fields with names "int64", "float" and "float_vec".
|
|
|
+ // Their DataType must also be defined.
|
|
|
+ static List<Map<String, Object>> generateDefaultFields(int dimension) {
|
|
|
+ List<Map<String, Object>> fieldList = new ArrayList<>();
|
|
|
+ Map<String, Object> intField = new HashMap<>();
|
|
|
+ intField.put("field", "int64");
|
|
|
+ intField.put("type", DataType.INT64);
|
|
|
+
|
|
|
+ Map<String, Object> floatField = new HashMap<>();
|
|
|
+ floatField.put("field", "float");
|
|
|
+ floatField.put("type", DataType.FLOAT);
|
|
|
+
|
|
|
+ Map<String, Object> vecField = new HashMap<>();
|
|
|
+ vecField.put("field", "float_vec");
|
|
|
+ vecField.put("type", DataType.VECTOR_FLOAT);
|
|
|
+ JSONObject jsonObject = new JSONObject();
|
|
|
+ jsonObject.put("dim", dimension);
|
|
|
+ vecField.put("params", jsonObject.toString());
|
|
|
+
|
|
|
+ fieldList.add(intField);
|
|
|
+ fieldList.add(floatField);
|
|
|
+ fieldList.add(vecField);
|
|
|
+ return fieldList;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Helper function that generates entity field values for inserting into a collection
|
|
|
+ // This corresponds to the above function that initializes fields.
|
|
|
+ static List<Map<String, Object>> generateDefaultFieldValues(int vectorCount, List<List<Float>> vectors) {
|
|
|
+ List<Map<String, Object>> fieldList = new ArrayList<>();
|
|
|
+ Map<String, Object> intField = new HashMap<>();
|
|
|
+ intField.put("field", "int64");
|
|
|
+ intField.put("type", DataType.INT64);
|
|
|
+
|
|
|
+ Map<String, Object> floatField = new HashMap<>();
|
|
|
+ floatField.put("field", "float");
|
|
|
+ floatField.put("type", DataType.FLOAT);
|
|
|
+
|
|
|
+ Map<String, Object> vecField = new HashMap<>();
|
|
|
+ vecField.put("field", "float_vec");
|
|
|
+ vecField.put("type", DataType.VECTOR_FLOAT);
|
|
|
+
|
|
|
+ List<Long> intValues = new ArrayList<>(vectorCount);
|
|
|
+ List<Float> floatValues = new ArrayList<>(vectorCount);
|
|
|
+ for (int i = 0; i < vectorCount; i++) {
|
|
|
+ intValues.add((long) i);
|
|
|
+ floatValues.add((float) i);
|
|
|
+ }
|
|
|
+ intField.put("values", intValues);
|
|
|
+ floatField.put("values", floatValues);
|
|
|
+ vecField.put("values", vectors);
|
|
|
+
|
|
|
+ fieldList.add(intField);
|
|
|
+ fieldList.add(floatField);
|
|
|
+ fieldList.add(vecField);
|
|
|
+ return fieldList;
|
|
|
+ }
|
|
|
+
|
|
|
+ public static void main(String[] args) throws InterruptedException, ConnectFailedException {
|
|
|
|
|
|
// You may need to change the following to the host and port of your Milvus server
|
|
|
String host = "localhost";
|
|
@@ -66,18 +126,26 @@ public class MilvusClientExample {
|
|
|
port = Integer.parseInt(args[1]);
|
|
|
}
|
|
|
|
|
|
+ // Create Milvus client
|
|
|
+ MilvusClient client = new MilvusGrpcClient();
|
|
|
+
|
|
|
+ // Connect to Milvus server
|
|
|
ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build();
|
|
|
- MilvusClient client = new MilvusGrpcClient(connectParam);
|
|
|
+ try {
|
|
|
+ Response connectResponse = client.connect(connectParam);
|
|
|
+ } catch (ConnectFailedException e) {
|
|
|
+ System.out.println("Failed to connect to Milvus server: " + e.toString());
|
|
|
+ throw e;
|
|
|
+ }
|
|
|
|
|
|
// Create a collection with the following collection mapping
|
|
|
final String collectionName = "example"; // collection name
|
|
|
final int dimension = 128; // dimension of each vector
|
|
|
- final int indexFileSize = 1024; // maximum size (in MB) of each index file
|
|
|
- final MetricType metricType = MetricType.IP; // we choose IP (Inner Product) as our metric type
|
|
|
+ // The collection mapping no longer sets a metric type; IP (Inner Product) is chosen in the search DSL below
|
|
|
CollectionMapping collectionMapping =
|
|
|
- new CollectionMapping.Builder(collectionName, dimension)
|
|
|
- .withIndexFileSize(indexFileSize)
|
|
|
- .withMetricType(metricType)
|
|
|
+ new CollectionMapping.Builder(collectionName)
|
|
|
+ .withFields(generateDefaultFields(dimension))
|
|
|
+ .withParamsInJson("{\"segment_row_count\": 50000, \"auto_id\": true}")
|
|
|
.build();
|
|
|
Response createCollectionResponse = client.createCollection(collectionMapping);
|
|
|
|
|
@@ -87,42 +155,37 @@ public class MilvusClientExample {
|
|
|
// Get collection info
|
|
|
GetCollectionInfoResponse getCollectionInfoResponse = client.getCollectionInfo(collectionName);
|
|
|
|
|
|
- // Insert randomly generated vectors to collection
|
|
|
+ // Insert randomly generated field values to collection
|
|
|
final int vectorCount = 100000;
|
|
|
List<List<Float>> vectors = generateVectors(vectorCount, dimension);
|
|
|
vectors =
|
|
|
vectors.stream().map(MilvusClientExample::normalizeVector).collect(Collectors.toList());
|
|
|
+ List<Map<String, Object>> defaultFieldValues = generateDefaultFieldValues(vectorCount, vectors);
|
|
|
InsertParam insertParam =
|
|
|
- new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
|
|
|
+ new InsertParam.Builder(collectionName)
|
|
|
+ .withFields(defaultFieldValues)
|
|
|
+ .build();
|
|
|
InsertResponse insertResponse = client.insert(insertParam);
|
|
|
- // Insert returns a list of vector ids that you will be using (if you did not supply them
|
|
|
- // yourself) to reference the vectors you just inserted
|
|
|
- List<Long> vectorIds = insertResponse.getVectorIds();
|
|
|
+ // Insert returns a list of entity ids that you will be using (if you did not supply them
|
|
|
+ // yourself) to reference the entities you just inserted
|
|
|
+ List<Long> vectorIds = insertResponse.getEntityIds();
|
|
|
|
|
|
// Flush data in collection
|
|
|
Response flushResponse = client.flush(collectionName);
|
|
|
|
|
|
// Get current entity count of collection
|
|
|
- CountEntitiesResponse ountEntitiesResponse = client.countEntities(collectionName);
|
|
|
+ CountEntitiesResponse countEntitiesResponse = client.countEntities(collectionName);
|
|
|
|
|
|
// Create index for the collection
|
|
|
- // We choose IVF_SQ8 as our index type here. Refer to IndexType javadoc for a
|
|
|
- // complete explanation of different index types
|
|
|
- final IndexType indexType = IndexType.IVF_SQ8;
|
|
|
- // Each index type has its optional parameters you can set. Refer to the Milvus documentation
|
|
|
- // for how to set the optimal parameters based on your needs.
|
|
|
- JsonObject indexParamsJson = new JsonObject();
|
|
|
- indexParamsJson.addProperty("nlist", 16384);
|
|
|
+ // We choose IVF_SQ8 as our index type here. Refer to Milvus documentation for a
|
|
|
+ // complete explanation of different index types and their respective parameters.
|
|
|
Index index =
|
|
|
- new Index.Builder(collectionName, indexType)
|
|
|
- .withParamsInJson(indexParamsJson.toString())
|
|
|
+ new Index.Builder(collectionName, "float_vec")
|
|
|
+ .withParamsInJson("{\"index_type\": \"IVF_SQ8\", \"metric_type\": \"L2\", "
|
|
|
+ + "\"params\": {\"nlist\": 2048}}")
|
|
|
.build();
|
|
|
Response createIndexResponse = client.createIndex(index);
|
|
|
|
|
|
- // Get index info for your collection
|
|
|
- GetIndexInfoResponse getIndexInfoResponse = client.getIndexInfo(collectionName);
|
|
|
- System.out.format("Index Info: %s\n", getIndexInfoResponse.getIndex().get().toString());
|
|
|
-
|
|
|
// Get collection info
|
|
|
Response getCollectionStatsResponse = client.getCollectionStats(collectionName);
|
|
|
if (getCollectionStatsResponse.ok()) {
|
|
@@ -138,32 +201,41 @@ public class MilvusClientExample {
|
|
|
throw new AssertionError("Wrong results!");
|
|
|
}
|
|
|
|
|
|
- // Search vectors
|
|
|
- // Searching the first 5 vectors of the vectors we just inserted
|
|
|
+ // Search entities using DSL statement.
|
|
|
+ // Search with the vectors of the first 5 entities we just inserted, embedded as the query in the DSL.
|
|
|
final int searchBatchSize = 5;
|
|
|
List<List<Float>> vectorsToSearch = vectors.subList(0, searchBatchSize);
|
|
|
final long topK = 10;
|
|
|
// Based on the index you created, the available search parameters will be different. Refer to
|
|
|
// the Milvus documentation for how to set the optimal parameters based on your needs.
|
|
|
- JsonObject searchParamsJson = new JsonObject();
|
|
|
- searchParamsJson.addProperty("nprobe", 20);
|
|
|
+ String dsl = String.format(
|
|
|
+ "{\"bool\": {"
|
|
|
+ + "\"must\": [{"
|
|
|
+ + " \"range\": {"
|
|
|
+ + " \"float\": {\"GT\": -10, \"LT\": 100}"
|
|
|
+ + " }},{"
|
|
|
+ + " \"vector\": {"
|
|
|
+ + " \"float_vec\": {"
|
|
|
+ + " \"topk\": %d, \"metric_type\": \"IP\", \"type\": \"float\", \"query\": "
|
|
|
+ + "%s, \"params\": {\"nprobe\": 50}"
|
|
|
+ + " }}}]}}",
|
|
|
+ topK, vectorsToSearch.toString());
|
|
|
SearchParam searchParam =
|
|
|
new SearchParam.Builder(collectionName)
|
|
|
- .withFloatVectors(vectorsToSearch)
|
|
|
- .withTopK(topK)
|
|
|
- .withParamsInJson(searchParamsJson.toString())
|
|
|
+ .withDSL(dsl)
|
|
|
+ .withParamsInJson("{\"fields\": [\"int64\", \"float\"]}")
|
|
|
.build();
|
|
|
SearchResponse searchResponse = client.search(searchParam);
|
|
|
if (searchResponse.ok()) {
|
|
|
List<List<SearchResponse.QueryResult>> queryResultsList =
|
|
|
searchResponse.getQueryResultsList();
|
|
|
- final double epsilon = 0.001;
|
|
|
+ final double epsilon = 0.01;
|
|
|
for (int i = 0; i < searchBatchSize; i++) {
|
|
|
// Since we are searching for vector that is already present in the collection,
|
|
|
// the first result vector should be itself and the distance (inner product) should be
|
|
|
// very close to 1 (some precision is lost during the process)
|
|
|
SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
|
|
|
- if (firstQueryResult.getVectorId() != vectorIds.get(i)
|
|
|
+ if (firstQueryResult.getEntityId() != vectorIds.get(i)
|
|
|
|| Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
|
|
|
throw new AssertionError("Wrong results!");
|
|
|
}
|
|
@@ -183,31 +255,36 @@ public class MilvusClientExample {
|
|
|
e.printStackTrace();
|
|
|
}
|
|
|
|
|
|
- // Delete the first 5 vectors you just searched
|
|
|
+ // Delete the first 5 entities you just searched
|
|
|
Response deleteByIdsResponse =
|
|
|
client.deleteEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
|
|
|
flushResponse = client.flush(collectionName);
|
|
|
|
|
|
- // Try to get the corresponding vector of the first id you just deleted.
|
|
|
+ // After deleting them, we call getEntityByID; none of the 5 entities should be returned.
|
|
|
GetEntityByIDResponse getEntityByIDResponse =
|
|
|
client.getEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
|
|
|
- // Obviously you won't get anything
|
|
|
- if (!getEntityByIDResponse.getFloatVectors().get(0).isEmpty()) {
|
|
|
+ if (getEntityByIDResponse.getValidIds().size() > 0) {
|
|
|
throw new AssertionError("This can never happen!");
|
|
|
}
|
|
|
|
|
|
// Compact the collection, erase deleted data from disk and rebuild index in background (if
|
|
|
// the data size after compaction is still larger than indexFileSize). Data was only
|
|
|
// soft-deleted until you call compact.
|
|
|
- Response compactResponse = client.compact(collectionName);
|
|
|
+ Response compactResponse = client.compact(
|
|
|
+ new CompactParam.Builder(collectionName).withThreshold(0.2).build());
|
|
|
|
|
|
// Drop index for the collection
|
|
|
- Response dropIndexResponse = client.dropIndex(collectionName);
|
|
|
+ Response dropIndexResponse = client.dropIndex(collectionName, "float_vec");
|
|
|
|
|
|
// Drop collection
|
|
|
Response dropCollectionResponse = client.dropCollection(collectionName);
|
|
|
|
|
|
// Disconnect from Milvus server
|
|
|
- client.close();
|
|
|
+ try {
|
|
|
+ Response disconnectResponse = client.disconnect();
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ System.out.println("Failed to disconnect: " + e.toString());
|
|
|
+ throw e;
|
|
|
+ }
|
|
|
}
|
|
|
}
|