oauth_sessions.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import time
  2. import logging
  3. import uuid
  4. from typing import Optional, List
  5. import base64
  6. import hashlib
  7. import json
  8. from cryptography.fernet import Fernet
  9. from open_webui.internal.db import Base, get_db
  10. from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
  11. from pydantic import BaseModel, ConfigDict
  12. from sqlalchemy import BigInteger, Column, String, Text, Index
  # Module-level logger; its verbosity follows the "MODELS" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
  15. ####################
  16. # DB MODEL
  17. ####################
  class OAuthSession(Base):
      """ORM row for one stored OAuth session of a user with a provider.

      ``token`` holds the encrypted token payload (see
      ``OAuthSessionTable._encrypt_token``); the three timestamp columns are
      integer epoch seconds.
      """

      __tablename__ = "oauth_session"

      id = Column(Text, primary_key=True)
      user_id = Column(Text, nullable=False)
      provider = Column(Text, nullable=False)
      # Fernet-encrypted JSON of the provider token dict (access_token,
      # id_token, refresh_token, ...), not the plaintext JSON itself.
      token = Column(Text, nullable=False)
      expires_at = Column(BigInteger, nullable=False)
      created_at = Column(BigInteger, nullable=False)
      updated_at = Column(BigInteger, nullable=False)

      # Secondary indexes on user_id, expires_at and (user_id, provider),
      # added for lookup performance.
      __table_args__ = (
          Index("idx_oauth_session_user_id", "user_id"),
          Index("idx_oauth_session_expires_at", "expires_at"),
          Index("idx_oauth_session_user_provider", "user_id", "provider"),
      )
  class OAuthSessionModel(BaseModel):
      # In-memory view of an oauth_session row whose ``token`` has already been
      # decrypted back into a dict. ``from_attributes`` allows validating
      # directly from ORM objects.
      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      provider: str
      token: dict  # decrypted token payload
      expires_at: int  # epoch seconds
      created_at: int  # epoch seconds
      updated_at: int  # epoch seconds
  44. ####################
  45. # Forms
  46. ####################
  class OAuthSessionResponse(BaseModel):
      # Session summary: identifies the session, its owner and provider, and
      # when it expires. It omits ``token`` and the created/updated timestamps.
      id: str
      user_id: str
      provider: str
      expires_at: int  # epoch seconds
  class OAuthSessionTable:
      """Data-access helper for ``oauth_session`` rows.

      Token payloads are JSON-serialized and Fernet-encrypted before being
      written, and decrypted again when read back. Public methods are
      best-effort: database/crypto errors are logged and reported as ``None``,
      ``[]`` or ``False`` instead of being raised.
      """

      def __init__(self):
          """Build the Fernet cipher from ``OAUTH_SESSION_TOKEN_ENCRYPTION_KEY``.

          A 44-character value (the url-safe base64 length of 32 bytes, i.e. a
          native Fernet key) is used as-is. Any other value is hashed with
          SHA-256 and base64-encoded, so arbitrary passphrases are accepted.

          Raises:
              Exception: if the key is unset or empty. Whatever ``Fernet``
                  raises for an unusable key is logged and re-raised.
          """
          self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
          if not self.encryption_key:
              raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")

          # Fernet requires 32 url-safe base64-encoded bytes (44 characters).
          if len(self.encryption_key) != 44:
              key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
              self.encryption_key = base64.urlsafe_b64encode(key_bytes)
          else:
              self.encryption_key = self.encryption_key.encode()

          try:
              self.fernet = Fernet(self.encryption_key)
          except Exception as e:
              log.error(f"Error initializing Fernet with provided key: {e}")
              raise

      def _encrypt_token(self, token) -> str:
          """Serialize *token* to JSON and return it Fernet-encrypted as ``str``.

          Errors are logged and re-raised.
          """
          try:
              token_json = json.dumps(token)
              return self.fernet.encrypt(token_json.encode()).decode()
          except Exception as e:
              log.error(f"Error encrypting tokens: {e}")
              raise

      def _decrypt_token(self, token: str):
          """Decrypt a value produced by ``_encrypt_token`` and parse its JSON.

          Errors (wrong key, corrupted ciphertext, bad JSON) are logged and
          re-raised.
          """
          try:
              decrypted = self.fernet.decrypt(token.encode()).decode()
              return json.loads(decrypted)
          except Exception as e:
              log.error(f"Error decrypting tokens: {e}")
              raise

      @staticmethod
      def _resolve_expires_at(token: dict, now: int) -> Optional[int]:
          """Return the absolute expiry of *token* in epoch seconds.

          Uses ``expires_at`` when present; otherwise derives it from the
          relative ``expires_in`` lifetime (as found in raw OAuth 2.0 token
          responses). Returns ``None`` if neither key is present, in which case
          the NOT NULL ``expires_at`` column rejects the write. A non-numeric
          value raises ``ValueError``/``TypeError`` from ``int()``.
          """
          expires_at = token.get("expires_at")
          if expires_at is not None:
              return int(expires_at)
          expires_in = token.get("expires_in")
          if expires_in is not None:
              return now + int(expires_in)
          return None

      @staticmethod
      def _to_model(row: OAuthSession, token: dict) -> OAuthSessionModel:
          """Build an ``OAuthSessionModel`` from *row* with the plaintext *token*.

          The ORM row itself is not modified. Assigning the plaintext token onto
          it would mark it dirty, and a later flush/commit could then write the
          unencrypted token (or a dict) into the ``token`` column.
          """
          return OAuthSessionModel(
              id=row.id,
              user_id=row.user_id,
              provider=row.provider,
              token=token,
              expires_at=row.expires_at,
              created_at=row.created_at,
              updated_at=row.updated_at,
          )

      def create_session(
          self,
          user_id: str,
          provider: str,
          token: dict,
      ) -> Optional[OAuthSessionModel]:
          """Persist a new OAuth session.

          Args:
              user_id: Owner of the session.
              provider: OAuth provider name.
              token: Provider token dict. It is stored encrypted; its
                  ``expires_at`` (or ``expires_in``) sets the session expiry.

          Returns:
              The created session carrying the plaintext *token*, or ``None``
              on any failure (the error is logged).
          """
          try:
              with get_db() as db:
                  current_time = int(time.time())
                  row = OAuthSession(
                      id=str(uuid.uuid4()),
                      user_id=user_id,
                      provider=provider,
                      token=self._encrypt_token(token),
                      expires_at=self._resolve_expires_at(token, current_time),
                      created_at=current_time,
                      updated_at=current_time,
                  )
                  db.add(row)
                  db.commit()
                  db.refresh(row)
                  return self._to_model(row, token)
          except Exception as e:
              log.error(f"Error creating OAuth session: {e}")
              return None

      def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
          """Return session *session_id* with its token decrypted.

          Returns ``None`` if the session does not exist or on error.
          """
          try:
              with get_db() as db:
                  row = db.query(OAuthSession).filter_by(id=session_id).first()
                  if row is None:
                      return None
                  return self._to_model(row, self._decrypt_token(row.token))
          except Exception as e:
              log.error(f"Error getting OAuth session by ID: {e}")
              return None

      def get_session_by_id_and_user_id(
          self, session_id: str, user_id: str
      ) -> Optional[OAuthSessionModel]:
          """Return session *session_id* only if it belongs to *user_id*.

          Returns ``None`` if there is no such session for that user, or on error.
          """
          try:
              with get_db() as db:
                  row = (
                      db.query(OAuthSession)
                      .filter_by(id=session_id, user_id=user_id)
                      .first()
                  )
                  if row is None:
                      return None
                  return self._to_model(row, self._decrypt_token(row.token))
          except Exception as e:
              log.error(f"Error getting OAuth session by ID and user ID: {e}")
              return None

      def get_session_by_provider_and_user_id(
          self, provider: str, user_id: str
      ) -> Optional[OAuthSessionModel]:
          """Return the newest session of *user_id* with *provider*.

          Several rows may match, because (user_id, provider) is indexed but not
          unique. The most recently created one is returned, so the result is
          deterministic. Returns ``None`` if there is none, or on error.
          """
          try:
              with get_db() as db:
                  row = (
                      db.query(OAuthSession)
                      .filter_by(provider=provider, user_id=user_id)
                      .order_by(OAuthSession.created_at.desc())
                      .first()
                  )
                  if row is None:
                      return None
                  return self._to_model(row, self._decrypt_token(row.token))
          except Exception as e:
              log.error(f"Error getting OAuth session by provider and user ID: {e}")
              return None

      def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
          """Return every session of *user_id* with its token decrypted.

          A row whose token cannot be decrypted (for example one written under a
          different key) is logged and skipped instead of discarding the whole
          result. Returns ``[]`` if the query itself fails.
          """
          try:
              with get_db() as db:
                  rows = db.query(OAuthSession).filter_by(user_id=user_id).all()
                  results = []
                  for row in rows:
                      try:
                          token = self._decrypt_token(row.token)
                      except Exception:
                          # _decrypt_token has already logged the underlying cause.
                          log.warning(f"Skipping undecryptable OAuth session {row.id}")
                          continue
                      results.append(self._to_model(row, token))
                  return results
          except Exception as e:
              log.error(f"Error getting OAuth sessions by user ID: {e}")
              return []

      def update_session_by_id(
          self, session_id: str, token: dict
      ) -> Optional[OAuthSessionModel]:
          """Replace the stored token and expiry of session *session_id*.

          Returns the updated session, with its token decrypted from what was
          stored. Returns ``None`` if the session does not exist or on error.
          """
          try:
              with get_db() as db:
                  current_time = int(time.time())
                  db.query(OAuthSession).filter_by(id=session_id).update(
                      {
                          "token": self._encrypt_token(token),
                          "expires_at": self._resolve_expires_at(token, current_time),
                          "updated_at": current_time,
                      }
                  )
                  db.commit()
                  row = db.query(OAuthSession).filter_by(id=session_id).first()
                  if row is None:
                      return None
                  return self._to_model(row, self._decrypt_token(row.token))
          except Exception as e:
              log.error(f"Error updating OAuth session tokens: {e}")
              return None

      def delete_session_by_id(self, session_id: str) -> bool:
          """Delete session *session_id*.

          Returns ``True`` if a row was deleted, and ``False`` if none matched or
          on error.
          """
          try:
              with get_db() as db:
                  deleted = db.query(OAuthSession).filter_by(id=session_id).delete()
                  db.commit()
                  return deleted > 0
          except Exception as e:
              log.error(f"Error deleting OAuth session: {e}")
              return False

      def delete_sessions_by_user_id(self, user_id: str) -> bool:
          """Delete all sessions of *user_id*.

          Returns ``True`` once the delete commits, even if no rows matched, and
          ``False`` on error.
          """
          try:
              with get_db() as db:
                  db.query(OAuthSession).filter_by(user_id=user_id).delete()
                  db.commit()
                  return True
          except Exception as e:
              log.error(f"Error deleting OAuth sessions by user ID: {e}")
              return False
  # Shared module-level instance. Because it is constructed at import time, the
  # encryption-key checks in OAuthSessionTable.__init__ run (and can raise) when
  # this module is imported.
  OAuthSessions = OAuthSessionTable()