
Merge pull request #2742 from open-webui/dev

0.2.2
Timothy Jaeryang Baek 1 year ago
parent
commit
0744806523

+ 11 - 0
CHANGELOG.md

@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.2.2] - 2024-06-02
+
+### Added
+
+- **🌊 Mermaid Rendering Support**: We've included support for Mermaid rendering. This allows you to create beautiful diagrams and flowcharts directly within Open WebUI.
+- **🔄 New Environment Variable 'RESET_CONFIG_ON_START'**: Introducing a new environment variable: `RESET_CONFIG_ON_START`. Set this variable to reset your configuration settings upon starting the application, making it easier to revert to default settings.
+
+### Fixed
+
+- **🔧 Pipelines Filter Issue**: We've addressed an issue with the pipelines where filters were not functioning as expected.
+
 ## [0.2.1] - 2024-06-02
 
 ### Added
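
The new `RESET_CONFIG_ON_START` variable is read once at startup. Below is a minimal sketch of the behaviour it enables, mirroring the `backend/config.py` change further down in this diff; the `DATA_DIR` default here is an assumption for illustration and error handling is omitted.

```python
import os
from pathlib import Path

# Assumed default for illustration; the real path comes from DATA_DIR in backend/config.py.
DATA_DIR = Path(os.getenv("DATA_DIR", "./backend/data"))

# When RESET_CONFIG_ON_START is "true" (case-insensitive), the persisted
# config.json is replaced with an empty JSON object, so the app falls back
# to environment-variable defaults on this start.
if os.environ.get("RESET_CONFIG_ON_START", "false").lower() == "true":
    (DATA_DIR / "config.json").write_text("{}")
```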

+ 54 - 347
backend/apps/ollama/main.py

@@ -29,6 +29,8 @@ import time
 from urllib.parse import urlparse
 from typing import Optional, List, Union
 
+from starlette.background import BackgroundTask
+
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 
 
-REQUEST_POOL = []
-
-
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
 
 
-@app.get("/cancel/{request_id}")
-async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
-    if user:
-        if request_id in REQUEST_POOL:
-            REQUEST_POOL.remove(request_id)
-        return True
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
-
-
 async def fetch_url(url):
     timeout = aiohttp.ClientTimeout(total=5)
     try:
@@ -154,6 +143,45 @@ async def fetch_url(url):
         return None
 
 
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def post_streaming_url(url: str, payload: str):
+    r = None
+    try:
+        session = aiohttp.ClientSession()
+        r = await session.post(url, data=payload)
+        r.raise_for_status()
+
+        return StreamingResponse(
+            r.content,
+            status_code=r.status,
+            headers=dict(r.headers),
+            background=BackgroundTask(cleanup_response, response=r, session=session),
+        )
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=error_detail,
+        )
+
+
 def merge_models_lists(model_lists):
     merged_models = {}
 
@@ -313,65 +341,7 @@ async def pull_model(
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
 
-    def get_request():
-        nonlocal url
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/pull",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
 
 
 class PushModelForm(BaseModel):
@@ -399,50 +369,9 @@ async def push_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.debug(f"url: {url}")
 
-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/push",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 class CreateModelForm(BaseModel):
@@ -461,53 +390,9 @@ async def create_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/create",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            log.debug(f"r: {r}")
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 class CopyModelForm(BaseModel):
@@ -797,66 +682,9 @@ async def generate_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = None
-
-    def get_request():
-        nonlocal form_data
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if form_data.stream:
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/generate",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 class ChatMessage(BaseModel):
@@ -1014,67 +842,7 @@ async def generate_chat_completion(
 
     print(payload)
 
-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream", None):
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/chat",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            log.exception(e)
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
 
 
 # TODO: we should update this part once Ollama supports other types
@@ -1165,68 +933,7 @@ async def generate_openai_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream"):
-                        yield json.dumps(
-                            {"request_id": request_id, "done": False}
-                        ) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/v1/chat/completions",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
 
 
 @app.get("/v1/models")
@@ -1555,7 +1262,7 @@ async def deprecated_proxy(
                     if path == "generate":
                         data = json.loads(body.decode("utf-8"))
 
-                        if not ("stream" in data and data["stream"] == False):
+                        if data.get("stream", True):
                             yield json.dumps({"id": request_id, "done": False}) + "\n"
 
                     elif path == "chat":
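
The TODO near the top of `backend/apps/ollama/main.py` notes that URL selection is still a plain `random.choice` and suggests strategies such as weighted round-robin or least connections. A hypothetical sketch of the weighted round-robin idea, with made-up backend URLs and weights:

```python
import itertools
import random

def weighted_round_robin(urls_with_weights):
    # Expand each URL according to its weight, then cycle through the shuffled
    # list so heavier backends receive proportionally more requests.
    expanded = [url for url, weight in urls_with_weights for _ in range(weight)]
    random.shuffle(expanded)
    return itertools.cycle(expanded)

# Hypothetical backends and weights.
picker = weighted_round_robin([
    ("http://ollama-a:11434", 3),
    ("http://ollama-b:11434", 1),
])
url = next(picker)  # next backend for an outgoing request
```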

+ 29 - 10
backend/apps/openai/main.py

@@ -9,6 +9,7 @@ import json
 import logging
 
 from pydantic import BaseModel
+from starlette.background import BackgroundTask
 
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
         return None
 
 
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
 def merge_models_lists(model_lists):
     log.debug(f"merge_models_lists {model_lists}")
     merged_list = []
@@ -447,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     headers["Content-Type"] = "application/json"
 
     r = None
+    session = None
+    streaming = False
 
     try:
-        r = requests.request(
+        session = aiohttp.ClientSession()
+        r = await session.request(
             method=request.method,
             url=target_url,
             data=payload if payload else body,
             headers=headers,
-            stream=True,
         )
 
         r.raise_for_status()
 
         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
             return StreamingResponse(
-                r.iter_content(chunk_size=8192),
-                status_code=r.status_code,
+                r.content,
+                status_code=r.status,
                 headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
             )
         else:
-            response_data = r.json()
+            response_data = await r.json()
             return response_data
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
-                res = r.json()
+                res = await r.json()
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except:
                 error_detail = f"External: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500, detail=error_detail
-        )
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
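
Both backends now share the same pattern: proxy the upstream body through `StreamingResponse` and defer closing the `aiohttp` response and session to a `BackgroundTask` that runs only after the client has consumed the stream. A minimal, self-contained sketch of that pattern (the endpoint name and upstream URL are placeholders, not part of this commit):

```python
import aiohttp
from fastapi import FastAPI
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse

app = FastAPI()

async def cleanup(response: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs after the streamed response has been fully sent.
    response.close()
    await session.close()

@app.get("/proxy-demo")  # placeholder endpoint for illustration
async def proxy_demo():
    session = aiohttp.ClientSession()
    r = await session.get("http://localhost:11434/api/tags")  # placeholder upstream
    return StreamingResponse(
        r.content,
        status_code=r.status,
        headers=dict(r.headers),
        background=BackgroundTask(cleanup, response=r, session=session),
    )
```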

+ 11 - 0
backend/config.py

@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
 DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
 FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
 
+RESET_CONFIG_ON_START = (
+    os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
+)
+if RESET_CONFIG_ON_START:
+    try:
+        os.remove(f"{DATA_DIR}/config.json")
+        with open(f"{DATA_DIR}/config.json", "w") as f:
+            f.write("{}")
+    except:
+        pass
+
 try:
     CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
 except:

+ 10 - 4
backend/main.py

@@ -498,10 +498,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     ]
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
 
-    model = app.state.MODELS[model_id]
+    print(model_id)
 
-    if "pipeline" in model:
-        sorted_filters = [model] + sorted_filters
+    if model_id in app.state.MODELS:
+        model = app.state.MODELS[model_id]
+        if "pipeline" in model:
+            sorted_filters = [model] + sorted_filters
 
     for filter in sorted_filters:
         r = None
@@ -550,7 +552,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
     responses = await get_openai_models(raw=True)
 
     print(responses)
-    urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
+    urlIdxs = [
+        idx
+        for idx, response in enumerate(responses)
+        if response != None and "pipelines" in response
+    ]
 
     return {
         "data": [

The diff for this file has been truncated because it is too large
+ 1054 - 3
package-lock.json


+ 2 - 1
package.json

@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.2.1",
+	"version": "0.2.2",
 	"private": true,
 	"scripts": {
 		"dev": "npm run pyodide:fetch && vite dev --host",
@@ -63,6 +63,7 @@
 		"js-sha256": "^0.10.1",
 		"katex": "^0.16.9",
 		"marked": "^9.1.0",
+		"mermaid": "^10.9.1",
 		"pyodide": "^0.26.0-alpha.4",
 		"sortablejs": "^1.15.2",
 		"svelte-sonner": "^0.3.19",

+ 3 - 22
src/lib/apis/ollama/index.ts

@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 	return [res, controller];
 };
 
-export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
-	let error = null;
-
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
-		method: 'GET',
-		headers: {
-			'Content-Type': 'text/event-stream',
-			Authorization: `Bearer ${token}`
-		}
-	}).catch((err) => {
-		error = err;
-		return null;
-	});
-
-	if (error) {
-		throw error;
-	}
-
-	return res;
-};
-
 export const createModel = async (token: string, tagName: string, content: string) => {
 	let error = null;
 
@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
 
 export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
 	let error = null;
+	const controller = new AbortController();
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
+		signal: controller.signal,
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
 	if (error) {
 		throw error;
 	}
-	return res;
+	return [res, controller];
 };
 
 export const downloadModel = async (

+ 86 - 115
src/lib/components/chat/Chat.svelte

@@ -1,6 +1,7 @@
 <script lang="ts">
 	import { v4 as uuidv4 } from 'uuid';
 	import { toast } from 'svelte-sonner';
+	import mermaid from 'mermaid';
 
 	import { getContext, onMount, tick } from 'svelte';
 	import { goto } from '$app/navigation';
@@ -26,7 +27,7 @@
 		splitStream
 	} from '$lib/utils';
 
-	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -65,7 +66,6 @@
 	let autoScroll = true;
 	let processing = '';
 	let messagesContainerElement: HTMLDivElement;
-	let currentRequestId = null;
 
 	let showModelSelector = true;
 
@@ -130,10 +130,6 @@
 	//////////////////////////
 
 	const initNewChat = async () => {
-		if (currentRequestId !== null) {
-			await cancelOllamaRequest(localStorage.token, currentRequestId);
-			currentRequestId = null;
-		}
 		window.history.replaceState(history.state, '', `/`);
 		await chatId.set('');
 
@@ -251,6 +247,39 @@
 		}
 	};
 
+	const chatCompletedHandler = async (model, messages) => {
+		await mermaid.run({
+			querySelector: '.mermaid'
+		});
+
+		const res = await chatCompleted(localStorage.token, {
+			model: model.id,
+			messages: messages.map((m) => ({
+				id: m.id,
+				role: m.role,
+				content: m.content,
+				timestamp: m.timestamp
+			})),
+			chat_id: $chatId
+		}).catch((error) => {
+			console.error(error);
+			return null;
+		});
+
+		if (res !== null) {
+			// Update chat history with the new messages
+			for (const message of res.messages) {
+				history.messages[message.id] = {
+					...history.messages[message.id],
+					...(history.messages[message.id].content !== message.content
+						? { originalContent: history.messages[message.id].content }
+						: {}),
+					...message
+				};
+			}
+		}
+	};
+
 	//////////////////////////
 	// Ollama functions
 	//////////////////////////
@@ -616,39 +645,11 @@
 
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
-						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					} else {
 						const messages = createMessagesList(responseMessageId);
-						const res = await chatCompleted(localStorage.token, {
-							model: model,
-							messages: messages.map((m) => ({
-								id: m.id,
-								role: m.role,
-								content: m.content,
-								timestamp: m.timestamp
-							})),
-							chat_id: $chatId
-						}).catch((error) => {
-							console.error(error);
-							return null;
-						});
-
-						if (res !== null) {
-							// Update chat history with the new messages
-							for (const message of res.messages) {
-								history.messages[message.id] = {
-									...history.messages[message.id],
-									...(history.messages[message.id].content !== message.content
-										? { originalContent: history.messages[message.id].content }
-										: {}),
-									...message
-								};
-							}
-						}
+						await chatCompletedHandler(model, messages);
 					}
 
-					currentRequestId = null;
-
 					break;
 				}
 
@@ -669,63 +670,58 @@
 								throw data;
 							}
 
-							if ('id' in data) {
-								console.log(data);
-								currentRequestId = data.id;
-							} else {
-								if (data.done == false) {
-									if (responseMessage.content == '' && data.message.content == '\n') {
-										continue;
-									} else {
-										responseMessage.content += data.message.content;
-										messages = messages;
-									}
+							if (data.done == false) {
+								if (responseMessage.content == '' && data.message.content == '\n') {
+									continue;
 								} else {
-									responseMessage.done = true;
-
-									if (responseMessage.content == '') {
-										responseMessage.error = {
-											code: 400,
-											content: `Oops! No text generated from Ollama, Please try again.`
-										};
-									}
-
-									responseMessage.context = data.context ?? null;
-									responseMessage.info = {
-										total_duration: data.total_duration,
-										load_duration: data.load_duration,
-										sample_count: data.sample_count,
-										sample_duration: data.sample_duration,
-										prompt_eval_count: data.prompt_eval_count,
-										prompt_eval_duration: data.prompt_eval_duration,
-										eval_count: data.eval_count,
-										eval_duration: data.eval_duration
-									};
+									responseMessage.content += data.message.content;
 									messages = messages;
+								}
+							} else {
+								responseMessage.done = true;
 
-									if ($settings.notificationEnabled && !document.hasFocus()) {
-										const notification = new Notification(
-											selectedModelfile
-												? `${
-														selectedModelfile.title.charAt(0).toUpperCase() +
-														selectedModelfile.title.slice(1)
-												  }`
-												: `${model}`,
-											{
-												body: responseMessage.content,
-												icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
-											}
-										);
-									}
-
-									if ($settings.responseAutoCopy) {
-										copyToClipboard(responseMessage.content);
-									}
-
-									if ($settings.responseAutoPlayback) {
-										await tick();
-										document.getElementById(`speak-button-${responseMessage.id}`)?.click();
-									}
+								if (responseMessage.content == '') {
+									responseMessage.error = {
+										code: 400,
+										content: `Oops! No text generated from Ollama, Please try again.`
+									};
+								}
+
+								responseMessage.context = data.context ?? null;
+								responseMessage.info = {
+									total_duration: data.total_duration,
+									load_duration: data.load_duration,
+									sample_count: data.sample_count,
+									sample_duration: data.sample_duration,
+									prompt_eval_count: data.prompt_eval_count,
+									prompt_eval_duration: data.prompt_eval_duration,
+									eval_count: data.eval_count,
+									eval_duration: data.eval_duration
+								};
+								messages = messages;
+
+								if ($settings.notificationEnabled && !document.hasFocus()) {
+									const notification = new Notification(
+										selectedModelfile
+											? `${
+													selectedModelfile.title.charAt(0).toUpperCase() +
+													selectedModelfile.title.slice(1)
+											  }`
+											: `${model}`,
+										{
+											body: responseMessage.content,
+											icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
+										}
+									);
+								}
+
+								if ($settings.responseAutoCopy) {
+									copyToClipboard(responseMessage.content);
+								}
+
+								if ($settings.responseAutoPlayback) {
+									await tick();
+									document.getElementById(`speak-button-${responseMessage.id}`)?.click();
 								}
 							}
 						}
@@ -906,32 +902,7 @@
 						} else {
 							const messages = createMessagesList(responseMessageId);
 
-							const res = await chatCompleted(localStorage.token, {
-								model: model.id,
-								messages: messages.map((m) => ({
-									id: m.id,
-									role: m.role,
-									content: m.content,
-									timestamp: m.timestamp
-								})),
-								chat_id: $chatId
-							}).catch((error) => {
-								console.error(error);
-								return null;
-							});
-
-							if (res !== null) {
-								// Update chat history with the new messages
-								for (const message of res.messages) {
-									history.messages[message.id] = {
-										...history.messages[message.id],
-										...(history.messages[message.id].content !== message.content
-											? { originalContent: history.messages[message.id].content }
-											: {}),
-										...message
-									};
-								}
-							}
+							await chatCompletedHandler(model, messages);
 						}
 
 						break;

+ 1 - 2
src/lib/components/chat/Messages.svelte

@@ -1,8 +1,7 @@
 <script lang="ts">
 	import { v4 as uuidv4 } from 'uuid';
-
 	import { chats, config, settings, user as _user, mobile } from '$lib/stores';
-	import { tick, getContext } from 'svelte';
+	import { tick, getContext, onMount } from 'svelte';
 
 	import { toast } from 'svelte-sonner';
 	import { getChatList, updateChatById } from '$lib/apis/chats';

+ 14 - 5
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -5,6 +5,7 @@
 	import tippy from 'tippy.js';
 	import auto_render from 'katex/dist/contrib/auto-render.mjs';
 	import 'katex/dist/katex.min.css';
+	import mermaid from 'mermaid';
 
 	import { fade } from 'svelte/transition';
 	import { createEventDispatcher } from 'svelte';
@@ -343,6 +344,10 @@
 	onMount(async () => {
 		await tick();
 		renderStyling();
+
+		await mermaid.run({
+			querySelector: '.mermaid'
+		});
 	});
 </script>
 
@@ -458,11 +463,15 @@
 								<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
 								{#each tokens as token, tokenIdx}
 									{#if token.type === 'code'}
-										<CodeBlock
-											id={`${message.id}-${tokenIdx}`}
-											lang={token?.lang ?? ''}
-											code={revertSanitizedResponseContent(token?.text ?? '')}
-										/>
+										{#if token.lang === 'mermaid'}
+											<pre class="mermaid">{revertSanitizedResponseContent(token.text)}</pre>
+										{:else}
+											<CodeBlock
+												id={`${message.id}-${tokenIdx}`}
+												lang={token?.lang ?? ''}
+												code={revertSanitizedResponseContent(token?.text ?? '')}
+											/>
+										{/if}
 									{:else}
 										{@html marked.parse(token.raw, {
 											...defaults,

+ 21 - 21
src/lib/components/chat/ModelSelector/Selector.svelte

@@ -8,7 +8,7 @@
 	import Check from '$lib/components/icons/Check.svelte';
 	import Search from '$lib/components/icons/Search.svelte';
 
-	import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
+	import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
 
 	import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
 	import { toast } from 'svelte-sonner';
@@ -72,10 +72,12 @@
 			return;
 		}
 
-		const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-			toast.error(error);
-			return null;
-		});
+		const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+			(error) => {
+				toast.error(error);
+				return null;
+			}
+		);
 
 		if (res) {
 			const reader = res.body
@@ -83,6 +85,16 @@
 				.pipeThrough(splitStream('\n'))
 				.getReader();
 
+			MODEL_DOWNLOAD_POOL.set({
+				...$MODEL_DOWNLOAD_POOL,
+				[sanitizedModelTag]: {
+					...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+					abortController: controller,
+					reader,
+					done: false
+				}
+			});
+
 			while (true) {
 				try {
 					const { value, done } = await reader.read();
@@ -101,19 +113,6 @@
 								throw data.detail;
 							}
 
-							if (data.id) {
-								MODEL_DOWNLOAD_POOL.set({
-									...$MODEL_DOWNLOAD_POOL,
-									[sanitizedModelTag]: {
-										...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-										requestId: data.id,
-										reader,
-										done: false
-									}
-								});
-								console.log(data);
-							}
-
 							if (data.status) {
 								if (data.digest) {
 									let downloadProgress = 0;
@@ -181,11 +180,12 @@
 	});
 
 	const cancelModelPullHandler = async (model: string) => {
-		const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+		const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+		if (abortController) {
+			abortController.abort();
+		}
 		if (reader) {
 			await reader.cancel();
-
-			await cancelOllamaRequest(localStorage.token, requestId);
 			delete $MODEL_DOWNLOAD_POOL[model];
 			MODEL_DOWNLOAD_POOL.set({
 				...$MODEL_DOWNLOAD_POOL

+ 28 - 27
src/lib/components/chat/Settings/Models.svelte

@@ -8,7 +8,6 @@
 		getOllamaUrls,
 		getOllamaVersion,
 		pullModel,
-		cancelOllamaRequest,
 		uploadModel,
 		getOllamaConfig
 	} from '$lib/apis/ollama';
@@ -70,12 +69,14 @@
 			console.log(model);
 
 			updateModelId = model.id;
-			const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
-				(error) => {
-					toast.error(error);
-					return null;
-				}
-			);
+			const [res, controller] = await pullModel(
+				localStorage.token,
+				model.id,
+				selectedOllamaUrlIdx
+			).catch((error) => {
+				toast.error(error);
+				return null;
+			});
 
 			if (res) {
 				const reader = res.body
@@ -144,10 +145,12 @@
 			return;
 		}
 
-		const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-			toast.error(error);
-			return null;
-		});
+		const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+			(error) => {
+				toast.error(error);
+				return null;
+			}
+		);
 
 		if (res) {
 			const reader = res.body
@@ -155,6 +158,16 @@
 				.pipeThrough(splitStream('\n'))
 				.getReader();
 
+			MODEL_DOWNLOAD_POOL.set({
+				...$MODEL_DOWNLOAD_POOL,
+				[sanitizedModelTag]: {
+					...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+					abortController: controller,
+					reader,
+					done: false
+				}
+			});
+
 			while (true) {
 				try {
 					const { value, done } = await reader.read();
@@ -173,19 +186,6 @@
 								throw data.detail;
 							}
 
-							if (data.id) {
-								MODEL_DOWNLOAD_POOL.set({
-									...$MODEL_DOWNLOAD_POOL,
-									[sanitizedModelTag]: {
-										...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-										requestId: data.id,
-										reader,
-										done: false
-									}
-								});
-								console.log(data);
-							}
-
 							if (data.status) {
 								if (data.digest) {
 									let downloadProgress = 0;
@@ -419,11 +419,12 @@
 	};
 
 	const cancelModelPullHandler = async (model: string) => {
-		const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+		const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+		if (abortController) {
+			abortController.abort();
+		}
 		if (reader) {
 			await reader.cancel();
-
-			await cancelOllamaRequest(localStorage.token, requestId);
 			delete $MODEL_DOWNLOAD_POOL[model];
 			MODEL_DOWNLOAD_POOL.set({
 				...$MODEL_DOWNLOAD_POOL

+ 1 - 13
src/lib/components/workspace/Playground.svelte

@@ -8,7 +8,7 @@
 	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
-	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion } from '$lib/apis/ollama';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import { splitStream } from '$lib/utils';
@@ -24,7 +24,6 @@
 	let selectedModelId = '';
 
 	let loading = false;
-	let currentRequestId = null;
 	let stopResponseFlag = false;
 
 	let messagesContainerElement: HTMLDivElement;
@@ -46,14 +45,6 @@
 		}
 	};
 
-	// const cancelHandler = async () => {
-	// 	if (currentRequestId) {
-	// 		const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
-	// 		currentRequestId = null;
-	// 		loading = false;
-	// 	}
-	// };
-
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');
@@ -171,8 +162,6 @@
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 					}
-
-					currentRequestId = null;
 					break;
 				}
 
@@ -229,7 +218,6 @@
 
 			loading = false;
 			stopResponseFlag = false;
-			currentRequestId = null;
 		}
 	};
 

Not all files can be shown because too many files changed in this diff