Browse Source

support setup output fields for vector search (#568)

* support setup output fields for vector search

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* remove console

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

---------

Signed-off-by: ryjiang <jiangruiyi@gmail.com>
ryjiang 1 year ago
parent
commit
376f76925b

+ 68 - 0
client/src/components/customSelector/CustomMultiSelector.tsx

@@ -0,0 +1,68 @@
+import { FC } from 'react';
+import {
+  FormControl,
+  InputLabel,
+  MenuItem,
+  Select,
+  Checkbox,
+} from '@material-ui/core';
+import { withStyles } from '@material-ui/core/styles';
+import { CustomMultiSelectorType } from './Types';
+import { generateId } from '../../utils/Common';
+import { render } from '@testing-library/react';
+
+const CustomMenuItem = withStyles({
+  root: {
+    minHeight: 'auto',
+    padding: '0 8px',
+    fontSize: '0.875rem',
+  },
+})(MenuItem);
+
+const CustomSelector: FC<CustomMultiSelectorType> = props => {
+  const {
+    label,
+    values,
+    onChange,
+    options,
+    classes,
+    variant,
+    wrapperClass = '',
+    labelClass = '',
+    size = 'small',
+    renderValue = selected => <>selected</>,
+    ...others
+  } = props;
+
+  const id = generateId('selector');
+
+  return (
+    <FormControl variant={variant} className={wrapperClass} size={size}>
+      {label && (
+        <InputLabel classes={{ root: labelClass }} htmlFor={id}>
+          {label}
+        </InputLabel>
+      )}
+      <Select
+        classes={{ ...classes }}
+        {...others}
+        multiple={true}
+        value={values}
+        onChange={onChange}
+        inputProps={{
+          id,
+        }}
+        renderValue={renderValue}
+      >
+        {options.map(v => (
+          <CustomMenuItem key={v.value} value={v.value}>
+            <Checkbox checked={values.indexOf(v.value as string) !== -1} />
+            {v.label}
+          </CustomMenuItem>
+        ))}
+      </Select>
+    </FormControl>
+  );
+};
+
+export default CustomSelector;

+ 5 - 0
client/src/components/customSelector/Types.ts

@@ -24,6 +24,11 @@ export type CustomSelectorType = SelectProps & {
   size?: 'small' | 'medium' | undefined;
 };
 
+export type CustomMultiSelectorType = Omit<CustomSelectorType, 'value'> & {
+  values: string[];
+  renderValue?: (selected: string[]) => React.ReactNode;
+};
+
 export interface ICustomGroupSelect {
   className?: string;
   options: GroupOption[];

+ 2 - 0
client/src/i18n/cn/search.ts

@@ -25,6 +25,8 @@ const searchTrans = {
   noSelectedVectorField: '至少选择一个向量字段进行搜索.',
   rerank: '排序器',
   groupBy: '分组',
+  outputFields: '输出字段',
+  consistency: '一致性',
 };
 
 export default searchTrans;

+ 2 - 0
client/src/i18n/en/search.ts

@@ -25,6 +25,8 @@ const searchTrans = {
   noSelectedVectorField: 'At least select one vector field to search.',
   rerank: 'Reranker',
   groupBy: 'Group By',
+  outputFields: 'Outputs',
+  consistency: 'Consistency',
 };
 
 export default searchTrans;

+ 1 - 0
client/src/pages/databases/Databases.tsx

@@ -97,6 +97,7 @@ const Databases = () => {
                 weightedParams: {
                   weights: Array(c.schema.vectorFields.length).fill(0.5),
                 },
+                output_fields: c.schema.scalarFields.map(s => s.name),
               },
               searchResult: null,
               searchLatency: 0,

+ 3 - 2
client/src/pages/databases/collections/search/Search.tsx

@@ -197,7 +197,7 @@ const Search = (props: CollectionDataProps) => {
     }
 
     const s = searchParams.collection.schema!;
-    const _outputFields = s.scalarFields.map(f => f.name);
+    const _outputFields = [...searchParams.globalParams.output_fields];
 
     if (s.enable_dynamic_field) {
       _outputFields.push(DYNAMIC_FIELD);
@@ -243,7 +243,7 @@ const Search = (props: CollectionDataProps) => {
           .filter(item => {
             // if primary key field name is id, don't filter it
             const invalidItems = primaryKeyField === 'id' ? [] : ['id'];
-            return !invalidItems.includes(item);
+            return !invalidItems.includes(item) && orderArray.includes(item);
           })
           .map(key => {
             // find the field
@@ -411,6 +411,7 @@ const Search = (props: CollectionDataProps) => {
                 setHighlightField('');
               }}
               fields={searchParams.collection.schema.scalarFields}
+              outputFields={searchParams.collection.schema.scalarFields}
               searchParams={searchParams}
               searchGlobalParams={searchParams.globalParams}
               handleFormChange={(params: any) => {

+ 41 - 3
client/src/pages/databases/collections/search/SearchGlobalParams.tsx

@@ -3,6 +3,8 @@ import { useTranslation } from 'react-i18next';
 import { Slider } from '@material-ui/core';
 import CustomInput from '@/components/customInput/CustomInput';
 import CustomSelector from '@/components/customSelector/CustomSelector';
+import CustomMultiSelector from '@/components/customSelector/CustomMultiSelector';
+
 import {
   CONSISTENCY_LEVEL_OPTIONS,
   TOP_K_OPTIONS,
@@ -19,6 +21,7 @@ export interface SearchGlobalProps {
   onSlideChange: (field: string) => void;
   onSlideChangeCommitted: () => void;
   fields: FieldObject[];
+  outputFields: FieldObject[];
 }
 
 const UNSPORTED_GROUPBY_TYPES = [
@@ -36,6 +39,7 @@ const SearchGlobalParams = (props: SearchGlobalProps) => {
     onSlideChange,
     onSlideChangeCommitted,
     fields,
+    outputFields,
   } = props;
   const selectedCount = searchParams.searchParams.filter(
     sp => sp.selected
@@ -44,8 +48,9 @@ const SearchGlobalParams = (props: SearchGlobalProps) => {
 
   // translations
   const { t: warningTrans } = useTranslation('warning');
-  const { t: collectionTrans } = useTranslation('collection');
+  const { t: commonTrans } = useTranslation();
   const { t: searchTrans } = useTranslation('search');
+  const gridTrans = commonTrans('grid');
 
   // UI functions
   const handleInputChange = useCallback(
@@ -76,7 +81,7 @@ const SearchGlobalParams = (props: SearchGlobalProps) => {
       <CustomSelector
         options={TOP_K_OPTIONS}
         value={searchGlobalParams.topK}
-        label={collectionTrans('topK')}
+        label={searchTrans('topK')}
         wrapperClass="selector"
         variant="filled"
         onChange={(e: { target: { value: unknown } }) => {
@@ -87,7 +92,7 @@ const SearchGlobalParams = (props: SearchGlobalProps) => {
       <CustomSelector
         options={CONSISTENCY_LEVEL_OPTIONS}
         value={searchGlobalParams.consistency_level}
-        label={collectionTrans('consistency')}
+        label={searchTrans('consistency')}
         wrapperClass="selector"
         variant="filled"
         onChange={(e: { target: { value: unknown } }) => {
@@ -96,6 +101,39 @@ const SearchGlobalParams = (props: SearchGlobalProps) => {
         }}
       />
 
+      <CustomMultiSelector
+        options={outputFields.map(f => {
+          return { label: f.name, value: f.name };
+        })}
+        values={searchGlobalParams.output_fields}
+        renderValue={selected => (
+          <span>{`${(selected as string[]).length} ${
+            gridTrans[(selected as string[]).length > 1 ? 'fields' : 'field']
+          }`}</span>
+        )}
+        label={searchTrans('outputFields')}
+        wrapperClass="selector"
+        variant="filled"
+        onChange={(e: { target: { value: unknown } }) => {
+          // add value to output fields if not exist, remove if exist
+          const outputFields = [...searchGlobalParams.output_fields];
+          const values = e.target.value as string[];
+          const newFields = values.filter(
+            v => !outputFields.includes(v as string)
+          );
+          const removeFields = outputFields.filter(
+            v => !values.includes(v as string)
+          );
+          outputFields.push(...newFields);
+          removeFields.forEach(f => {
+            const index = outputFields.indexOf(f);
+            outputFields.splice(index, 1);
+          });
+
+          handleInputChange('output_fields', outputFields);
+        }}
+      />
+
       {!showReranker && (
         <CustomSelector
           options={[{ label: '--', value: '' }, ...groupByOptions]}

+ 1 - 0
client/src/pages/databases/types.ts

@@ -18,6 +18,7 @@ export type GlobalParams = {
   weightedParams: { weights: number[] };
   round_decimal?: number;
   group_by_field?: string;
+  output_fields: string[];
 };
 
 export type SearchResultView = {