// VectorSearch.tsx
  1. import { useCallback, useEffect, useMemo, useState, useContext } from 'react';
  2. import { Typography, Button, Card, CardContent } from '@material-ui/core';
  3. import { useTranslation } from 'react-i18next';
  4. import { useLocation } from 'react-router-dom';
  5. import { ALL_ROUTER_TYPES } from '@/router/Types';
  6. import { useNavigationHook, useSearchResult, usePaginationHook } from '@/hooks';
  7. import { dataContext } from '@/context';
  8. import { saveCsvAs } from '@/utils';
  9. import CustomSelector from '@/components/customSelector/CustomSelector';
  10. import { ColDefinitionsType } from '@/components/grid/Types';
  11. import AttuGrid from '@/components/grid/Grid';
  12. import EmptyCard from '@/components/cards/EmptyCard';
  13. import icons from '@/components/icons/Icons';
  14. import CustomButton from '@/components/customButton/CustomButton';
  15. import SimpleMenu from '@/components/menu/SimpleMenu';
  16. import { Option } from '@/components/customSelector/Types';
  17. import Filter from '@/components/advancedSearch';
  18. import { Field } from '@/components/advancedSearch/Types';
  19. import { Collection } from '@/http';
  20. import {
  21. parseValue,
  22. parseLocationSearch,
  23. classifyFields,
  24. getDefaultIndexType,
  25. getEmbeddingType,
  26. getNonVectorFieldsForFilter,
  27. getVectorFieldOptions,
  28. cloneObj,
  29. generateVector,
  30. } from '@/utils';
  31. import {
  32. LOADING_STATE,
  33. DEFAULT_METRIC_VALUE_MAP,
  34. DYNAMIC_FIELD,
  35. DataTypeEnum,
  36. } from '@/consts';
  37. import { getLabelDisplayedRows } from './Utils';
  38. import SearchParams from './SearchParams';
  39. import { getVectorSearchStyles } from './Styles';
  40. import { TOP_K_OPTIONS } from './Constants';
  41. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  42. const VectorSearch = () => {
  43. useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
  44. const location = useLocation();
  45. const { database } = useContext(dataContext);
  46. // i18n
  47. const { t: searchTrans } = useTranslation('search');
  48. const { t: btnTrans } = useTranslation('btn');
  49. const classes = getVectorSearchStyles();
  50. // data stored inside the component
  51. const [tableLoading, setTableLoading] = useState<boolean>(false);
  52. const [collections, setCollections] = useState<Collection[]>([]);
  53. const [selectedCollection, setSelectedCollection] = useState<string>('');
  54. const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
  55. // fields for advanced filter
  56. const [filterFields, setFilterFields] = useState<Field[]>([]);
  57. const [selectedField, setSelectedField] = useState<string>('');
  58. // search params form
  59. const [searchParam, setSearchParam] = useState<{ [key in string]: number }>(
  60. {}
  61. );
  62. // search params disable state
  63. const [paramDisabled, setParamDisabled] = useState<boolean>(true);
  64. // use null as init value before search, empty array means no results
  65. const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
  66. null
  67. );
  68. // default topK is 100
  69. const [topK, setTopK] = useState<number>(100);
  70. const [expression, setExpression] = useState<string>('');
  71. const [vectors, setVectors] = useState<string>('');
  72. // latency
  73. const [latency, setLatency] = useState<number>(0);
  74. const searchResultMemo = useSearchResult(searchResult as any);
  75. const {
  76. pageSize,
  77. handlePageSize,
  78. currentPage,
  79. handleCurrentPage,
  80. total,
  81. data: result,
  82. order,
  83. orderBy,
  84. handleGridSort,
  85. } = usePaginationHook(searchResultMemo || []);
  86. const collectionOptions: Option[] = useMemo(
  87. () =>
  88. collections.map(c => ({
  89. label: c.collectionName,
  90. value: c.collectionName,
  91. })),
  92. [collections]
  93. );
  94. const outputFields: string[] = useMemo(() => {
  95. const s = collections.find(c => c.collectionName === selectedCollection);
  96. if (!s) {
  97. return [];
  98. }
  99. const fields = s.fields || [];
  100. // vector field can't be output fields
  101. const invalidTypes = ['BinaryVector', 'FloatVector'];
  102. const nonVectorFields = fields.filter(
  103. field => !invalidTypes.includes(field.fieldType)
  104. );
  105. const _outputFields = nonVectorFields.map(f => f.name);
  106. if (s.enableDynamicField) {
  107. _outputFields.push(DYNAMIC_FIELD);
  108. }
  109. return _outputFields;
  110. }, [selectedCollection, collections]);
  111. const primaryKeyField = useMemo(() => {
  112. const selectedCollectionInfo = collections.find(
  113. c => c.collectionName === selectedCollection
  114. );
  115. const fields = selectedCollectionInfo?.fields || [];
  116. return fields.find(f => f.isPrimaryKey)?.name;
  117. }, [selectedCollection, collections]);
  118. const orderArray = [primaryKeyField, 'id', 'score', ...outputFields];
  119. const colDefinitions: ColDefinitionsType[] = useMemo(() => {
  120. /**
  121. * id represents primary key, score represents distance
  122. * since we transfer score to distance in the view, and original field which is primary key has already in the table
  123. * we filter 'id' and 'score' to avoid redundant data
  124. */
  125. return searchResult && searchResult.length > 0
  126. ? Object.keys(searchResult[0])
  127. .sort((a, b) => {
  128. const indexA = orderArray.indexOf(a);
  129. const indexB = orderArray.indexOf(b);
  130. return indexA - indexB;
  131. })
  132. .filter(item => {
  133. // if primary key field name is id, don't filter it
  134. const invalidItems = primaryKeyField === 'id' ? [] : ['id'];
  135. return !invalidItems.includes(item);
  136. })
  137. .map(key => ({
  138. id: key,
  139. align: 'left',
  140. disablePadding: false,
  141. label: key === DYNAMIC_FIELD ? searchTrans('dynamicFields') : key,
  142. needCopy: key !== 'score',
  143. }))
  144. : [];
  145. }, [searchResult, primaryKeyField, orderArray]);
  146. const [selectedMetricType, setSelectedMetricType] = useState<string>('');
  147. const [selectedConsistencyLevel, setSelectedConsistencyLevel] =
  148. useState<string>('');
  149. const {
  150. indexType,
  151. indexParams,
  152. fieldType,
  153. embeddingType,
  154. selectedFieldDimension,
  155. } = useMemo(() => {
  156. if (selectedField !== '') {
  157. // field options must contain selected field, so selectedFieldInfo will never undefined
  158. const selectedFieldInfo = fieldOptions.find(
  159. f => f.value === selectedField
  160. );
  161. const index = selectedFieldInfo?.indexInfo;
  162. const embeddingType = getEmbeddingType(selectedFieldInfo!.fieldType);
  163. const metric =
  164. index?.metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];
  165. const indexParams = index?.indexParameterPairs || [];
  166. const dim = selectedFieldInfo?.dimension || 0;
  167. setSelectedMetricType(metric);
  168. return {
  169. metricType: metric,
  170. indexType: index?.indexType || getDefaultIndexType(embeddingType),
  171. indexParams,
  172. fieldType:
  173. DataTypeEnum[
  174. selectedFieldInfo?.fieldType! as keyof typeof DataTypeEnum
  175. ],
  176. embeddingType,
  177. selectedFieldDimension: dim,
  178. };
  179. }
  180. return {
  181. indexType: '',
  182. indexParams: [],
  183. fieldType: 0,
  184. embeddingType: DataTypeEnum.FloatVector,
  185. selectedFieldDimension: 0,
  186. };
  187. }, [selectedField, fieldOptions]);
  188. /**
  189. * vector value validation
  190. * @return whether is valid
  191. */
  192. const vectorValueValid = useMemo(() => {
  193. // if user hasn't input value or not select field, don't trigger validation check
  194. if (vectors === '' || selectedFieldDimension === 0) {
  195. return true;
  196. }
  197. const dim =
  198. fieldType === DataTypeEnum.BinaryVector
  199. ? selectedFieldDimension / 8
  200. : selectedFieldDimension;
  201. const value = parseValue(vectors);
  202. const isArray = Array.isArray(value);
  203. return isArray && value.length === dim;
  204. }, [vectors, selectedFieldDimension, fieldType]);
  205. const searchDisabled = useMemo(() => {
  206. /**
  207. * before search, user must:
  208. * 1. enter vector value, it should be an array and length should be equal to selected field dimension
  209. * 2. choose collection and field
  210. * 3. set extra search params
  211. */
  212. const isInvalid =
  213. vectors === '' ||
  214. selectedCollection === '' ||
  215. selectedField === '' ||
  216. paramDisabled ||
  217. !vectorValueValid;
  218. return isInvalid;
  219. }, [
  220. paramDisabled,
  221. selectedField,
  222. selectedCollection,
  223. vectors,
  224. vectorValueValid,
  225. ]);
  226. // fetch data
  227. const fetchCollections = useCallback(async () => {
  228. const collections = await Collection.getCollections();
  229. setCollections(collections.filter(c => c.status === LOADING_STATE.LOADED));
  230. }, [database]);
  231. const fetchFieldsWithIndex = useCallback(
  232. async (collectionName: string, collections: Collection[]) => {
  233. const col = collections.find(c => c.collectionName === collectionName);
  234. const fields = col?.fields ?? [];
  235. const { vectorFields, nonVectorFields } = classifyFields(fields);
  236. // only vector type fields can be select
  237. const fieldOptions = getVectorFieldOptions(
  238. vectorFields,
  239. col?.indexes ?? []
  240. );
  241. setFieldOptions(fieldOptions);
  242. if (fieldOptions.length > 0) {
  243. // set first option value as default field value
  244. const [{ value: defaultFieldValue }] = fieldOptions;
  245. setSelectedField(defaultFieldValue as string);
  246. }
  247. // only non vector type fields can be advanced filter
  248. const filterFields = getNonVectorFieldsForFilter(nonVectorFields);
  249. setFilterFields(filterFields);
  250. },
  251. [collections]
  252. );
  253. useEffect(() => {
  254. fetchCollections();
  255. }, [fetchCollections]);
  256. // clear selection if database is changed
  257. useEffect(() => {
  258. setSelectedCollection('');
  259. }, [database]);
  260. // get field options with index when selected collection changed
  261. useEffect(() => {
  262. if (selectedCollection !== '') {
  263. fetchFieldsWithIndex(selectedCollection, collections);
  264. }
  265. const level = collections.find(c => c.collectionName == selectedCollection)
  266. ?.consistency_level!;
  267. setSelectedConsistencyLevel(level);
  268. }, [selectedCollection, collections, fetchFieldsWithIndex]);
  269. // set default collection value if is from overview page
  270. useEffect(() => {
  271. if (location.search && collections.length > 0) {
  272. const { collectionName } = parseLocationSearch(location.search);
  273. // collection name validation
  274. const isNameValid = collections
  275. .map(c => c.collectionName)
  276. .includes(collectionName);
  277. isNameValid && setSelectedCollection(collectionName);
  278. }
  279. }, [location, collections]);
  280. // icons
  281. const VectorSearchIcon = icons.vectorSearch;
  282. const ResetIcon = icons.refresh;
  283. const ArrowIcon = icons.dropdown;
  284. const ExportIcon = icons.download;
  285. // methods
  286. const handlePageChange = (e: any, page: number) => {
  287. handleCurrentPage(page);
  288. };
  289. const handleReset = () => {
  290. /**
  291. * reset search includes:
  292. * 1. reset vectors
  293. * 2. reset selected collection and field
  294. * 3. reset search params
  295. * 4. reset advanced filter expression
  296. * 5. clear search result
  297. */
  298. setVectors('');
  299. setSelectedField('');
  300. setSelectedCollection('');
  301. setSearchResult(null);
  302. setFilterFields([]);
  303. setExpression('');
  304. };
  305. const handleSearch = async (topK: number, expr = expression) => {
  306. const clonedSearchParams = cloneObj(searchParam);
  307. delete clonedSearchParams.round_decimal;
  308. const searchParamPairs = {
  309. params: JSON.stringify(clonedSearchParams),
  310. anns_field: selectedField,
  311. topk: topK,
  312. metric_type: selectedMetricType,
  313. round_decimal: searchParam.round_decimal,
  314. };
  315. const params: VectorSearchParam = {
  316. output_fields: outputFields,
  317. expr,
  318. search_params: searchParamPairs,
  319. vectors: [parseValue(vectors)],
  320. vector_type: fieldType,
  321. consistency_level:
  322. selectedConsistencyLevel ||
  323. collections.find(c => c.collectionName == selectedCollection)
  324. ?.consistency_level!,
  325. };
  326. setTableLoading(true);
  327. try {
  328. const res = await Collection.vectorSearchData(selectedCollection, params);
  329. setTableLoading(false);
  330. setSearchResult(res.results);
  331. setLatency(res.latency);
  332. } catch (err) {
  333. setTableLoading(false);
  334. }
  335. };
  336. const handleAdvancedFilterChange = (expression: string) => {
  337. setExpression(expression);
  338. if (!searchDisabled) {
  339. handleSearch(topK, expression);
  340. }
  341. };
  342. const handleVectorChange = (value: string) => {
  343. setVectors(value);
  344. };
  345. const fillWithExampleVector = (selectedFieldDimension: number) => {
  346. const v = generateVector(selectedFieldDimension);
  347. setVectors(`[${v}]`);
  348. };
  349. return (
  350. <section className={`page-wrapper ${classes.pageContainer}`}>
  351. <Card className={classes.form}>
  352. <CardContent className={classes.s1}>
  353. <div className="wrapper">
  354. <CustomSelector
  355. options={collectionOptions}
  356. wrapperClass={classes.selector}
  357. variant="filled"
  358. label={searchTrans('collection')}
  359. disabled={collectionOptions.length === 0}
  360. value={selectedCollection}
  361. onChange={(e: { target: { value: unknown } }) => {
  362. const collection = e.target.value as string;
  363. setSelectedCollection(collection);
  364. // every time selected collection changed, reset field
  365. setSelectedField('');
  366. setSearchResult([]);
  367. }}
  368. />
  369. <CustomSelector
  370. options={fieldOptions}
  371. // readOnly can't avoid all events, so we use disabled instead
  372. disabled={selectedCollection === ''}
  373. wrapperClass={classes.selector}
  374. variant="filled"
  375. label={searchTrans('field')}
  376. value={selectedField}
  377. onChange={(e: { target: { value: unknown } }) => {
  378. const field = e.target.value as string;
  379. setSelectedField(field);
  380. }}
  381. />
  382. </div>
  383. </CardContent>
  384. <CardContent className={classes.s2}>
  385. <textarea
  386. className="textarea"
  387. placeholder={searchTrans('vectorPlaceholder')}
  388. value={vectors}
  389. onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
  390. handleVectorChange(e.target.value as string);
  391. }}
  392. />
  393. {/* validation */}
  394. {!vectorValueValid && (
  395. <Typography variant="caption" className={classes.error}>
  396. {searchTrans('vectorValueWarning', {
  397. dimension:
  398. fieldType === DataTypeEnum.BinaryVector
  399. ? selectedFieldDimension / 8
  400. : selectedFieldDimension,
  401. })}
  402. </Typography>
  403. )}
  404. {selectedFieldDimension !== 0 ? (
  405. <Button
  406. className={classes.exampleBtn}
  407. onClick={() => {
  408. const dim =
  409. fieldType === DataTypeEnum.BinaryVector
  410. ? selectedFieldDimension / 8
  411. : selectedFieldDimension;
  412. fillWithExampleVector(dim);
  413. }}
  414. >
  415. {btnTrans('example')}
  416. </Button>
  417. ) : null}
  418. <CustomButton
  419. variant="contained"
  420. disabled={searchDisabled}
  421. onClick={() => handleSearch(topK)}
  422. >
  423. {btnTrans('search')}
  424. </CustomButton>
  425. </CardContent>
  426. <CardContent className={classes.s3}>
  427. <Typography className="text">
  428. {searchTrans('thirdTip')} ({selectedMetricType})
  429. </Typography>
  430. <SearchParams
  431. wrapperClass={classes.paramsWrapper}
  432. metricType={selectedMetricType}
  433. embeddingType={
  434. embeddingType as
  435. | DataTypeEnum.BinaryVector
  436. | DataTypeEnum.FloatVector
  437. }
  438. consistency_level={selectedConsistencyLevel}
  439. handleConsistencyChange={(level: string) => {
  440. setSelectedConsistencyLevel(level);
  441. }}
  442. indexType={indexType}
  443. indexParams={indexParams!}
  444. searchParamsForm={searchParam}
  445. handleFormChange={setSearchParam}
  446. handleMetricTypeChange={setSelectedMetricType}
  447. topK={topK}
  448. setParamsDisabled={setParamDisabled}
  449. />
  450. </CardContent>
  451. </Card>
  452. <section className={classes.resultsWrapper}>
  453. <section className={classes.toolbar}>
  454. <div className="left">
  455. <Typography variant="h5" className="text">
  456. {`${searchTrans('result')}: `}
  457. </Typography>
  458. {/* topK selector */}
  459. <SimpleMenu
  460. label={searchTrans('topK', { number: topK })}
  461. menuItems={TOP_K_OPTIONS.map(item => ({
  462. label: item.toString(),
  463. callback: () => {
  464. setTopK(item);
  465. if (!searchDisabled) {
  466. handleSearch(item);
  467. }
  468. },
  469. wrapperClass: classes.menuItem,
  470. }))}
  471. buttonProps={{
  472. className: classes.menuLabel,
  473. endIcon: <ArrowIcon />,
  474. }}
  475. menuItemWidth="108px"
  476. />
  477. <Filter
  478. title="Advanced Filter"
  479. fields={filterFields}
  480. filterDisabled={selectedField === '' || selectedCollection === ''}
  481. onSubmit={handleAdvancedFilterChange}
  482. />
  483. <CustomButton
  484. className="btn"
  485. disabled={result.length === 0}
  486. onClick={() => {
  487. saveCsvAs(searchResult, `search_result_${selectedCollection}`);
  488. }}
  489. >
  490. <ExportIcon classes={{ root: 'icon' }} />
  491. {btnTrans('export')}
  492. </CustomButton>
  493. </div>
  494. <div className="right">
  495. <CustomButton className="btn" onClick={handleReset}>
  496. <ResetIcon classes={{ root: 'icon' }} />
  497. {btnTrans('reset')}
  498. </CustomButton>
  499. </div>
  500. </section>
  501. {/* search result table section */}
  502. {(searchResult && searchResult.length > 0) || tableLoading ? (
  503. <AttuGrid
  504. toolbarConfigs={[]}
  505. colDefinitions={colDefinitions}
  506. rows={result}
  507. rowCount={total}
  508. primaryKey="rank"
  509. page={currentPage}
  510. onPageChange={handlePageChange}
  511. rowsPerPage={pageSize}
  512. setRowsPerPage={handlePageSize}
  513. openCheckBox={false}
  514. isLoading={tableLoading}
  515. orderBy={orderBy}
  516. order={order}
  517. labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
  518. handleSort={handleGridSort}
  519. />
  520. ) : (
  521. <EmptyCard
  522. wrapperClass={`page-empty-card`}
  523. icon={<VectorSearchIcon />}
  524. text={
  525. searchResult !== null
  526. ? searchTrans('empty')
  527. : searchTrans('startTip')
  528. }
  529. />
  530. )}
  531. </section>
  532. </section>
  533. );
  534. };
  535. export default VectorSearch;