소스 검색

Merge pull request #11120 from OrenZhang/fix_jupyter

fix(jupyter): fix kernel_id not set and optimize code
Timothy Jaeryang Baek 5 달 전
부모
커밋
bb2bd7d721
1개의 변경된 파일170개의 추가작업 그리고 129개의 파일을 삭제
  1. 170 129
      backend/open_webui/utils/code_interpreter.py

+ 170 - 129
backend/open_webui/utils/code_interpreter.py

@@ -1,148 +1,189 @@
 import asyncio
 import json
+import logging
 import uuid
+from typing import Optional
+
+import aiohttp
 import websockets
-import requests
-from urllib.parse import urljoin
+from pydantic import BaseModel
+from websockets import ClientConnection
 
+from open_webui.env import SRC_LOG_LEVELS
 
-async def execute_code_jupyter(
-    jupyter_url, code, token=None, password=None, timeout=10
-):
+logger = logging.getLogger(__name__)
+logger.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+class ResultModel(BaseModel):
     """
-    Executes Python code in a Jupyter kernel.
-    Supports authentication with a token or password.
-    :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
-    :param code: Code to execute
-    :param token: Jupyter authentication token (optional)
-    :param password: Jupyter password (optional)
-    :param timeout: WebSocket timeout in seconds (default: 10s)
-    :return: Dictionary with stdout, stderr, and result
-             - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
+    Execute Code Result Model
     """
-    session = requests.Session()  # Maintain cookies
-    headers = {}  # Headers for requests
 
-    # Authenticate using password
-    if password and not token:
-        try:
-            login_url = urljoin(jupyter_url, "/login")
-            response = session.get(login_url)
-            response.raise_for_status()
-            xsrf_token = session.cookies.get("_xsrf")
-            if not xsrf_token:
-                raise ValueError("Failed to fetch _xsrf token")
+    stdout: Optional[str] = ""
+    stderr: Optional[str] = ""
+    result: Optional[str] = ""
 
-            login_data = {"_xsrf": xsrf_token, "password": password}
-            login_response = session.post(
-                login_url, data=login_data, cookies=session.cookies
-            )
-            login_response.raise_for_status()
-            headers["X-XSRFToken"] = xsrf_token
-        except Exception as e:
-            return {
-                "stdout": "",
-                "stderr": f"Authentication Error: {str(e)}",
-                "result": "",
-            }
-
-    # Construct API URLs with authentication token if provided
-    params = f"?token={token}" if token else ""
-    kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
 
-    try:
-        response = session.post(kernel_url, headers=headers, cookies=session.cookies)
-        response.raise_for_status()
-        kernel_id = response.json()["id"]
+class JupyterCodeExecuter:
+    """
+    Execute code in jupyter notebook
+    """
 
-        websocket_url = urljoin(
-            jupyter_url.replace("http", "ws"),
-            f"/api/kernels/{kernel_id}/channels{params}",
-        )
+    def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
+        """
+        :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
+        :param code: Code to execute
+        :param token: Jupyter authentication token (optional)
+        :param password: Jupyter password (optional)
+        :param timeout: WebSocket timeout in seconds (default: 60s)
+        """
+        self.base_url = base_url.rstrip("/")
+        self.code = code
+        self.token = token
+        self.password = password
+        self.timeout = timeout
+        self.kernel_id = ""
+        self.session = aiohttp.ClientSession(base_url=self.base_url)
+        self.params = {}
+        self.result = ResultModel()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if self.kernel_id:
+            try:
+                async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
+                    response.raise_for_status()
+            except Exception as err:
+                logger.exception("close kernel failed, %s", err)
+        await self.session.close()
+
+    async def run(self) -> ResultModel:
+        try:
+            await self.sign_in()
+            await self.init_kernel()
+            await self.execute_code()
+        except Exception as err:
+            logger.exception("execute code failed, %s", err)
+            self.result.stderr = f"Error: {err}"
+        return self.result
+
+    async def sign_in(self) -> None:
+        # password authentication
+        if self.password and not self.token:
+            async with self.session.get("/login") as response:
+                response.raise_for_status()
+                xsrf_token = response.cookies["_xsrf"].value
+                if not xsrf_token:
+                    raise ValueError("_xsrf token not found")
+                self.session.cookie_jar.update_cookies(response.cookies)
+                self.session.headers.update({"X-XSRFToken": xsrf_token})
+            async with self.session.post(
+                "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
+            ) as response:
+                response.raise_for_status()
+                self.session.cookie_jar.update_cookies(response.cookies)
+
+        # token authentication
+        if self.token:
+            self.params.update({"token": self.token})
+
+    async def init_kernel(self) -> None:
+        async with self.session.post(url="/api/kernels", params=self.params) as response:
+            response.raise_for_status()
+            kernel_data = await response.json()
+            self.kernel_id = kernel_data["id"]
 
+    def init_ws(self) -> (str, dict):
+        ws_base = self.base_url.replace("http", "ws")
+        ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
+        websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
         ws_headers = {}
-        if password and not token:
-            ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
-            cookies = {name: value for name, value in session.cookies.items()}
-            ws_headers["Cookie"] = "; ".join(
-                [f"{name}={value}" for name, value in cookies.items()]
-            )
-
-        async with websockets.connect(
-            websocket_url, additional_headers=ws_headers
-        ) as ws:
-            msg_id = str(uuid.uuid4())
-            execute_request = {
-                "header": {
-                    "msg_id": msg_id,
-                    "msg_type": "execute_request",
-                    "username": "user",
-                    "session": str(uuid.uuid4()),
-                    "date": "",
-                    "version": "5.3",
-                },
-                "parent_header": {},
-                "metadata": {},
-                "content": {
-                    "code": code,
-                    "silent": False,
-                    "store_history": True,
-                    "user_expressions": {},
-                    "allow_stdin": False,
-                    "stop_on_error": True,
-                },
-                "channel": "shell",
+        if self.password and not self.token:
+            ws_headers = {
+                "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
+                **self.session.headers,
             }
-            await ws.send(json.dumps(execute_request))
-
-            stdout, stderr, result = "", "", []
-
-            while True:
-                try:
-                    message = await asyncio.wait_for(ws.recv(), timeout)
-                    message_data = json.loads(message)
-                    if message_data.get("parent_header", {}).get("msg_id") == msg_id:
-                        msg_type = message_data.get("msg_type")
-
-                        if msg_type == "stream":
-                            if message_data["content"]["name"] == "stdout":
-                                stdout += message_data["content"]["text"]
-                            elif message_data["content"]["name"] == "stderr":
-                                stderr += message_data["content"]["text"]
-
-                        elif msg_type in ("execute_result", "display_data"):
-                            data = message_data["content"]["data"]
-                            if "image/png" in data:
-                                result.append(
-                                    f"data:image/png;base64,{data['image/png']}"
-                                )
-                            elif "text/plain" in data:
-                                result.append(data["text/plain"])
-
-                        elif msg_type == "error":
-                            stderr += "\n".join(message_data["content"]["traceback"])
-
-                        elif (
-                            msg_type == "status"
-                            and message_data["content"]["execution_state"] == "idle"
-                        ):
+        return websocket_url, ws_headers
+
+    async def execute_code(self) -> None:
+        # initialize ws
+        websocket_url, ws_headers = self.init_ws()
+        # execute
+        async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
+            await self.execute_in_jupyter(ws)
+
+    async def execute_in_jupyter(self, ws: ClientConnection) -> None:
+        # send message
+        msg_id = uuid.uuid4().hex
+        await ws.send(
+            json.dumps(
+                {
+                    "header": {
+                        "msg_id": msg_id,
+                        "msg_type": "execute_request",
+                        "username": "user",
+                        "session": uuid.uuid4().hex,
+                        "date": "",
+                        "version": "5.3",
+                    },
+                    "parent_header": {},
+                    "metadata": {},
+                    "content": {
+                        "code": self.code,
+                        "silent": False,
+                        "store_history": True,
+                        "user_expressions": {},
+                        "allow_stdin": False,
+                        "stop_on_error": True,
+                    },
+                    "channel": "shell",
+                }
+            )
+        )
+        # parse message
+        stdout, stderr, result = "", "", []
+        while True:
+            try:
+                # wait for message
+                message = await asyncio.wait_for(ws.recv(), self.timeout)
+                message_data = json.loads(message)
+                # msg id not match, skip
+                if message_data.get("parent_header", {}).get("msg_id") != msg_id:
+                    continue
+                # check message type
+                msg_type = message_data.get("msg_type")
+                match msg_type:
+                    case "stream":
+                        if message_data["content"]["name"] == "stdout":
+                            stdout += message_data["content"]["text"]
+                        elif message_data["content"]["name"] == "stderr":
+                            stderr += message_data["content"]["text"]
+                    case "execute_result" | "display_data":
+                        data = message_data["content"]["data"]
+                        if "image/png" in data:
+                            result.append(f"data:image/png;base64,{data['image/png']}")
+                        elif "text/plain" in data:
+                            result.append(data["text/plain"])
+                    case "error":
+                        stderr += "\n".join(message_data["content"]["traceback"])
+                    case "status":
+                        if message_data["content"]["execution_state"] == "idle":
                             break
 
-                except asyncio.TimeoutError:
-                    stderr += "\nExecution timed out."
-                    break
+            except asyncio.TimeoutError:
+                stderr += "\nExecution timed out."
+                break
+        self.result.stdout = stdout.strip()
+        self.result.stderr = stderr.strip()
+        self.result.result = "\n".join(result).strip() if result else ""
 
-    except Exception as e:
-        return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
 
-    finally:
-        if kernel_id:
-            requests.delete(
-                f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
-            )
-
-    return {
-        "stdout": stdout.strip(),
-        "stderr": stderr.strip(),
-        "result": "\n".join(result).strip() if result else "",
-    }
+async def execute_code_jupyter(
+    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
+) -> dict:
+    async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
+        result = await executor.run()
+        return result.model_dump()