"""
Authentication helpers: password hashing, JWT and API-key handling,
HMAC signature verification, license retrieval and FastAPI user dependencies.
"""

import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
import os


from datetime import datetime, timedelta
import pytz
from pytz import UTC
from typing import Optional, Union, List, Dict

from opentelemetry import trace

from open_webui.models.users import Users

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    WEBUI_SECRET_KEY,
    TRUSTED_SIGNATURE_KEY,
    STATIC_DIR,
    SRC_LOG_LEVELS,
)

from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext


logging.getLogger("passlib").setLevel(logging.ERROR)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"

##############
# Auth Utils
##############


def verify_signature(payload: str, signature: str) -> bool:
    """
    Verifies the HMAC signature of the received payload.

    The expected signature is the Base64-encoded HMAC-SHA256 of ``payload``
    (UTF-8 encoded), keyed with ``TRUSTED_SIGNATURE_KEY``.

    Returns:
        True if ``signature`` matches. False otherwise, including when any
        exception occurs while computing or comparing the signature.
    """
    try:
        expected_signature = base64.b64encode(
            hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
        ).decode()

        # Compare securely to prevent timing attacks
        return hmac.compare_digest(expected_signature, signature)

    except Exception:
        return False


def override_static(path: str, content: str):
    """
    Write Base64-encoded ``content`` to the file ``path`` inside ``STATIC_DIR``.

    ``path`` must be a bare file name. Empty names, ``.``, anything containing
    ``..``, and anything with a path separator or drive prefix are logged and
    ignored, so the write can never land outside ``STATIC_DIR``.

    Args:
        path: Target file name inside ``STATIC_DIR``.
        content: Base64-encoded file contents.
    """
    # Ensure path is safe. The basename comparison also rejects the
    # platform's own separators (e.g. "\\" on Windows) and drive-relative
    # names such as "C:x", which the "/" check alone would let through.
    if (
        not path
        or path == "."
        or "/" in path
        or ".." in path
        or os.path.basename(path) != path
    ):
        log.error(f"Invalid path: {path}")
        return

    file_path = os.path.join(STATIC_DIR, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(base64.b64decode(content))  # Convert Base64 back to raw binary


def get_license_data(app, key):
    """
    Fetch license data for ``key`` and apply it to ``app.state``.

    The recognised payload keys are handled as follows:
    - ``resources``: each name/content pair is passed to ``override_static``.
    - ``count``, ``name``, ``metadata``: stored as ``app.state.USER_COUNT``,
      ``WEBUI_NAME`` and ``LICENSE_METADATA`` respectively.

    Returns:
        True when the request succeeded and every payload entry was applied.
        False when ``key`` is falsy, the response is not ok, or any exception
        is raised (the exception is logged, not propagated).
    """
    if key:
        try:
            res = requests.post(
                "https://api.openwebui.com/api/v1/license/",
                json={"key": key, "version": "1"},
                timeout=5,
            )

            if getattr(res, "ok", False):
                payload = getattr(res, "json", lambda: {})()
                for k, v in payload.items():
                    if k == "resources":
                        for p, c in v.items():
                            # Resolved via globals(): a missing override_static
                            # degrades to a no-op instead of a NameError.
                            globals().get("override_static", lambda a, b: None)(p, c)
                    elif k == "count":
                        setattr(app.state, "USER_COUNT", v)
                    elif k == "name":
                        setattr(app.state, "WEBUI_NAME", v)
                    elif k == "metadata":
                        setattr(app.state, "LICENSE_METADATA", v)
                return True
            else:
                log.error(
                    f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
                )
        except Exception as ex:
            log.exception(f"License: Uncaught Exception: {ex}")
    return False


bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """
    Check ``plain_password`` against a bcrypt ``hashed_password``.

    Returns:
        The result of ``pwd_context.verify`` when a hash is present, or
        ``None`` (falsy) when ``hashed_password`` is empty/None.
    """
    return (
        pwd_context.verify(plain_password, hashed_password) if hashed_password else None
    )


def get_password_hash(password):
    """Hash ``password`` with the module's bcrypt ``CryptContext``."""
    return pwd_context.hash(password)


def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """
    Create an HS256 JWT signed with ``SESSION_SECRET``.

    Args:
        data: Claims to embed. The dict is copied and never mutated.
        expires_delta: Token lifetime. When falsy (``None``), no ``exp``
            claim is added, so the token does not expire.

    Returns:
        The encoded JWT.
    """
    payload = data.copy()

    # NOTE(review): timedelta(0) is falsy, so a zero lifetime currently yields
    # a NON-expiring token. Consider `is not None` once callers are audited.
    if expires_delta:
        expire = datetime.now(UTC) + expires_delta
        payload.update({"exp": expire})

    encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
    return encoded_jwt


def decode_token(token: str) -> Optional[dict]:
    """
    Decode and verify a JWT signed with ``SESSION_SECRET``.

    Returns:
        The decoded claims, or ``None`` if decoding/verification raises for
        any reason (e.g. invalid signature or malformed token).
    """
    try:
        decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
        return decoded
    except Exception:
        return None


def extract_token_from_auth_header(auth_header: str):
    """
    Return ``auth_header`` with its first ``len("Bearer ")`` characters removed.

    The prefix is not validated: callers must ensure the header really starts
    with ``"Bearer "``.
    """
    return auth_header[len("Bearer ") :]


def create_api_key():
    """Generate an API key of the form ``sk-<32 hex chars>`` from a random UUID4."""
    key = str(uuid.uuid4()).replace("-", "")
    return f"sk-{key}"


def get_http_authorization_cred(auth_header: Optional[str]):
    """
    Parse a raw ``Authorization`` header value into credentials.

    Returns:
        ``HTTPAuthorizationCredentials(scheme, credentials)``, or ``None``
        when the header is empty, does not split into exactly two
        space-separated parts, or cannot be converted.
    """
    if not auth_header:
        return None
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception:
        return None


def _set_span_user_attributes(user, auth_type: str) -> None:
    """Attach the authenticated user's id/email/role and auth type to the current span."""
    current_span = trace.get_current_span()
    if current_span:
        current_span.set_attribute("client.user.id", user.id)
        current_span.set_attribute("client.user.email", user.email)
        current_span.set_attribute("client.user.role", user.role)
        current_span.set_attribute("client.auth.type", auth_type)


def get_current_user(
    request: Request,
    background_tasks: BackgroundTasks,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """
    FastAPI dependency that resolves the authenticated user for a request.

    The token comes from the ``Authorization: Bearer`` header, falling back to
    the ``token`` cookie. Tokens starting with ``sk-`` are treated as API
    keys. Anything else is decoded as a JWT whose claims must contain ``id``.

    Raises:
        HTTPException: 403 when no token is present, when API keys are
            disabled, or when the path is not in the API-key allow-list.
            401 when the token is invalid or does not resolve to a user.
    """
    token = None

    if auth_token is not None:
        token = auth_token.credentials

    if token is None and "token" in request.cookies:
        token = request.cookies.get("token")

    if token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        if not request.state.enable_api_key:
            raise HTTPException(
                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
            )

        if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
            # Blank entries (empty setting, trailing or doubled commas) are
            # dropped. An empty prefix would make `startswith("" + "/")`
            # match every path and silently disable the restriction.
            allowed_paths = [
                path.strip()
                for path in str(
                    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
                ).split(",")
                if path.strip()
            ]

            # Check if the request path matches any allowed endpoint.
            if not any(
                request.url.path == allowed
                or request.url.path.startswith(allowed + "/")
                for allowed in allowed_paths
            ):
                raise HTTPException(
                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
                )

        # Raises 401 for unknown keys and annotates the current span itself.
        return get_current_user_by_api_key(token)

    # auth by jwt token. decode_token returns None on any decoding failure.
    data = decode_token(token)

    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(data["id"])
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    _set_span_user_attributes(user, "jwt")

    # Refresh the user's last active timestamp asynchronously
    # to prevent blocking the request
    if background_tasks:
        background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
    return user


def get_current_user_by_api_key(api_key: str):
    """
    Resolve the user owning ``api_key`` and mark them as active.

    Raises:
        HTTPException: 401 (``INVALID_TOKEN``) when no user has this key.
    """
    user = Users.get_user_by_api_key(api_key)

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    _set_span_user_attributes(user, "api_key")
    # Unlike the JWT path, this update runs synchronously within the request.
    Users.update_user_last_active_by_id(user.id)

    return user


def get_verified_user(user=Depends(get_current_user)):
    """
    Dependency that requires the current user's role to be "user" or "admin".

    Raises:
        HTTPException: 401 (``ACCESS_PROHIBITED``) for any other role.
    """
    # NOTE(review): 403 is the conventional status for an authenticated but
    # unauthorized user; 401 is kept here for client compatibility.
    if user.role not in {"user", "admin"}:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user


def get_admin_user(user=Depends(get_current_user)):
    """
    Dependency that requires the current user's role to be "admin".

    Raises:
        HTTPException: 401 (``ACCESS_PROHIBITED``) for any other role.
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user