Преглед на файлове

refac/enh: model default filter/feature

Timothy Jaeryang Baek преди 3 седмици
родител
ревизия
9a55547827

+ 24 - 0
backend/open_webui/models/functions.py

@@ -37,6 +37,7 @@ class Function(Base):
 class FunctionMeta(BaseModel):
     description: Optional[str] = None
     manifest: Optional[dict] = {}
+    model_config = ConfigDict(extra="allow")
 
 
 class FunctionModel(BaseModel):
@@ -260,6 +261,29 @@ class FunctionsTable:
             except Exception:
                 return None
 
+    def update_function_metadata_by_id(
+        self, id: str, metadata: dict
+    ) -> Optional[FunctionModel]:
+        with get_db() as db:
+            try:
+                function = db.get(Function, id)
+
+                if function:
+                    if function.meta:
+                        function.meta = {**function.meta, **metadata}
+                    else:
+                        function.meta = metadata
+
+                    function.updated_at = int(time.time())
+                    db.commit()
+                    db.refresh(function)
+                    return self.get_function_by_id(id)
+                else:
+                    return None
+            except Exception as e:
+                log.exception(f"Error updating function metadata by id {id}: {e}")
+                return None
+
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str
     ) -> Optional[dict]:

+ 6 - 0
backend/open_webui/routers/functions.py

@@ -192,6 +192,9 @@ async def create_new_function(
             function_cache_dir = CACHE_DIR / "functions" / form_data.id
             function_cache_dir.mkdir(parents=True, exist_ok=True)
 
+            if function_type == "filter" and getattr(function_module, "toggle", None):
+                Functions.update_function_metadata_by_id(id, {"toggle": True})
+
             if function:
                 return function
             else:
@@ -308,6 +311,9 @@ async def update_function_by_id(
 
         function = Functions.update_function_by_id(id, updated)
 
+        if function_type == "filter" and getattr(function_module, "toggle", None):
+            Functions.update_function_metadata_by_id(id, {"toggle": True})
+
         if function:
             return function
         else:

+ 44 - 13
src/lib/components/chat/Chat.svelte

@@ -37,6 +37,7 @@
 		showArtifacts,
 		tools,
 		toolServers,
+		functions,
 		selectedFolder,
 		pinnedChats
 	} from '$lib/stores';
@@ -88,6 +89,7 @@
 	import Spinner from '../common/Spinner.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 	import Sidebar from '../icons/Sidebar.svelte';
+	import { getFunctions } from '$lib/apis/functions';
 
 	export let chatIdProp = '';
 
@@ -236,33 +238,62 @@
 	};
 
 	const resetInput = () => {
-		console.debug('resetInput');
-		setToolIds();
-
+		selectedToolIds = [];
 		selectedFilterIds = [];
 		webSearchEnabled = false;
 		imageGenerationEnabled = false;
 		codeInterpreterEnabled = false;
+
+		setDefaults();
 	};
 
-	const setToolIds = async () => {
+	const setDefaults = async () => {
 		if (!$tools) {
 			tools.set(await getTools(localStorage.token));
 		}
-
+		if (!$functions) {
+			functions.set(await getFunctions(localStorage.token));
+		}
 		if (selectedModels.length !== 1 && !atSelectedModel) {
 			return;
 		}
 
 		const model = atSelectedModel ?? $models.find((m) => m.id === selectedModels[0]);
-		if (model && model?.info?.meta?.toolIds) {
-			selectedToolIds = [
-				...new Set(
-					[...(model?.info?.meta?.toolIds ?? [])].filter((id) => $tools.find((t) => t.id === id))
-				)
-			];
-		} else {
-			selectedToolIds = [];
+		if (model) {
+			if (model?.info?.meta?.toolIds) {
+				selectedToolIds = [
+					...new Set(
+						[...(model?.info?.meta?.toolIds ?? [])].filter((id) => $tools.find((t) => t.id === id))
+					)
+				];
+			} else {
+				selectedToolIds = [];
+			}
+
+			if (model?.info?.meta?.defaultFilterIds) {
+				console.log('model.info.meta.defaultFilterIds', model.info.meta.defaultFilterIds);
+				selectedFilterIds = model.info.meta.defaultFilterIds;
+				console.log('selectedFilterIds', selectedFilterIds);
+			} else {
+				selectedFilterIds = [];
+			}
+
+			if (model?.info?.meta?.defaultFeatureIds) {
+				console.log('model.info.meta.defaultFeatureIds', model.info.meta.defaultFeatureIds);
+				imageGenerationEnabled = model.info.meta.defaultFeatureIds.includes('image_generation');
+				webSearchEnabled = model.info.meta.defaultFeatureIds.includes('web_search');
+				codeInterpreterEnabled = model.info.meta.defaultFeatureIds.includes('code_interpreter');
+
+				console.log({
+					imageGenerationEnabled,
+					webSearchEnabled,
+					codeInterpreterEnabled
+				});
+			} else {
+				imageGenerationEnabled = false;
+				webSearchEnabled = false;
+				codeInterpreterEnabled = false;
+			}
 		}
 	};
 

+ 10 - 14
src/lib/components/workspace/Models/ActionsSelector.svelte

@@ -22,18 +22,14 @@
 	});
 </script>
 
-<div>
-	<div class="flex w-full justify-between mb-1">
-		<div class=" self-center text-sm font-semibold">{$i18n.t('Actions')}</div>
-	</div>
-
-	<div class=" text-xs dark:text-gray-500">
-		{$i18n.t('To select actions here, add them to the "Functions" workspace first.')}
-	</div>
-
-	<div class="flex flex-col">
-		{#if actions.length > 0}
-			<div class=" flex items-center mt-2 flex-wrap">
+{#if actions.length > 0}
+	<div>
+		<div class="flex w-full justify-between mb-1">
+			<div class=" self-center text-sm font-semibold">{$i18n.t('Actions')}</div>
+		</div>
+
+		<div class="flex flex-col">
+			<div class=" flex items-center flex-wrap">
 				{#each Object.keys(_actions) as action, actionIdx}
 					<div class=" flex items-center gap-2 mr-3">
 						<div class="self-center flex items-center">
@@ -54,6 +50,6 @@
 					</div>
 				{/each}
 			</div>
-		{/if}
+		</div>
 	</div>
-</div>
+{/if}

+ 54 - 0
src/lib/components/workspace/Models/DefaultFeatures.svelte

@@ -0,0 +1,54 @@
+<script lang="ts">
+	import { getContext } from 'svelte';
+	import Checkbox from '$lib/components/common/Checkbox.svelte';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
+	import { marked } from 'marked';
+
+	const i18n = getContext('i18n');
+
+	const featureLabels = {
+		web_search: {
+			label: $i18n.t('Web Search'),
+			description: $i18n.t('Model can search the web for information')
+		},
+		image_generation: {
+			label: $i18n.t('Image Generation'),
+			description: $i18n.t('Model can generate images based on text prompts')
+		},
+		code_interpreter: {
+			label: $i18n.t('Code Interpreter'),
+			description: $i18n.t('Model can execute code and perform calculations')
+		}
+	};
+
+	export let availableFeatures = ['web_search', 'image_generation', 'code_interpreter'];
+	export let featureIds = [];
+</script>
+
+<div>
+	<div class="flex w-full justify-between mb-1">
+		<div class=" self-center text-sm font-semibold">{$i18n.t('Default Features')}</div>
+	</div>
+	<div class="flex items-center mt-2 flex-wrap">
+		{#each availableFeatures as feature}
+			<div class=" flex items-center gap-2 mr-3">
+				<Checkbox
+					state={featureIds.includes(feature) ? 'checked' : 'unchecked'}
+					on:change={(e) => {
+						if (e.detail === 'checked') {
+							featureIds = [...featureIds, feature];
+						} else {
+							featureIds = featureIds.filter((id) => id !== feature);
+						}
+					}}
+				/>
+
+				<div class=" py-0.5 text-sm capitalize">
+					<Tooltip content={marked.parse(featureLabels[feature].description)}>
+						{$i18n.t(featureLabels[feature].label)}
+					</Tooltip>
+				</div>
+			</div>
+		{/each}
+	</div>
+</div>

+ 62 - 0
src/lib/components/workspace/Models/DefaultFiltersSelector.svelte

@@ -0,0 +1,62 @@
+<script lang="ts">
+	import { getContext, onMount } from 'svelte';
+	import Checkbox from '$lib/components/common/Checkbox.svelte';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
+
+	const i18n = getContext('i18n');
+
+	export let filters = [];
+	export let selectedFilterIds = [];
+
+	let _filters = {};
+
+	onMount(() => {
+		_filters = filters.reduce((acc, filter) => {
+			acc[filter.id] = {
+				...filter,
+				selected: selectedFilterIds.includes(filter.id)
+			};
+
+			return acc;
+		}, {});
+	});
+</script>
+
+<div>
+	<div class="flex w-full justify-between mb-1">
+		<div class=" self-center text-sm font-semibold">{$i18n.t('Default Filters')}</div>
+	</div>
+
+	<div class="flex flex-col">
+		{#if filters.length > 0}
+			<div class=" flex items-center flex-wrap">
+				{#each Object.keys(_filters) as filter, filterIdx}
+					<div class=" flex items-center gap-2 mr-3">
+						<div class="self-center flex items-center">
+							<Checkbox
+								state={_filters[filter].is_global
+									? 'checked'
+									: _filters[filter].selected
+										? 'checked'
+										: 'unchecked'}
+								disabled={_filters[filter].is_global}
+								on:change={(e) => {
+									if (!_filters[filter].is_global) {
+										_filters[filter].selected = e.detail === 'checked';
+										selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
+									}
+								}}
+							/>
+						</div>
+
+						<div class=" py-0.5 text-sm w-full capitalize font-medium">
+							<Tooltip content={_filters[filter].meta.description}>
+								{_filters[filter].name}
+							</Tooltip>
+						</div>
+					</div>
+				{/each}
+			</div>
+		{/if}
+	</div>
+</div>

+ 11 - 15
src/lib/components/workspace/Models/FiltersSelector.svelte

@@ -22,19 +22,15 @@
 	});
 </script>
 
-<div>
-	<div class="flex w-full justify-between mb-1">
-		<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
-	</div>
-
-	<div class=" text-xs dark:text-gray-500">
-		{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
-	</div>
-
-	<!-- TODO: Filer order matters -->
-	<div class="flex flex-col">
-		{#if filters.length > 0}
-			<div class=" flex items-center mt-2 flex-wrap">
+{#if filters.length > 0}
+	<div>
+		<div class="flex w-full justify-between mb-1">
+			<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
+		</div>
+
+		<!-- TODO: Filer order matters -->
+		<div class="flex flex-col">
+			<div class=" flex items-center flex-wrap">
 				{#each Object.keys(_filters) as filter, filterIdx}
 					<div class=" flex items-center gap-2 mr-3">
 						<div class="self-center flex items-center">
@@ -62,6 +58,6 @@
 					</div>
 				{/each}
 			</div>
-		{/if}
+		</div>
 	</div>
-</div>
+{/if}

+ 69 - 13
src/lib/components/workspace/Models/ModelEditor.svelte

@@ -1,8 +1,14 @@
 <script lang="ts">
+	import { toast } from 'svelte-sonner';
+
 	import { onMount, getContext, tick } from 'svelte';
 	import { models, tools, functions, knowledge as knowledgeCollections, user } from '$lib/stores';
 	import { WEBUI_BASE_URL } from '$lib/constants';
 
+	import { getTools } from '$lib/apis/tools';
+	import { getFunctions } from '$lib/apis/functions';
+	import { getKnowledgeBases } from '$lib/apis/knowledge';
+
 	import AdvancedParams from '$lib/components/chat/Settings/Advanced/AdvancedParams.svelte';
 	import Tags from '$lib/components/common/Tags.svelte';
 	import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
@@ -11,15 +17,11 @@
 	import ActionsSelector from '$lib/components/workspace/Models/ActionsSelector.svelte';
 	import Capabilities from '$lib/components/workspace/Models/Capabilities.svelte';
 	import Textarea from '$lib/components/common/Textarea.svelte';
-	import { getTools } from '$lib/apis/tools';
-	import { getFunctions } from '$lib/apis/functions';
-	import { getKnowledgeBases } from '$lib/apis/knowledge';
 	import AccessControl from '../common/AccessControl.svelte';
-	import { stringify } from 'postcss';
-	import { toast } from 'svelte-sonner';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import XMark from '$lib/components/icons/XMark.svelte';
-	import { getNoteList } from '$lib/apis/notes';
+	import DefaultFiltersSelector from './DefaultFiltersSelector.svelte';
+	import DefaultFeatures from './DefaultFeatures.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -79,6 +81,13 @@
 	let params = {
 		system: ''
 	};
+
+	let knowledge = [];
+	let toolIds = [];
+
+	let filterIds = [];
+	let defaultFilterIds = [];
+
 	let capabilities = {
 		vision: true,
 		file_upload: true,
@@ -89,12 +98,9 @@
 		status_updates: true,
 		usage: undefined
 	};
+	let defaultFeatureIds = [];
 
-	let knowledge = [];
-	let toolIds = [];
-	let filterIds = [];
 	let actionIds = [];
-
 	let accessControl = {};
 
 	const addUsage = (base_model_id) => {
@@ -172,6 +178,14 @@
 			}
 		}
 
+		if (defaultFilterIds.length > 0) {
+			info.meta.defaultFilterIds = defaultFilterIds;
+		} else {
+			if (info.meta.defaultFilterIds) {
+				delete info.meta.defaultFilterIds;
+			}
+		}
+
 		if (actionIds.length > 0) {
 			info.meta.actionIds = actionIds;
 		} else {
@@ -180,6 +194,14 @@
 			}
 		}
 
+		if (defaultFeatureIds.length > 0) {
+			info.meta.defaultFeatureIds = defaultFeatureIds;
+		} else {
+			if (info.meta.defaultFeatureIds) {
+				delete info.meta.defaultFeatureIds;
+			}
+		}
+
 		info.params.system = system.trim() === '' ? null : system;
 		info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
 		Object.keys(info.params).forEach((key) => {
@@ -236,9 +258,6 @@
 					)
 				: null;
 
-			toolIds = model?.meta?.toolIds ?? [];
-			filterIds = model?.meta?.filterIds ?? [];
-			actionIds = model?.meta?.actionIds ?? [];
 			knowledge = (model?.meta?.knowledge ?? []).map((item) => {
 				if (item?.collection_name && item?.type !== 'file') {
 					return {
@@ -257,7 +276,14 @@
 					return item;
 				}
 			});
+
+			toolIds = model?.meta?.toolIds ?? [];
+			filterIds = model?.meta?.filterIds ?? [];
+			defaultFilterIds = model?.meta?.defaultFilterIds ?? [];
+			actionIds = model?.meta?.actionIds ?? [];
+
 			capabilities = { ...capabilities, ...(model?.meta?.capabilities ?? {}) };
+			defaultFeatureIds = model?.meta?.defaultFeatureIds ?? [];
 
 			if ('access_control' in model) {
 				accessControl = model.access_control;
@@ -725,6 +751,21 @@
 						/>
 					</div>
 
+					{#if filterIds.length > 0}
+						{@const toggleableFilters = $functions.filter(
+							(func) => func.type === 'filter' && filterIds.includes(func.id) && func?.meta?.toggle
+						)}
+
+						{#if toggleableFilters.length > 0}
+							<div class="my-2">
+								<DefaultFiltersSelector
+									bind:selectedFilterIds={defaultFilterIds}
+									filters={toggleableFilters}
+								/>
+							</div>
+						{/if}
+					{/if}
+
 					<div class="my-2">
 						<ActionsSelector
 							bind:selectedActionIds={actionIds}
@@ -736,6 +777,21 @@
 						<Capabilities bind:capabilities />
 					</div>
 
+					{#if Object.keys(capabilities).filter((key) => capabilities[key]).length > 0}
+						{@const availableFeatures = Object.entries(capabilities)
+							.filter(
+								([key, value]) =>
+									value && ['web_search', 'code_interpreter', 'image_generation'].includes(key)
+							)
+							.map(([key, value]) => key)}
+
+						{#if availableFeatures.length > 0}
+							<div class="my-2">
+								<DefaultFeatures {availableFeatures} bind:featureIds={defaultFeatureIds} />
+							</div>
+						{/if}
+					{/if}
+
 					<div class="my-2 text-gray-300 dark:text-gray-700">
 						<div class="flex w-full justify-between mb-2">
 							<div class=" self-center text-sm font-semibold">{$i18n.t('JSON Preview')}</div>