models.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. import time
  2. import logging
  3. import sys
  4. from aiocache import cached
  5. from fastapi import Request
  6. from open_webui.routers import openai, ollama
  7. from open_webui.functions import get_function_models
  8. from open_webui.models.functions import Functions
  9. from open_webui.models.models import Models
  10. from open_webui.utils.plugin import (
  11. load_function_module_by_id,
  12. get_function_module_from_cache,
  13. )
  14. from open_webui.utils.access_control import has_access
  15. from open_webui.config import (
  16. DEFAULT_ARENA_MODEL,
  17. )
  18. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  19. from open_webui.models.users import UserModel
  # Send root-logger output to stdout at the globally configured level, then
  # give this module's logger the level configured for the "MAIN" source.
  logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["MAIN"])
  async def get_all_base_models(request: Request, user: UserModel = None):
      """Return the base models exposed by every enabled backend.

      The result is the concatenation, in this order, of:
        * function models from ``get_function_models(request)``;
        * the ``"data"`` list returned by ``openai.get_all_models`` (only when
          ``ENABLE_OPENAI_API`` is set);
        * the ``"models"`` entries returned by ``ollama.get_all_models`` (only
          when ``ENABLE_OLLAMA_API`` is set), re-shaped into dicts with
          ``"object": "model"`` and ``"owned_by": "ollama"``. Each dict keeps
          the raw entry under ``"ollama"``.

      Args:
          request: Incoming request; ``request.app.state.config`` decides which
              backends are queried.
          user: Forwarded unchanged to the OpenAI and Ollama model listings.

      Returns:
          A new list of model dicts.
      """
      openai_models = []
      if request.app.state.config.ENABLE_OPENAI_API:
          openai_response = await openai.get_all_models(request, user=user)
          openai_models = openai_response["data"]

      ollama_models = []
      if request.app.state.config.ENABLE_OLLAMA_API:
          ollama_response = await ollama.get_all_models(request, user=user)
          for raw in ollama_response["models"]:
              ollama_models.append(
                  {
                      "id": raw["model"],
                      "name": raw["name"],
                      "object": "model",
                      "created": int(time.time()),
                      "owned_by": "ollama",
                      "ollama": raw,
                      "connection_type": raw.get("connection_type", "local"),
                      "tags": raw.get("tags", []),
                  }
              )

      function_models = await get_function_models(request)
      return function_models + openai_models + ollama_models
  def _get_arena_models(arena_model_configs):
      """Build model dicts for the evaluation-arena pseudo-models.

      Uses ``arena_model_configs`` when it is non-empty, otherwise a single
      entry built from ``DEFAULT_ARENA_MODEL``. Each source must provide
      ``"id"``, ``"name"`` and ``"meta"`` keys.
      """
      sources = (
          arena_model_configs
          if len(arena_model_configs) > 0
          else [DEFAULT_ARENA_MODEL]
      )
      return [
          {
              "id": arena_model["id"],
              "name": arena_model["name"],
              "info": {
                  "meta": arena_model["meta"],
              },
              "object": "model",
              "created": int(time.time()),
              "owned_by": "arena",
              "arena": True,
          }
          for arena_model in sources
      ]


  def _overrides_model(custom_model_id, model):
      """Return True if a custom model with ``custom_model_id`` overrides ``model``.

      Ollama may return model ids in different formats (e.g. 'llama3' vs.
      'llama3:7b'), so for Ollama-owned models the part of the id before the
      first ':' is also accepted as a match.
      """
      return custom_model_id == model["id"] or (
          model.get("owned_by") == "ollama"
          and custom_model_id == model["id"].split(":")[0]
      )


  def _build_preset_model(custom_model, models):
      """Build the model dict for a custom model layered on a base model.

      ``owned_by`` and ``pipe`` are copied from the first entry in ``models``
      whose id (or the part of it before ':') equals
      ``custom_model.base_model_id``. If no entry matches, ``owned_by`` is
      "openai" and no ``pipe`` key is added. Action/filter ids come from the
      custom model's meta (``actionIds`` / ``filterIds``).
      """
      owned_by = "openai"
      pipe = None
      for model in models:
          if custom_model.base_model_id in (model["id"], model["id"].split(":")[0]):
              owned_by = model.get("owned_by", "unknown owner")
              if "pipe" in model:
                  pipe = model["pipe"]
              break

      action_ids = []
      filter_ids = []
      if custom_model.meta:
          meta = custom_model.meta.model_dump()
          # `or []` tolerates keys that are present but explicitly None.
          action_ids.extend(meta.get("actionIds") or [])
          filter_ids.extend(meta.get("filterIds") or [])

      return {
          "id": f"{custom_model.id}",
          "name": custom_model.name,
          "object": "model",
          "created": custom_model.created_at,
          "owned_by": owned_by,
          "info": custom_model.model_dump(),
          "preset": True,
          **({"pipe": pipe} if pipe is not None else {}),
          "action_ids": action_ids,
          "filter_ids": filter_ids,
      }


  def _apply_custom_models(models, custom_models):
      """Apply stored custom models to ``models`` in place.

      * A custom model without ``base_model_id`` applies to every model it
        overrides (see ``_overrides_model``). When active, it replaces
        ``name``/``info`` and seeds ``action_ids``/``filter_ids`` from its meta.
        When inactive, the matching models are removed.
      * An active custom model with a ``base_model_id`` is appended as a preset
        (see ``_build_preset_model``), unless its id is already present.
      """
      for custom_model in custom_models:
          if custom_model.base_model_id is None:
              # Iterate over a snapshot: removing from the list being iterated
              # would silently skip the element right after each removal.
              for model in list(models):
                  if not _overrides_model(custom_model.id, model):
                      continue
                  if custom_model.is_active:
                      model["name"] = custom_model.name
                      model["info"] = custom_model.model_dump()
                      meta = model["info"].get("meta") or {}
                      model["action_ids"] = list(meta.get("actionIds") or [])
                      model["filter_ids"] = list(meta.get("filterIds") or [])
                  else:
                      models.remove(model)
          elif custom_model.is_active and custom_model.id not in {
              model["id"] for model in models
          }:
              models.append(_build_preset_model(custom_model, models))


  def _get_action_items_from_module(function, module):
      """Describe the UI action entries an action function exposes.

      If the module defines ``actions``, one entry is produced per action with
      id ``"<function id>.<action id>"``. Otherwise a single entry represents
      the whole function. Icons fall back from the action's ``icon_url`` to the
      function manifest's ``icon_url``, then the module's ``icon_url`` or ``icon``.
      """
      if hasattr(module, "actions"):
          actions = module.actions
          return [
              {
                  "id": f"{function.id}.{action['id']}",
                  "name": action.get("name", f"{function.name} ({action['id']})"),
                  "description": function.meta.description,
                  "icon": action.get(
                      "icon_url",
                      function.meta.manifest.get("icon_url", None)
                      or getattr(module, "icon_url", None)
                      or getattr(module, "icon", None),
                  ),
              }
              for action in actions
          ]
      return [
          {
              "id": function.id,
              "name": function.name,
              "description": function.meta.description,
              "icon": function.meta.manifest.get("icon_url", None)
              or getattr(module, "icon_url", None)
              or getattr(module, "icon", None),
          }
      ]


  def _get_filter_items_from_module(function, module):
      """Describe the single UI entry for a filter function."""
      return [
          {
              "id": function.id,
              "name": function.name,
              "description": function.meta.description,
              "icon": function.meta.manifest.get("icon_url", None)
              or getattr(module, "icon_url", None)
              or getattr(module, "icon", None),
          }
      ]


  def _get_function_module(request, function_id):
      """Return the module for ``function_id`` from the plugin module cache."""
      function_module, _, _ = get_function_module_from_cache(request, function_id)
      return function_module


  async def get_all_models(request, user: UserModel = None):
      """Build the complete model list and cache it on the application state.

      Steps:
        1. Fetch base models via ``get_all_base_models``.
        2. Append arena models when ``ENABLE_EVALUATION_ARENA_MODELS`` is set.
        3. Apply custom models from ``Models.get_all_models()`` (overrides,
           removals of inactive overrides, and presets).
        4. Resolve each model's action/filter ids (model-specific ids plus
           global ones, restricted to active functions) into ``"actions"`` and
           ``"filters"`` lists. Only filters whose module has a truthy
           ``toggle`` are listed.

      Args:
          request: Incoming request; supplies app config/state and the plugin
              module cache.
          user: Forwarded to ``get_all_base_models``.

      Returns:
          The list of model dicts. It is also stored in
          ``request.app.state.MODELS`` keyed by model id, except on the
          empty-list early return.

      Raises:
          Exception: If an enabled action/filter id has no stored function.
      """
      models = await get_all_base_models(request, user=user)

      # If there are no models, return an empty list.
      # NOTE(review): this early return leaves request.app.state.MODELS as it
      # was after the previous call; confirm keeping that stale cache is intended.
      if len(models) == 0:
          return []

      if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
          models = models + _get_arena_models(
              request.app.state.config.EVALUATION_ARENA_MODELS
          )

      global_action_ids = [
          function.id for function in Functions.get_global_action_functions()
      ]
      enabled_action_ids = {
          function.id
          for function in Functions.get_functions_by_type("action", active_only=True)
      }
      global_filter_ids = [
          function.id for function in Functions.get_global_filter_functions()
      ]
      enabled_filter_ids = {
          function.id
          for function in Functions.get_functions_by_type("filter", active_only=True)
      }

      _apply_custom_models(models, Models.get_all_models())

      for model in models:
          # dict.fromkeys de-duplicates while keeping a stable order
          # (model-specific ids first, then global ones); iterating a set of
          # strings can yield a different order in every interpreter run.
          action_ids = [
              action_id
              for action_id in dict.fromkeys(
                  model.pop("action_ids", []) + global_action_ids
              )
              if action_id in enabled_action_ids
          ]
          filter_ids = [
              filter_id
              for filter_id in dict.fromkeys(
                  model.pop("filter_ids", []) + global_filter_ids
              )
              if filter_id in enabled_filter_ids
          ]

          model["actions"] = []
          for action_id in action_ids:
              action_function = Functions.get_function_by_id(action_id)
              if action_function is None:
                  raise Exception(f"Action not found: {action_id}")
              function_module = _get_function_module(request, action_id)
              model["actions"].extend(
                  _get_action_items_from_module(action_function, function_module)
              )

          model["filters"] = []
          for filter_id in filter_ids:
              filter_function = Functions.get_function_by_id(filter_id)
              if filter_function is None:
                  raise Exception(f"Filter not found: {filter_id}")
              function_module = _get_function_module(request, filter_id)
              # Only filters that declare a truthy `toggle` are listed here.
              if getattr(function_module, "toggle", None):
                  model["filters"].extend(
                      _get_filter_items_from_module(filter_function, function_module)
                  )

      log.debug("get_all_models() returned %d models", len(models))
      request.app.state.MODELS = {model["id"]: model for model in models}
      return models
  def check_model_access(user, model):
      """Raise ``Exception("Model not found")`` unless ``user`` may read ``model``.

      Arena models (``model["arena"]`` truthy) are checked with ``has_access``
      against ``model["info"]["meta"]["access_control"``].

      Every other model must have a record in ``Models``; a missing record
      raises. Access is granted when the user owns that record or passes a
      ``has_access`` read check on its ``access_control``.

      The same message is used for "missing" and "forbidden", so callers
      cannot tell the two apart.
      """
      if not model.get("arena"):
          model_info = Models.get_model_by_id(model.get("id"))
          if not model_info:
              raise Exception("Model not found")
          is_owner = user.id == model_info.user_id
          if not is_owner and not has_access(
              user.id, type="read", access_control=model_info.access_control
          ):
              raise Exception("Model not found")
          return

      arena_meta = model.get("info", {}).get("meta", {})
      arena_access_control = arena_meta.get("access_control", {})
      if not has_access(user.id, type="read", access_control=arena_access_control):
          raise Exception("Model not found")
  261. )
  262. ):
  263. raise Exception("Model not found")