Browse Source

enh: folder filter

Timothy Jaeryang Baek 2 months ago
parent
commit
3f7d3def02

+ 15 - 0
backend/open_webui/models/chats.py

@@ -6,6 +6,7 @@ from typing import Optional
 
 from open_webui.internal.db import Base, get_db
 from open_webui.models.tags import TagModel, Tag, Tags
+from open_webui.models.folders import Folders
 from open_webui.env import SRC_LOG_LEVELS
 
 from pydantic import BaseModel, ConfigDict
@@ -617,6 +618,17 @@ class ChatTable:
             if word.startswith("tag:")
         ]
 
+        # Extract folder names - handle spaces and case insensitivity
+        folders = Folders.search_folders_by_names(
+            user_id,
+            [
+                word.replace("folder:", "")
+                for word in search_text_words
+                if word.startswith("folder:")
+            ],
+        )
+        folder_ids = [folder.id for folder in folders]
+
         is_pinned = None
         if "pinned:true" in search_text_words:
             is_pinned = True
@@ -666,6 +678,9 @@ class ChatTable:
                 else:
                     query = query.filter(Chat.share_id.is_(None))
 
+            if folder_ids:
+                query = query.filter(Chat.folder_id.in_(folder_ids))
+
             query = query.order_by(Chat.updated_at.desc())
 
             # Check if the database dialect is either 'sqlite' or 'postgresql'

+ 58 - 6
backend/open_webui/models/folders.py

@@ -2,14 +2,14 @@ import logging
 import time
 import uuid
 from typing import Optional
+import re
 
-from open_webui.internal.db import Base, get_db
-from open_webui.models.chats import Chats
 
-from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
-from open_webui.utils.access_control import get_permissions
+from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
+
+from open_webui.internal.db import Base, get_db
+from open_webui.env import SRC_LOG_LEVELS
 
 
 log = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class FolderTable:
 
     def get_children_folders_by_id_and_user_id(
         self, id: str, user_id: str
-    ) -> Optional[FolderModel]:
+    ) -> Optional[list[FolderModel]]:
         try:
             with get_db() as db:
                 folders = []
@@ -283,5 +283,57 @@ class FolderTable:
             log.error(f"delete_folder: {e}")
             return []
 
+    def normalize_folder_name(self, name: str) -> str:
+        # Replace _ and space with a single space, lower case, collapse multiple spaces
+        name = re.sub(r"[\s_]+", " ", name)
+        return name.strip().lower()
+
+    def search_folders_by_names(
+        self, user_id: str, queries: list[str]
+    ) -> list[FolderModel]:
+        """
+        Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
+        """
+        normalized_queries = [self.normalize_folder_name(q) for q in queries]
+        if not normalized_queries:
+            return []
+
+        results = {}
+        with get_db() as db:
+            folders = db.query(Folder).filter_by(user_id=user_id).all()
+            for folder in folders:
+                if self.normalize_folder_name(folder.name) in normalized_queries:
+                    results[folder.id] = FolderModel.model_validate(folder)
+
+                    # get children folders
+                    children = self.get_children_folders_by_id_and_user_id(
+                        folder.id, user_id
+                    )
+                    for child in children:
+                        results[child.id] = child
+
+        # Return the results as a list
+        if not results:
+            return []
+        else:
+            results = list(results.values())
+            return results
+
+    def search_folders_by_name_contains(
+        self, user_id: str, query: str
+    ) -> list[FolderModel]:
+        """
+        Partial match: normalized name contains (as substring) the normalized query.
+        """
+        normalized_query = self.normalize_folder_name(query)
+        results = []
+        with get_db() as db:
+            folders = db.query(Folder).filter_by(user_id=user_id).all()
+            for folder in folders:
+                norm_name = self.normalize_folder_name(folder.name)
+                if normalized_query in norm_name:
+                    results.append(FolderModel.model_validate(folder))
+        return results
+
 
 Folders = FolderTable()

+ 11 - 3
src/lib/components/layout/SearchModal.svelte

@@ -29,8 +29,7 @@
 
 	let searchDebounceTimeout;
 
-	let selectedIdx = 0;
-
+	let selectedIdx = null;
 	let selectedChat = null;
 
 	let selectedModels = [''];
@@ -42,7 +41,12 @@
 	}
 
 	const loadChatPreview = async (selectedIdx) => {
-		if (!chatList || chatList.length === 0 || chatList[selectedIdx] === undefined) {
+		if (
+			!chatList ||
+			chatList.length === 0 ||
+			selectedIdx === null ||
+			chatList[selectedIdx] === undefined
+		) {
 			selectedChat = null;
 			messages = null;
 			history = null;
@@ -217,6 +221,10 @@
 				on:input={searchHandler}
 				placeholder={$i18n.t('Search')}
 				showClearButton={true}
+				onFocus={() => {
+					selectedIdx = null;
+					messages = null;
+				}}
 				onKeydown={(e) => {
 					console.log('e', e);
 

+ 2 - 0
src/lib/components/layout/Sidebar.svelte

@@ -10,6 +10,7 @@
 		showSettings,
 		chatId,
 		tags,
+		folders as _folders,
 		showSidebar,
 		showSearch,
 		mobile,
@@ -85,6 +86,7 @@
 			toast.error(`${error}`);
 			return [];
 		});
+		_folders.set(folderList);
 
 		folders = {};
 

+ 36 - 1
src/lib/components/layout/Sidebar/SearchInput.svelte

@@ -1,6 +1,6 @@
 <script lang="ts">
 	import { getAllTags } from '$lib/apis/chats';
-	import { tags } from '$lib/stores';
+	import { folders, tags } from '$lib/stores';
 	import { getContext, createEventDispatcher, onMount, onDestroy, tick } from 'svelte';
 	import { fade } from 'svelte/transition';
 	import Search from '$lib/components/icons/Search.svelte';
@@ -12,6 +12,8 @@
 	export let placeholder = '';
 	export let value = '';
 	export let showClearButton = false;
+
+	export let onFocus = () => {};
 	export let onKeydown = (e) => {};
 
 	let selectedIdx = 0;
@@ -25,6 +27,10 @@
 			name: 'tag:',
 			description: $i18n.t('search for tags')
 		},
+		{
+			name: 'folder:',
+			description: $i18n.t('search for folders')
+		},
 		{
 			name: 'pinned:',
 			description: $i18n.t('search for pinned chats')
@@ -88,6 +94,30 @@
 						type: 'tag'
 					};
 				});
+		} else if (lastWord.startsWith('folder:')) {
+			filteredItems = [...$folders]
+				.filter((folder) => {
+					const folderName = lastWord.slice(7);
+					if (folderName) {
+						const id = folder.name.replace(' ', '_').toLowerCase();
+						const folderId = folderName.replace(' ', '_').toLowerCase();
+
+						if (id !== folderId) {
+							return id.startsWith(folderId);
+						} else {
+							return false;
+						}
+					} else {
+						return true;
+					}
+				})
+				.map((folder) => {
+					return {
+						id: folder.name.replace(' ', '_').toLowerCase(),
+						name: folder.name,
+						type: 'folder'
+					};
+				});
 		} else if (lastWord.startsWith('pinned:')) {
 			filteredItems = [
 				{
@@ -163,6 +193,7 @@
 				dispatch('input');
 			}}
 			on:focus={() => {
+				onFocus();
 				hovering = false;
 				focused = true;
 				initTags();
@@ -211,6 +242,9 @@
 					selectedIdx = 0;
 				}
 
+				const item = document.querySelector(`[data-selected="true"]`);
+				item?.scrollIntoView({ block: 'center', inline: 'nearest', behavior: 'instant' });
+
 				if (!document.getElementById('search-options-container')) {
 					onKeydown(e);
 				}
@@ -257,6 +291,7 @@
 								itemIdx
 									? 'bg-gray-100 dark:bg-gray-900'
 									: ''}"
+								data-selected={selectedIdx === itemIdx}
 								id="search-item-{itemIdx}"
 								on:click|stopPropagation={async () => {
 									const words = value.split(' ');