
remove prints and fix download progress for SD (Stable Diffusion)

Pranav Veldurthi · 5 months ago
commit 9986fb86d4

+ 3 - 3
exo/api/chatgpt_api.py

@@ -463,9 +463,9 @@ class ChatGPTAPI:
     model = data.get("model", "")
     prompt = data.get("prompt", "")
     image_url = data.get("image_url", "")
-    print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
     shard = build_base_shard(model, self.inference_engine_classname)
-    print(f"shard: {shard}")
+    if DEBUG >= 2: print(f"shard: {shard}")
     if not shard:
         return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
 
@@ -683,7 +683,7 @@ class ChatGPTAPI:
     img = Image.open(BytesIO(image_data))
     W, H = (dim - dim % 64 for dim in (img.width, img.height))
     if W != img.width or H != img.height:
-        print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
         img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
     img = mx.array(np.array(img))
     img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
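
The edits above gate the prints behind exo's DEBUG level rather than removing them outright. A minimal sketch of that pattern, assuming DEBUG is an integer verbosity read from the DEBUG environment variable (as in exo's helpers module) and using a hypothetical log_request helper:

import os

# Assumed shape of the DEBUG flag: integer verbosity from the environment, 0 = silent.
DEBUG = int(os.getenv("DEBUG", default="0"))

def log_request(model: str, prompt: str, stream: bool) -> None:
  # Hypothetical helper showing the gate: diagnostics only print at DEBUG >= 2.
  if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")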

+ 11 - 7
exo/download/hf/hf_shard_download.py

@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
           print(f"No snapshot directory found for {self.current_repo_id}")
         return None
 
+      if not await aios.path.exists(snapshot_dir/"model_index.json"):
       # Get the weight map to know what files we need
-      weight_map = await get_weight_map(self.current_repo_id, self.revision)
-      if not weight_map:
-        if DEBUG >= 2:
-          print(f"No weight map found for {self.current_repo_id}")
-        return None
+        weight_map = await get_weight_map(self.current_repo_id, self.revision)
+        if not weight_map:
+          if DEBUG >= 2:
+            print(f"No weight map found for {self.current_repo_id}")
+          return None
+
+        # Get all files needed for this shard
+        patterns = get_allow_patterns(weight_map, self.current_shard)
+      else:
+        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
 
-      # Get all files needed for this shard
-      patterns = get_allow_patterns(weight_map, self.current_shard)
 
       # Check download status for all relevant files
       status = {}
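
The new branch keys off model_index.json, which diffusers-style Stable Diffusion repos ship at the snapshot root instead of a transformers weight map; when it is present, the downloader falls back to broad allow patterns so the download status is computed over the right files. A standalone sketch of that check, assuming the aios alias in the diff is aiofiles.os and using a hypothetical is_stable_diffusion_repo name:

from pathlib import Path
import aiofiles.os as aios  # assumption: the aios alias used in the diff

# Broad patterns for an SD snapshot with no weight map: config/tokenizer JSON,
# text files, and every per-component *model.safetensors.
SD_ALLOW_PATTERNS = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]

async def is_stable_diffusion_repo(snapshot_dir: Path) -> bool:
  # diffusers pipelines are identified by a top-level model_index.json.
  return await aios.path.exists(snapshot_dir / "model_index.json")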

+ 0 - 2
exo/inference/mlx/sharded_utils.py

@@ -254,6 +254,4 @@ def load_model_index(model_path: Path, model_index_path: Path):
         m = {}
         m[model] = model_config
         models_config.update(m)
-  models_config = json.dumps(models_config)
-  models_config = json.loads(models_config)
   return models_config
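
The two lines deleted above serialized models_config to JSON and immediately parsed it back. For the JSON-native dict built in load_model_index (string keys, scalar/list/dict values) that round trip yields an equal dict, so dropping it leaves the return value unchanged; a quick check with an illustrative config:

import json

models_config = {"unet": {"in_channels": 4}, "vae": {"scaling_factor": 0.18215}}
# dumps/loads on JSON-native data is an identity up to equality,
# so the removed round trip was redundant.
assert json.loads(json.dumps(models_config)) == models_config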