misc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. import hashlib
  2. import re
  3. import threading
  4. import time
  5. import uuid
  6. import logging
  7. from datetime import timedelta
  8. from pathlib import Path
  9. from typing import Callable, Optional
  10. import json
  11. import aiohttp
  12. import collections.abc
  13. from open_webui.env import SRC_LOG_LEVELS, CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
  14. log = logging.getLogger(__name__)
  15. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  16. def deep_update(d, u):
  17. for k, v in u.items():
  18. if isinstance(v, collections.abc.Mapping):
  19. d[k] = deep_update(d.get(k, {}), v)
  20. else:
  21. d[k] = v
  22. return d
  23. def get_message_list(messages_map, message_id):
  24. """
  25. Reconstructs a list of messages in order up to the specified message_id.
  26. :param message_id: ID of the message to reconstruct the chain
  27. :param messages: Message history dict containing all messages
  28. :return: List of ordered messages starting from the root to the given message
  29. """
  30. # Handle case where messages is None
  31. if not messages_map:
  32. return [] # Return empty list instead of None to prevent iteration errors
  33. # Find the message by its id
  34. current_message = messages_map.get(message_id)
  35. if not current_message:
  36. return [] # Return empty list instead of None to prevent iteration errors
  37. # Reconstruct the chain by following the parentId links
  38. message_list = []
  39. while current_message:
  40. message_list.insert(
  41. 0, current_message
  42. ) # Insert the message at the beginning of the list
  43. parent_id = current_message.get("parentId") # Use .get() for safety
  44. current_message = messages_map.get(parent_id) if parent_id else None
  45. return message_list
  46. def get_messages_content(messages: list[dict]) -> str:
  47. return "\n".join(
  48. [
  49. f"{message['role'].upper()}: {get_content_from_message(message)}"
  50. for message in messages
  51. ]
  52. )
  53. def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
  54. for message in reversed(messages):
  55. if message["role"] == "user":
  56. return message
  57. return None
  58. def get_content_from_message(message: dict) -> Optional[str]:
  59. if isinstance(message.get("content"), list):
  60. for item in message["content"]:
  61. if item["type"] == "text":
  62. return item["text"]
  63. else:
  64. return message.get("content")
  65. return None
  66. def get_last_user_message(messages: list[dict]) -> Optional[str]:
  67. message = get_last_user_message_item(messages)
  68. if message is None:
  69. return None
  70. return get_content_from_message(message)
  71. def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
  72. for message in reversed(messages):
  73. if message["role"] == "assistant":
  74. return message
  75. return None
  76. def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
  77. for message in reversed(messages):
  78. if message["role"] == "assistant":
  79. return get_content_from_message(message)
  80. return None
  81. def get_system_message(messages: list[dict]) -> Optional[dict]:
  82. for message in messages:
  83. if message["role"] == "system":
  84. return message
  85. return None
  86. def remove_system_message(messages: list[dict]) -> list[dict]:
  87. return [message for message in messages if message["role"] != "system"]
  88. def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
  89. return get_system_message(messages), remove_system_message(messages)
  90. def update_message_content(message: dict, content: str, append: bool = True) -> dict:
  91. if isinstance(message["content"], list):
  92. for item in message["content"]:
  93. if item["type"] == "text":
  94. if append:
  95. item["text"] = f"{item['text']}\n{content}"
  96. else:
  97. item["text"] = f"{content}\n{item['text']}"
  98. else:
  99. if append:
  100. message["content"] = f"{message['content']}\n{content}"
  101. else:
  102. message["content"] = f"{content}\n{message['content']}"
  103. return message
  104. def replace_system_message_content(content: str, messages: list[dict]) -> dict:
  105. for message in messages:
  106. if message["role"] == "system":
  107. message["content"] = content
  108. break
  109. return messages
  110. def add_or_update_system_message(
  111. content: str, messages: list[dict], append: bool = False
  112. ):
  113. """
  114. Adds a new system message at the beginning of the messages list
  115. or updates the existing system message at the beginning.
  116. :param msg: The message to be added or appended.
  117. :param messages: The list of message dictionaries.
  118. :return: The updated list of message dictionaries.
  119. """
  120. if messages and messages[0].get("role") == "system":
  121. messages[0] = update_message_content(messages[0], content, append)
  122. else:
  123. # Insert at the beginning
  124. messages.insert(0, {"role": "system", "content": content})
  125. return messages
  126. def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
  127. """
  128. Adds a new user message at the end of the messages list
  129. or updates the existing user message at the end.
  130. :param msg: The message to be added or appended.
  131. :param messages: The list of message dictionaries.
  132. :return: The updated list of message dictionaries.
  133. """
  134. if messages and messages[-1].get("role") == "user":
  135. messages[-1] = update_message_content(messages[-1], content, append)
  136. else:
  137. # Insert at the end
  138. messages.append({"role": "user", "content": content})
  139. return messages
  140. def prepend_to_first_user_message_content(
  141. content: str, messages: list[dict]
  142. ) -> list[dict]:
  143. for message in messages:
  144. if message["role"] == "user":
  145. message = update_message_content(message, content, append=False)
  146. break
  147. return messages
  148. def append_or_update_assistant_message(content: str, messages: list[dict]):
  149. """
  150. Adds a new assistant message at the end of the messages list
  151. or updates the existing assistant message at the end.
  152. :param msg: The message to be added or appended.
  153. :param messages: The list of message dictionaries.
  154. :return: The updated list of message dictionaries.
  155. """
  156. if messages and messages[-1].get("role") == "assistant":
  157. messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
  158. else:
  159. # Insert at the end
  160. messages.append({"role": "assistant", "content": content})
  161. return messages
  162. def openai_chat_message_template(model: str):
  163. return {
  164. "id": f"{model}-{str(uuid.uuid4())}",
  165. "created": int(time.time()),
  166. "model": model,
  167. "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
  168. }
  169. def openai_chat_chunk_message_template(
  170. model: str,
  171. content: Optional[str] = None,
  172. reasoning_content: Optional[str] = None,
  173. tool_calls: Optional[list[dict]] = None,
  174. usage: Optional[dict] = None,
  175. ) -> dict:
  176. template = openai_chat_message_template(model)
  177. template["object"] = "chat.completion.chunk"
  178. template["choices"][0]["index"] = 0
  179. template["choices"][0]["delta"] = {}
  180. if content:
  181. template["choices"][0]["delta"]["content"] = content
  182. if reasoning_content:
  183. template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
  184. if tool_calls:
  185. template["choices"][0]["delta"]["tool_calls"] = tool_calls
  186. if not content and not reasoning_content and not tool_calls:
  187. template["choices"][0]["finish_reason"] = "stop"
  188. if usage:
  189. template["usage"] = usage
  190. return template
  191. def openai_chat_completion_message_template(
  192. model: str,
  193. message: Optional[str] = None,
  194. reasoning_content: Optional[str] = None,
  195. tool_calls: Optional[list[dict]] = None,
  196. usage: Optional[dict] = None,
  197. ) -> dict:
  198. template = openai_chat_message_template(model)
  199. template["object"] = "chat.completion"
  200. if message is not None:
  201. template["choices"][0]["message"] = {
  202. "role": "assistant",
  203. "content": message,
  204. **({"reasoning_content": reasoning_content} if reasoning_content else {}),
  205. **({"tool_calls": tool_calls} if tool_calls else {}),
  206. }
  207. template["choices"][0]["finish_reason"] = "stop"
  208. if usage:
  209. template["usage"] = usage
  210. return template
  211. def get_gravatar_url(email):
  212. # Trim leading and trailing whitespace from
  213. # an email address and force all characters
  214. # to lower case
  215. address = str(email).strip().lower()
  216. # Create a SHA256 hash of the final string
  217. hash_object = hashlib.sha256(address.encode())
  218. hash_hex = hash_object.hexdigest()
  219. # Grab the actual image URL
  220. return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
  221. def calculate_sha256(file_path, chunk_size):
  222. # Compute SHA-256 hash of a file efficiently in chunks
  223. sha256 = hashlib.sha256()
  224. with open(file_path, "rb") as f:
  225. while chunk := f.read(chunk_size):
  226. sha256.update(chunk)
  227. return sha256.hexdigest()
  228. def calculate_sha256_string(string):
  229. # Create a new SHA-256 hash object
  230. sha256_hash = hashlib.sha256()
  231. # Update the hash object with the bytes of the input string
  232. sha256_hash.update(string.encode("utf-8"))
  233. # Get the hexadecimal representation of the hash
  234. hashed_string = sha256_hash.hexdigest()
  235. return hashed_string
  236. def validate_email_format(email: str) -> bool:
  237. if email.endswith("@localhost"):
  238. return True
  239. return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
  240. def sanitize_filename(file_name):
  241. # Convert to lowercase
  242. lower_case_file_name = file_name.lower()
  243. # Remove special characters using regular expression
  244. sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
  245. # Replace spaces with dashes
  246. final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
  247. return final_file_name
  248. def extract_folders_after_data_docs(path):
  249. # Convert the path to a Path object if it's not already
  250. path = Path(path)
  251. # Extract parts of the path
  252. parts = path.parts
  253. # Find the index of '/data/docs' in the path
  254. try:
  255. index_data_docs = parts.index("data") + 1
  256. index_docs = parts.index("docs", index_data_docs) + 1
  257. except ValueError:
  258. return []
  259. # Exclude the filename and accumulate folder names
  260. tags = []
  261. folders = parts[index_docs:-1]
  262. for idx, _ in enumerate(folders):
  263. tags.append("/".join(folders[: idx + 1]))
  264. return tags
  265. def parse_duration(duration: str) -> Optional[timedelta]:
  266. if duration == "-1" or duration == "0":
  267. return None
  268. # Regular expression to find number and unit pairs
  269. pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
  270. matches = re.findall(pattern, duration)
  271. if not matches:
  272. raise ValueError("Invalid duration string")
  273. total_duration = timedelta()
  274. for number, _, unit in matches:
  275. number = float(number)
  276. if unit == "ms":
  277. total_duration += timedelta(milliseconds=number)
  278. elif unit == "s":
  279. total_duration += timedelta(seconds=number)
  280. elif unit == "m":
  281. total_duration += timedelta(minutes=number)
  282. elif unit == "h":
  283. total_duration += timedelta(hours=number)
  284. elif unit == "d":
  285. total_duration += timedelta(days=number)
  286. elif unit == "w":
  287. total_duration += timedelta(weeks=number)
  288. return total_duration
  289. def parse_ollama_modelfile(model_text):
  290. parameters_meta = {
  291. "mirostat": int,
  292. "mirostat_eta": float,
  293. "mirostat_tau": float,
  294. "num_ctx": int,
  295. "repeat_last_n": int,
  296. "repeat_penalty": float,
  297. "temperature": float,
  298. "seed": int,
  299. "tfs_z": float,
  300. "num_predict": int,
  301. "top_k": int,
  302. "top_p": float,
  303. "num_keep": int,
  304. "presence_penalty": float,
  305. "frequency_penalty": float,
  306. "num_batch": int,
  307. "num_gpu": int,
  308. "use_mmap": bool,
  309. "use_mlock": bool,
  310. "num_thread": int,
  311. }
  312. data = {"base_model_id": None, "params": {}}
  313. # Parse base model
  314. base_model_match = re.search(
  315. r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
  316. )
  317. if base_model_match:
  318. data["base_model_id"] = base_model_match.group(1)
  319. # Parse template
  320. template_match = re.search(
  321. r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  322. )
  323. if template_match:
  324. data["params"] = {"template": template_match.group(1).strip()}
  325. # Parse stops
  326. stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
  327. if stops:
  328. data["params"]["stop"] = stops
  329. # Parse other parameters from the provided list
  330. for param, param_type in parameters_meta.items():
  331. param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
  332. if param_match:
  333. value = param_match.group(1)
  334. try:
  335. if param_type is int:
  336. value = int(value)
  337. elif param_type is float:
  338. value = float(value)
  339. elif param_type is bool:
  340. value = value.lower() == "true"
  341. except Exception as e:
  342. log.exception(f"Failed to parse parameter {param}: {e}")
  343. continue
  344. data["params"][param] = value
  345. # Parse adapter
  346. adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
  347. if adapter_match:
  348. data["params"]["adapter"] = adapter_match.group(1)
  349. # Parse system description
  350. system_desc_match = re.search(
  351. r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  352. )
  353. system_desc_match_single = re.search(
  354. r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
  355. )
  356. if system_desc_match:
  357. data["params"]["system"] = system_desc_match.group(1).strip()
  358. elif system_desc_match_single:
  359. data["params"]["system"] = system_desc_match_single.group(1).strip()
  360. # Parse messages
  361. messages = []
  362. message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
  363. for role, content in message_matches:
  364. messages.append({"role": role, "content": content})
  365. if messages:
  366. data["params"]["messages"] = messages
  367. return data
  368. def convert_logit_bias_input_to_json(user_input):
  369. logit_bias_pairs = user_input.split(",")
  370. logit_bias_json = {}
  371. for pair in logit_bias_pairs:
  372. token, bias = pair.split(":")
  373. token = str(token.strip())
  374. bias = int(bias.strip())
  375. bias = 100 if bias > 100 else -100 if bias < -100 else bias
  376. logit_bias_json[token] = bias
  377. return json.dumps(logit_bias_json)
  378. def freeze(value):
  379. """
  380. Freeze a value to make it hashable.
  381. """
  382. if isinstance(value, dict):
  383. return frozenset((k, freeze(v)) for k, v in value.items())
  384. elif isinstance(value, list):
  385. return tuple(freeze(v) for v in value)
  386. return value
  387. def throttle(interval: float = 10.0):
  388. """
  389. Decorator to prevent a function from being called more than once within a specified duration.
  390. If the function is called again within the duration, it returns None. To avoid returning
  391. different types, the return type of the function should be Optional[T].
  392. :param interval: Duration in seconds to wait before allowing the function to be called again.
  393. """
  394. def decorator(func):
  395. last_calls = {}
  396. lock = threading.Lock()
  397. def wrapper(*args, **kwargs):
  398. if interval is None:
  399. return func(*args, **kwargs)
  400. key = (args, freeze(kwargs))
  401. now = time.time()
  402. if now - last_calls.get(key, 0) < interval:
  403. return None
  404. with lock:
  405. if now - last_calls.get(key, 0) < interval:
  406. return None
  407. last_calls[key] = now
  408. return func(*args, **kwargs)
  409. return wrapper
  410. return decorator
  411. def extract_urls(text: str) -> list[str]:
  412. # Regex pattern to match URLs
  413. url_pattern = re.compile(
  414. r"(https?://[^\s]+)", re.IGNORECASE
  415. ) # Matches http and https URLs
  416. return url_pattern.findall(text)
  417. def stream_chunks_handler(stream: aiohttp.StreamReader):
  418. """
  419. Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
  420. When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data
  421. until encountering normally sized data.
  422. :param stream: The stream reader to handle.
  423. :return: An async generator that yields the stream data.
  424. """
  425. max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
  426. if max_buffer_size is None or max_buffer_size <= 0:
  427. return stream
  428. async def yield_safe_stream_chunks():
  429. buffer = b""
  430. skip_mode = False
  431. async for data, _ in stream.iter_chunks():
  432. if not data:
  433. continue
  434. # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
  435. if skip_mode and len(buffer) > max_buffer_size:
  436. buffer = b""
  437. lines = (buffer + data).split(b"\n")
  438. # Process complete lines (except the last possibly incomplete fragment)
  439. for i in range(len(lines) - 1):
  440. line = lines[i]
  441. if skip_mode:
  442. # Skip mode: check if current line is small enough to exit skip mode
  443. if len(line) <= max_buffer_size:
  444. skip_mode = False
  445. yield line
  446. else:
  447. yield b"data: {}"
  448. else:
  449. # Normal mode: check if line exceeds limit
  450. if len(line) > max_buffer_size:
  451. skip_mode = True
  452. yield b"data: {}"
  453. log.info(f"Skip mode triggered, line size: {len(line)}")
  454. else:
  455. yield line
  456. # Save the last incomplete fragment
  457. buffer = lines[-1]
  458. # Check if buffer exceeds limit
  459. if not skip_mode and len(buffer) > max_buffer_size:
  460. skip_mode = True
  461. log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
  462. # Clear oversized buffer to prevent unlimited growth
  463. buffer = b""
  464. # Process remaining buffer data
  465. if buffer and not skip_mode:
  466. yield buffer
  467. return yield_safe_stream_chunks()