1
0

auth.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import logging
  2. import uuid
  3. import jwt
  4. import base64
  5. import hmac
  6. import hashlib
  7. import requests
  8. import os
  9. import bcrypt
  10. from cryptography.hazmat.primitives.ciphers.aead import AESGCM
  11. from cryptography.hazmat.primitives.asymmetric import ed25519
  12. from cryptography.hazmat.primitives import serialization
  13. import json
  14. from datetime import datetime, timedelta
  15. import pytz
  16. from pytz import UTC
  17. from typing import Optional, Union, List, Dict
  18. from opentelemetry import trace
  19. from open_webui.models.users import Users
  20. from open_webui.constants import ERROR_MESSAGES
  21. from open_webui.env import (
  22. OFFLINE_MODE,
  23. LICENSE_BLOB,
  24. pk,
  25. WEBUI_SECRET_KEY,
  26. TRUSTED_SIGNATURE_KEY,
  27. STATIC_DIR,
  28. SRC_LOG_LEVELS,
  29. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  30. )
  31. from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
  32. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
# Module-level logger; its level is taken from the "OAUTH" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

# Secret used by create_token / decode_token to sign and verify session JWTs.
SESSION_SECRET = WEBUI_SECRET_KEY
# Algorithm name passed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"
  37. ##############
  38. # Auth Utils
  39. ##############
  40. def verify_signature(payload: str, signature: str) -> bool:
  41. """
  42. Verifies the HMAC signature of the received payload.
  43. """
  44. try:
  45. expected_signature = base64.b64encode(
  46. hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
  47. ).decode()
  48. # Compare securely to prevent timing attacks
  49. return hmac.compare_digest(expected_signature, signature)
  50. except Exception:
  51. return False
  52. def override_static(path: str, content: str):
  53. # Ensure path is safe
  54. if "/" in path or ".." in path:
  55. log.error(f"Invalid path: {path}")
  56. return
  57. file_path = os.path.join(STATIC_DIR, path)
  58. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  59. with open(file_path, "wb") as f:
  60. f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
  61. def get_license_data(app, key):
  62. def data_handler(data):
  63. for k, v in data.items():
  64. if k == "resources":
  65. for p, c in v.items():
  66. globals().get("override_static", lambda a, b: None)(p, c)
  67. elif k == "count":
  68. setattr(app.state, "USER_COUNT", v)
  69. elif k == "name":
  70. setattr(app.state, "WEBUI_NAME", v)
  71. elif k == "metadata":
  72. setattr(app.state, "LICENSE_METADATA", v)
  73. def handler(u):
  74. res = requests.post(
  75. f"{u}/api/v1/license/",
  76. json={"key": key, "version": "1"},
  77. timeout=5,
  78. )
  79. if getattr(res, "ok", False):
  80. payload = getattr(res, "json", lambda: {})()
  81. data_handler(payload)
  82. return True
  83. else:
  84. log.error(
  85. f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
  86. )
  87. if key:
  88. us = [
  89. "https://api.openwebui.com",
  90. "https://licenses.api.openwebui.com",
  91. ]
  92. try:
  93. for u in us:
  94. if handler(u):
  95. return True
  96. except Exception as ex:
  97. log.exception(f"License: Uncaught Exception: {ex}")
  98. try:
  99. if LICENSE_BLOB:
  100. nl = 12
  101. kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
  102. def nt(b):
  103. return b[:nl], b[nl:]
  104. lb = base64.b64decode(LICENSE_BLOB)
  105. ln, lt = nt(lb)
  106. aesgcm = AESGCM(kb)
  107. p = json.loads(aesgcm.decrypt(ln, lt, None))
  108. pk.verify(base64.b64decode(p["s"]), p["p"].encode())
  109. pb = base64.b64decode(p["p"])
  110. pn, pt = nt(pb)
  111. data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
  112. if not data.get("exp") and data.get("exp") < datetime.now().date():
  113. return False
  114. data_handler(data)
  115. return True
  116. except Exception as e:
  117. log.error(f"License: {e}")
  118. return False
# Bearer-credential dependency injected into get_current_user. It is created
# with auto_error=False; get_current_user handles a missing credential itself
# by falling back to the "token" cookie.
bearer_security = HTTPBearer(auto_error=False)
  120. def get_password_hash(password: str) -> str:
  121. """Hash a password using bcrypt"""
  122. return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
  123. def verify_password(plain_password: str, hashed_password: str) -> bool:
  124. """Verify a password against its hash"""
  125. return (
  126. bcrypt.checkpw(
  127. plain_password.encode("utf-8"),
  128. hashed_password.encode("utf-8"),
  129. )
  130. if hashed_password
  131. else None
  132. )
  133. def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
  134. payload = data.copy()
  135. if expires_delta:
  136. expire = datetime.now(UTC) + expires_delta
  137. payload.update({"exp": expire})
  138. encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  139. return encoded_jwt
  140. def decode_token(token: str) -> Optional[dict]:
  141. try:
  142. decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
  143. return decoded
  144. except Exception:
  145. return None
  146. def extract_token_from_auth_header(auth_header: str):
  147. return auth_header[len("Bearer ") :]
  148. def create_api_key():
  149. key = str(uuid.uuid4()).replace("-", "")
  150. return f"sk-{key}"
  151. def get_http_authorization_cred(auth_header: Optional[str]):
  152. if not auth_header:
  153. return None
  154. try:
  155. scheme, credentials = auth_header.split(" ")
  156. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  157. except Exception:
  158. return None
  159. def get_current_user(
  160. request: Request,
  161. response: Response,
  162. background_tasks: BackgroundTasks,
  163. auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  164. ):
  165. token = None
  166. if auth_token is not None:
  167. token = auth_token.credentials
  168. if token is None and "token" in request.cookies:
  169. token = request.cookies.get("token")
  170. if token is None:
  171. raise HTTPException(status_code=401, detail="Not authenticated")
  172. # auth by api key
  173. if token.startswith("sk-"):
  174. if not request.state.enable_api_key:
  175. raise HTTPException(
  176. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  177. )
  178. if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
  179. allowed_paths = [
  180. path.strip()
  181. for path in str(
  182. request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
  183. ).split(",")
  184. ]
  185. # Check if the request path matches any allowed endpoint.
  186. if not any(
  187. request.url.path == allowed
  188. or request.url.path.startswith(allowed + "/")
  189. for allowed in allowed_paths
  190. ):
  191. raise HTTPException(
  192. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  193. )
  194. user = get_current_user_by_api_key(token)
  195. # Add user info to current span
  196. current_span = trace.get_current_span()
  197. if current_span:
  198. current_span.set_attribute("client.user.id", user.id)
  199. current_span.set_attribute("client.user.email", user.email)
  200. current_span.set_attribute("client.user.role", user.role)
  201. current_span.set_attribute("client.auth.type", "api_key")
  202. return user
  203. # auth by jwt token
  204. try:
  205. try:
  206. data = decode_token(token)
  207. except Exception as e:
  208. raise HTTPException(
  209. status_code=status.HTTP_401_UNAUTHORIZED,
  210. detail="Invalid token",
  211. )
  212. if data is not None and "id" in data:
  213. user = Users.get_user_by_id(data["id"])
  214. if user is None:
  215. raise HTTPException(
  216. status_code=status.HTTP_401_UNAUTHORIZED,
  217. detail=ERROR_MESSAGES.INVALID_TOKEN,
  218. )
  219. else:
  220. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  221. trusted_email = request.headers.get(
  222. WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
  223. ).lower()
  224. if trusted_email and user.email != trusted_email:
  225. raise HTTPException(
  226. status_code=status.HTTP_401_UNAUTHORIZED,
  227. detail="User mismatch. Please sign in again.",
  228. )
  229. # Add user info to current span
  230. current_span = trace.get_current_span()
  231. if current_span:
  232. current_span.set_attribute("client.user.id", user.id)
  233. current_span.set_attribute("client.user.email", user.email)
  234. current_span.set_attribute("client.user.role", user.role)
  235. current_span.set_attribute("client.auth.type", "jwt")
  236. # Refresh the user's last active timestamp asynchronously
  237. # to prevent blocking the request
  238. if background_tasks:
  239. background_tasks.add_task(
  240. Users.update_user_last_active_by_id, user.id
  241. )
  242. return user
  243. else:
  244. raise HTTPException(
  245. status_code=status.HTTP_401_UNAUTHORIZED,
  246. detail=ERROR_MESSAGES.UNAUTHORIZED,
  247. )
  248. except Exception as e:
  249. # Delete the token cookie
  250. if request.cookies.get("token"):
  251. response.delete_cookie("token")
  252. if request.cookies.get("oauth_id_token"):
  253. response.delete_cookie("oauth_id_token")
  254. # Delete OAuth session if present
  255. if request.cookies.get("oauth_session_id"):
  256. response.delete_cookie("oauth_session_id")
  257. raise e
  258. def get_current_user_by_api_key(api_key: str):
  259. user = Users.get_user_by_api_key(api_key)
  260. if user is None:
  261. raise HTTPException(
  262. status_code=status.HTTP_401_UNAUTHORIZED,
  263. detail=ERROR_MESSAGES.INVALID_TOKEN,
  264. )
  265. else:
  266. # Add user info to current span
  267. current_span = trace.get_current_span()
  268. if current_span:
  269. current_span.set_attribute("client.user.id", user.id)
  270. current_span.set_attribute("client.user.email", user.email)
  271. current_span.set_attribute("client.user.role", user.role)
  272. current_span.set_attribute("client.auth.type", "api_key")
  273. Users.update_user_last_active_by_id(user.id)
  274. return user
  275. def get_verified_user(user=Depends(get_current_user)):
  276. if user.role not in {"user", "admin"}:
  277. raise HTTPException(
  278. status_code=status.HTTP_401_UNAUTHORIZED,
  279. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  280. )
  281. return user
  282. def get_admin_user(user=Depends(get_current_user)):
  283. if user.role != "admin":
  284. raise HTTPException(
  285. status_code=status.HTTP_401_UNAUTHORIZED,
  286. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  287. )
  288. return user