Ver Fonte

await ensure_downloads_dir in brackets

Alex Cheema há 1 ano atrás
pai
commit
d17fdd223d
1 ficheiros alterados com 2 adições e 2 exclusões
  1. 2 2
      exo/download/new_shard_download.py

+ 2 - 2
exo/download/new_shard_download.py

@@ -50,7 +50,7 @@ async def ensure_downloads_dir() -> Path:
 
 
 async def delete_model(model_id: str, inference_engine_name: str) -> bool:
 async def delete_model(model_id: str, inference_engine_name: str) -> bool:
   repo_id = get_repo(model_id, inference_engine_name)
   repo_id = get_repo(model_id, inference_engine_name)
-  model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  model_dir = (await ensure_downloads_dir())/repo_id.replace("/", "--")
   if not await aios.path.exists(model_dir): return False
   if not await aios.path.exists(model_dir): return False
   await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
   await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
   return True
   return True
@@ -199,7 +199,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
   repo_id = get_repo(shard.model_id, inference_engine_classname)
   revision = "main"
   revision = "main"
-  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  target_dir = (await ensure_downloads_dir())/repo_id.replace("/", "--")
   if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
   if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
 
 
   if repo_id is None:
   if repo_id is None: