|
@@ -12,10 +12,11 @@ import org.apache.lucene.index.IndexOptions;
|
|
import org.apache.lucene.index.IndexableField;
|
|
import org.apache.lucene.index.IndexableField;
|
|
import org.apache.lucene.search.DocValuesFieldExistsQuery;
|
|
import org.apache.lucene.search.DocValuesFieldExistsQuery;
|
|
import org.apache.lucene.search.Query;
|
|
import org.apache.lucene.search.Query;
|
|
-import org.apache.lucene.util.ArrayUtil;
|
|
|
|
import org.apache.lucene.util.BytesRef;
|
|
import org.apache.lucene.util.BytesRef;
|
|
import org.elasticsearch.common.settings.Settings;
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
|
+import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
import org.elasticsearch.common.xcontent.XContentParser.Token;
|
|
import org.elasticsearch.common.xcontent.XContentParser.Token;
|
|
|
|
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
|
import org.elasticsearch.index.fielddata.IndexFieldData;
|
|
import org.elasticsearch.index.fielddata.IndexFieldData;
|
|
import org.elasticsearch.index.mapper.ArrayValueMapperParser;
|
|
import org.elasticsearch.index.mapper.ArrayValueMapperParser;
|
|
import org.elasticsearch.index.mapper.FieldMapper;
|
|
import org.elasticsearch.index.mapper.FieldMapper;
|
|
@@ -56,12 +57,28 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
|
|
}
|
|
}
|
|
|
|
|
|
public static class Builder extends FieldMapper.Builder<Builder, DenseVectorFieldMapper> {
|
|
public static class Builder extends FieldMapper.Builder<Builder, DenseVectorFieldMapper> {
|
|
|
|
+ private int dims = 0;
|
|
|
|
|
|
public Builder(String name) {
|
|
public Builder(String name) {
|
|
super(name, Defaults.FIELD_TYPE, Defaults.FIELD_TYPE);
|
|
super(name, Defaults.FIELD_TYPE, Defaults.FIELD_TYPE);
|
|
builder = this;
|
|
builder = this;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ public Builder dims(int dims) {
|
|
|
|
+ if ((dims > MAX_DIMS_COUNT) || (dims < 1)) {
|
|
|
|
+ throw new MapperParsingException("The number of dimensions for field [" + name +
|
|
|
|
+ "] should be in the range [1, " + MAX_DIMS_COUNT + "]");
|
|
|
|
+ }
|
|
|
|
+ this.dims = dims;
|
|
|
|
+ return this;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ protected void setupFieldType(BuilderContext context) {
|
|
|
|
+ super.setupFieldType(context);
|
|
|
|
+ fieldType().setDims(dims);
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
public DenseVectorFieldType fieldType() {
|
|
public DenseVectorFieldType fieldType() {
|
|
return (DenseVectorFieldType) super.fieldType();
|
|
return (DenseVectorFieldType) super.fieldType();
|
|
@@ -80,11 +97,17 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
|
|
@Override
|
|
@Override
|
|
public Mapper.Builder<?,?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
|
|
public Mapper.Builder<?,?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
|
|
DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder(name);
|
|
DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder(name);
|
|
- return builder;
|
|
|
|
|
|
+ Object dimsField = node.remove("dims");
|
|
|
|
+ if (dimsField == null) {
|
|
|
|
+ throw new MapperParsingException("The [dims] property must be specified for field [" + name + "].");
|
|
|
|
+ }
|
|
|
|
+ int dims = XContentMapValues.nodeIntegerValue(dimsField);
|
|
|
|
+ return builder.dims(dims);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
public static final class DenseVectorFieldType extends MappedFieldType {
|
|
public static final class DenseVectorFieldType extends MappedFieldType {
|
|
|
|
+ private int dims;
|
|
|
|
|
|
public DenseVectorFieldType() {}
|
|
public DenseVectorFieldType() {}
|
|
|
|
|
|
@@ -96,6 +119,14 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
|
|
return new DenseVectorFieldType(this);
|
|
return new DenseVectorFieldType(this);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ int dims() {
|
|
|
|
+ return dims;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ void setDims(int dims) {
|
|
|
|
+ this.dims = dims;
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
public String typeName() {
|
|
public String typeName() {
|
|
return CONTENT_TYPE;
|
|
return CONTENT_TYPE;
|
|
@@ -145,28 +176,30 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
|
|
if (context.externalValueSet()) {
|
|
if (context.externalValueSet()) {
|
|
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] can't be used in multi-fields");
|
|
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] can't be used in multi-fields");
|
|
}
|
|
}
|
|
|
|
+ int dims = fieldType().dims(); //number of vector dimensions
|
|
|
|
|
|
// encode array of floats as array of integers and store into buf
|
|
// encode array of floats as array of integers and store into buf
|
|
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
|
|
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
|
|
- byte[] buf = new byte[0];
|
|
|
|
|
|
+ byte[] buf = new byte[dims * INT_BYTES];
|
|
int offset = 0;
|
|
int offset = 0;
|
|
int dim = 0;
|
|
int dim = 0;
|
|
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
|
|
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
|
|
|
|
+ if (dim++ >= dims) {
|
|
|
|
+ throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
|
|
|
|
+ context.sourceToParse().id() + "] has exceeded the number of dimensions [" + dims + "] defined in mapping");
|
|
|
|
+ }
|
|
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
|
|
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
|
|
float value = context.parser().floatValue(true);
|
|
float value = context.parser().floatValue(true);
|
|
- if (buf.length < (offset + INT_BYTES)) {
|
|
|
|
- buf = ArrayUtil.grow(buf, (offset + INT_BYTES));
|
|
|
|
- }
|
|
|
|
int intValue = Float.floatToIntBits(value);
|
|
int intValue = Float.floatToIntBits(value);
|
|
- buf[offset] = (byte) (intValue >> 24);
|
|
|
|
- buf[offset+1] = (byte) (intValue >> 16);
|
|
|
|
- buf[offset+2] = (byte) (intValue >> 8);
|
|
|
|
- buf[offset+3] = (byte) intValue;
|
|
|
|
- offset += INT_BYTES;
|
|
|
|
- if (dim++ >= MAX_DIMS_COUNT) {
|
|
|
|
- throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
|
|
|
|
- "] has exceeded the maximum allowed number of dimensions of [" + MAX_DIMS_COUNT + "]");
|
|
|
|
- }
|
|
|
|
|
|
+ buf[offset++] = (byte) (intValue >> 24);
|
|
|
|
+ buf[offset++] = (byte) (intValue >> 16);
|
|
|
|
+ buf[offset++] = (byte) (intValue >> 8);
|
|
|
|
+ buf[offset++] = (byte) intValue;
|
|
|
|
+ }
|
|
|
|
+ if (dim != dims) {
|
|
|
|
+ throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
|
|
|
|
+ context.sourceToParse().id() + "] has number of dimensions [" + dim +
|
|
|
|
+ "] less than defined in the mapping [" + dims +"]");
|
|
}
|
|
}
|
|
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));
|
|
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));
|
|
if (context.doc().getByKey(fieldType().name()) != null) {
|
|
if (context.doc().getByKey(fieldType().name()) != null) {
|
|
@@ -176,6 +209,12 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
|
|
context.doc().addWithKey(fieldType().name(), field);
|
|
context.doc().addWithKey(fieldType().name(), field);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
+ protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException {
|
|
|
|
+ super.doXContentBody(builder, includeDefaults, params);
|
|
|
|
+ builder.field("dims", fieldType().dims());
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
protected void parseCreateField(ParseContext context, List<IndexableField> fields) {
|
|
protected void parseCreateField(ParseContext context, List<IndexableField> fields) {
|
|
throw new AssertionError("parse is implemented directly");
|
|
throw new AssertionError("parse is implemented directly");
|