# functions.py (10 KB)

  1. import logging
  2. import time
  3. from typing import Optional
  4. from open_webui.internal.db import Base, JSONField, get_db
  5. from open_webui.models.users import Users
  6. from open_webui.env import SRC_LOG_LEVELS
  7. from pydantic import BaseModel, ConfigDict
  8. from sqlalchemy import BigInteger, Boolean, Column, String, Text
  9. log = logging.getLogger(__name__)
  10. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  11. ####################
  12. # Functions DB Schema
  13. ####################
  14. class Function(Base):
  15. __tablename__ = "function"
  16. id = Column(String, primary_key=True)
  17. user_id = Column(String)
  18. name = Column(Text)
  19. type = Column(Text)
  20. content = Column(Text)
  21. meta = Column(JSONField)
  22. valves = Column(JSONField)
  23. is_active = Column(Boolean)
  24. is_global = Column(Boolean)
  25. updated_at = Column(BigInteger)
  26. created_at = Column(BigInteger)
  27. class FunctionMeta(BaseModel):
  28. description: Optional[str] = None
  29. manifest: Optional[dict] = {}
  30. class FunctionModel(BaseModel):
  31. id: str
  32. user_id: str
  33. name: str
  34. type: str
  35. content: str
  36. meta: FunctionMeta
  37. is_active: bool = False
  38. is_global: bool = False
  39. updated_at: int # timestamp in epoch
  40. created_at: int # timestamp in epoch
  41. model_config = ConfigDict(from_attributes=True)
  42. ####################
  43. # Forms
  44. ####################
  45. class FunctionResponse(BaseModel):
  46. id: str
  47. user_id: str
  48. type: str
  49. name: str
  50. meta: FunctionMeta
  51. is_active: bool
  52. is_global: bool
  53. updated_at: int # timestamp in epoch
  54. created_at: int # timestamp in epoch
  55. class FunctionForm(BaseModel):
  56. id: str
  57. name: str
  58. content: str
  59. meta: FunctionMeta
  60. class FunctionValves(BaseModel):
  61. valves: Optional[dict] = None
  62. class FunctionsTable:
  63. def insert_new_function(
  64. self, user_id: str, type: str, form_data: FunctionForm
  65. ) -> Optional[FunctionModel]:
  66. function = FunctionModel(
  67. **{
  68. **form_data.model_dump(),
  69. "user_id": user_id,
  70. "type": type,
  71. "updated_at": int(time.time()),
  72. "created_at": int(time.time()),
  73. }
  74. )
  75. try:
  76. with get_db() as db:
  77. result = Function(**function.model_dump())
  78. db.add(result)
  79. db.commit()
  80. db.refresh(result)
  81. if result:
  82. return FunctionModel.model_validate(result)
  83. else:
  84. return None
  85. except Exception as e:
  86. log.exception(f"Error creating a new function: {e}")
  87. return None
  88. def sync_functions(
  89. self, user_id: str, functions: list[FunctionModel]
  90. ) -> list[FunctionModel]:
  91. # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
  92. try:
  93. with get_db() as db:
  94. # Get existing functions
  95. existing_functions = db.query(Function).all()
  96. existing_ids = {func.id for func in existing_functions}
  97. # Prepare a set of new function IDs
  98. new_function_ids = {func.id for func in functions}
  99. # Update or insert functions
  100. for func in functions:
  101. if func.id in existing_ids:
  102. db.query(Function).filter_by(id=func.id).update(
  103. {
  104. **func.model_dump(),
  105. "user_id": user_id,
  106. "updated_at": int(time.time()),
  107. }
  108. )
  109. else:
  110. new_func = Function(
  111. **{
  112. **func.model_dump(),
  113. "user_id": user_id,
  114. "updated_at": int(time.time()),
  115. }
  116. )
  117. db.add(new_func)
  118. # Remove functions that are no longer present
  119. for func in existing_functions:
  120. if func.id not in new_function_ids:
  121. db.delete(func)
  122. db.commit()
  123. return [
  124. FunctionModel.model_validate(func)
  125. for func in db.query(Function).all()
  126. ]
  127. except Exception as e:
  128. log.exception(f"Error syncing functions for user {user_id}: {e}")
  129. return []
  130. def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
  131. try:
  132. with get_db() as db:
  133. function = db.get(Function, id)
  134. return FunctionModel.model_validate(function)
  135. except Exception:
  136. return None
  137. def get_functions(self, active_only=False) -> list[FunctionModel]:
  138. with get_db() as db:
  139. if active_only:
  140. return [
  141. FunctionModel.model_validate(function)
  142. for function in db.query(Function).filter_by(is_active=True).all()
  143. ]
  144. else:
  145. return [
  146. FunctionModel.model_validate(function)
  147. for function in db.query(Function).all()
  148. ]
  149. def get_functions_by_type(
  150. self, type: str, active_only=False
  151. ) -> list[FunctionModel]:
  152. with get_db() as db:
  153. if active_only:
  154. return [
  155. FunctionModel.model_validate(function)
  156. for function in db.query(Function)
  157. .filter_by(type=type, is_active=True)
  158. .all()
  159. ]
  160. else:
  161. return [
  162. FunctionModel.model_validate(function)
  163. for function in db.query(Function).filter_by(type=type).all()
  164. ]
  165. def get_global_filter_functions(self) -> list[FunctionModel]:
  166. with get_db() as db:
  167. return [
  168. FunctionModel.model_validate(function)
  169. for function in db.query(Function)
  170. .filter_by(type="filter", is_active=True, is_global=True)
  171. .all()
  172. ]
  173. def get_global_action_functions(self) -> list[FunctionModel]:
  174. with get_db() as db:
  175. return [
  176. FunctionModel.model_validate(function)
  177. for function in db.query(Function)
  178. .filter_by(type="action", is_active=True, is_global=True)
  179. .all()
  180. ]
  181. def get_function_valves_by_id(self, id: str) -> Optional[dict]:
  182. with get_db() as db:
  183. try:
  184. function = db.get(Function, id)
  185. return function.valves if function.valves else {}
  186. except Exception as e:
  187. log.exception(f"Error getting function valves by id {id}: {e}")
  188. return None
  189. def update_function_valves_by_id(
  190. self, id: str, valves: dict
  191. ) -> Optional[FunctionValves]:
  192. with get_db() as db:
  193. try:
  194. function = db.get(Function, id)
  195. function.valves = valves
  196. function.updated_at = int(time.time())
  197. db.commit()
  198. db.refresh(function)
  199. return self.get_function_by_id(id)
  200. except Exception:
  201. return None
  202. def get_user_valves_by_id_and_user_id(
  203. self, id: str, user_id: str
  204. ) -> Optional[dict]:
  205. try:
  206. user = Users.get_user_by_id(user_id)
  207. user_settings = user.settings.model_dump() if user.settings else {}
  208. # Check if user has "functions" and "valves" settings
  209. if "functions" not in user_settings:
  210. user_settings["functions"] = {}
  211. if "valves" not in user_settings["functions"]:
  212. user_settings["functions"]["valves"] = {}
  213. return user_settings["functions"]["valves"].get(id, {})
  214. except Exception as e:
  215. log.exception(
  216. f"Error getting user values by id {id} and user id {user_id}: {e}"
  217. )
  218. return None
  219. def update_user_valves_by_id_and_user_id(
  220. self, id: str, user_id: str, valves: dict
  221. ) -> Optional[dict]:
  222. try:
  223. user = Users.get_user_by_id(user_id)
  224. user_settings = user.settings.model_dump() if user.settings else {}
  225. # Check if user has "functions" and "valves" settings
  226. if "functions" not in user_settings:
  227. user_settings["functions"] = {}
  228. if "valves" not in user_settings["functions"]:
  229. user_settings["functions"]["valves"] = {}
  230. user_settings["functions"]["valves"][id] = valves
  231. # Update the user settings in the database
  232. Users.update_user_by_id(user_id, {"settings": user_settings})
  233. return user_settings["functions"]["valves"][id]
  234. except Exception as e:
  235. log.exception(
  236. f"Error updating user valves by id {id} and user_id {user_id}: {e}"
  237. )
  238. return None
  239. def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
  240. with get_db() as db:
  241. try:
  242. db.query(Function).filter_by(id=id).update(
  243. {
  244. **updated,
  245. "updated_at": int(time.time()),
  246. }
  247. )
  248. db.commit()
  249. return self.get_function_by_id(id)
  250. except Exception:
  251. return None
  252. def deactivate_all_functions(self) -> Optional[bool]:
  253. with get_db() as db:
  254. try:
  255. db.query(Function).update(
  256. {
  257. "is_active": False,
  258. "updated_at": int(time.time()),
  259. }
  260. )
  261. db.commit()
  262. return True
  263. except Exception:
  264. return None
  265. def delete_function_by_id(self, id: str) -> bool:
  266. with get_db() as db:
  267. try:
  268. db.query(Function).filter_by(id=id).delete()
  269. db.commit()
  270. return True
  271. except Exception:
  272. return False
  273. Functions = FunctionsTable()