|
|
@@ -110,19 +110,20 @@ public class DistanceFunctionBenchmark {
|
|
|
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
|
|
|
|
|
|
final BytesRef docVector;
|
|
|
+ final float[] docFloatVector;
|
|
|
final float[] queryVector;
|
|
|
|
|
|
private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
|
|
|
super(dims);
|
|
|
|
|
|
- float[] docVector = new float[dims];
|
|
|
+ docFloatVector = new float[dims];
|
|
|
queryVector = new float[dims];
|
|
|
|
|
|
float docMagnitude = 0f;
|
|
|
float queryMagnitude = 0f;
|
|
|
|
|
|
for (int i = 0; i < dims; ++i) {
|
|
|
- docVector[i] = (float) (dims - i);
|
|
|
+ docFloatVector[i] = (float) (dims - i);
|
|
|
queryVector[i] = (float) i;
|
|
|
|
|
|
docMagnitude += (float) (dims - i);
|
|
|
@@ -136,11 +137,11 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
for (int i = 0; i < dims; ++i) {
|
|
|
if (normalize) {
|
|
|
- docVector[i] /= docMagnitude;
|
|
|
+ docFloatVector[i] /= docMagnitude;
|
|
|
queryVector[i] /= queryMagnitude;
|
|
|
}
|
|
|
|
|
|
- byteBuffer.putFloat(docVector[i]);
|
|
|
+ byteBuffer.putFloat(docFloatVector[i]);
|
|
|
}
|
|
|
|
|
|
byteBuffer.putFloat(docMagnitude);
|
|
|
@@ -178,6 +179,7 @@ public class DistanceFunctionBenchmark {
|
|
|
private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
|
|
|
|
|
|
final BytesRef docVector;
|
|
|
+ final byte[] vectorValue;
|
|
|
final byte[] queryVector;
|
|
|
|
|
|
final float queryMagnitude;
|
|
|
@@ -187,12 +189,14 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
|
|
|
queryVector = new byte[dims];
|
|
|
+ vectorValue = new byte[dims];
|
|
|
|
|
|
float docMagnitude = 0f;
|
|
|
float queryMagnitude = 0f;
|
|
|
|
|
|
for (int i = 0; i < dims; ++i) {
|
|
|
docVector.put((byte) (dims - i));
|
|
|
+ vectorValue[i] = (byte) (dims - i);
|
|
|
queryVector[i] = (byte) i;
|
|
|
|
|
|
docMagnitude += (float) (dims - i);
|
|
|
@@ -238,7 +242,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector);
|
|
|
+ new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).dotProduct(queryVector);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -250,7 +254,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new ByteBinaryDenseVector(docVector, dims).dotProduct(queryVector);
|
|
|
+ new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -286,7 +290,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
|
|
|
+ new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -298,7 +302,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new ByteBinaryDenseVector(docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
|
|
|
+ new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -334,7 +338,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
|
|
|
+ new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -346,7 +350,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new ByteBinaryDenseVector(docVector, dims).l1Norm(queryVector);
|
|
|
+ new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -382,7 +386,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
|
|
|
+ new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -394,7 +398,7 @@ public class DistanceFunctionBenchmark {
|
|
|
|
|
|
@Override
|
|
|
public void execute(Consumer<Object> consumer) {
|
|
|
- consumer.accept(new ByteBinaryDenseVector(docVector, dims).l2Norm(queryVector));
|
|
|
+ consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
|
|
|
}
|
|
|
}
|
|
|
|