@@ -13,7 +13,6 @@ import uuid
 import numpy as np
 from functools import partial
 from tqdm import tqdm
-from tqdm.asyncio import tqdm_asyncio
 from exo.train.dataset import load_dataset, iterate_batches, compose
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.network_topology_config import NetworkTopology
@@ -33,6 +32,41 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
+import uvloop
+from contextlib import asynccontextmanager
+import concurrent.futures
+import socket
+import resource
+import psutil
+
+# Configure uvloop for maximum performance
+def configure_uvloop():
+  # Install uvloop as event loop policy
+  uvloop.install()
+
+  # Create new event loop
+  loop = asyncio.new_event_loop()
+  asyncio.set_event_loop(loop)
+
+  # Increase file descriptor limits on Unix systems
+  if not psutil.WINDOWS:
+    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+    try:
+      resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+    except ValueError:
+      try:
+        resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+      except ValueError:
+        pass
+
+  # Configure thread pool for blocking operations
+  loop.set_default_executor(
+    concurrent.futures.ThreadPoolExecutor(
+      max_workers=min(32, (os.cpu_count() or 1) * 4)
+    )
+  )
+
+  return loop
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -223,7 +257,7 @@ def clean_path(path):
 async def hold_outstanding(node: Node):
   while node.outstanding_requests:
     await asyncio.sleep(.5)
-  return
+  return
 
 async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
   losses = []
@@ -234,7 +268,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
     tokens.append(np.sum(lengths))
   total_tokens = np.sum(tokens)
   total_loss = np.sum(losses) / total_tokens
-
+
   return total_loss, total_tokens
 
 async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
@@ -270,7 +304,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
       await hold_outstanding(node)
   await hold_outstanding(node)
 
-
+
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -285,7 +319,7 @@ async def main():
           {"❌ No read access" if not has_read else ""}
          {"❌ No write access" if not has_write else ""}
          """)
-
+
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
@@ -330,29 +364,31 @@ async def main():
         print("Error: This train ain't leaving the station without a model")
        return
      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
-
+
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
     await asyncio.Event().wait()
-
+
   if args.wait_for_peers > 0:
     print("Cooldown to allow peers to exit gracefully")
     for i in tqdm(range(50)):
       await asyncio.sleep(.1)
 
+@asynccontextmanager
+async def setup_node(args):
+  # Rest of setup_node implementation...
+  pass
 
 def run():
-  loop = asyncio.new_event_loop()
-  asyncio.set_event_loop(loop)
-  try:
-    loop.run_until_complete(main())
-
-  except KeyboardInterrupt:
-    print("Received keyboard interrupt. Shutting down...")
-  finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
-    loop.close()
-
+  loop = None
+  try:
+    loop = configure_uvloop()
+    loop.run_until_complete(main())
+  except KeyboardInterrupt:
+    print("\nShutdown requested... exiting")
+  finally:
+    if loop:
+      loop.close()
 
 if __name__ == "__main__":
   run()