code_interpreter.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import asyncio
  2. import json
  3. import logging
  4. import uuid
  5. from typing import Optional
  6. import aiohttp
  7. import websockets
  8. from pydantic import BaseModel
  9. from websockets import ClientConnection
  10. logger = logging.getLogger(__name__)
  11. class ResultModel(BaseModel):
  12. """
  13. Execute Code Result Model
  14. """
  15. stdout: Optional[str] = ""
  16. stderr: Optional[str] = ""
  17. result: Optional[str] = ""
  18. class JupyterCodeExecuter:
  19. """
  20. Execute code in jupyter notebook
  21. """
  22. def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
  23. """
  24. :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
  25. :param code: Code to execute
  26. :param token: Jupyter authentication token (optional)
  27. :param password: Jupyter password (optional)
  28. :param timeout: WebSocket timeout in seconds (default: 60s)
  29. """
  30. self.base_url = base_url.rstrip("/")
  31. self.code = code
  32. self.token = token
  33. self.password = password
  34. self.timeout = timeout
  35. self.kernel_id = ""
  36. self.session = aiohttp.ClientSession(base_url=self.base_url)
  37. self.params = {}
  38. self.result = ResultModel()
  39. async def __aenter__(self):
  40. return self
  41. async def __aexit__(self, exc_type, exc_val, exc_tb):
  42. if self.kernel_id:
  43. try:
  44. await self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params)
  45. except Exception as err:
  46. logger.exception("close kernel failed, %s", err)
  47. await self.session.close()
  48. async def run(self) -> ResultModel:
  49. try:
  50. await self.sign_in()
  51. await self.init_kernel()
  52. await self.execute_code()
  53. except Exception as err:
  54. logger.error(err)
  55. self.result.stderr = f"Error: {err}"
  56. return self.result
  57. async def sign_in(self) -> None:
  58. # password authentication
  59. if self.password and not self.token:
  60. async with self.session.get("/login") as response:
  61. response.raise_for_status()
  62. xsrf_token = response.cookies["_xsrf"].value
  63. if not xsrf_token:
  64. raise ValueError("_xsrf token not found")
  65. self.session.cookie_jar.update_cookies(response.cookies)
  66. self.session.headers.update({"X-XSRFToken": xsrf_token})
  67. async with self.session.post(
  68. "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
  69. ) as response:
  70. response.raise_for_status()
  71. self.session.cookie_jar.update_cookies(response.cookies)
  72. # token authentication
  73. if self.token:
  74. self.params.update({"token": self.token})
  75. async def init_kernel(self) -> None:
  76. async with self.session.post(url="/api/kernels", params=self.params) as response:
  77. response.raise_for_status()
  78. kernel_data = await response.json()
  79. self.kernel_id = kernel_data["id"]
  80. def init_ws(self) -> (str, dict):
  81. ws_base = self.base_url.replace("http", "ws")
  82. ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
  83. websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
  84. ws_headers = {}
  85. if self.password and not self.token:
  86. ws_headers = {
  87. "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
  88. **self.session.headers,
  89. }
  90. return websocket_url, ws_headers
  91. async def execute_code(self) -> None:
  92. # initialize ws
  93. websocket_url, ws_headers = self.init_ws()
  94. # execute
  95. async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
  96. await self.execute_in_jupyter(ws)
  97. async def execute_in_jupyter(self, ws: ClientConnection) -> None:
  98. # send message
  99. msg_id = uuid.uuid4().hex
  100. await ws.send(
  101. json.dumps(
  102. {
  103. "header": {
  104. "msg_id": msg_id,
  105. "msg_type": "execute_request",
  106. "username": "user",
  107. "session": uuid.uuid4().hex,
  108. "date": "",
  109. "version": "5.3",
  110. },
  111. "parent_header": {},
  112. "metadata": {},
  113. "content": {
  114. "code": self.code,
  115. "silent": False,
  116. "store_history": True,
  117. "user_expressions": {},
  118. "allow_stdin": False,
  119. "stop_on_error": True,
  120. },
  121. "channel": "shell",
  122. }
  123. )
  124. )
  125. # parse message
  126. stdout, stderr, result = "", "", []
  127. while True:
  128. try:
  129. # wait for message
  130. message = await asyncio.wait_for(ws.recv(), self.timeout)
  131. message_data = json.loads(message)
  132. # msg id not match, skip
  133. if message_data.get("parent_header", {}).get("msg_id") != msg_id:
  134. continue
  135. # check message type
  136. msg_type = message_data.get("msg_type")
  137. match msg_type:
  138. case "stream":
  139. if message_data["content"]["name"] == "stdout":
  140. stdout += message_data["content"]["text"]
  141. elif message_data["content"]["name"] == "stderr":
  142. stderr += message_data["content"]["text"]
  143. case "execute_result" | "display_data":
  144. data = message_data["content"]["data"]
  145. if "image/png" in data:
  146. result.append(f"data:image/png;base64,{data['image/png']}")
  147. elif "text/plain" in data:
  148. result.append(data["text/plain"])
  149. case "error":
  150. stderr += "\n".join(message_data["content"]["traceback"])
  151. case "status":
  152. if message_data["content"]["execution_state"] == "idle":
  153. break
  154. except asyncio.TimeoutError:
  155. stderr += "\nExecution timed out."
  156. break
  157. self.result.stdout = stdout.strip()
  158. self.result.stderr = stderr.strip()
  159. self.result.result = "\n".join(result).strip() if result else ""
  160. async def execute_code_jupyter(
  161. base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
  162. ) -> dict:
  163. async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
  164. result = await executor.run()
  165. return result.model_dump()