浏览代码

fix/refac: ollama api backward compatibility

Timothy Jaeryang Baek 3 月之前
父节点
当前提交
1c41e95ba6
共有 2 个文件被更改,包括 27 次插入12 次删除
  1. 26 11
      backend/open_webui/routers/ollama.py
  2. 1 1
      src/lib/apis/ollama/index.ts

+ 26 - 11
backend/open_webui/routers/ollama.py

@@ -636,7 +636,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
 
 
 
 
 class ModelNameForm(BaseModel):
 class ModelNameForm(BaseModel):
-    model: str
+    model: Optional[str] = None
     model_config = ConfigDict(
     model_config = ConfigDict(
         extra="allow",
         extra="allow",
     )
     )
@@ -648,7 +648,9 @@ async def unload_model(
     form_data: ModelNameForm,
     form_data: ModelNameForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
-    model_name = form_data.model
+    form_data = form_data.model_dump(exclude_none=True)
+    model_name = form_data.get("model", form_data.get("name"))
+
     if not model_name:
     if not model_name:
         raise HTTPException(
         raise HTTPException(
             status_code=400, detail="Missing name of the model to unload."
             status_code=400, detail="Missing name of the model to unload."
@@ -714,11 +716,14 @@ async def pull_model(
     url_idx: int = 0,
     url_idx: int = 0,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
+    form_data = form_data.model_dump(exclude_none=True)
+    form_data["model"] = form_data.get("model", form_data.get("name"))
+
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
     # Admin should be able to pull models from any source
     # Admin should be able to pull models from any source
-    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
+    payload = {**form_data, "insecure": True}
 
 
     return await send_post_request(
     return await send_post_request(
         url=f"{url}/api/pull",
         url=f"{url}/api/pull",
@@ -870,16 +875,21 @@ async def delete_model(
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
+    form_data = form_data.model_dump(exclude_none=True)
+    form_data["model"] = form_data.get("model", form_data.get("name"))
+
+    model = form_data.get("model")
+
     if url_idx is None:
     if url_idx is None:
         await get_all_models(request, user=user)
         await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
         models = request.app.state.OLLAMA_MODELS
 
 
-        if form_data.model in models:
-            url_idx = models[form_data.model]["urls"][0]
+        if model in models:
+            url_idx = models[model]["urls"][0]
         else:
         else:
             raise HTTPException(
             raise HTTPException(
                 status_code=400,
                 status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
             )
             )
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -889,7 +899,7 @@ async def delete_model(
         r = requests.request(
         r = requests.request(
             method="DELETE",
             method="DELETE",
             url=f"{url}/api/delete",
             url=f"{url}/api/delete",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
+            data=json.dumps(form_data).encode(),
             headers={
             headers={
                 "Content-Type": "application/json",
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
                 **({"Authorization": f"Bearer {key}"} if key else {}),
@@ -931,16 +941,21 @@ async def delete_model(
 async def show_model_info(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
 ):
+    form_data = form_data.model_dump(exclude_none=True)
+    form_data["model"] = form_data.get("model", form_data.get("name"))
+
     await get_all_models(request, user=user)
     await get_all_models(request, user=user)
     models = request.app.state.OLLAMA_MODELS
     models = request.app.state.OLLAMA_MODELS
 
 
-    if form_data.model not in models:
+    model = form_data.get("model")
+
+    if model not in models:
         raise HTTPException(
         raise HTTPException(
             status_code=400,
             status_code=400,
-            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
         )
         )
 
 
-    url_idx = random.choice(models[form_data.model]["urls"])
+    url_idx = random.choice(models[model]["urls"])
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
     key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -963,7 +978,7 @@ async def show_model_info(
                     else {}
                     else {}
                 ),
                 ),
             },
             },
-            data=form_data.model_dump_json(exclude_none=True).encode(),
+            data=json.dumps(form_data).encode(),
         )
         )
         r.raise_for_status()
         r.raise_for_status()
 
 

+ 1 - 1
src/lib/apis/ollama/index.ts

@@ -419,7 +419,7 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
 				Authorization: `Bearer ${token}`
 				Authorization: `Bearer ${token}`
 			},
 			},
 			body: JSON.stringify({
 			body: JSON.stringify({
-				name: tagName
+				model: tagName
 			})
 			})
 		}
 		}
 	)
 	)