import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
import os
import bcrypt

from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives import serialization
import json

from datetime import date, datetime, timedelta
import pytz
from pytz import UTC
from typing import Optional, Union, List, Dict

from opentelemetry import trace

from open_webui.models.users import Users

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    OFFLINE_MODE,
    LICENSE_BLOB,
    pk,
    WEBUI_SECRET_KEY,
    TRUSTED_SIGNATURE_KEY,
    STATIC_DIR,
    SRC_LOG_LEVELS,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
)

from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"

##############
# Auth Utils
##############


def verify_signature(payload: str, signature: str) -> bool:
    """
    Verify that ``signature`` is the HMAC of ``payload``.

    The expected signature is the base64-encoded HMAC-SHA256 digest of
    ``payload`` keyed with TRUSTED_SIGNATURE_KEY. Any error while computing
    or comparing (e.g. a key of an unsupported type) yields ``False``.
    """
    try:
        expected_signature = base64.b64encode(
            hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
        ).decode()

        # Compare securely to prevent timing attacks
        return hmac.compare_digest(expected_signature, signature)

    except Exception:
        return False


def override_static(path: str, content: str):
    """
    Write base64-decoded ``content`` to the file ``path`` inside STATIC_DIR.

    ``path`` must be a bare file name. Anything that could resolve outside
    STATIC_DIR (separators of either kind, parent references, absolute paths,
    drive prefixes, or an empty name) is rejected with an error log and
    nothing is written.
    """
    # Ensure path is safe: os.path.join discards STATIC_DIR for absolute or
    # drive-qualified names, and "\\" is a separator on Windows.
    if (
        not path
        or "/" in path
        or "\\" in path
        or ".." in path
        or os.path.isabs(path)
        or os.path.splitdrive(path)[0]
    ):
        log.error(f"Invalid path: {path}")
        return

    file_path = os.path.join(STATIC_DIR, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(base64.b64decode(content))  # Convert Base64 back to raw binary


def _parse_license_expiry(value) -> Optional[date]:
    """
    Convert a license ``exp`` value to a ``date``, or return ``None``.

    Accepts ``date``/``datetime`` objects, ISO-8601 strings and numeric Unix
    timestamps (interpreted as UTC). Missing, boolean or unparseable values
    return ``None``.

    NOTE(review): the encoding the license issuer uses for ``exp`` is not
    visible in this module; confirm it is one of the accepted forms.
    """
    if not value:
        return None
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    # bool is a subclass of int; True must not be read as epoch second 1.
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        try:
            return datetime.fromtimestamp(value, UTC).date()
        except (OverflowError, OSError, ValueError):
            return None
    if isinstance(value, str):
        try:
            return datetime.fromisoformat(value).date()
        except ValueError:
            return None
    return None


def get_license_data(app, key):
    """
    Load license data for ``key`` and apply it to ``app.state``.

    Sources are tried in order:

    1. If ``key`` is set, each license server is POSTed
       ``{"key": key, "version": "1"}``. The first successful JSON response
       is applied.
    2. Otherwise (or if every server fails), LICENSE_BLOB is decrypted with
       AES-GCM, using SHA-256 of the normalised key. Its signature is checked
       with ``pk``, and it is applied only if its ``exp`` date has not passed.

    Applied fields:

    - ``resources`` entries are written to STATIC_DIR via ``override_static``.
    - ``count`` becomes USER_COUNT.
    - ``name`` becomes WEBUI_NAME.
    - ``metadata`` becomes LICENSE_METADATA.

    Returns ``True`` when license data was applied, else ``False``. Errors
    are logged, not raised.
    """

    def data_handler(data):
        # Copy recognised license fields onto the application state.
        for k, v in data.items():
            if k == "resources":
                for p, c in v.items():
                    override_static(p, c)
            elif k == "count":
                setattr(app.state, "USER_COUNT", v)
            elif k == "name":
                setattr(app.state, "WEBUI_NAME", v)
            elif k == "metadata":
                setattr(app.state, "LICENSE_METADATA", v)

    def handler(u):
        # Query one license server; True when its payload was applied.
        res = requests.post(
            f"{u}/api/v1/license/",
            json={"key": key, "version": "1"},
            timeout=5,
        )

        if getattr(res, "ok", False):
            payload = getattr(res, "json", lambda: {})()
            data_handler(payload)
            return True

        log.error(
            f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
        )
        return False

    if key:
        us = [
            "https://api.openwebui.com",
            "https://licenses.api.openwebui.com",
        ]
        for u in us:
            # Guard each server separately so a network error on the primary
            # server does not prevent the fallback server from being tried.
            try:
                if handler(u):
                    return True
            except Exception as ex:
                log.exception(f"License: Uncaught Exception: {ex}")

    try:
        if LICENSE_BLOB:
            if not key:
                log.error("License: LICENSE_BLOB is set but no license key was given")
                return False

            nonce_len = 12  # AES-GCM nonce prefix length used by the blob format
            key_bytes = hashlib.sha256(key.replace("-", "").upper().encode()).digest()

            def split_nonce(b):
                return b[:nonce_len], b[nonce_len:]

            blob = base64.b64decode(LICENSE_BLOB)
            blob_nonce, blob_ct = split_nonce(blob)

            aesgcm = AESGCM(key_bytes)
            envelope = json.loads(aesgcm.decrypt(blob_nonce, blob_ct, None))
            # The signed payload ("p") must verify against the public key
            # before its inner ciphertext is trusted.
            pk.verify(base64.b64decode(envelope["s"]), envelope["p"].encode())

            inner = base64.b64decode(envelope["p"])
            inner_nonce, inner_ct = split_nonce(inner)

            data = json.loads(aesgcm.decrypt(inner_nonce, inner_ct, None).decode())

            # Reject licenses whose expiry is missing, unreadable or past.
            expiry = _parse_license_expiry(data.get("exp"))
            if expiry is None or expiry < datetime.now(UTC).date():
                log.error(
                    f"License: missing, invalid or expired expiry date: {data.get('exp')!r}"
                )
                return False

            data_handler(data)
            return True
    except Exception as e:
        log.error(f"License: {e}")

    return False


bearer_security = HTTPBearer(auto_error=False)


def get_password_hash(password: str) -> str:
    """Hash a password using bcrypt"""
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against its bcrypt hash.

    Returns ``False`` when ``hashed_password`` is empty or ``None``, e.g. for
    accounts without a local password.
    """
    if not hashed_password:
        return False
    return bcrypt.checkpw(
        plain_password.encode("utf-8"),
        hashed_password.encode("utf-8"),
    )


def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """
    Encode ``data`` as a JWT signed with SESSION_SECRET using ALGORITHM.

    When ``expires_delta`` is given, an ``exp`` claim of now (UTC) plus the
    delta is added. The caller's dict is not modified.
    """
    payload = data.copy()

    if expires_delta:
        expire = datetime.now(UTC) + expires_delta
        payload.update({"exp": expire})

    encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
    return encoded_jwt


def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its claims, or ``None`` on any failure."""
    try:
        decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
        return decoded
    except Exception:
        return None


def extract_token_from_auth_header(auth_header: str):
    """
    Return the text after the 7-character ``"Bearer "`` prefix.

    The prefix is not checked; the caller must pass a Bearer header.
    """
    return auth_header[len("Bearer ") :]


def create_api_key():
    """Return a new API key: ``"sk-"`` followed by a random UUID4's 32 hex digits."""
    key = str(uuid.uuid4()).replace("-", "")
    return f"sk-{key}"


def get_http_authorization_cred(auth_header: Optional[str]):
    """
    Parse ``"<scheme> <credentials>"`` into HTTPAuthorizationCredentials.

    Returns ``None`` for an empty header or one that does not split into
    exactly two space-separated parts.
    """
    if not auth_header:
        return None
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception:
        return None


def _annotate_current_span(user, auth_type: str) -> None:
    """Record the user's id, email and role plus ``auth_type`` on the current trace span."""
    current_span = trace.get_current_span()
    if current_span:
        current_span.set_attribute("client.user.id", user.id)
        current_span.set_attribute("client.user.email", user.email)
        current_span.set_attribute("client.user.role", user.role)
        current_span.set_attribute("client.auth.type", auth_type)


def get_current_user(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """
    Resolve the authenticated user for a request.

    The token is taken from the Bearer credentials, falling back to the
    ``token`` cookie.

    - Tokens starting with ``"sk-"`` are API keys. They are subject to the
      ``enable_api_key`` flag and to the optional endpoint allow-list.
    - Any other token is decoded as a JWT. The user's last-active timestamp
      is then refreshed in a background task.

    Raises HTTPException: 401 when no or invalid credentials are given, and
    403 when API keys are disabled or the endpoint is not allow-listed.
    """
    token = None

    if auth_token is not None:
        token = auth_token.credentials

    if token is None and "token" in request.cookies:
        token = request.cookies.get("token")

    if token is None:
        raise HTTPException(status_code=401, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        if not request.state.enable_api_key:
            raise HTTPException(
                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
            )

        if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
            # Drop blank entries (empty setting, trailing comma): "" + "/"
            # would otherwise prefix-match every request path.
            allowed_paths = [
                path.strip()
                for path in str(
                    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
                ).split(",")
                if path.strip()
            ]

            # Check if the request path matches any allowed endpoint.
            if not any(
                request.url.path == allowed
                or request.url.path.startswith(allowed + "/")
                for allowed in allowed_paths
            ):
                raise HTTPException(
                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
                )

        # get_current_user_by_api_key also annotates the trace span.
        return get_current_user_by_api_key(token)

    # auth by jwt token
    try:
        # decode_token returns None (never raises) for invalid tokens.
        data = decode_token(token)

        if data is None or "id" not in data:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.UNAUTHORIZED,
            )

        user = Users.get_user_by_id(data["id"])
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.INVALID_TOKEN,
            )

        if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
            trusted_email = request.headers.get(
                WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
            ).lower()
            # The header is lower-cased, so compare case-insensitively.
            if trusted_email and (user.email or "").lower() != trusted_email:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="User mismatch. Please sign in again.",
                )

        _annotate_current_span(user, "jwt")

        # Refresh the user's last active timestamp asynchronously
        # to prevent blocking the request
        if background_tasks:
            background_tasks.add_task(Users.update_user_last_active_by_id, user.id)

        return user
    except Exception:
        # Delete the auth cookies so the client re-authenticates.
        # NOTE(review): confirm that cookie deletions on this injected
        # Response actually reach the client when an HTTPException is raised;
        # FastAPI builds the error response separately.
        if request.cookies.get("token"):
            response.delete_cookie("token")

        if request.cookies.get("oauth_id_token"):
            response.delete_cookie("oauth_id_token")

        # Delete OAuth session if present
        if request.cookies.get("oauth_session_id"):
            response.delete_cookie("oauth_session_id")

        raise


def get_current_user_by_api_key(api_key: str):
    """
    Look up the user owning ``api_key``.

    On success, annotates the trace span and updates the user's last-active
    timestamp synchronously. Raises HTTPException 401 if no user matches.
    """
    user = Users.get_user_by_api_key(api_key)

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    _annotate_current_span(user, "api_key")
    Users.update_user_last_active_by_id(user.id)

    return user


def get_verified_user(user=Depends(get_current_user)):
    """
    Return the current user if their role is "user" or "admin".

    Raises HTTPException 401 otherwise.
    NOTE(review): 403 is the conventional status here; 401 is kept unchanged.
    """
    if user.role not in {"user", "admin"}:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user


def get_admin_user(user=Depends(get_current_user)):
    """Return the current user if their role is "admin"; raise HTTPException 401 otherwise."""
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user