@@ -3,19 +3,15 @@ import asyncio
import atexit
import signal
import json
-import logging
import platform
import os
-import sys
import time
import traceback
import uuid
import numpy as np
-from functools import partial
from tqdm import tqdm
-from exo.train.dataset import load_dataset, iterate_batches, compose
+from exo.train.dataset import load_dataset, iterate_batches
from exo.networking.manual.manual_discovery import ManualDiscovery
-from exo.networking.manual.network_topology_config import NetworkTopology
from exo.orchestration.node import Node
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -28,12 +24,11 @@ from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
from exo.inference.shard import Shard
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.inference_engine import get_inference_engine
from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
import uvloop
-from contextlib import asynccontextmanager
import concurrent.futures
import resource
import psutil
@@ -45,31 +40,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Configure uvloop for maximum performance
def configure_uvloop():
-  # Install uvloop as event loop policy
  uvloop.install()
-
-  # Create new event loop
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)

  # Increase file descriptor limits on Unix systems
  if not psutil.WINDOWS:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-    try:
-      resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
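+    # Raise the soft limit to the hard limit; if that is refused, fall back to 8192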
+    try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    except ValueError:
-      try:
-        resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
-      except ValueError:
-        pass
-
-  # Configure thread pool for blocking operations
-  loop.set_default_executor(
-    concurrent.futures.ThreadPoolExecutor(
-      max_workers=min(32, (os.cpu_count() or 1) * 4)
-    )
-  )
+      try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+      except ValueError: pass
 
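+  # Thread pool for blocking operations: four workers per CPU core, capped at 32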
+  loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
  return loop

# parse args
@@ -183,7 +168,7 @@ server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(
  node,
-  inference_engine.__class__.__name__,
+  node.inference_engine.__class__.__name__,
  response_timeout=args.chatgpt_api_response_timeout,
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
  default_model=args.default_model,
@@ -192,21 +177,33 @@ api = ChatGPTAPI(
buffered_token_output = {}
def update_topology_viz(req_id, tokens, __):
  if not topology_viz: return
-  if not inference_engine.shard: return
-  if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
-
+  if not node.inference_engine.shard: return
+  if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
  if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
  else: buffered_token_output[req_id] = tokens
-  topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+  topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
node.on_token.register("update_topology_viz").on_next(update_topology_viz)
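+# Surface each prompt in the topology viz as soon as a node reports it has started processing it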
+def update_prompt_viz(request_id, opaque_status: str):
+  if not topology_viz: return
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
+    topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to update prompt viz: {e}")
+      traceback.print_exc()
+node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)

def preemptively_start_download(request_id: str, opaque_status: str):
  try:
    status = json.loads(opaque_status)
-    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
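+    # Only react when a node reports it is starting to process a prompt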
+    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
+    current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+    if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+    asyncio.create_task(shard_downloader.ensure_shard(current_shard, node.inference_engine.__class__.__name__))
  except Exception as e:
    if DEBUG >= 2:
      print(f"Failed to preemptively start download: {e}")
@@ -229,11 +226,11 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

-async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  inference_class = inference_engine.__class__.__name__
+async def run_model_cli(node: Node, model_name: str, prompt: str):
+  inference_class = node.inference_engine.__class__.__name__
  shard = build_base_shard(model_name, inference_class)
  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
    return
  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
  request_id = str(uuid.uuid4())
@@ -284,11 +281,11 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):

  return total_loss, total_tokens

-async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
-  inference_class = inference_engine.__class__.__name__
+async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
+  inference_class = node.inference_engine.__class__.__name__
  shard = build_base_shard(model_name, inference_class)
  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
    return
  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
  train, val, test = dataloader(tokenizer.encode)
@@ -298,11 +295,11 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
  print("Waiting for outstanding tasks")
  await hold_outstanding(node)

-async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
-  inference_class = inference_engine.__class__.__name__
+async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
+  inference_class = node.inference_engine.__class__.__name__
  shard = build_base_shard(model_name, inference_class)
  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
    return
  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
  train, val, test = dataloader(tokenizer.encode)
@@ -362,7 +359,7 @@ async def main():
    if not model_name:
      print("Error: Model name is required when using 'run' command or --run-model")
      return
-    await run_model_cli(node, inference_engine, model_name, args.prompt)
+    await run_model_cli(node, model_name, args.prompt)
  elif args.command == "eval" or args.command == 'train':
    model_name = args.model_name
    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
@@ -371,12 +368,12 @@ async def main():
      if not model_name:
        print("Error: Much like a human, I can't evaluate anything without a model")
        return
-      await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
+      await eval_model_cli(node, model_name, dataloader, args.batch_size)
    else:
      if not model_name:
        print("Error: This train ain't leaving the station without a model")
        return
-      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
+      await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)

  else:
    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
@@ -387,11 +384,6 @@ async def main():
    for i in tqdm(range(50)):
      await asyncio.sleep(.1)

-@asynccontextmanager
-async def setup_node(args):
-  # Rest of setup_node implementation...
-  pass
-
def run():
  loop = None
  try:
@@ -400,8 +392,7 @@ def run():
  except KeyboardInterrupt:
    print("\nShutdown requested... exiting")
  finally:
-    if loop:
-      loop.close()
+    if loop: loop.close()

if __name__ == "__main__":
  run()