1
0

auth.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. import logging
  2. import uuid
  3. import jwt
  4. import base64
  5. import hmac
  6. import hashlib
  7. import requests
  8. import os
  9. from cryptography.hazmat.primitives.ciphers.aead import AESGCM
  10. from cryptography.hazmat.primitives.asymmetric import ed25519
  11. from cryptography.hazmat.primitives import serialization
  12. import json
  13. from datetime import datetime, timedelta
  14. import pytz
  15. from pytz import UTC
  16. from typing import Optional, Union, List, Dict
  17. from opentelemetry import trace
  18. from open_webui.models.users import Users
  19. from open_webui.constants import ERROR_MESSAGES
  20. from open_webui.env import (
  21. OFFLINE_MODE,
  22. LICENSE_BLOB,
  23. pk,
  24. WEBUI_SECRET_KEY,
  25. TRUSTED_SIGNATURE_KEY,
  26. STATIC_DIR,
  27. SRC_LOG_LEVELS,
  28. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  29. )
  30. from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
  31. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  32. from passlib.context import CryptContext
  33. logging.getLogger("passlib").setLevel(logging.ERROR)
  34. log = logging.getLogger(__name__)
  35. log.setLevel(SRC_LOG_LEVELS["OAUTH"])
  36. SESSION_SECRET = WEBUI_SECRET_KEY
  37. ALGORITHM = "HS256"
  38. ##############
  39. # Auth Utils
  40. ##############
  41. def verify_signature(payload: str, signature: str) -> bool:
  42. """
  43. Verifies the HMAC signature of the received payload.
  44. """
  45. try:
  46. expected_signature = base64.b64encode(
  47. hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
  48. ).decode()
  49. # Compare securely to prevent timing attacks
  50. return hmac.compare_digest(expected_signature, signature)
  51. except Exception:
  52. return False
  53. def override_static(path: str, content: str):
  54. # Ensure path is safe
  55. if "/" in path or ".." in path:
  56. log.error(f"Invalid path: {path}")
  57. return
  58. file_path = os.path.join(STATIC_DIR, path)
  59. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  60. with open(file_path, "wb") as f:
  61. f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
  62. def get_license_data(app, key):
  63. def data_handler(data):
  64. for k, v in data.items():
  65. if k == "resources":
  66. for p, c in v.items():
  67. globals().get("override_static", lambda a, b: None)(p, c)
  68. elif k == "count":
  69. setattr(app.state, "USER_COUNT", v)
  70. elif k == "name":
  71. setattr(app.state, "WEBUI_NAME", v)
  72. elif k == "metadata":
  73. setattr(app.state, "LICENSE_METADATA", v)
  74. def handler(u):
  75. res = requests.post(
  76. f"{u}/api/v1/license/",
  77. json={"key": key, "version": "1"},
  78. timeout=5,
  79. )
  80. if getattr(res, "ok", False):
  81. payload = getattr(res, "json", lambda: {})()
  82. data_handler(payload)
  83. return True
  84. else:
  85. log.error(
  86. f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
  87. )
  88. if key:
  89. us = [
  90. "https://api.openwebui.com",
  91. "https://licenses.api.openwebui.com",
  92. ]
  93. try:
  94. for u in us:
  95. if handler(u):
  96. return True
  97. except Exception as ex:
  98. log.exception(f"License: Uncaught Exception: {ex}")
  99. try:
  100. if LICENSE_BLOB:
  101. nl = 12
  102. kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
  103. def nt(b):
  104. return b[:nl], b[nl:]
  105. lb = base64.b64decode(LICENSE_BLOB)
  106. ln, lt = nt(lb)
  107. aesgcm = AESGCM(kb)
  108. p = json.loads(aesgcm.decrypt(ln, lt, None))
  109. pk.verify(base64.b64decode(p["s"]), p["p"].encode())
  110. pb = base64.b64decode(p["p"])
  111. pn, pt = nt(pb)
  112. data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
  113. if not data.get("exp") and data.get("exp") < datetime.now().date():
  114. return False
  115. data_handler(data)
  116. return True
  117. except Exception as e:
  118. log.error(f"License: {e}")
  119. return False
  120. bearer_security = HTTPBearer(auto_error=False)
  121. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  122. def verify_password(plain_password, hashed_password):
  123. return (
  124. pwd_context.verify(plain_password, hashed_password) if hashed_password else None
  125. )
  126. def get_password_hash(password):
  127. return pwd_context.hash(password)
  128. def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
  129. payload = data.copy()
  130. if expires_delta:
  131. expire = datetime.now(UTC) + expires_delta
  132. payload.update({"exp": expire})
  133. encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  134. return encoded_jwt
  135. def decode_token(token: str) -> Optional[dict]:
  136. try:
  137. decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
  138. return decoded
  139. except Exception:
  140. return None
  141. def extract_token_from_auth_header(auth_header: str):
  142. return auth_header[len("Bearer ") :]
  143. def create_api_key():
  144. key = str(uuid.uuid4()).replace("-", "")
  145. return f"sk-{key}"
  146. def get_http_authorization_cred(auth_header: Optional[str]):
  147. if not auth_header:
  148. return None
  149. try:
  150. scheme, credentials = auth_header.split(" ")
  151. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  152. except Exception:
  153. return None
  154. def get_current_user(
  155. request: Request,
  156. response: Response,
  157. background_tasks: BackgroundTasks,
  158. auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  159. ):
  160. token = None
  161. if auth_token is not None:
  162. token = auth_token.credentials
  163. if token is None and "token" in request.cookies:
  164. token = request.cookies.get("token")
  165. if token is None:
  166. raise HTTPException(status_code=401, detail="Not authenticated")
  167. # auth by api key
  168. if token.startswith("sk-"):
  169. if not request.state.enable_api_key:
  170. raise HTTPException(
  171. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  172. )
  173. if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
  174. allowed_paths = [
  175. path.strip()
  176. for path in str(
  177. request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
  178. ).split(",")
  179. ]
  180. # Check if the request path matches any allowed endpoint.
  181. if not any(
  182. request.url.path == allowed
  183. or request.url.path.startswith(allowed + "/")
  184. for allowed in allowed_paths
  185. ):
  186. raise HTTPException(
  187. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  188. )
  189. user = get_current_user_by_api_key(token)
  190. # Add user info to current span
  191. current_span = trace.get_current_span()
  192. if current_span:
  193. current_span.set_attribute("client.user.id", user.id)
  194. current_span.set_attribute("client.user.email", user.email)
  195. current_span.set_attribute("client.user.role", user.role)
  196. current_span.set_attribute("client.auth.type", "api_key")
  197. return user
  198. # auth by jwt token
  199. try:
  200. data = decode_token(token)
  201. except Exception as e:
  202. raise HTTPException(
  203. status_code=status.HTTP_401_UNAUTHORIZED,
  204. detail="Invalid token",
  205. )
  206. if data is not None and "id" in data:
  207. user = Users.get_user_by_id(data["id"])
  208. if user is None:
  209. raise HTTPException(
  210. status_code=status.HTTP_401_UNAUTHORIZED,
  211. detail=ERROR_MESSAGES.INVALID_TOKEN,
  212. )
  213. else:
  214. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  215. trusted_email = request.headers.get(
  216. WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
  217. ).lower()
  218. if trusted_email and user.email != trusted_email:
  219. # Delete the token cookie
  220. response.delete_cookie("token")
  221. # Delete OAuth token if present
  222. if request.cookies.get("oauth_id_token"):
  223. response.delete_cookie("oauth_id_token")
  224. raise HTTPException(
  225. status_code=status.HTTP_401_UNAUTHORIZED,
  226. detail="User mismatch. Please sign in again.",
  227. )
  228. # Add user info to current span
  229. current_span = trace.get_current_span()
  230. if current_span:
  231. current_span.set_attribute("client.user.id", user.id)
  232. current_span.set_attribute("client.user.email", user.email)
  233. current_span.set_attribute("client.user.role", user.role)
  234. current_span.set_attribute("client.auth.type", "jwt")
  235. # Refresh the user's last active timestamp asynchronously
  236. # to prevent blocking the request
  237. if background_tasks:
  238. background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
  239. return user
  240. else:
  241. raise HTTPException(
  242. status_code=status.HTTP_401_UNAUTHORIZED,
  243. detail=ERROR_MESSAGES.UNAUTHORIZED,
  244. )
  245. def get_current_user_by_api_key(api_key: str):
  246. user = Users.get_user_by_api_key(api_key)
  247. if user is None:
  248. raise HTTPException(
  249. status_code=status.HTTP_401_UNAUTHORIZED,
  250. detail=ERROR_MESSAGES.INVALID_TOKEN,
  251. )
  252. else:
  253. # Add user info to current span
  254. current_span = trace.get_current_span()
  255. if current_span:
  256. current_span.set_attribute("client.user.id", user.id)
  257. current_span.set_attribute("client.user.email", user.email)
  258. current_span.set_attribute("client.user.role", user.role)
  259. current_span.set_attribute("client.auth.type", "api_key")
  260. Users.update_user_last_active_by_id(user.id)
  261. return user
  262. def get_verified_user(user=Depends(get_current_user)):
  263. if user.role not in {"user", "admin"}:
  264. raise HTTPException(
  265. status_code=status.HTTP_401_UNAUTHORIZED,
  266. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  267. )
  268. return user
  269. def get_admin_user(user=Depends(get_current_user)):
  270. if user.role != "admin":
  271. raise HTTPException(
  272. status_code=status.HTTP_401_UNAUTHORIZED,
  273. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  274. )
  275. return user