1
0

models.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import time
  2. import logging
  3. import sys
  4. from aiocache import cached
  5. from fastapi import Request
  6. from open_webui.routers import openai, ollama
  7. from open_webui.functions import get_function_models
  8. from open_webui.models.functions import Functions
  9. from open_webui.models.models import Models
  10. from open_webui.utils.plugin import load_function_module_by_id
  11. from open_webui.utils.access_control import has_access
  12. from open_webui.config import (
  13. DEFAULT_ARENA_MODEL,
  14. )
  15. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  16. from open_webui.models.users import UserModel
  17. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  18. log = logging.getLogger(__name__)
  19. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  20. async def get_all_base_models(request: Request, user: UserModel = None):
  21. function_models = []
  22. openai_models = []
  23. ollama_models = []
  24. if request.app.state.config.ENABLE_OPENAI_API:
  25. openai_models = await openai.get_all_models(request, user=user)
  26. openai_models = openai_models["data"]
  27. if request.app.state.config.ENABLE_OLLAMA_API:
  28. ollama_models = await ollama.get_all_models(request, user=user)
  29. ollama_models = [
  30. {
  31. "id": model["model"],
  32. "name": model["name"],
  33. "object": "model",
  34. "created": int(time.time()),
  35. "owned_by": "ollama",
  36. "ollama": model,
  37. "tags": model.get("tags", []),
  38. }
  39. for model in ollama_models["models"]
  40. ]
  41. function_models = await get_function_models(request)
  42. models = function_models + openai_models + ollama_models
  43. return models
  44. async def get_all_models(request, user: UserModel = None):
  45. models = await get_all_base_models(request, user=user)
  46. # If there are no models, return an empty list
  47. if len(models) == 0:
  48. return []
  49. # Add arena models
  50. if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  51. arena_models = []
  52. if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
  53. arena_models = [
  54. {
  55. "id": model["id"],
  56. "name": model["name"],
  57. "info": {
  58. "meta": model["meta"],
  59. },
  60. "object": "model",
  61. "created": int(time.time()),
  62. "owned_by": "arena",
  63. "arena": True,
  64. }
  65. for model in request.app.state.config.EVALUATION_ARENA_MODELS
  66. ]
  67. else:
  68. # Add default arena model
  69. arena_models = [
  70. {
  71. "id": DEFAULT_ARENA_MODEL["id"],
  72. "name": DEFAULT_ARENA_MODEL["name"],
  73. "info": {
  74. "meta": DEFAULT_ARENA_MODEL["meta"],
  75. },
  76. "object": "model",
  77. "created": int(time.time()),
  78. "owned_by": "arena",
  79. "arena": True,
  80. }
  81. ]
  82. models = models + arena_models
  83. global_action_ids = [
  84. function.id for function in Functions.get_global_action_functions()
  85. ]
  86. enabled_action_ids = [
  87. function.id
  88. for function in Functions.get_functions_by_type("action", active_only=True)
  89. ]
  90. global_filter_ids = [
  91. function.id for function in Functions.get_global_filter_functions()
  92. ]
  93. enabled_filter_ids = [
  94. function.id
  95. for function in Functions.get_functions_by_type("filter", active_only=True)
  96. ]
  97. custom_models = Models.get_all_models()
  98. for custom_model in custom_models:
  99. if custom_model.base_model_id is None:
  100. for model in models:
  101. if custom_model.id == model["id"] or (
  102. model.get("owned_by") == "ollama"
  103. and custom_model.id
  104. == model["id"].split(":")[
  105. 0
  106. ] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
  107. ):
  108. if custom_model.is_active:
  109. model["name"] = custom_model.name
  110. model["info"] = custom_model.model_dump()
  111. # Set action_ids and filter_ids
  112. action_ids = []
  113. filter_ids = []
  114. if "info" in model and "meta" in model["info"]:
  115. action_ids.extend(
  116. model["info"]["meta"].get("actionIds", [])
  117. )
  118. filter_ids.extend(
  119. model["info"]["meta"].get("filterIds", [])
  120. )
  121. model["action_ids"] = action_ids
  122. model["filter_ids"] = filter_ids
  123. else:
  124. models.remove(model)
  125. elif custom_model.is_active and (
  126. custom_model.id not in [model["id"] for model in models]
  127. ):
  128. owned_by = "openai"
  129. pipe = None
  130. action_ids = []
  131. filter_ids = []
  132. for model in models:
  133. if (
  134. custom_model.base_model_id == model["id"]
  135. or custom_model.base_model_id == model["id"].split(":")[0]
  136. ):
  137. owned_by = model.get("owned_by", "unknown owner")
  138. if "pipe" in model:
  139. pipe = model["pipe"]
  140. break
  141. if custom_model.meta:
  142. meta = custom_model.meta.model_dump()
  143. if "actionIds" in meta:
  144. action_ids.extend(meta["actionIds"])
  145. if "filterIds" in meta:
  146. filter_ids.extend(meta["filterIds"])
  147. models.append(
  148. {
  149. "id": f"{custom_model.id}",
  150. "name": custom_model.name,
  151. "object": "model",
  152. "created": custom_model.created_at,
  153. "owned_by": owned_by,
  154. "info": custom_model.model_dump(),
  155. "preset": True,
  156. **({"pipe": pipe} if pipe is not None else {}),
  157. "action_ids": action_ids,
  158. "filter_ids": filter_ids,
  159. }
  160. )
  161. # Process action_ids to get the actions
  162. def get_action_items_from_module(function, module):
  163. actions = []
  164. if hasattr(module, "actions"):
  165. actions = module.actions
  166. return [
  167. {
  168. "id": f"{function.id}.{action['id']}",
  169. "name": action.get("name", f"{function.name} ({action['id']})"),
  170. "description": function.meta.description,
  171. "icon_url": action.get(
  172. "icon_url", function.meta.manifest.get("icon_url", None)
  173. ),
  174. }
  175. for action in actions
  176. ]
  177. else:
  178. return [
  179. {
  180. "id": function.id,
  181. "name": function.name,
  182. "description": function.meta.description,
  183. "icon_url": function.meta.manifest.get("icon_url", None),
  184. }
  185. ]
  186. # Process filter_ids to get the filters
  187. def get_filter_items_from_module(function, module):
  188. return [
  189. {
  190. "id": function.id,
  191. "name": function.name,
  192. "description": function.meta.description,
  193. "icon_url": function.meta.manifest.get("icon_url", None),
  194. }
  195. ]
  196. def get_function_module_by_id(function_id):
  197. if function_id in request.app.state.FUNCTIONS:
  198. function_module = request.app.state.FUNCTIONS[function_id]
  199. else:
  200. function_module, _, _ = load_function_module_by_id(function_id)
  201. request.app.state.FUNCTIONS[function_id] = function_module
  202. return function_module
  203. for model in models:
  204. action_ids = [
  205. action_id
  206. for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
  207. if action_id in enabled_action_ids
  208. ]
  209. filter_ids = [
  210. filter_id
  211. for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
  212. if filter_id in enabled_filter_ids
  213. ]
  214. model["actions"] = []
  215. for action_id in action_ids:
  216. action_function = Functions.get_function_by_id(action_id)
  217. if action_function is None:
  218. raise Exception(f"Action not found: {action_id}")
  219. function_module = get_function_module_by_id(action_id)
  220. model["actions"].extend(
  221. get_action_items_from_module(action_function, function_module)
  222. )
  223. model["filters"] = []
  224. for filter_id in filter_ids:
  225. filter_function = Functions.get_function_by_id(filter_id)
  226. if filter_function is None:
  227. raise Exception(f"Filter not found: {filter_id}")
  228. function_module = get_function_module_by_id(filter_id)
  229. if getattr(function_module, "toggle", None):
  230. model["filters"].extend(
  231. get_filter_items_from_module(filter_function, function_module)
  232. )
  233. log.debug(f"get_all_models() returned {len(models)} models")
  234. request.app.state.MODELS = {model["id"]: model for model in models}
  235. return models
  236. def check_model_access(user, model):
  237. if model.get("arena"):
  238. if not has_access(
  239. user.id,
  240. type="read",
  241. access_control=model.get("info", {})
  242. .get("meta", {})
  243. .get("access_control", {}),
  244. ):
  245. raise Exception("Model not found")
  246. else:
  247. model_info = Models.get_model_by_id(model.get("id"))
  248. if not model_info:
  249. raise Exception("Model not found")
  250. elif not (
  251. user.id == model_info.user_id
  252. or has_access(
  253. user.id, type="read", access_control=model_info.access_control
  254. )
  255. ):
  256. raise Exception("Model not found")