@@ -203,57 +203,6 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
-  def is_model_downloaded(self, model_name):
-    if DEBUG >= 2:
-      print(f"\nChecking if model {model_name} is downloaded:")
-
-    cache_dir = get_hf_home() / "hub"
-    repo = get_repo(model_name, self.inference_engine_classname)
-
-    if DEBUG >= 2:
-      print(f" Cache dir: {cache_dir}")
-      print(f" Repo: {repo}")
-      print(f" Engine: {self.inference_engine_classname}")
-
-    if not repo:
-      return False
-
-    # Convert repo path (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
-    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit")
-    repo_parts = repo.split('/')
-    formatted_path = f"models--{repo_parts[0]}--{repo_parts[1]}"
-    repo_path = cache_dir / formatted_path / "snapshots"
-
-    if DEBUG >= 2:
-      print(f" Looking in: {repo_path}")
-
-    if repo_path.exists():
-      # Look for the most recent snapshot directory
-      snapshots = list(repo_path.glob("*"))
-      if snapshots:
-        latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
-
-        # Check for model files and their index files
-        model_files = (
-          list(latest_snapshot.glob("model.safetensors")) +
-          list(latest_snapshot.glob("model.safetensors.index.json")) +
-          list(latest_snapshot.glob("*.mlx"))
-        )
-
-        if DEBUG >= 2:
-          print(f" Latest snapshot: {latest_snapshot}")
-          print(f" Found files: {model_files}")
-
-        # Model is considered downloaded if we find either:
-        # 1. model.safetensors file
-        # 2. model.safetensors.index.json file (for sharded models)
-        # 3. *.mlx file
-        return len(model_files) > 0
-
-    if DEBUG >= 2:
-      print(" No valid model files found")
-    return False
-
   async def handle_model_support(self, request):
     try:
       model_pool = {}
@@ -272,7 +221,6 @@ class ChatGPTAPI:
 
       # Check if model supports required engines
       if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-        # Create a shard for status checking
        shard = build_base_shard(model_name, self.inference_engine_classname)
        if shard:
          downloader = HFShardDownloader()
@@ -293,7 +241,7 @@ class ChatGPTAPI:
 
            model_pool[model_name] = {
              "name": pretty,
-              "downloaded": self.is_model_downloaded(model_name),
+              "downloaded": download_percentage == 100 if download_percentage is not None else False,
              "download_percentage": download_percentage
            }
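
Note on the replacement: "downloaded" is now derived from the same downloader-reported value that feeds "download_percentage", instead of a separate scan of the Hugging Face cache directory. A minimal sketch of the new predicate, assuming download_percentage is a 0-100 number when shard status is available and None otherwise (the helper name is hypothetical, for illustration only):

def is_downloaded(download_percentage):
  # Hypothetical helper mirroring the ternary added in the diff: a model
  # counts as downloaded only when the downloader reports exactly 100%;
  # absent status (None) is treated as not downloaded.
  return download_percentage == 100 if download_percentage is not None else False

assert is_downloaded(100) is True
assert is_downloaded(99.5) is False
assert is_downloaded(None) is False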