SearchGlobalParams.tsx 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import { useCallback } from 'react';
  2. import { useTranslation } from 'react-i18next';
  3. import { Slider } from '@mui/material';
  4. import CustomInput from '@/components/customInput/CustomInput';
  5. import CustomSelector from '@/components/customSelector/CustomSelector';
  6. import CustomMultiSelector from '@/components/customSelector/CustomMultiSelector';
  7. import {
  8. DYNAMIC_FIELD,
  9. CONSISTENCY_LEVEL_OPTIONS,
  10. TOP_K_OPTIONS,
  11. RERANKER_OPTIONS,
  12. DataTypeStringEnum,
  13. } from '@/consts';
  14. import type { SearchParams, GlobalParams } from '../../types';
  15. export interface SearchGlobalProps {
  16. searchParams: SearchParams;
  17. handleFormChange: (form: GlobalParams) => void;
  18. onSlideChange: (field: string) => void;
  19. onSlideChangeCommitted: () => void;
  20. }
  21. const UNSPORTED_GROUPBY_TYPES = [
  22. DataTypeStringEnum.Double,
  23. DataTypeStringEnum.Float,
  24. DataTypeStringEnum.JSON,
  25. ];
  26. const SearchGlobalParams = (props: SearchGlobalProps) => {
  27. // props
  28. const {
  29. searchParams,
  30. handleFormChange,
  31. onSlideChange,
  32. onSlideChangeCommitted,
  33. } = props;
  34. // values
  35. const searchGlobalParams = searchParams.globalParams;
  36. const fields = searchParams.collection.schema?.scalarFields || [];
  37. const outputFields = [
  38. ...(searchParams.collection.schema?.scalarFields || []),
  39. ...(searchParams.collection.schema?.dynamicFields || []),
  40. ];
  41. const selectedCount = searchParams.searchParams.filter(
  42. sp => sp.selected
  43. ).length;
  44. const showReranker = selectedCount > 1;
  45. // translations
  46. const { t: warningTrans } = useTranslation('warning');
  47. const { t: commonTrans } = useTranslation();
  48. const { t: searchTrans } = useTranslation('search');
  49. const gridTrans = commonTrans('grid');
  50. // UI functions
  51. const handleInputChange = useCallback(
  52. <K extends keyof GlobalParams>(key: K, value: GlobalParams[K]) => {
  53. let form = { ...searchGlobalParams };
  54. if (value === '') {
  55. delete form[key];
  56. } else {
  57. form = { ...searchGlobalParams, [key]: value };
  58. }
  59. handleFormChange(form);
  60. },
  61. [handleFormChange, searchGlobalParams]
  62. );
  63. const groupByOptions = fields
  64. .filter(f => !UNSPORTED_GROUPBY_TYPES.includes(f.data_type as any))
  65. .map(f => {
  66. return {
  67. value: f.name,
  68. label: f.name,
  69. };
  70. });
  71. return (
  72. <>
  73. <CustomSelector
  74. options={TOP_K_OPTIONS}
  75. value={searchGlobalParams.topK}
  76. label={searchTrans('topK')}
  77. wrapperClass="selector"
  78. variant="filled"
  79. onChange={(e: { target: { value: unknown } }) => {
  80. const topK = e.target.value as number;
  81. handleInputChange('topK', topK);
  82. }}
  83. />
  84. <CustomSelector
  85. options={CONSISTENCY_LEVEL_OPTIONS}
  86. value={searchGlobalParams.consistency_level}
  87. label={searchTrans('consistency')}
  88. wrapperClass="selector"
  89. variant="filled"
  90. onChange={(e: { target: { value: unknown } }) => {
  91. const consistency = e.target.value as string;
  92. handleInputChange('consistency_level', consistency);
  93. }}
  94. />
  95. <CustomMultiSelector
  96. options={outputFields.map(f => {
  97. return {
  98. label:
  99. f.name === DYNAMIC_FIELD ? searchTrans('dynamicFields') : f.name,
  100. value: f.name,
  101. };
  102. })}
  103. values={searchGlobalParams.output_fields}
  104. renderValue={selected => (
  105. <span>{`${(selected as string[]).length} ${
  106. gridTrans[(selected as string[]).length > 1 ? 'fields' : 'field']
  107. }`}</span>
  108. )}
  109. label={searchTrans('outputFields')}
  110. wrapperClass="selector"
  111. variant="filled"
  112. onChange={(e: { target: { value: unknown } }) => {
  113. // add value to output fields if not exist, remove if exist
  114. const newOutputFields = [...searchGlobalParams.output_fields];
  115. const values = e.target.value as string[];
  116. const newFields = values.filter(
  117. v => !newOutputFields.includes(v as string)
  118. );
  119. const removeFields = newOutputFields.filter(
  120. v => !values.includes(v as string)
  121. );
  122. newOutputFields.push(...newFields);
  123. removeFields.forEach(f => {
  124. const index = newOutputFields.indexOf(f);
  125. newOutputFields.splice(index, 1);
  126. });
  127. // sort output fields by schema order
  128. newOutputFields.sort((a, b) => {
  129. const aIndex = outputFields.findIndex(f => f.name === a);
  130. const bIndex = outputFields.findIndex(f => f.name === b);
  131. return aIndex - bIndex;
  132. });
  133. handleInputChange('output_fields', newOutputFields);
  134. }}
  135. />
  136. {!showReranker && (
  137. <CustomSelector
  138. options={[{ label: '--', value: '' }, ...groupByOptions]}
  139. value={searchGlobalParams.group_by_field || ''}
  140. label={searchTrans('groupBy')}
  141. wrapperClass="selector"
  142. variant="filled"
  143. onChange={(e: { target: { value: unknown } }) => {
  144. const groupBy = e.target.value as string;
  145. handleInputChange('group_by_field', groupBy);
  146. }}
  147. />
  148. )}
  149. {showReranker && (
  150. <>
  151. <CustomSelector
  152. options={RERANKER_OPTIONS}
  153. value={
  154. searchGlobalParams.rerank
  155. ? searchGlobalParams.rerank
  156. : RERANKER_OPTIONS[0].value
  157. }
  158. label={searchTrans('rerank')}
  159. wrapperClass="selector"
  160. variant="filled"
  161. onChange={(e: { target: { value: unknown } }) => {
  162. const rerankerStr = e.target.value as 'rrf' | 'weighted';
  163. handleInputChange('rerank', rerankerStr);
  164. }}
  165. />
  166. {searchGlobalParams.rerank == 'rrf' && (
  167. <CustomInput
  168. type="text"
  169. textConfig={{
  170. type: 'number',
  171. label: 'K',
  172. key: 'k',
  173. onChange: value => {
  174. handleInputChange('rrfParams', { k: Number(value) });
  175. },
  176. variant: 'filled',
  177. placeholder: 'k',
  178. fullWidth: true,
  179. validations: [
  180. {
  181. rule: 'require',
  182. errorText: warningTrans('required', {
  183. name: 'k',
  184. }),
  185. },
  186. ],
  187. defaultValue: 60,
  188. value: searchGlobalParams.rrfParams!.k,
  189. }}
  190. checkValid={() => true}
  191. />
  192. )}
  193. {searchGlobalParams.rerank == 'weighted' &&
  194. searchParams.searchParams.map((s, index) => {
  195. if (s.selected) {
  196. return (
  197. <Slider
  198. key={s.anns_field}
  199. color="secondary"
  200. defaultValue={0.5}
  201. value={searchGlobalParams.weightedParams!.weights[index]}
  202. getAriaValueText={value => {
  203. return `${s.anns_field}'s weight: ${value}`;
  204. }}
  205. onChange={(event: Event, value: number | number[]) => {
  206. // update the selected field
  207. const weights = [
  208. ...searchGlobalParams.weightedParams!.weights,
  209. ];
  210. weights[index] = Number(value);
  211. handleInputChange('weightedParams', { weights: weights });
  212. // fire on change event
  213. onSlideChange(s.anns_field);
  214. }}
  215. onChangeCommitted={() => {
  216. onSlideChangeCommitted();
  217. }}
  218. aria-labelledby="weight-slider"
  219. valueLabelDisplay="auto"
  220. size="small"
  221. step={0.1}
  222. min={0}
  223. max={1}
  224. />
  225. );
  226. }
  227. })}
  228. </>
  229. )}
  230. </>
  231. );
  232. };
  233. export default SearchGlobalParams;