move everything under exo module

Alex Cheema committed 1 year ago
commit 5bbde22a23
42 changed files with 56 additions and 56 deletions
  1. example_user.py (+5 -5)
  2. example_user_2.py (+5 -5)
  3. exo/__init__.py (+0 -0)
  4. exo/inference/__init__.py (+0 -0)
  5. exo/inference/inference_engine.py (+0 -0)
  6. exo/inference/mlx/__init__.py (+0 -0)
  7. exo/inference/mlx/models/__init__.py (+0 -0)
  8. exo/inference/mlx/models/sharded_llama.py (+0 -0)
  9. exo/inference/mlx/sharded_inference_engine.py (+0 -0)
  10. exo/inference/mlx/sharded_model.py (+0 -0)
  11. exo/inference/mlx/sharded_utils.py (+0 -0)
  12. exo/inference/mlx/test_sharded_llama.py (+3 -3)
  13. exo/inference/mlx/test_sharded_model.py (+2 -2)
  14. exo/inference/shard.py (+0 -0)
  15. exo/inference/test_inference_engine.py (+4 -4)
  16. exo/inference/tinygrad/inference.py (+3 -3)
  17. exo/inference/tinygrad/models/llama.py (+0 -0)
  18. exo/networking/__init__.py (+0 -0)
  19. exo/networking/discovery.py (+0 -0)
  20. exo/networking/grpc/__init__.py (+0 -0)
  21. exo/networking/grpc/grpc_discovery.py (+1 -1)
  22. exo/networking/grpc/grpc_peer_handle.py (+3 -3)
  23. exo/networking/grpc/grpc_server.py (+2 -2)
  24. exo/networking/grpc/node_service.proto (+0 -0)
  25. exo/networking/grpc/node_service_pb2.py (+0 -0)
  26. exo/networking/grpc/node_service_pb2_grpc.py (+0 -0)
  27. exo/networking/grpc/test_grpc_discovery.py (+0 -0)
  28. exo/networking/peer_handle.py (+3 -3)
  29. exo/networking/server.py (+0 -0)
  30. exo/orchestration/__init__.py (+0 -0)
  31. exo/orchestration/node.py (+2 -2)
  32. exo/orchestration/standard_node.py (+6 -6)
  33. exo/orchestration/test_node.py (+1 -1)
  34. exo/topology/__init__.py (+0 -0)
  35. exo/topology/device_capabilities.py (+0 -0)
  36. exo/topology/partitioning_strategy.py (+2 -2)
  37. exo/topology/ring_memory_weighted_partitioning_strategy.py (+1 -1)
  38. exo/topology/test_device_capabilities.py (+0 -0)
  39. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+0 -0)
  40. exo/topology/topology.py (+0 -0)
  41. main.py (+6 -6)
  42. main_dynamic.py (+6 -6)

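The change is a pure package move: the top-level inference, networking, orchestration, and topology packages now live under a single exo package, so every absolute import gains an exo. prefix. A minimal sketch of what caller code looks like after this commit, using only module paths and the Shard constructor that appear in the diffs below (assumes the repository root is on the Python path; the shard variable and its use are illustrative only):

    # Old layout (before this commit):
    #   from inference.shard import Shard
    #   from networking.grpc.grpc_peer_handle import GRPCPeerHandle
    # New layout (after this commit): everything is imported through exo.*
    from exo.inference.shard import Shard
    from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
    from exo.topology.device_capabilities import DeviceCapabilities

    # Illustrative: a shard spanning layers 0-31 of a 32-layer model,
    # mirroring shard_full = Shard("llama", 0, 31, 32) in
    # exo/inference/mlx/test_sharded_llama.py below.
    shard = Shard("llama", 0, 31, 32)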
+ 5 - 5
example_user.py

@@ -2,11 +2,11 @@
 # They are prompting the cluster to generate a response to a question.
 # The cluster is given the question, and the user is given the response.
 
-from inference.mlx.sharded_utils import get_model_path, load_tokenizer
-from inference.shard import Shard
-from networking.peer_handle import PeerHandle
-from networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from topology.device_capabilities import DeviceCapabilities
+from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+from exo.inference.shard import Shard
+from exo.networking.peer_handle import PeerHandle
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities
 from typing import List
 import asyncio
 import argparse

+ 5 - 5
example_user_2.py

@@ -2,11 +2,11 @@
 # They are prompting the cluster to generate a response to a question.
 # The cluster is given the question, and the user is given the response.
 
-from inference.mlx.sharded_utils import get_model_path, load_tokenizer
-from inference.shard import Shard
-from networking.peer_handle import PeerHandle
-from networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from topology.device_capabilities import DeviceCapabilities
+from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+from exo.inference.shard import Shard
+from exo.networking.peer_handle import PeerHandle
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities
 from typing import List
 import asyncio
 import argparse

+ 0 - 0
inference/__init__.py → exo/__init__.py


+ 0 - 0
inference/mlx/__init__.py → exo/inference/__init__.py


+ 0 - 0
inference/inference_engine.py → exo/inference/inference_engine.py


+ 0 - 0
inference/mlx/models/__init__.py → exo/inference/mlx/__init__.py


+ 0 - 0
networking/grpc/__init__.py → exo/inference/mlx/models/__init__.py


+ 0 - 0
inference/mlx/models/sharded_llama.py → exo/inference/mlx/models/sharded_llama.py


+ 0 - 0
inference/mlx/sharded_inference_engine.py → exo/inference/mlx/sharded_inference_engine.py


+ 0 - 0
inference/mlx/sharded_model.py → exo/inference/mlx/sharded_model.py


+ 0 - 0
inference/mlx/sharded_utils.py → exo/inference/mlx/sharded_utils.py


+ 3 - 3
inference/mlx/test_sharded_llama.py → exo/inference/mlx/test_sharded_llama.py

@@ -1,7 +1,7 @@
 import mlx.core as mx
-from inference.mlx.sharded_model import StatefulShardedModel
-from inference.mlx.sharded_utils import load_shard
-from inference.shard import Shard
+from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.sharded_utils import load_shard
+from exo.inference.shard import Shard
 
 shard_full = Shard("llama", 0, 31, 32)
 shard1 = Shard("llama", 0, 12, 32)

+ 2 - 2
inference/mlx/test_sharded_model.py → exo/inference/mlx/test_sharded_model.py

@@ -1,5 +1,5 @@
-from inference.shard import Shard
-from inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.shard import Shard
+from exo.inference.mlx.sharded_model import StatefulShardedModel
 import mlx.core as mx
 import mlx.nn as nn
 from typing import Optional

+ 0 - 0
inference/shard.py → exo/inference/shard.py


+ 4 - 4
inference/test_inference_engine.py → exo/inference/test_inference_engine.py

@@ -1,7 +1,7 @@
-from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from inference.inference_engine import InferenceEngine
-from inference.shard import Shard
-from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import numpy as np
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.

+ 3 - 3
inference/tinygrad/inference.py → exo/inference/tinygrad/inference.py

@@ -4,12 +4,12 @@ from typing import List
 import json, argparse, random, time
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
-from inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
 from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
-from inference.shard import Shard
-from inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+from exo.inference.inference_engine import InferenceEngine
 import numpy as np
 
 MODEL_PARAMS = {

+ 0 - 0
inference/tinygrad/models/llama.py → exo/inference/tinygrad/models/llama.py


+ 0 - 0
networking/__init__.py → exo/networking/__init__.py


+ 0 - 0
networking/discovery.py → exo/networking/discovery.py


+ 0 - 0
topology/__init__.py → exo/networking/grpc/__init__.py


+ 1 - 1
networking/grpc/grpc_discovery.py → exo/networking/grpc/grpc_discovery.py

@@ -6,7 +6,7 @@ from typing import List, Dict
 from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
-from topology.device_capabilities import DeviceCapabilities, device_capabilities
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
 
 class GRPCDiscovery(Discovery):
     def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):

+ 3 - 3
networking/grpc/grpc_peer_handle.py → exo/networking/grpc/grpc_peer_handle.py

@@ -7,9 +7,9 @@ from . import node_service_pb2
 from . import node_service_pb2_grpc
 
 from ..peer_handle import PeerHandle
-from inference.shard import Shard
-from topology.topology import Topology
-from topology.device_capabilities import DeviceCapabilities
+from exo.inference.shard import Shard
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import DeviceCapabilities
 
 class GRPCPeerHandle(PeerHandle):
     def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):

+ 2 - 2
networking/grpc/grpc_server.py → exo/networking/grpc/grpc_server.py

@@ -4,9 +4,9 @@ import numpy as np
 
 from . import node_service_pb2
 from . import node_service_pb2_grpc
-from inference.shard import Shard
+from exo.inference.shard import Shard
 
-from orchestration import Node
+from exo.orchestration import Node
 
 import uuid
 

+ 0 - 0
networking/grpc/node_service.proto → exo/networking/grpc/node_service.proto


+ 0 - 0
networking/grpc/node_service_pb2.py → exo/networking/grpc/node_service_pb2.py


+ 0 - 0
networking/grpc/node_service_pb2_grpc.py → exo/networking/grpc/node_service_pb2_grpc.py


+ 0 - 0
networking/grpc/test_grpc_discovery.py → exo/networking/grpc/test_grpc_discovery.py


+ 3 - 3
networking/peer_handle.py → exo/networking/peer_handle.py

@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
 from typing import Optional, Tuple
 import numpy as np
-from inference.shard import Shard
-from topology.device_capabilities import DeviceCapabilities
-from topology.topology import Topology
+from exo.inference.shard import Shard
+from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.topology import Topology
 
 class PeerHandle(ABC):
     @abstractmethod

+ 0 - 0
networking/server.py → exo/networking/server.py


+ 0 - 0
orchestration/__init__.py → exo/orchestration/__init__.py


+ 2 - 2
orchestration/node.py → exo/orchestration/node.py

@@ -1,8 +1,8 @@
 from typing import Optional, Tuple
 import numpy as np
 from abc import ABC, abstractmethod
-from inference.shard import Shard
-from topology.topology import Topology
+from exo.inference.shard import Shard
+from exo.topology.topology import Topology
 
 class Node(ABC):
     @abstractmethod

+ 6 - 6
orchestration/standard_node.py → exo/orchestration/standard_node.py

@@ -1,12 +1,12 @@
 from typing import List, Dict, Optional, Callable, Tuple
 import numpy as np
-from networking import Discovery, PeerHandle, Server
-from inference.inference_engine import InferenceEngine, Shard
+from exo.networking import Discovery, PeerHandle, Server
+from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
-from topology.topology import Topology
-from topology.device_capabilities import device_capabilities
-from topology.partitioning_strategy import PartitioningStrategy
-from topology.partitioning_strategy import Partition
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import device_capabilities
+from exo.topology.partitioning_strategy import PartitioningStrategy
+from exo.topology.partitioning_strategy import Partition
 import asyncio
 import uuid
 

+ 1 - 1
orchestration/test_node.py → exo/orchestration/test_node.py

@@ -3,7 +3,7 @@ from unittest.mock import Mock, AsyncMock
 import numpy as np
 
 from .standard_node import StandardNode
-from networking.peer_handle import PeerHandle
+from exo.networking.peer_handle import PeerHandle
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
     def setUp(self):

+ 0 - 0
exo/topology/__init__.py


+ 0 - 0
topology/device_capabilities.py → exo/topology/device_capabilities.py


+ 2 - 2
topology/partitioning_strategy.py → exo/topology/partitioning_strategy.py

@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional
 from dataclasses import dataclass
-from inference.shard import Shard
-from networking.peer_handle import PeerHandle
+from exo.inference.shard import Shard
+from exo.networking.peer_handle import PeerHandle
 from .topology import Topology
 
 # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1

+ 1 - 1
topology/ring_memory_weighted_partitioning_strategy.py → exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -1,6 +1,6 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from inference.shard import Shard
+from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition
 

+ 1 - 1
topology/test_device_capabilities.py → exo/topology/test_device_capabilities.py

@@ -1,6 +1,6 @@
 import unittest
 from unittest.mock import patch
-from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
+from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
 
 class TestMacDeviceCapabilities(unittest.TestCase):
     @patch('subprocess.check_output')

+ 0 - 0
topology/test_ring_memory_weighted_partitioning_strategy.py → exo/topology/test_ring_memory_weighted_partitioning_strategy.py


+ 0 - 0
topology/topology.py → exo/topology/topology.py


+ 6 - 6
main.py

@@ -3,12 +3,12 @@ import asyncio
 import signal
 import mlx.core as mx
 import mlx.nn as nn
-from orchestration.standard_node import StandardNode
-from networking.grpc.grpc_server import GRPCServer
-from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
-from inference.shard import Shard
-from networking.grpc.grpc_discovery import GRPCDiscovery
-from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.orchestration.standard_node import StandardNode
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
+from exo.inference.shard import Shard
+from exo.networking.grpc.grpc_discovery import GRPCDiscovery
+from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")

+ 6 - 6
main_dynamic.py

@@ -4,12 +4,12 @@ import signal
 import mlx.core as mx
 import mlx.nn as nn
 from typing import List
-from orchestration.standard_node import StandardNode
-from networking.grpc.grpc_server import GRPCServer
-from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from inference.shard import Shard
-from networking.grpc.grpc_discovery import GRPCDiscovery
-from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.orchestration.standard_node import StandardNode
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.inference.shard import Shard
+from exo.networking.grpc.grpc_discovery import GRPCDiscovery
+from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")