VectorSearch.tsx 19 KB


  1. import { useCallback, useEffect, useMemo, useState, useContext } from 'react';
  2. import { TextField, Typography, Button } from '@material-ui/core';
  3. import { useTranslation } from 'react-i18next';
  4. import { useLocation } from 'react-router-dom';
  5. import { ALL_ROUTER_TYPES } from '@/router/Types';
  6. import {
  7. useNavigationHook,
  8. useSearchResult,
  9. usePaginationHook,
  10. useTimeTravelHook,
  11. } from '@/hooks';
  12. import { databaseContext } from '@/context';
  13. import CustomSelector from '@/components/customSelector/CustomSelector';
  14. import { ColDefinitionsType } from '@/components/grid/Types';
  15. import AttuGrid from '@/components/grid/Grid';
  16. import EmptyCard from '@/components/cards/EmptyCard';
  17. import icons from '@/components/icons/Icons';
  18. import CustomButton from '@/components/customButton/CustomButton';
  19. import SimpleMenu from '@/components/menu/SimpleMenu';
  20. import { Option } from '@/components/customSelector/Types';
  21. import Filter from '@/components/advancedSearch';
  22. import { Field } from '@/components/advancedSearch/Types';
  23. import { CustomDatePicker } from '@/components/customDatePicker/CustomDatePicker';
  24. import { CollectionHttp, IndexHttp } from '@/http';
  25. import {
  26. parseValue,
  27. parseLocationSearch,
  28. classifyFields,
  29. getDefaultIndexType,
  30. getEmbeddingType,
  31. getNonVectorFieldsForFilter,
  32. getVectorFieldOptions,
  33. cloneObj,
  34. generateVector,
  35. } from '@/utils';
  36. import { LOADING_STATE, DEFAULT_METRIC_VALUE_MAP } from '@/consts';
  37. import { getLabelDisplayedRows } from './Utils';
  38. import SearchParams from './SearchParams';
  39. import { getVectorSearchStyles } from './Styles';
  40. import { CollectionData, DataTypeEnum } from '../collections/Types';
  41. import { TOP_K_OPTIONS } from './Constants';
  42. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  43. const VectorSearch = () => {
  useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
  // router location and the database exposed by databaseContext
  const location = useLocation();
  const database = useContext(databaseContext).database;
  // translators for the 'search' and 'btn' i18n namespaces
  const searchTrans = useTranslation('search').t;
  const btnTrans = useTranslation('btn').t;
  // style classes for this page
  const classes = getVectorSearchStyles();
  // ---- state owned by this component ----
  // loading flag handed to the result grid
  const [tableLoading, setTableLoading] = useState(false);
  // collections offered in the collection selector
  const [collections, setCollections] = useState<CollectionData[]>([]);
  const [selectedCollection, setSelectedCollection] = useState('');
  // vector fields offered in the field selector
  const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
  // non-vector fields offered to the advanced filter
  const [filterFields, setFilterFields] = useState<Field[]>([]);
  const [selectedField, setSelectedField] = useState('');
  // values of the search params form
  const [searchParam, setSearchParam] = useState<Record<string, number>>({});
  // whether the search params form currently blocks searching
  const [paramDisabled, setParamDisabled] = useState(true);
  // null: no search has run yet; empty array: the search returned nothing
  const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
    null
  );
  // topK defaults to 100
  const [topK, setTopK] = useState(100);
  // advanced filter expression
  const [expression, setExpression] = useState('');
  // raw text typed into the vector textarea
  const [vectors, setVectors] = useState('');
  // latency reported by the last search
  const [latency, setLatency] = useState(0);
  // view model of the raw search result, built by useSearchResult
  const searchResultMemo = useSearchResult(searchResult as any, classes);
  // client-side pagination / sorting over the view model
  const pagination = usePaginationHook(searchResultMemo || []);
  const {
    pageSize,
    handlePageSize,
    currentPage,
    handleCurrentPage,
    total,
    data: result,
    order,
    orderBy,
    handleGridSort,
  } = pagination;
  // time travel state used by the date picker and the search request
  const timeTravelState = useTimeTravelHook();
  const { timeTravel, setTimeTravel, timeTravelInfo, handleDateTimeChange } =
    timeTravelState;
  // selector options: one entry per collection, name used as label and value
  const collectionOptions: Option[] = useMemo(() => {
    return collections.map(({ _name }) => ({ label: _name, value: _name }));
  }, [collections]);
  // names of the selected collection's fields that may be requested as output
  const outputFields: string[] = useMemo(() => {
    // vector field can't be output fields
    const vectorTypes = ['BinaryVector', 'FloatVector'];
    const current = collections.find(c => c._name === selectedCollection);
    return (current?._fields || [])
      .filter(field => !vectorTypes.includes(field._fieldType))
      .map(field => field._fieldName);
  }, [selectedCollection, collections]);
  // name of the selected collection's primary key field (undefined if none)
  const primaryKeyField = useMemo(() => {
    const current = collections.find(c => c._name === selectedCollection);
    const primaryKey = (current?._fields || []).find(f => f._isPrimaryKey);
    return primaryKey?._fieldName;
  }, [selectedCollection, collections]);
  /**
   * Preferred column order: primary key, 'id', 'score', then output fields.
   * Memoized so the array keeps its identity between renders; a fresh array
   * on every render would invalidate the colDefinitions memo below each time.
   */
  const orderArray = useMemo(
    () => [primaryKeyField, 'id', 'score', ...outputFields],
    [primaryKeyField, outputFields]
  );
  /**
   * Grid column definitions derived from the keys of the first result row.
   * Keys are sorted by their position in orderArray (keys that are not listed
   * get index -1 and therefore come first). The 'id' column is dropped unless
   * the primary key field itself is named 'id', because the primary key field
   * already shows the same value.
   */
  const colDefinitions: ColDefinitionsType[] = useMemo(() => {
    if (!searchResult || searchResult.length === 0) {
      return [];
    }
    // if primary key field name is id, don't filter it
    const hiddenKeys: string[] = primaryKeyField === 'id' ? [] : ['id'];
    return Object.keys(searchResult[0])
      .sort((a, b) => orderArray.indexOf(a) - orderArray.indexOf(b))
      .filter(key => !hiddenKeys.includes(key))
      .map(key => ({
        id: key,
        align: 'left',
        disablePadding: false,
        label: key,
        // only the primary key column gets the copy helper
        needCopy: primaryKeyField === key,
      }));
  }, [searchResult, primaryKeyField, orderArray]);
  const [selectedMetricType, setSelectedMetricType] = useState<string>('');
  /**
   * Index and type metadata of the currently selected vector field.
   * Falls back to neutral defaults when no field is selected, or when the
   * selected field is not present in fieldOptions, instead of dereferencing
   * an undefined option.
   * The calculation is kept free of side effects; the metric type state is
   * synced in the effect below.
   */
  const selectedFieldMeta = useMemo(() => {
    const selectedFieldInfo =
      selectedField === ''
        ? undefined
        : fieldOptions.find(f => f.value === selectedField);
    if (!selectedFieldInfo) {
      return {
        metricType: '',
        indexType: '',
        indexParams: [],
        fieldType: 0,
        embeddingType: DataTypeEnum.FloatVector,
        selectedFieldDimension: 0,
      };
    }
    const index = selectedFieldInfo.indexInfo;
    const embeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
    return {
      // metric of the field's index, else the default for its embedding type
      metricType: index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType],
      indexType: index?._indexType || getDefaultIndexType(embeddingType),
      indexParams: index?._indexParameterPairs || [],
      fieldType: DataTypeEnum[selectedFieldInfo.fieldType],
      embeddingType,
      selectedFieldDimension: selectedFieldInfo.dimension || 0,
    };
  }, [selectedField, fieldOptions]);
  const {
    indexType,
    indexParams,
    fieldType,
    embeddingType,
    selectedFieldDimension,
  } = selectedFieldMeta;
  // adopt the field's metric type whenever the field metadata is recomputed;
  // done in an effect because state must not be set while rendering
  useEffect(() => {
    if (selectedFieldMeta.metricType) {
      setSelectedMetricType(selectedFieldMeta.metricType);
    }
  }, [selectedFieldMeta]);
  /**
   * vector value validation
   * The input is valid when it parses to an array whose length equals the
   * selected field dimension (dimension / 8 for binary vector fields).
   * @return whether is valid
   */
  const vectorValueValid = useMemo(() => {
    // nothing to validate until the user typed a value and a field is selected
    const nothingToValidate = vectors === '' || selectedFieldDimension === 0;
    if (nothingToValidate) {
      return true;
    }
    const parsed = parseValue(vectors);
    if (!Array.isArray(parsed)) {
      return false;
    }
    const expectedLength =
      fieldType === DataTypeEnum.BinaryVector
        ? selectedFieldDimension / 8
        : selectedFieldDimension;
    return parsed.length === expectedLength;
  }, [vectors, selectedFieldDimension, fieldType]);
  /**
   * Searching is enabled only when all of the following hold:
   * 1. a vector value is entered and passes validation
   * 2. a collection and a field are chosen
   * 3. the search params form does not report itself as disabled
   */
  const searchDisabled = useMemo(() => {
    const readyToSearch =
      vectors !== '' &&
      selectedCollection !== '' &&
      selectedField !== '' &&
      !paramDisabled &&
      vectorValueValid;
    return !readyToSearch;
  }, [
    paramDisabled,
    selectedField,
    selectedCollection,
    vectors,
    vectorValueValid,
  ]);
  // fetch data
  // Loads all collections and keeps only those whose status is LOADED.
  // `database` is not read in the body; listing it as a dependency gives the
  // callback a new identity whenever the database changes.
  const fetchCollections = useCallback(async () => {
    const allCollections = await CollectionHttp.getCollections();
    const loadedCollections = allCollections.filter(
      c => c._status === LOADING_STATE.LOADED
    );
    setCollections(loadedCollections);
  }, [database]);
  /**
   * Loads index info for a collection and derives both selector inputs:
   * vector field options (first one becomes the selected field) and the
   * non-vector fields usable in the advanced filter.
   * @param collectionName collection whose fields are loaded
   * @param collections list searched for the collection's field definitions
   */
  const fetchFieldsWithIndex = useCallback(
    async (collectionName: string, collections: CollectionData[]) => {
      const fields =
        collections.find(c => c._name === collectionName)?._fields || [];
      const indexes = await IndexHttp.getIndexInfo(collectionName);
      const { vectorFields, nonVectorFields } = classifyFields(fields);
      // only vector type fields can be select
      const vectorFieldOptions = getVectorFieldOptions(vectorFields, indexes);
      setFieldOptions(vectorFieldOptions);
      if (vectorFieldOptions.length > 0) {
        // set first option value as default field value
        const [{ value: defaultFieldValue }] = vectorFieldOptions;
        setSelectedField(defaultFieldValue as string);
      }
      // only non vector type fields can be advanced filter
      const advancedFilterFields = getNonVectorFieldsForFilter(nonVectorFields);
      setFilterFields(advancedFilterFields);
    },
    // the callback only uses its arguments, imports and state setters, so it
    // has no reactive dependencies (the `collections` it reads is a parameter)
    []
  );
  // load collections on mount and whenever fetchCollections gets a new identity
  useEffect(() => {
    fetchCollections();
  }, [fetchCollections]);
  // clear selection if database is changed
  useEffect(() => {
    setSelectedCollection('');
  }, [database]);
  // get field options with index when selected collection changed
  useEffect(() => {
    if (selectedCollection === '') {
      return;
    }
    fetchFieldsWithIndex(selectedCollection, collections);
  }, [selectedCollection, collections, fetchFieldsWithIndex]);
  // set default collection value if is from overview page
  useEffect(() => {
    if (!location.search || collections.length === 0) {
      return;
    }
    const { collectionName } = parseLocationSearch(location.search);
    // only accept a name that matches one of the loaded collections
    const isNameValid = collections.some(c => c._name === collectionName);
    if (isNameValid) {
      setSelectedCollection(collectionName);
    }
  }, [location, collections]);
  // icons
  const {
    vectorSearch: VectorSearchIcon,
    refresh: ResetIcon,
    dropdown: ArrowIcon,
  } = icons;
  // methods
  // grid page change: the event itself is ignored, only the page is forwarded
  const handlePageChange = (_event: any, page: number) => {
    handleCurrentPage(page);
  };
  /**
   * Resets the page:
   * - vector input
   * - selected field and collection
   * - search result (back to null)
   * - advanced filter fields and expression
   * - time travel date
   */
  const handleReset = () => {
    // form inputs
    setVectors('');
    setSelectedField('');
    setSelectedCollection('');
    // result table
    setSearchResult(null);
    // advanced filter
    setFilterFields([]);
    setExpression('');
    // time travel
    setTimeTravel(null);
  };
  /**
   * Runs a vector search against the selected collection and stores the
   * results and latency.
   * @param topK number of results to request
   * @param expr filter expression, defaults to the current expression state
   */
  const handleSearch = async (topK: number, expr = expression) => {
    // round_decimal is sent as its own entry, so it is removed from the
    // params that get serialized
    const serializableParams = cloneObj(searchParam);
    delete serializableParams.round_decimal;
    const searchParamPairs = {
      params: JSON.stringify(serializableParams),
      anns_field: selectedField,
      topk: topK,
      metric_type: selectedMetricType,
      round_decimal: searchParam.round_decimal,
    };
    const request: VectorSearchParam = {
      output_fields: outputFields,
      expr,
      search_params: searchParamPairs,
      vectors: [parseValue(vectors)],
      vector_type: fieldType,
      travel_timestamp: timeTravelInfo.timestamp,
    };

    setTableLoading(true);
    try {
      const response = await CollectionHttp.vectorSearchData(
        selectedCollection,
        request
      );
      setTableLoading(false);
      setSearchResult(response.results);
      setLatency(response.latency);
    } catch (err) {
      // a failed request only clears the loading flag; the previous result stays
      setTableLoading(false);
    }
  };
  // stores the submitted filter expression and, if searching is currently
  // allowed, immediately searches with it
  const handleAdvancedFilterChange = (nextExpression: string) => {
    setExpression(nextExpression);
    if (searchDisabled) {
      return;
    }
    handleSearch(topK, nextExpression);
  };
  // keeps the textarea content in state
  const handleVectorChange = (value: string) => {
    setVectors(value);
  };
  // fills the textarea with a generated vector wrapped in square brackets
  const fillWithExampleVector = (selectedFieldDimension: number) => {
    const exampleVector = generateVector(selectedFieldDimension);
    setVectors(`[${exampleVector}]`);
  };
  // number of elements the vector input must contain; for binary vector
  // fields this is the field dimension divided by 8
  const expectedVectorLength =
    fieldType === DataTypeEnum.BinaryVector
      ? selectedFieldDimension / 8
      : selectedFieldDimension;
  const dimensionKnown = selectedFieldDimension !== 0;
  const noCollectionAvailable = collectionOptions.length === 0;
  // the grid is shown while loading or when there is at least one result row
  const gridVisible =
    (searchResult !== null && searchResult.length > 0) || tableLoading;

  const onCollectionSelect = (e: { target: { value: unknown } }) => {
    setSelectedCollection(e.target.value as string);
    // every time selected collection changed, reset field
    setSelectedField('');
    setSearchResult([]);
  };
  const onFieldSelect = (e: { target: { value: unknown } }) => {
    setSelectedField(e.target.value as string);
  };
  // choosing a topK stores it and, if searching is allowed, searches right away
  const topKMenuItems = TOP_K_OPTIONS.map(item => ({
    label: item.toString(),
    callback: () => {
      setTopK(item);
      if (!searchDisabled) {
        handleSearch(item);
      }
    },
    wrapperClass: classes.menuItem,
  }));

  return (
    <section className="page-wrapper">
      {/* form section */}
      <form className={classes.form}>
        {/* collection and field selectors */}
        <fieldset className="field">
          <Typography className="text">{searchTrans('secondTip')}</Typography>
          <CustomSelector
            options={collectionOptions}
            wrapperClass={classes.selector}
            variant="filled"
            label={searchTrans(
              noCollectionAvailable ? 'noCollection' : 'collection'
            )}
            disabled={noCollectionAvailable}
            value={selectedCollection}
            onChange={onCollectionSelect}
          />
          <CustomSelector
            options={fieldOptions}
            // readOnly can't avoid all events, so we use disabled instead
            disabled={selectedCollection === ''}
            wrapperClass={classes.selector}
            variant="filled"
            label={searchTrans('field')}
            value={selectedField}
            onChange={onFieldSelect}
          />
        </fieldset>
        {/**
         * vector value textarea
         * use field-params class because it also has error msg if invalid
         */}
        <fieldset className="field field-params field-second">
          <Typography className="text">
            {searchTrans('firstTip', {
              dimensionTip: dimensionKnown
                ? `(dimension: ${selectedFieldDimension})`
                : '',
            })}
            {dimensionKnown ? (
              <Button
                variant="outlined"
                size="small"
                onClick={() => fillWithExampleVector(expectedVectorLength)}
              >
                {btnTrans('example')}
              </Button>
            ) : null}
          </Typography>
          <TextField
            className="textarea"
            InputProps={{
              classes: {
                root: 'textfield',
                multiline: 'multiline',
              },
            }}
            multiline
            rows={5}
            placeholder={searchTrans('vectorPlaceholder')}
            value={vectors}
            onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
              handleVectorChange(e.target.value as string);
            }}
          />
          {/* validation */}
          {vectorValueValid ? null : (
            <Typography variant="caption" className={classes.error}>
              {searchTrans('vectorValueWarning', {
                dimension: expectedVectorLength,
              })}
            </Typography>
          )}
        </fieldset>
        {/* search params selectors */}
        <fieldset className="field field-params">
          <Typography className="text">{searchTrans('thirdTip')}</Typography>
          <SearchParams
            wrapperClass={classes.paramsWrapper}
            metricType={selectedMetricType}
            embeddingType={
              embeddingType as
                | DataTypeEnum.BinaryVector
                | DataTypeEnum.FloatVector
            }
            indexType={indexType}
            indexParams={indexParams!}
            searchParamsForm={searchParam}
            handleFormChange={setSearchParam}
            handleMetricTypeChange={setSelectedMetricType}
            topK={topK}
            setParamsDisabled={setParamDisabled}
          />
        </fieldset>
      </form>
      {/**
       * search toolbar section
       * including topK selector, advanced filter, search and reset btn
       */}
      <section className={classes.toolbar}>
        <div className="left">
          <Typography variant="h5" className="text">
            {`${searchTrans('result')}: `}
          </Typography>
          {/* topK selector */}
          <SimpleMenu
            label={searchTrans('topK', { number: topK })}
            menuItems={topKMenuItems}
            buttonProps={{
              className: classes.menuLabel,
              endIcon: <ArrowIcon />,
            }}
            menuItemWidth="108px"
          />
          <Filter
            title="Advanced Filter"
            fields={filterFields}
            filterDisabled={selectedField === '' || selectedCollection === ''}
            onSubmit={handleAdvancedFilterChange}
          />
          <CustomDatePicker
            label={timeTravelInfo.label}
            onChange={handleDateTimeChange}
            date={timeTravel}
            setDate={setTimeTravel}
          />
        </div>
        <div className="right">
          <CustomButton className="btn" onClick={handleReset}>
            <ResetIcon classes={{ root: 'icon' }} />
            {btnTrans('reset')}
          </CustomButton>
          <CustomButton
            variant="contained"
            disabled={searchDisabled}
            onClick={() => handleSearch(topK)}
          >
            {btnTrans('search')}
          </CustomButton>
        </div>
      </section>
      {/* search result table section */}
      {gridVisible ? (
        <AttuGrid
          toolbarConfigs={[]}
          colDefinitions={colDefinitions}
          rows={result}
          rowCount={total}
          primaryKey="rank"
          page={currentPage}
          onChangePage={handlePageChange}
          rowsPerPage={pageSize}
          setRowsPerPage={handlePageSize}
          openCheckBox={false}
          isLoading={tableLoading}
          orderBy={orderBy}
          order={order}
          labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
          handleSort={handleGridSort}
          tableCellMaxWidth="100%"
        />
      ) : (
        <EmptyCard
          wrapperClass="page-empty-card"
          icon={<VectorSearchIcon />}
          // null means no search has run yet, so show the start tip
          text={searchTrans(searchResult !== null ? 'empty' : 'startTip')}
        />
      )}
    </section>
  );
  538. };
  539. export default VectorSearch;