auth.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. import logging
  2. import uuid
  3. import jwt
  4. import base64
  5. import hmac
  6. import hashlib
  7. import requests
  8. import os
  9. from datetime import datetime, timedelta
  10. import pytz
  11. from pytz import UTC
  12. from typing import Optional, Union, List, Dict
  13. from opentelemetry import trace
  14. from open_webui.models.users import Users
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import (
  17. WEBUI_SECRET_KEY,
  18. TRUSTED_SIGNATURE_KEY,
  19. STATIC_DIR,
  20. SRC_LOG_LEVELS,
  21. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  22. )
  23. from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
  24. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  25. from passlib.context import CryptContext
  # JWT signing key and algorithm shared by create_token / decode_token below.
  SESSION_SECRET = WEBUI_SECRET_KEY
  ALGORITHM = "HS256"

  # Only let passlib log at ERROR level or above.
  logging.getLogger("passlib").setLevel(logging.ERROR)

  # Module logger; its level is taken from the "OAUTH" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["OAUTH"])

  ##############
  # Auth Utils
  ##############
  def verify_signature(payload: str, signature: str) -> bool:
      """
      Check that *signature* is the Base64-encoded HMAC-SHA256 of *payload*.

      The HMAC is keyed with TRUSTED_SIGNATURE_KEY. Any failure while computing
      or comparing (e.g. an unusable key, or a signature that is not a str) is
      treated as a mismatch, so the check fails closed and returns False.
      """
      try:
          mac = hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256)
          expected = base64.b64encode(mac.digest()).decode()
          # Constant-time comparison so response timing does not reveal how
          # much of a forged signature was correct.
          matches = hmac.compare_digest(expected, signature)
      except Exception:
          return False
      return matches
  def override_static(path: str, content: str):
      """
      Write Base64-encoded *content* as the file *path* directly inside STATIC_DIR.

      *path* must be a bare file name. It is rejected (an error is logged and
      nothing is written) when it is empty, equals ".", contains a path
      separator ("/" or "\\"), or contains "..". As a final guard, the joined
      target must resolve to a direct child of STATIC_DIR.

      The content is decoded before the target file is opened, so malformed
      Base64 leaves any existing file untouched.

      Raises:
          binascii.Error: if *content* is not valid Base64.
          OSError: if the directory cannot be created or the file written.
      """
      # Backslash is checked explicitly: it is a separator on Windows, where a
      # name like "\\Windows\\x" would be joined as an absolute path.
      if not path or path == "." or "/" in path or "\\" in path or ".." in path:
          log.error(f"Invalid path: {path}")
          return

      static_root = os.path.abspath(STATIC_DIR)
      file_path = os.path.join(static_root, path)
      # Containment check; also catches drive-qualified names such as "D:x".
      if os.path.dirname(os.path.abspath(file_path)) != static_root:
          log.error(f"Invalid path: {path}")
          return

      # Decode first: opening with "wb" truncates, so a decode failure after
      # opening would wipe the existing file.
      data = base64.b64decode(content)  # Base64 text -> raw binary
      os.makedirs(static_root, exist_ok=True)
      with open(file_path, "wb") as f:
          f.write(data)
  def get_license_data(app, key):
      """
      Fetch license data for *key* from the Open WebUI license API and apply it to *app*.

      The key is POSTed with a 5 second timeout. When the response is OK, each
      recognised top-level field of its JSON body is applied:

      * ``resources``: mapping of file name to content, each pair passed to
        ``override_static`` (looked up at call time; a no-op if it is absent);
      * ``count`` / ``name`` / ``metadata``: stored on ``app.state`` as
        ``USER_COUNT`` / ``WEBUI_NAME`` / ``LICENSE_METADATA``.

      Returns True after a successful response has been applied. Returns False
      when *key* is falsy, when the server answers with a non-OK response
      (logged), or when any exception occurs (logged, not propagated).
      """
      if not key:
          return False

      # Payload fields copied verbatim onto app.state, keyed by payload name.
      state_attributes = {
          "count": "USER_COUNT",
          "name": "WEBUI_NAME",
          "metadata": "LICENSE_METADATA",
      }
      try:
          response = requests.post(
              "https://api.openwebui.com/api/v1/license/",
              json={"key": key, "version": "1"},
              timeout=5,
          )
          if not getattr(response, "ok", False):
              detail = getattr(response, "text", "unknown error")
              log.error(f"License: retrieval issue: {detail}")
              return False

          body = getattr(response, "json", lambda: {})()
          for field, value in body.items():
              if field == "resources":
                  for file_name, file_content in value.items():
                      # Resolved through globals() so a missing helper is a no-op.
                      globals().get("override_static", lambda a, b: None)(
                          file_name, file_content
                      )
              elif field in state_attributes:
                  setattr(app.state, state_attributes[field], value)
          return True
      except Exception as ex:
          log.exception(f"License: Uncaught Exception: {ex}")
          return False
  # Password hashing context used by verify_password / get_password_hash.
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

  # Bearer-credential dependency for get_current_user. auto_error=False because
  # get_current_user handles a missing credential itself: it checks for None
  # and falls back to the "token" cookie.
  bearer_security = HTTPBearer(auto_error=False)
  def verify_password(plain_password, hashed_password):
      """
      Check *plain_password* against the stored *hashed_password*.

      Returns the result of ``pwd_context.verify`` when a hash is present. When
      *hashed_password* is empty or None, returns None (falsy), so an account
      without a stored hash never verifies.
      """
      if not hashed_password:
          return None
      return pwd_context.verify(plain_password, hashed_password)
  def get_password_hash(password):
      """Return the hash of *password* produced by the module-level ``pwd_context``."""
      hashed = pwd_context.hash(password)
      return hashed
  def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
      """
      Encode *data* as a JWT signed with SESSION_SECRET using ALGORITHM.

      Args:
          data: Claims to embed. The dict is copied; the caller's dict is not mutated.
          expires_delta: Token lifetime. When not None, an ``exp`` claim of
              ``datetime.now(UTC) + expires_delta`` is added (a zero or negative
              delta therefore yields an already-expired token). None, the
              default, adds no ``exp`` claim.

      Returns:
          The encoded token returned by ``jwt.encode``.
      """
      payload = data.copy()
      # Explicit None check: timedelta(0) is falsy, and a truthiness test would
      # silently turn a zero lifetime into a token that never expires.
      if expires_delta is not None:
          payload["exp"] = datetime.now(UTC) + expires_delta
      return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  def decode_token(token: str) -> Optional[dict]:
      """
      Decode *token* with SESSION_SECRET and ALGORITHM.

      Returns the decoded claims, or None if ``jwt.decode`` raises for any
      reason. Every exception is swallowed on purpose so callers only need a
      None check.
      """
      try:
          claims = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
      except Exception:
          claims = None
      return claims
  def extract_token_from_auth_header(auth_header: str):
      """
      Return the credential part of a ``Bearer <token>`` Authorization header.

      The ``Bearer`` scheme is matched case-insensitively. A header that does
      not start with ``"Bearer "`` is returned unchanged and treated as a bare
      token. (Previously its first seven characters were chopped off
      unconditionally, which silently corrupted the token.)
      """
      prefix = "Bearer "
      if auth_header[: len(prefix)].lower() == prefix.lower():
          return auth_header[len(prefix) :]
      return auth_header
  def create_api_key():
      """Return a new API key: ``"sk-"`` followed by the 32 lowercase hex digits of a random UUID4."""
      return "sk-" + uuid.uuid4().hex
  def get_http_authorization_cred(auth_header: Optional[str]):
      """
      Parse a raw ``"<scheme> <credentials>"`` Authorization header value.

      Returns an ``HTTPAuthorizationCredentials``, or None when the header is
      missing/empty, does not split on single spaces into exactly two parts,
      or cannot be turned into credentials for any other reason.
      """
      if not auth_header:
          return None
      try:
          scheme, credentials = auth_header.split(" ")
          parsed = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
      except Exception:
          parsed = None
      return parsed
  def _api_key_path_allowed(path: str, allowed_endpoints) -> bool:
      """Return True if *path* equals, or is nested under, an entry of the comma-separated *allowed_endpoints*."""
      # Empty entries (empty setting, trailing comma, ",,") must be dropped:
      # "" would turn into the prefix "/" and match every path.
      allowed_paths = [
          entry.strip()
          for entry in str(allowed_endpoints).split(",")
          if entry.strip()
      ]
      return any(
          path == allowed or path.startswith(allowed + "/")
          for allowed in allowed_paths
      )


  def _annotate_current_span(user, auth_type: str) -> None:
      """Tag the active trace span with the user's id/email/role and the auth type."""
      current_span = trace.get_current_span()
      if current_span:
          current_span.set_attribute("client.user.id", user.id)
          current_span.set_attribute("client.user.email", user.email)
          current_span.set_attribute("client.user.role", user.role)
          current_span.set_attribute("client.auth.type", auth_type)


  def get_current_user(
      request: Request,
      response: Response,
      background_tasks: BackgroundTasks,
      auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  ):
      """
      FastAPI dependency that authenticates the request and returns its user.

      The token comes from the bearer credentials if present, otherwise from
      the ``token`` cookie.

      * Tokens starting with ``sk-`` are API keys. They are refused unless
        ``request.state.enable_api_key`` is set. When
        ``ENABLE_API_KEY_ENDPOINT_RESTRICTIONS`` is on, they are also refused
        unless the request path equals, or is nested under, an entry of
        ``API_KEY_ALLOWED_ENDPOINTS``. An empty list allows nothing.
      * Any other token is decoded as a JWT whose ``id`` claim names the user.
        If ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` is configured and that header
        carries a different email, the auth cookies are cleared and the
        request is rejected. On success the user's last-active timestamp is
        refreshed in a background task.

      Raises:
          HTTPException: 403 when no token is supplied or an API key is not
              allowed. 401 for an undecodable/unknown token or a
              trusted-email mismatch.
      """
      token = auth_token.credentials if auth_token is not None else None
      if token is None and "token" in request.cookies:
          token = request.cookies.get("token")
      if token is None:
          raise HTTPException(status_code=403, detail="Not authenticated")

      # auth by api key
      if token.startswith("sk-"):
          if not request.state.enable_api_key:
              raise HTTPException(
                  status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
              )
          config = request.app.state.config
          if config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS and not _api_key_path_allowed(
              request.url.path, config.API_KEY_ALLOWED_ENDPOINTS
          ):
              raise HTTPException(
                  status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
              )
          user = get_current_user_by_api_key(token)
          _annotate_current_span(user, "api_key")
          return user

      # auth by jwt token
      try:
          data = decode_token(token)
      except Exception:
          # decode_token already maps failures to None; this guards against it raising anyway.
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Invalid token",
          )

      if data is None or "id" not in data:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      user = Users.get_user_by_id(data["id"])
      if user is None:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.INVALID_TOKEN,
          )

      if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
          # NOTE(review): only the header value is lowercased; user.email is
          # compared as stored. Confirm stored emails are always lowercase.
          trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER, "").lower()
          if trusted_email and user.email != trusted_email:
              # Clear the session cookie and any OAuth ID token cookie.
              response.delete_cookie("token")
              if request.cookies.get("oauth_id_token"):
                  response.delete_cookie("oauth_id_token")
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED,
                  detail="User mismatch. Please sign in again.",
              )

      _annotate_current_span(user, "jwt")
      # Refresh last-active in a background task so the request is not blocked on it.
      if background_tasks:
          background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
      return user
  def get_current_user_by_api_key(api_key: str):
      """
      Resolve *api_key* to the user who owns it.

      On success the active trace span is tagged with the user and the auth
      type ``"api_key"``, and the user's last-active timestamp is updated
      synchronously before the user is returned.

      Raises:
          HTTPException: 401 INVALID_TOKEN when no user owns *api_key*.
      """
      user = Users.get_user_by_api_key(api_key)
      if user is None:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.INVALID_TOKEN,
          )

      span = trace.get_current_span()
      if span:
          for attribute, value in (
              ("client.user.id", user.id),
              ("client.user.email", user.email),
              ("client.user.role", user.role),
              ("client.auth.type", "api_key"),
          ):
              span.set_attribute(attribute, value)

      Users.update_user_last_active_by_id(user.id)
      return user
  def get_verified_user(user=Depends(get_current_user)):
      """
      FastAPI dependency returning the authenticated user if their role is "user" or "admin".

      Any other role is refused with 401 and ERROR_MESSAGES.ACCESS_PROHIBITED.
      """
      # NOTE(review): 401 rather than 403 is returned for an authenticated user
      # with an insufficient role; kept as-is since clients may rely on it.
      allowed_roles = {"user", "admin"}
      if user.role in allowed_roles:
          return user
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  def get_admin_user(user=Depends(get_current_user)):
      """
      FastAPI dependency returning the authenticated user if their role is "admin".

      Any other role is refused with 401 and ERROR_MESSAGES.ACCESS_PROHIBITED.
      """
      is_admin = user.role == "admin"
      if not is_admin:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )
      return user