Timothy Jaeryang Baek 2 months ago
parent
commit
b8da4a8cd8

+ 0 - 6
backend/open_webui/config.py

@@ -779,12 +779,6 @@ if CUSTOM_NAME:
         pass
 
 
-####################################
-# LICENSE_KEY
-####################################
-
-LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
-
 ####################################
 # STORAGE PROVIDER
 ####################################

+ 29 - 0
backend/open_webui/env.py

@@ -7,6 +7,7 @@ import sys
 import shutil
 from uuid import uuid4
 from pathlib import Path
+from cryptography.hazmat.primitives import serialization
 
 import markdown
 from bs4 import BeautifulSoup
@@ -430,6 +431,34 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
     os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
 )
 
+
+####################################
+# LICENSE_KEY
+####################################
+
+LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
+
+LICENSE_BLOB = None
+LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data")
+if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH):
+    with open(LICENSE_BLOB_PATH, "rb") as f:
+        LICENSE_BLOB = f.read()
+
+LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "")
+
+pk = None
+if LICENSE_PUBLIC_KEY:
+    pk = serialization.load_pem_public_key(
+        f"""
+-----BEGIN PUBLIC KEY-----
+{LICENSE_PUBLIC_KEY}
+-----END PUBLIC KEY-----
+""".encode(
+            "utf-8"
+        )
+    )
+
+
 ####################################
 # MODELS
 ####################################

+ 1 - 1
backend/open_webui/main.py

@@ -102,7 +102,6 @@ from open_webui.models.users import UserModel, Users
 from open_webui.models.chats import Chats
 
 from open_webui.config import (
-    LICENSE_KEY,
     # Ollama
     ENABLE_OLLAMA_API,
     OLLAMA_BASE_URLS,
@@ -395,6 +394,7 @@ from open_webui.config import (
     reset_config,
 )
 from open_webui.env import (
+    LICENSE_KEY,
     AUDIT_EXCLUDED_PATHS,
     AUDIT_LOG_LEVEL,
     CHANGELOG,

+ 55 - 11
backend/open_webui/utils/auth.py

@@ -8,6 +8,12 @@ import requests
 import os
 
 
+from cryptography.hazmat.primitives.ciphers.aead import AESGCM
+from cryptography.hazmat.primitives.asymmetric import ed25519
+from cryptography.hazmat.primitives import serialization
+import json
+
+
 from datetime import datetime, timedelta
 import pytz
 from pytz import UTC
@@ -18,7 +24,11 @@ from opentelemetry import trace
 from open_webui.models.users import Users
 
 from open_webui.constants import ERROR_MESSAGES
+
 from open_webui.env import (
+    OFFLINE_MODE,
+    LICENSE_BLOB,
+    pk,
     WEBUI_SECRET_KEY,
     TRUSTED_SIGNATURE_KEY,
     STATIC_DIR,
@@ -74,6 +84,18 @@ def override_static(path: str, content: str):
 
 
 def get_license_data(app, key):
+    def data_handler(data):
+        for k, v in data.items():
+            if k == "resources":
+                for p, c in v.items():
+                    globals().get("override_static", lambda a, b: None)(p, c)
+            elif k == "count":
+                setattr(app.state, "USER_COUNT", v)
+            elif k == "name":
+                setattr(app.state, "WEBUI_NAME", v)
+            elif k == "metadata":
+                setattr(app.state, "LICENSE_METADATA", v)
+
     def handler(u):
         res = requests.post(
             f"{u}/api/v1/license/",
@@ -83,16 +105,7 @@ def get_license_data(app, key):
 
         if getattr(res, "ok", False):
             payload = getattr(res, "json", lambda: {})()
-            for k, v in payload.items():
-                if k == "resources":
-                    for p, c in v.items():
-                        globals().get("override_static", lambda a, b: None)(p, c)
-                elif k == "count":
-                    setattr(app.state, "USER_COUNT", v)
-                elif k == "name":
-                    setattr(app.state, "WEBUI_NAME", v)
-                elif k == "metadata":
-                    setattr(app.state, "LICENSE_METADATA", v)
+            data_handler(payload)
             return True
         else:
             log.error(
@@ -100,13 +113,44 @@ def get_license_data(app, key):
             )
 
     if key:
-        us = ["https://api.openwebui.com", "https://licenses.api.openwebui.com"]
+        us = [
+            "https://api.openwebui.com",
+            "https://licenses.api.openwebui.com",
+        ]
         try:
             for u in us:
                 if handler(u):
                     return True
         except Exception as ex:
             log.exception(f"License: Uncaught Exception: {ex}")
+
+    try:
+        if LICENSE_BLOB:
+            nl = 12
+            kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
+
+            def nt(b):
+                return b[:nl], b[nl:]
+
+            lb = base64.b64decode(LICENSE_BLOB)
+            ln, lt = nt(lb)
+
+            aesgcm = AESGCM(kb)
+            p = json.loads(aesgcm.decrypt(ln, lt, None))
+            pk.verify(base64.b64decode(p["s"]), p["p"].encode())
+
+            pb = base64.b64decode(p["p"])
+            pn, pt = nt(pb)
+
+            data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
+            if not data.get("exp") and data.get("exp") < datetime.now().date():
+                return False
+
+            data_handler(data)
+            return True
+    except Exception as e:
+        log.error(f"License: {e}")
+
     return False