Merge pull request #516 from exo-explore/imagegen_anim

Imagegen anim
Alex Cheema · 7 months ago · commit 3e7039bb66

+ 1 - 0
.gitattributes

@@ -1 +1,2 @@
 *.mp3 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
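
With this filter in place, every PNG in the repository is stored as a Git LFS pointer file (the baseimages added further down show the pointer format). If the repository is cloned without git-lfs installed, those paths contain pointer text rather than image data, and PIL's Image.open in anim.py would fail with a confusing error. A minimal guard one could add (a sketch, not part of this PR; the helper name is hypothetical):

# LFS pointer files are small text stubs beginning with a fixed version line.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"

def is_lfs_pointer(path) -> bool:
  # Read just enough bytes to compare against the pointer signature.
  with open(path, "rb") as f:
    return f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX

Checking this before Image.open turns PIL's "cannot identify image file" error into an actionable "run git lfs pull" message.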

BIN  docs/exo-logo-transparent-black-text.png
BIN  docs/exo-logo-transparent.png
BIN  docs/exo-rounded.png
BIN  docs/exo-screenshot.png
BIN  docs/ring-topology.png

+ 45 - 5
exo/api/chatgpt_api.py

@@ -8,16 +8,16 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
-import os
 import signal
-import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from exo.apputil import create_animation_mp4
 from typing import Callable, Optional
+import tempfile
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,6 +175,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -183,7 +184,7 @@ class ChatGPTAPI:
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
-  
+
   async def handle_quit(self, request):
     if DEBUG>=1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
@@ -216,11 +217,11 @@ class ChatGPTAPI:
   async def handle_model_support(self, request):
     return web.json_response({
       "model pool": {
-        model_name: pretty_name.get(model_name, model_name) 
+        model_name: pretty_name.get(model_name, model_name)
         for model_name in get_supported_models(self.node.topology_inference_engines_pool)
       }
     })
-  
+
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
@@ -370,6 +371,45 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+  async def handle_create_animation(self, request):
+    try:
+      data = await request.json()
+      replacement_image_path = data.get("replacement_image_path")
+      device_name = data.get("device_name", "Local Device")
+      prompt_text = data.get("prompt", "")
+
+      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
+
+      if not replacement_image_path:
+        return web.json_response({"error": "replacement_image_path is required"}, status=400)
+
+      # Create temp directory if it doesn't exist
+      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
+      tmp_dir.mkdir(parents=True, exist_ok=True)
+
+      # Generate unique output filename in temp directory
+      output_filename = f"animation_{uuid.uuid4()}.mp4"
+      output_path = str(tmp_dir/output_filename)
+
+      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
+
+      # Create the animation
+      create_animation_mp4(
+        replacement_image_path,
+        output_path,
+        device_name,
+        prompt_text
+      )
+
+      return web.json_response({
+        "status": "success",
+        "output_path": output_path
+      })
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
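
A minimal sketch of how a client might exercise the new endpoint, assuming the API is running locally on the default port above (the image path and prompt are hypothetical):

import asyncio
import aiohttp

async def main():
  async with aiohttp.ClientSession() as session:
    async with session.post(
      "http://localhost:52415/create_animation",
      json={
        "replacement_image_path": "/tmp/generated.png",  # hypothetical: must exist on the server
        "device_name": "MacBook Pro",
        "prompt": "a red bicycle in the snow",
      },
    ) as resp:
      print(await resp.json())  # {"status": "success", "output_path": "..."} on success

asyncio.run(main())

Note that replacement_image_path is resolved on the server side, so the file must be readable by the exo process.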

+ 1 - 0
exo/apputil/__init__.py

@@ -0,0 +1 @@
+from exo.apputil.anim import create_animation_mp4

+ 139 - 0
exo/apputil/anim.py

@@ -0,0 +1,139 @@
+from PIL import Image, ImageDraw, ImageFont, ImageFilter
+import os
+import numpy as np
+import cv2
+
+def draw_rounded_rectangle(draw, coords, radius, fill):
+  left, top, right, bottom = coords
+  diameter = radius * 2
+  draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
+  draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
+  draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
+  draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
+  draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
+  draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
+
+def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_width = rect_right - rect_left
+  rect_height = rect_bottom - rect_top
+  text_x = rect_left + (rect_width - text_width) // 2
+  text_y = rect_top + (rect_height - text_height) // 2
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_height = rect_bottom - rect_top
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = rect_left + padding_left
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  _, rect_top, rect_right, rect_bottom = base_coords
+  rect_height = rect_bottom - rect_top
+  new_rect_left = rect_right - (text_width + (padding * 2))
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = new_rect_left + padding
+  draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+  return new_rect_left
+
+def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
+  left, top, right, bottom = coords
+  total_width = right - left
+  draw.rectangle(coords, fill=bg_color)
+  progress_width = int(total_width * progress)
+  if progress_width > 0:
+    draw.rectangle((left, top, left + progress_width, bottom), fill=color)
+
+def crop_image(image, top_crop=70):
+  width, height = image.size
+  return image.crop((0, top_crop, width, height))
+
+def create_animation_mp4(
+  replacement_image_path,
+  output_path,
+  device_name,
+  prompt_text,
+  fps=30,
+  target_size=(512, 512),
+  target_position=(139, 755),
+  progress_coords=(139, 1285, 655, 1295),
+  device_coords=(1240, 370, 1640, 416),
+  prompt_coords=(332, 1702, 2662, 1745)
+):
+  frames = []
+  try:
+    font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
+    promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
+  except OSError:  # system font not available; fall back to PIL's default bitmap font
+    font = ImageFont.load_default()
+    promptfont = ImageFont.load_default()
+
+  # Intro: draw overlays on base images 1 and 2, hold each for 1 second
+  for i in range(1, 3):
+    base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", f"image{i}.png"))
+    draw = ImageDraw.Draw(base_img)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    if i == 2:
+      draw_left_aligned_text_rounded(draw, prompt_text, promptfont, prompt_coords)
+    frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
+
+  # Create blur sequence
+  replacement_img = Image.open(replacement_image_path)
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
+
+  for i, blur_amount in enumerate(blur_steps):
+    new_frame = base_img.copy()
+    draw = ImageDraw.Draw(new_frame)
+
+    replacement_copy = replacement_img.copy()
+    replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+    if blur_amount > 0:
+      replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
+
+    mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+    new_frame.paste(replacement_copy, target_position, mask)
+
+    draw_progress_bar(draw, (i + 1) / 9, progress_coords)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+    frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
+
+  # Create and add final frame (image4)
+  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  draw = ImageDraw.Draw(final_base)
+
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+  replacement_copy = replacement_img.copy()
+  replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+  mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+  final_base.paste(replacement_copy, target_position, mask)
+
+  frames.extend([crop_image(final_base)] * 30)  # 1 second at 30fps
+
+  # Convert frames to video
+  if frames:
+    first_frame = np.array(frames[0])
+    height, width = first_frame.shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    for frame in frames:
+      out.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+    out.release()
+    print(f"Video saved successfully to {output_path}")

+ 3 - 0
exo/apputil/baseimages/image1.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:361fdadd67c277d45cd18b0bfc8c5ceea5fd89f2d65aef157fd915ce9cbb8599
+size 814460

+ 3 - 0
exo/apputil/baseimages/image2.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0e3891bc6b4f4dfa7444af53fcaa4b3ba06b0549546202be3243f08a0e6bd7e
+size 814235

+ 3 - 0
exo/apputil/baseimages/image3.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2dc5b3378aef397d60fd1252da8a1c578ad97e202a859590ffa416b49551d19
+size 146633

+ 3 - 0
exo/apputil/baseimages/image4.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd9a6328e609bc594a08fb9f0b43f541aadbd52c90ae18a2c699aafdf6ed3ccb
+size 647478

+ 1 - 0
setup.py

@@ -15,6 +15,7 @@ install_requires = [
   "numpy==2.0.0",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
+  "opencv-python==4.10.0.84",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "protobuf==5.28.1",