VectorSearch.tsx 18 KB


  1. import { TextField, Typography, Button } from '@material-ui/core';
  2. import { useTranslation } from 'react-i18next';
  3. import { useNavigationHook } from '../../hooks/Navigation';
  4. import { ALL_ROUTER_TYPES } from '../../router/Types';
  5. import CustomSelector from '../../components/customSelector/CustomSelector';
  6. import { useCallback, useEffect, useMemo, useState } from 'react';
  7. import SearchParams from './SearchParams';
  8. import { DEFAULT_METRIC_VALUE_MAP } from '../../consts/Milvus';
  9. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  10. import AttuGrid from '../../components/grid/Grid';
  11. import EmptyCard from '../../components/cards/EmptyCard';
  12. import icons from '../../components/icons/Icons';
  13. import { usePaginationHook } from '../../hooks/Pagination';
  14. import CustomButton from '../../components/customButton/CustomButton';
  15. import SimpleMenu from '../../components/menu/SimpleMenu';
  16. import { TOP_K_OPTIONS } from './Constants';
  17. import { Option } from '../../components/customSelector/Types';
  18. import { CollectionHttp } from '../../http/Collection';
  19. import { CollectionData, DataTypeEnum } from '../collections/Types';
  20. import { IndexHttp } from '../../http/Index';
  21. import { getVectorSearchStyles } from './Styles';
  22. import { parseValue } from '../../utils/Insert';
  23. import {
  24. classifyFields,
  25. getDefaultIndexType,
  26. getEmbeddingType,
  27. getNonVectorFieldsForFilter,
  28. getVectorFieldOptions,
  29. transferSearchResult,
  30. } from '../../utils/search';
  31. import { ColDefinitionsType } from '../../components/grid/Types';
  32. import Filter from '../../components/advancedSearch';
  33. import { Field } from '../../components/advancedSearch/Types';
  34. import { useLocation } from 'react-router-dom';
  35. import { parseLocationSearch } from '../../utils/Format';
  36. import { cloneObj, generateVector } from '../../utils/Common';
  37. import { CustomDatePicker } from '../../components/customDatePicker/CustomDatePicker';
  38. import { useTimeTravelHook } from '../../hooks/TimeTravel';
  /**
   * Vector search page.
   *
   * The user picks a collection and one of its vector fields, enters a query
   * vector, tunes the index-specific search params, and can optionally add an
   * advanced filter expression and a time-travel point. The page then calls
   * `CollectionHttp.vectorSearchData` and renders the results in a paginated
   * grid.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();
    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [collections, setCollections] = useState<CollectionData[]>([]);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<Field[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<{ [key in string]: number }>(
      {}
    );
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // use null as init value before search, empty array means no results
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');
    const [selectedMetricType, setSelectedMetricType] = useState<string>('');

    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
      order,
      orderBy,
      handleGridSort,
    } = usePaginationHook(searchResult || []);
    const { timeTravel, setTimeTravel, timeTravelInfo, handleDateTimeChange } =
      useTimeTravelHook();

    const collectionOptions: Option[] = useMemo(
      () =>
        collections.map(c => ({
          label: c._name,
          value: c._name,
        })),
      [collections]
    );

    // Names of the selected collection's non-vector fields, requested as output fields.
    const outputFields: string[] = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      // vector field can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      return fields
        .filter(field => !invalidTypes.includes(field._fieldType))
        .map(f => f._fieldName);
    }, [selectedCollection, collections]);

    const primaryKeyField = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      return fields.find(f => f._isPrimaryKey)?._fieldName;
    }, [selectedCollection, collections]);

    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      if (!searchResult || searchResult.length === 0) {
        return [];
      }
      /**
       * id represents primary key, score represents distance.
       * Score is shown as distance in the view, and the original primary key
       * field is already a column, so 'id' and 'score' are hidden to avoid
       * redundant data -- unless the primary key field itself is named 'id'.
       */
      const hiddenKeys = primaryKeyField === 'id' ? ['score'] : ['id', 'score'];
      return Object.keys(searchResult[0])
        .filter(key => !hiddenKeys.includes(key))
        .map(key => ({
          id: key,
          align: 'left',
          disablePadding: false,
          label: key,
        }));
    }, [searchResult, primaryKeyField]);

    /**
     * Everything derived from the selected vector field: its index type and
     * params (or a default index type when none is found), the metric to
     * preselect, the numeric DataTypeEnum value and the dimension.
     */
    const selectedFieldInfo = useMemo(() => {
      // fieldOptions and selectedField are updated by separate setState calls,
      // so the selected field can briefly be missing from the options. Treat
      // that like "no field selected" instead of dereferencing undefined.
      const option =
        selectedField !== ''
          ? fieldOptions.find(f => f.value === selectedField)
          : undefined;
      if (!option) {
        return {
          metricType: undefined,
          indexType: '',
          indexParams: [],
          fieldType: 0,
          embeddingType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }
      const index = option.indexInfo;
      const embeddingType = getEmbeddingType(option.fieldType);
      return {
        metricType:
          index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType],
        indexType: index?._indexType || getDefaultIndexType(embeddingType),
        indexParams: index?._indexParameterPairs || [],
        fieldType: DataTypeEnum[option.fieldType],
        embeddingType,
        selectedFieldDimension: option.dimension || 0,
      };
    }, [selectedField, fieldOptions]);
    const {
      indexType,
      indexParams,
      fieldType,
      embeddingType,
      selectedFieldDimension,
    } = selectedFieldInfo;

    // Preselect the field's metric each time the field info is recomputed.
    // This is a state update, so it belongs in an effect rather than in the
    // (render-time, should-be-pure) useMemo above.
    useEffect(() => {
      const { metricType } = selectedFieldInfo;
      if (metricType !== undefined) {
        setSelectedMetricType(metricType);
      }
    }, [selectedFieldInfo]);

    // Expected length of the query vector array: binary vector fields expect
    // dimension / 8 values (presumably one per 8 bits), float fields the full dimension.
    const vectorDimension =
      fieldType === DataTypeEnum.BinaryVector
        ? selectedFieldDimension / 8
        : selectedFieldDimension;

    /**
     * vector value validation
     * @return whether the entered value is an array of the expected length;
     * true while nothing is entered or no field is selected
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      const value = parseValue(vectors);
      return Array.isArray(value) && value.length === vectorDimension;
    }, [vectors, selectedFieldDimension, vectorDimension]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [paramDisabled, selectedField, selectedCollection, vectors, vectorValueValid]
    );

    // fetch data
    const fetchCollections = useCallback(async () => {
      const collections = await CollectionHttp.getCollections();
      setCollections(collections);
    }, []);

    /**
     * Loads index info for `collectionName` and derives the vector field
     * options (defaulting the selection to the first one) and the non-vector
     * fields usable by the advanced filter.
     * @param isStale returns true when the response is no longer wanted (the
     * collection changed or was reset while waiting); state is then left alone.
     */
    const fetchFieldsWithIndex = useCallback(
      async (
        collectionName: string,
        collections: CollectionData[],
        isStale: () => boolean = () => false
      ) => {
        const fields =
          collections.find(c => c._name === collectionName)?._fields || [];
        const indexes = await IndexHttp.getIndexInfo(collectionName);
        if (isStale()) {
          return;
        }
        const { vectorFields, nonVectorFields } = classifyFields(fields);
        // only vector type fields can be select
        const vectorFieldOptions = getVectorFieldOptions(vectorFields, indexes);
        setFieldOptions(vectorFieldOptions);
        // set first option value as default field value; clear it if there is none
        const [firstOption] = vectorFieldOptions;
        setSelectedField(firstOption ? (firstOption.value as string) : '');
        // only non vector type fields can be advanced filter
        setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
      },
      []
    );

    useEffect(() => {
      fetchCollections().catch(err => console.error(err));
    }, [fetchCollections]);

    // get field options with index when selected collection changed
    useEffect(() => {
      let stale = false;
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(selectedCollection, collections, () => stale).catch(
          err => console.error(err)
        );
      }
      // ignore the in-flight response once the collection changes or the page unmounts
      return () => {
        stale = true;
      };
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (!location.search || collections.length === 0) {
        return;
      }
      const { collectionName } = parseLocationSearch(location.search);
      // collection name validation
      if (collections.some(c => c._name === collectionName)) {
        setSelectedCollection(collectionName);
      }
    }, [location.search, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    const handleCollectionChange = (collectionName: string) => {
      setSelectedCollection(collectionName);
      // every time selected collection changed, reset field
      setSelectedField('');
      setSearchResult([]);
      // a filter expression built on the previous collection's fields no longer applies
      setExpression('');
    };

    /**
     * reset search clears:
     * 1. vectors
     * 2. selected collection, field and the loaded field options
     * 3. advanced filter fields and expression
     * 4. time travel
     * 5. search result (back to the "before search" state)
     * NOTE(review): `searchParam` itself is not reset here; presumably
     * SearchParams re-derives it when the field changes -- verify.
     */
    const handleReset = () => {
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setFieldOptions([]);
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
      setTimeTravel(null);
    };

    /**
     * Runs the vector search with the current form state.
     * @param topK number of results to request
     * @param expr filter expression, defaults to the current advanced filter
     */
    const handleSearch = async (topK: number, expr = expression) => {
      // round_decimal is sent as its own search param, not inside `params`
      const clonedSearchParams = cloneObj(searchParam);
      delete clonedSearchParams.round_decimal;
      const searchParamPairs = {
        params: JSON.stringify(clonedSearchParams),
        anns_field: selectedField,
        topk: topK,
        metric_type: selectedMetricType,
        round_decimal: searchParam.round_decimal,
      };
      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType,
        travel_timestamp: timeTravelInfo.timestamp,
      };
      setTableLoading(true);
      try {
        const res = await CollectionHttp.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(transferSearchResult(res.results));
      } catch (err) {
        console.error(err);
      } finally {
        setTableLoading(false);
      }
    };

    const handleAdvancedFilterChange = (expression: string) => {
      setExpression(expression);
      if (!searchDisabled) {
        handleSearch(topK, expression);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    const fillWithExampleVector = (dim: number) => {
      const v = generateVector(dim);
      setVectors(`[${v}]`);
    };

    return (
      <section className="page-wrapper">
        {/* form section */}
        <form className={classes.form}>
          {/* collection and field selectors */}
          <fieldset className="field">
            <Typography className="text">{searchTrans('secondTip')}</Typography>
            <CustomSelector
              options={collectionOptions}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans(
                collectionOptions.length === 0 ? 'noCollection' : 'collection'
              )}
              disabled={collectionOptions.length === 0}
              value={selectedCollection}
              onChange={(e: { target: { value: unknown } }) => {
                handleCollectionChange(e.target.value as string);
              }}
            />
            <CustomSelector
              options={fieldOptions}
              // readOnly can't avoid all events, so we use disabled instead
              disabled={selectedCollection === ''}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans('field')}
              value={selectedField}
              onChange={(e: { target: { value: unknown } }) => {
                setSelectedField(e.target.value as string);
              }}
            />
          </fieldset>
          {/**
           * vector value textarea
           * use field-params class because it also has error msg if invalid
           */}
          <fieldset className="field field-params field-second">
            <Typography className="text">
              {searchTrans('firstTip', {
                dimensionTip:
                  selectedFieldDimension !== 0
                    ? `(dimension: ${selectedFieldDimension})`
                    : '',
              })}
              {selectedFieldDimension !== 0 ? (
                <Button
                  variant="outlined"
                  size="small"
                  onClick={() => {
                    fillWithExampleVector(vectorDimension);
                  }}
                >
                  {btnTrans('example')}
                </Button>
              ) : null}
            </Typography>
            <TextField
              className="textarea"
              InputProps={{
                classes: {
                  root: 'textfield',
                  multiline: 'multiline',
                },
              }}
              multiline
              rows={5}
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
                handleVectorChange(e.target.value as string);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: vectorDimension,
                })}
              </Typography>
            )}
          </fieldset>
          {/* search params selectors */}
          <fieldset className="field field-params">
            <Typography className="text">{searchTrans('thirdTip')}</Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              metricType={selectedMetricType}
              embeddingType={
                embeddingType as
                  | DataTypeEnum.BinaryVector
                  | DataTypeEnum.FloatVector
              }
              indexType={indexType}
              indexParams={indexParams}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              handleMetricTypeChange={setSelectedMetricType}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </fieldset>
        </form>
        {/**
         * search toolbar section
         * including topK selector, advanced filter, search and reset btn
         */}
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={TOP_K_OPTIONS.map(item => ({
                label: item.toString(),
                callback: () => {
                  setTopK(item);
                  if (!searchDisabled) {
                    handleSearch(item);
                  }
                },
                wrapperClass: classes.menuItem,
              }))}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
            <CustomDatePicker
              label={timeTravelInfo.label}
              onChange={handleDateTimeChange}
              date={timeTravel}
              setDate={setTimeTravel}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>
        {/* search result table section */}
        {(searchResult && searchResult.length > 0) || tableLoading ? (
          <AttuGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onChangePage={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
            orderBy={orderBy}
            order={order}
            handleSort={handleGridSort}
          />
        ) : (
          <EmptyCard
            wrapperClass={`page-empty-card`}
            icon={<VectorSearchIcon />}
            text={
              searchResult !== null
                ? searchTrans('empty')
                : searchTrans('startTip')
            }
          />
        )}
      </section>
    );
  };

  export default VectorSearch;