misc.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import hashlib
  2. import re
  3. import time
  4. import uuid
  5. import logging
  6. from datetime import timedelta
  7. from pathlib import Path
  8. from typing import Callable, Optional
  9. import json
  10. import collections.abc
  11. from open_webui.env import SRC_LOG_LEVELS
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  14. def deep_update(d, u):
  15. for k, v in u.items():
  16. if isinstance(v, collections.abc.Mapping):
  17. d[k] = deep_update(d.get(k, {}), v)
  18. else:
  19. d[k] = v
  20. return d
  21. def get_message_list(messages, message_id):
  22. """
  23. Reconstructs a list of messages in order up to the specified message_id.
  24. :param message_id: ID of the message to reconstruct the chain
  25. :param messages: Message history dict containing all messages
  26. :return: List of ordered messages starting from the root to the given message
  27. """
  28. # Handle case where messages is None
  29. if not messages:
  30. return [] # Return empty list instead of None to prevent iteration errors
  31. # Find the message by its id
  32. current_message = messages.get(message_id)
  33. if not current_message:
  34. return [] # Return empty list instead of None to prevent iteration errors
  35. # Reconstruct the chain by following the parentId links
  36. message_list = []
  37. while current_message:
  38. message_list.insert(
  39. 0, current_message
  40. ) # Insert the message at the beginning of the list
  41. parent_id = current_message.get("parentId") # Use .get() for safety
  42. current_message = messages.get(parent_id) if parent_id else None
  43. return message_list
  44. def get_messages_content(messages: list[dict]) -> str:
  45. return "\n".join(
  46. [
  47. f"{message['role'].upper()}: {get_content_from_message(message)}"
  48. for message in messages
  49. ]
  50. )
  51. def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
  52. for message in reversed(messages):
  53. if message["role"] == "user":
  54. return message
  55. return None
  56. def get_content_from_message(message: dict) -> Optional[str]:
  57. if isinstance(message.get("content"), list):
  58. for item in message["content"]:
  59. if item["type"] == "text":
  60. return item["text"]
  61. else:
  62. return message.get("content")
  63. return None
  64. def get_last_user_message(messages: list[dict]) -> Optional[str]:
  65. message = get_last_user_message_item(messages)
  66. if message is None:
  67. return None
  68. return get_content_from_message(message)
  69. def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
  70. for message in reversed(messages):
  71. if message["role"] == "assistant":
  72. return message
  73. return None
  74. def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
  75. for message in reversed(messages):
  76. if message["role"] == "assistant":
  77. return get_content_from_message(message)
  78. return None
  79. def get_system_message(messages: list[dict]) -> Optional[dict]:
  80. for message in messages:
  81. if message["role"] == "system":
  82. return message
  83. return None
  84. def remove_system_message(messages: list[dict]) -> list[dict]:
  85. return [message for message in messages if message["role"] != "system"]
  86. def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
  87. return get_system_message(messages), remove_system_message(messages)
  88. def prepend_to_first_user_message_content(
  89. content: str, messages: list[dict]
  90. ) -> list[dict]:
  91. for message in messages:
  92. if message["role"] == "user":
  93. if isinstance(message["content"], list):
  94. for item in message["content"]:
  95. if item["type"] == "text":
  96. item["text"] = f"{content}\n{item['text']}"
  97. else:
  98. message["content"] = f"{content}\n{message['content']}"
  99. break
  100. return messages
  101. def add_or_update_system_message(
  102. content: str, messages: list[dict], append: bool = False
  103. ):
  104. """
  105. Adds a new system message at the beginning of the messages list
  106. or updates the existing system message at the beginning.
  107. :param msg: The message to be added or appended.
  108. :param messages: The list of message dictionaries.
  109. :return: The updated list of message dictionaries.
  110. """
  111. if messages and messages[0].get("role") == "system":
  112. if append:
  113. messages[0]["content"] = f"{messages[0]['content']}\n{content}"
  114. else:
  115. messages[0]["content"] = f"{content}\n{messages[0]['content']}"
  116. else:
  117. # Insert at the beginning
  118. messages.insert(0, {"role": "system", "content": content})
  119. return messages
  120. def add_or_update_user_message(content: str, messages: list[dict]):
  121. """
  122. Adds a new user message at the end of the messages list
  123. or updates the existing user message at the end.
  124. :param msg: The message to be added or appended.
  125. :param messages: The list of message dictionaries.
  126. :return: The updated list of message dictionaries.
  127. """
  128. if messages and messages[-1].get("role") == "user":
  129. messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
  130. else:
  131. # Insert at the end
  132. messages.append({"role": "user", "content": content})
  133. return messages
  134. def append_or_update_assistant_message(content: str, messages: list[dict]):
  135. """
  136. Adds a new assistant message at the end of the messages list
  137. or updates the existing assistant message at the end.
  138. :param msg: The message to be added or appended.
  139. :param messages: The list of message dictionaries.
  140. :return: The updated list of message dictionaries.
  141. """
  142. if messages and messages[-1].get("role") == "assistant":
  143. messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
  144. else:
  145. # Insert at the end
  146. messages.append({"role": "assistant", "content": content})
  147. return messages
  148. def openai_chat_message_template(model: str):
  149. return {
  150. "id": f"{model}-{str(uuid.uuid4())}",
  151. "created": int(time.time()),
  152. "model": model,
  153. "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
  154. }
  155. def openai_chat_chunk_message_template(
  156. model: str,
  157. content: Optional[str] = None,
  158. tool_calls: Optional[list[dict]] = None,
  159. usage: Optional[dict] = None,
  160. ) -> dict:
  161. template = openai_chat_message_template(model)
  162. template["object"] = "chat.completion.chunk"
  163. template["choices"][0]["index"] = 0
  164. template["choices"][0]["delta"] = {}
  165. if content:
  166. template["choices"][0]["delta"]["content"] = content
  167. if tool_calls:
  168. template["choices"][0]["delta"]["tool_calls"] = tool_calls
  169. if not content and not tool_calls:
  170. template["choices"][0]["finish_reason"] = "stop"
  171. if usage:
  172. template["usage"] = usage
  173. return template
  174. def openai_chat_completion_message_template(
  175. model: str,
  176. message: Optional[str] = None,
  177. tool_calls: Optional[list[dict]] = None,
  178. usage: Optional[dict] = None,
  179. ) -> dict:
  180. template = openai_chat_message_template(model)
  181. template["object"] = "chat.completion"
  182. if message is not None:
  183. template["choices"][0]["message"] = {
  184. "content": message,
  185. "role": "assistant",
  186. **({"tool_calls": tool_calls} if tool_calls else {}),
  187. }
  188. template["choices"][0]["finish_reason"] = "stop"
  189. if usage:
  190. template["usage"] = usage
  191. return template
  192. def get_gravatar_url(email):
  193. # Trim leading and trailing whitespace from
  194. # an email address and force all characters
  195. # to lower case
  196. address = str(email).strip().lower()
  197. # Create a SHA256 hash of the final string
  198. hash_object = hashlib.sha256(address.encode())
  199. hash_hex = hash_object.hexdigest()
  200. # Grab the actual image URL
  201. return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
  202. def calculate_sha256(file_path, chunk_size):
  203. # Compute SHA-256 hash of a file efficiently in chunks
  204. sha256 = hashlib.sha256()
  205. with open(file_path, "rb") as f:
  206. while chunk := f.read(chunk_size):
  207. sha256.update(chunk)
  208. return sha256.hexdigest()
  209. def calculate_sha256_string(string):
  210. # Create a new SHA-256 hash object
  211. sha256_hash = hashlib.sha256()
  212. # Update the hash object with the bytes of the input string
  213. sha256_hash.update(string.encode("utf-8"))
  214. # Get the hexadecimal representation of the hash
  215. hashed_string = sha256_hash.hexdigest()
  216. return hashed_string
  217. def validate_email_format(email: str) -> bool:
  218. if email.endswith("@localhost"):
  219. return True
  220. return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
  221. def sanitize_filename(file_name):
  222. # Convert to lowercase
  223. lower_case_file_name = file_name.lower()
  224. # Remove special characters using regular expression
  225. sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
  226. # Replace spaces with dashes
  227. final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
  228. return final_file_name
  229. def extract_folders_after_data_docs(path):
  230. # Convert the path to a Path object if it's not already
  231. path = Path(path)
  232. # Extract parts of the path
  233. parts = path.parts
  234. # Find the index of '/data/docs' in the path
  235. try:
  236. index_data_docs = parts.index("data") + 1
  237. index_docs = parts.index("docs", index_data_docs) + 1
  238. except ValueError:
  239. return []
  240. # Exclude the filename and accumulate folder names
  241. tags = []
  242. folders = parts[index_docs:-1]
  243. for idx, _ in enumerate(folders):
  244. tags.append("/".join(folders[: idx + 1]))
  245. return tags
  246. def parse_duration(duration: str) -> Optional[timedelta]:
  247. if duration == "-1" or duration == "0":
  248. return None
  249. # Regular expression to find number and unit pairs
  250. pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
  251. matches = re.findall(pattern, duration)
  252. if not matches:
  253. raise ValueError("Invalid duration string")
  254. total_duration = timedelta()
  255. for number, _, unit in matches:
  256. number = float(number)
  257. if unit == "ms":
  258. total_duration += timedelta(milliseconds=number)
  259. elif unit == "s":
  260. total_duration += timedelta(seconds=number)
  261. elif unit == "m":
  262. total_duration += timedelta(minutes=number)
  263. elif unit == "h":
  264. total_duration += timedelta(hours=number)
  265. elif unit == "d":
  266. total_duration += timedelta(days=number)
  267. elif unit == "w":
  268. total_duration += timedelta(weeks=number)
  269. return total_duration
  270. def parse_ollama_modelfile(model_text):
  271. parameters_meta = {
  272. "mirostat": int,
  273. "mirostat_eta": float,
  274. "mirostat_tau": float,
  275. "num_ctx": int,
  276. "repeat_last_n": int,
  277. "repeat_penalty": float,
  278. "temperature": float,
  279. "seed": int,
  280. "tfs_z": float,
  281. "num_predict": int,
  282. "top_k": int,
  283. "top_p": float,
  284. "num_keep": int,
  285. "typical_p": float,
  286. "presence_penalty": float,
  287. "frequency_penalty": float,
  288. "penalize_newline": bool,
  289. "numa": bool,
  290. "num_batch": int,
  291. "num_gpu": int,
  292. "main_gpu": int,
  293. "low_vram": bool,
  294. "f16_kv": bool,
  295. "vocab_only": bool,
  296. "use_mmap": bool,
  297. "use_mlock": bool,
  298. "num_thread": int,
  299. }
  300. data = {"base_model_id": None, "params": {}}
  301. # Parse base model
  302. base_model_match = re.search(
  303. r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
  304. )
  305. if base_model_match:
  306. data["base_model_id"] = base_model_match.group(1)
  307. # Parse template
  308. template_match = re.search(
  309. r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  310. )
  311. if template_match:
  312. data["params"] = {"template": template_match.group(1).strip()}
  313. # Parse stops
  314. stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
  315. if stops:
  316. data["params"]["stop"] = stops
  317. # Parse other parameters from the provided list
  318. for param, param_type in parameters_meta.items():
  319. param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
  320. if param_match:
  321. value = param_match.group(1)
  322. try:
  323. if param_type is int:
  324. value = int(value)
  325. elif param_type is float:
  326. value = float(value)
  327. elif param_type is bool:
  328. value = value.lower() == "true"
  329. except Exception as e:
  330. log.exception(f"Failed to parse parameter {param}: {e}")
  331. continue
  332. data["params"][param] = value
  333. # Parse adapter
  334. adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
  335. if adapter_match:
  336. data["params"]["adapter"] = adapter_match.group(1)
  337. # Parse system description
  338. system_desc_match = re.search(
  339. r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  340. )
  341. system_desc_match_single = re.search(
  342. r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
  343. )
  344. if system_desc_match:
  345. data["params"]["system"] = system_desc_match.group(1).strip()
  346. elif system_desc_match_single:
  347. data["params"]["system"] = system_desc_match_single.group(1).strip()
  348. # Parse messages
  349. messages = []
  350. message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
  351. for role, content in message_matches:
  352. messages.append({"role": role, "content": content})
  353. if messages:
  354. data["params"]["messages"] = messages
  355. return data
  356. def convert_logit_bias_input_to_json(user_input):
  357. logit_bias_pairs = user_input.split(",")
  358. logit_bias_json = {}
  359. for pair in logit_bias_pairs:
  360. token, bias = pair.split(":")
  361. token = str(token.strip())
  362. bias = int(bias.strip())
  363. bias = 100 if bias > 100 else -100 if bias < -100 else bias
  364. logit_bias_json[token] = bias
  365. return json.dumps(logit_bias_json)