Bläddra i källkod

Merge pull request #14703 from rragundez/code-interpreter-blacklist

feat: Blacklist modules from arbitrary code execution in code interpreter
Tim Jaeryang Baek 1 månad sedan
förälder
incheckning
47560d4d72
2 ändrade filer med 31 tillägg och 0 borttagningar
  1. 10 0
      backend/open_webui/config.py
  2. 21 0
      backend/open_webui/utils/middleware.py

+ 10 - 0
backend/open_webui/config.py

@@ -1857,6 +1857,16 @@ CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
     ),
 )
 
+CODE_INTERPRETER_BLACKLISTED_MODULES = PersistentConfig(
+    "CODE_INTERPRETER_BLACKLISTED_MODULES",
+    "code_interpreter.blacklisted_modules",
+    [
+        library.strip()
+        for library in os.environ.get("CODE_INTERPRETER_BLACKLISTED_MODULES", "").split(",")
+        if library.strip()
+    ],
+)
+
 
 DEFAULT_CODE_INTERPRETER_PROMPT = """
 #### Tools Available

+ 21 - 0
backend/open_webui/utils/middleware.py

@@ -3,6 +3,7 @@ import logging
 import sys
 import os
 import base64
+import textwrap
 
 import asyncio
 from aiocache import cached
@@ -91,6 +92,7 @@ from open_webui.config import (
     CACHE_DIR,
     DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     DEFAULT_CODE_INTERPRETER_PROMPT,
+    CODE_INTERPRETER_BLACKLISTED_MODULES,
 )
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -2369,6 +2371,25 @@ async def process_chat_response(
                         try:
                             if content_blocks[-1]["attributes"].get("type") == "code":
                                 code = content_blocks[-1]["content"]
+                                if CODE_INTERPRETER_BLACKLISTED_MODULES:
+                                    blocking_code = textwrap.dedent(f"""
+                                        import builtins
+
+                                        BLACKLISTED_MODULES = {CODE_INTERPRETER_BLACKLISTED_MODULES}
+
+                                        _real_import = builtins.__import__
+                                        def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
+                                            if name.split('.')[0] in BLACKLISTED_MODULES:
+                                                importer_name = globals.get('__name__') if globals else None
+                                                if importer_name == '__main__':
+                                                    raise ImportError(
+                                                        f"Direct import of module {{name}} is restricted."
+                                                    )
+                                            return _real_import(name, globals, locals, fromlist, level)
+
+                                        builtins.__import__ = restricted_import
+                                    """)
+                                    code = blocking_code + "\n" + code
 
                                 if (
                                     request.app.state.config.CODE_INTERPRETER_ENGINE