VectorSearch.tsx 19 KB


  1. import { TextField, Typography, Button } from '@material-ui/core';
  2. import { useTranslation } from 'react-i18next';
  3. import { useNavigationHook } from '@/hooks/Navigation';
  4. import { ALL_ROUTER_TYPES } from '@/router/Types';
  5. import CustomSelector from '@/components/customSelector/CustomSelector';
  6. import { useCallback, useEffect, useMemo, useState } from 'react';
  7. import SearchParams from './SearchParams';
  8. import { DEFAULT_METRIC_VALUE_MAP } from '../../consts/Milvus';
  9. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  10. import AttuGrid from '@/components/grid/Grid';
  11. import EmptyCard from '@/components/cards/EmptyCard';
  12. import icons from '@/components/icons/Icons';
  13. import { usePaginationHook } from '@/hooks/Pagination';
  14. import CustomButton from '@/components/customButton/CustomButton';
  15. import SimpleMenu from '@/components/menu/SimpleMenu';
  16. import { TOP_K_OPTIONS } from './Constants';
  17. import { Option } from '@/components/customSelector/Types';
  18. import { CollectionHttp } from '@/http/Collection';
  19. import { CollectionData, DataTypeEnum } from '../collections/Types';
  20. import { IndexHttp } from '@/http/Index';
  21. import { getVectorSearchStyles } from './Styles';
  22. import { parseValue } from '@/utils/Insert';
  23. import {
  24. classifyFields,
  25. getDefaultIndexType,
  26. getEmbeddingType,
  27. getNonVectorFieldsForFilter,
  28. getVectorFieldOptions,
  29. } from '@/utils/search';
  30. import { ColDefinitionsType } from '@/components/grid/Types';
  31. import Filter from '@/components/advancedSearch';
  32. import { Field } from '@/components/advancedSearch/Types';
  33. import { useLocation } from 'react-router-dom';
  34. import { parseLocationSearch } from '@/utils/Format';
  35. import { cloneObj, generateVector } from '@/utils/Common';
  36. import { CustomDatePicker } from '@/components/customDatePicker/CustomDatePicker';
  37. import { useTimeTravelHook } from '@/hooks/TimeTravel';
  38. import { LOADING_STATE } from '@/consts/Milvus';
  39. import { getLabelDisplayedRows } from './Utils';
  40. import { useSearchResult } from '@/hooks/Result';
  /**
   * Vector search page.
   *
   * Lets the user pick a loaded collection and one of its vector fields, enter
   * a query vector, tune the search parameters, optionally add a filter
   * expression and a time-travel timestamp, then run the search and browse the
   * paginated results.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();

    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [collections, setCollections] = useState<CollectionData[]>([]);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<Field[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<Record<string, number>>({});
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // null means "no search has been run yet", an empty array means "no results"
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');
    // latency reported by the last search response
    const [latency, setLatency] = useState<number>(0);
    const [selectedMetricType, setSelectedMetricType] = useState<string>('');

    const searchResultMemo = useSearchResult(searchResult as any, classes);

    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
      order,
      orderBy,
      handleGridSort,
    } = usePaginationHook(searchResultMemo || []);

    const { timeTravel, setTimeTravel, timeTravelInfo, handleDateTimeChange } =
      useTimeTravelHook();

    const collectionOptions: Option[] = useMemo(
      () =>
        collections.map(c => ({
          label: c._name,
          value: c._name,
        })),
      [collections]
    );

    // names of every non-vector field of the selected collection
    const outputFields: string[] = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      // vector field can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      return fields
        .filter(field => !invalidTypes.includes(field._fieldType))
        .map(f => f._fieldName);
    }, [selectedCollection, collections]);

    const primaryKeyField = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      return fields.find(f => f._isPrimaryKey)?._fieldName;
    }, [selectedCollection, collections]);

    // Memoized so that it is a stable dependency of `colDefinitions`; a fresh
    // array on every render would make that memo recompute every time.
    const orderArray = useMemo(
      () => [primaryKeyField, 'id', 'score', ...outputFields],
      [primaryKeyField, outputFields]
    );

    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      if (!searchResult || searchResult.length === 0) {
        return [];
      }

      /**
       * Columns come from the keys of the first result row.
       * The 'id' key is dropped unless the primary key field itself is named
       * 'id', because the primary key column is already in the row.
       */
      const hiddenKeys = primaryKeyField === 'id' ? [] : ['id'];

      return (
        Object.keys(searchResult[0])
          // keys missing from orderArray get index -1 and therefore sort first
          .sort((a, b) => orderArray.indexOf(a) - orderArray.indexOf(b))
          .filter(key => !hiddenKeys.includes(key))
          .map(key => ({
            id: key,
            align: 'left',
            disablePadding: false,
            label: key,
            needCopy: primaryKeyField === key,
          }))
      );
    }, [searchResult, primaryKeyField, orderArray]);

    /**
     * Everything derived from the selected vector field.
     * This callback is kept pure: the metric type it derives is pushed into
     * state by the effect below instead of calling the setter from here.
     */
    const selectedFieldMeta = useMemo(() => {
      const selectedFieldInfo =
        selectedField !== ''
          ? fieldOptions.find(f => f.value === selectedField)
          : undefined;

      // no field selected, or the selected field is not among the options
      if (!selectedFieldInfo) {
        return {
          metricType: undefined,
          indexType: '',
          indexParams: [],
          fieldType: 0,
          embeddingType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }

      const index = selectedFieldInfo.indexInfo;
      const embeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
      const metric =
        index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];

      return {
        metricType: metric,
        indexType: index?._indexType || getDefaultIndexType(embeddingType),
        indexParams: index?._indexParameterPairs || [],
        fieldType: DataTypeEnum[selectedFieldInfo.fieldType],
        embeddingType,
        selectedFieldDimension: selectedFieldInfo.dimension || 0,
      };
    }, [selectedField, fieldOptions]);

    const {
      indexType,
      indexParams,
      fieldType,
      embeddingType,
      selectedFieldDimension,
    } = selectedFieldMeta;

    // Reset the metric type to the field's default whenever the derived field
    // info is recomputed (selected field or field options changed).
    useEffect(() => {
      const { metricType } = selectedFieldMeta;
      if (metricType !== undefined) {
        setSelectedMetricType(metricType);
      }
    }, [selectedFieldMeta]);

    // number of elements the query vector must contain for the selected field
    const expectedVectorLength =
      fieldType === DataTypeEnum.BinaryVector
        ? selectedFieldDimension / 8
        : selectedFieldDimension;

    /**
     * vector value validation
     * @return whether is valid
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      const value = parseValue(vectors);
      return Array.isArray(value) && value.length === expectedVectorLength;
    }, [vectors, selectedFieldDimension, expectedVectorLength]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [
        paramDisabled,
        selectedField,
        selectedCollection,
        vectors,
        vectorValueValid,
      ]
    );

    // fetch data: only collections in LOADED state are kept
    const fetchCollections = useCallback(async () => {
      const collections = await CollectionHttp.getCollections();
      setCollections(collections.filter(c => c._status === LOADING_STATE.LOADED));
    }, []);

    /**
     * Load index info for a collection, then fill the vector field selector
     * (defaulting to its first option) and the advanced filter fields.
     */
    const fetchFieldsWithIndex = useCallback(
      async (collectionName: string, collections: CollectionData[]) => {
        const fields =
          collections.find(c => c._name === collectionName)?._fields || [];
        const indexes = await IndexHttp.getIndexInfo(collectionName);
        const { vectorFields, nonVectorFields } = classifyFields(fields);

        // only vector type fields can be select
        const fieldOptions = getVectorFieldOptions(vectorFields, indexes);
        setFieldOptions(fieldOptions);
        if (fieldOptions.length > 0) {
          // set first option value as default field value
          const [{ value: defaultFieldValue }] = fieldOptions;
          setSelectedField(defaultFieldValue as string);
        }

        // only non vector type fields can be advanced filter
        setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
      },
      []
    );

    useEffect(() => {
      fetchCollections();
    }, [fetchCollections]);

    // get field options with index when selected collection changed
    useEffect(() => {
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(selectedCollection, collections);
      }
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (location.search && collections.length > 0) {
        const { collectionName } = parseLocationSearch(location.search);
        // only accept a name that matches one of the fetched collections
        if (collections.some(c => c._name === collectionName)) {
          setSelectedCollection(collectionName);
        }
      }
    }, [location, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    /**
     * reset search includes:
     * 1. reset vectors
     * 2. reset selected collection and field
     * 3. clear search result
     * 4. reset advanced filter fields and expression
     * 5. reset time travel
     */
    const handleReset = () => {
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
      setTimeTravel(null);
    };

    /**
     * Run the vector search with the current form state.
     * @param topK number of results to request
     * @param expr filter expression, defaults to the one stored in state
     */
    const handleSearch = async (topK: number, expr = expression) => {
      // round_decimal is sent as its own search param, not inside `params`
      const clonedSearchParams = cloneObj(searchParam);
      delete clonedSearchParams.round_decimal;

      const searchParamPairs = {
        params: JSON.stringify(clonedSearchParams),
        anns_field: selectedField,
        topk: topK,
        metric_type: selectedMetricType,
        round_decimal: searchParam.round_decimal,
      };

      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType,
        travel_timestamp: timeTravelInfo.timestamp,
      };

      setTableLoading(true);
      try {
        const res = await CollectionHttp.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(res.results);
        setLatency(res.latency);
      } catch (err) {
        // A failed request keeps the previous results on screen.
        // NOTE(review): presumably the HTTP layer reports the error to the
        // user; confirm, otherwise failures are silent here.
      } finally {
        setTableLoading(false);
      }
    };

    const handleAdvancedFilterChange = (expression: string) => {
      setExpression(expression);
      // pass the new expression explicitly, the state update is not visible yet
      if (!searchDisabled) {
        handleSearch(topK, expression);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    const fillWithExampleVector = (selectedFieldDimension: number) => {
      const v = generateVector(selectedFieldDimension);
      setVectors(`[${v}]`);
    };

    const hasField = selectedFieldDimension !== 0;
    const showGrid =
      (searchResult !== null && searchResult.length > 0) || tableLoading;

    return (
      <section className="page-wrapper">
        {/* form section */}
        <form className={classes.form}>
          {/* collection and field selectors */}
          <fieldset className="field">
            <Typography className="text">{searchTrans('secondTip')}</Typography>
            <CustomSelector
              options={collectionOptions}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans(
                collectionOptions.length === 0 ? 'noCollection' : 'collection'
              )}
              disabled={collectionOptions.length === 0}
              value={selectedCollection}
              onChange={(e: { target: { value: unknown } }) => {
                setSelectedCollection(e.target.value as string);
                // every time selected collection changed, reset field
                setSelectedField('');
                setSearchResult([]);
              }}
            />
            <CustomSelector
              options={fieldOptions}
              // readOnly can't avoid all events, so we use disabled instead
              disabled={selectedCollection === ''}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans('field')}
              value={selectedField}
              onChange={(e: { target: { value: unknown } }) => {
                setSelectedField(e.target.value as string);
              }}
            />
          </fieldset>

          {/**
           * vector value textarea
           * use field-params class because it also has error msg if invalid
           */}
          <fieldset className="field field-params field-second">
            <Typography className="text">
              {searchTrans('firstTip', {
                dimensionTip: hasField
                  ? `(dimension: ${selectedFieldDimension})`
                  : '',
              })}
              {hasField ? (
                <Button
                  variant="outlined"
                  size="small"
                  onClick={() => fillWithExampleVector(expectedVectorLength)}
                >
                  {btnTrans('example')}
                </Button>
              ) : null}
            </Typography>
            <TextField
              className="textarea"
              InputProps={{
                classes: {
                  root: 'textfield',
                  multiline: 'multiline',
                },
              }}
              multiline
              rows={5}
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
                handleVectorChange(e.target.value as string);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: expectedVectorLength,
                })}
              </Typography>
            )}
          </fieldset>

          {/* search params selectors */}
          <fieldset className="field field-params">
            <Typography className="text">{searchTrans('thirdTip')}</Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              metricType={selectedMetricType}
              embeddingType={
                embeddingType as
                  | DataTypeEnum.BinaryVector
                  | DataTypeEnum.FloatVector
              }
              indexType={indexType}
              indexParams={indexParams!}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              handleMetricTypeChange={setSelectedMetricType}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </fieldset>
        </form>

        {/**
         * search toolbar section
         * including topK selector, advanced filter, search and reset btn
         */}
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={TOP_K_OPTIONS.map(item => ({
                label: item.toString(),
                callback: () => {
                  setTopK(item);
                  // search with the new value directly, state is not updated yet
                  if (!searchDisabled) {
                    handleSearch(item);
                  }
                },
                wrapperClass: classes.menuItem,
              }))}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
            <CustomDatePicker
              label={timeTravelInfo.label}
              onChange={handleDateTimeChange}
              date={timeTravel}
              setDate={setTimeTravel}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>

        {/* search result table section */}
        {showGrid ? (
          <AttuGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onChangePage={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
            orderBy={orderBy}
            order={order}
            labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
            handleSort={handleGridSort}
            tableCellMaxWidth="100%"
          />
        ) : (
          <EmptyCard
            wrapperClass="page-empty-card"
            icon={<VectorSearchIcon />}
            text={
              searchResult !== null
                ? searchTrans('empty')
                : searchTrans('startTip')
            }
          />
        )}
      </section>
    );
  };

  export default VectorSearch;