Merge branch 'main' into package-exo-fixes

josh 5 months ago
parent
commit
0996bcc3b6

+ 4 - 4
.circleci/config.yml

@@ -27,11 +27,11 @@ commands:
             fi
 
             # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output1.log &
             PID1=$!
 
             # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output2.log &
             PID2=$!
 
             # Wait for discovery
@@ -149,9 +149,9 @@ jobs:
           name: Run discovery integration test
           command: |
             source env/bin/activate
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui > output1.log 2>&1 &
             PID1=$!
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui > output2.log 2>&1 &
             PID2=$!
             sleep 10
             kill $PID1 $PID2

+ 2 - 1
exo/inference/inference_engine.py

@@ -26,7 +26,8 @@ class InferenceEngine(ABC):
   
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
     return output_data 
 
 inference_engine_classes = {
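
infer_prompt now reshapes the encoded token ids to (1, -1), giving them an explicit batch dimension before they reach infer_tensor; the MLX shard wrapper change below drops its x[None] expansion for the same reason. A minimal sketch of the shape change, assuming NumPy int arrays and illustrative token values:

    import numpy as np

    tokens = np.array([128000, 9906, 1917])  # shape (3,), as returned by encode()
    x = tokens.reshape(1, -1)                # shape (1, 3): a batch of one sequence
    assert x.shape == (1, tokens.size)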

+ 1 - 1
exo/inference/mlx/sharded_utils.py

@@ -137,7 +137,7 @@ def load_model_shard(
       self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
 
     def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      y = super().__call__(x, *args, **kwargs)
       return y
 
   model_args = model_args_class.from_dict(config)

+ 2 - 0
exo/inference/test_inference_engine.py

@@ -13,6 +13,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
   next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
@@ -27,6 +28,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
   resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
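
The added reshapes exercise the decode loop the engine API now expects: sample() returns a bare token, which has to be reshaped back to (1, -1) before it is fed to infer_tensor as the next step's input. A rough sketch of that loop for a single full shard, with the request id, the fixed token budget, and the omitted EOS check as assumptions:

    async def generate(engine, shard, prompt, max_tokens=32):
        out = await engine.infer_prompt("req-1", shard=shard, prompt=prompt)
        tokens = []
        for _ in range(max_tokens):
            token = await engine.sample(out)   # single sampled token
            tokens.append(int(token.item()))
            x = token.reshape(1, -1)           # restore the batch dimension
            out = await engine.infer_tensor("req-1", shard=shard, input_data=x)
        return tokens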

+ 8 - 0
exo/main.py

@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import atexit
 import signal
 import json
 import logging
@@ -221,6 +222,13 @@ async def main():
     except Exception as e:
       print(f"Error moving models to .cache/huggingface: {e}")
 
+  def restore_cursor():
+    if platform.system() != "Windows":
+        os.system("tput cnorm")  # Show cursor
+
+  # Restore the cursor when the program exits
+  atexit.register(restore_cursor)
+
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))

+ 56 - 56
exo/orchestration/standard_node.py

@@ -102,11 +102,7 @@ class StandardNode(Node):
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
-  async def encode_prompt(self, shard: Shard, prompt):
-    toks = await self.inference_engine.encode(shard, prompt)
-    return toks
-  
-  async def process_result(
+  async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
@@ -114,32 +110,25 @@ class StandardNode(Node):
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
-    
-    if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = []
-
-    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
-
-    if shard.is_last_layer():
-      result = await self.inference_engine.sample(result)
-    
-    await self.inference_engine.ensure_shard(shard)
-    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:  # we got a new token out
-      self.buffered_token_output[request_id][0].append(result.item())
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-    
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    else:
+      forward = result
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
@@ -190,13 +179,13 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)
 
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if shard.start_layer != 0:
+    if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return result
 
   async def process_tensor(
@@ -255,46 +244,57 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
+    target_index: int,
   ) -> None:
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-
-    next_partition_index = self.get_partition_index(offset = 1)
-    if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
-    if next_partition_index is not None:
-      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
-      next_shard = self.get_current_shard(base_shard, next_partition_index)
-      if DEBUG >= 2: print(f"Computed next from: {base_shard} {next_partition_index}, {self.topology}. Next shard: {next_shard}")
-      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
-      if target_id == self.id:
-        if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
-        else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
-      else:
-        target_peer = next((p for p in self.peers if p.id() == target_id), None)
-        if not target_peer:
-          raise ValueError(f"Peer for {next_partition_index} not found")
-        if is_tensor:
-          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
-        else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
 
   def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
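
The single forward_to_next_shard path is split into forward_prompt and forward_tensor, and the missing-partitioning-strategy check moves into get_partition_index. A rough sketch of how the new pair is driven, mirroring the calls in the diff (hidden_state is an illustrative name for the activations produced locally):

    # a prompt received by a node whose shard does not start at layer 0
    # is handed off to partition 0:
    await self.forward_prompt(shard, prompt, request_id, 0)

    # activations produced locally are forwarded to the next partition:
    asyncio.create_task(
      self.forward_tensor(shard, hidden_state, request_id,
                          self.get_partition_index(offset=1)))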

+ 40 - 51
exo/tinychat/index.css

@@ -1,31 +1,11 @@
 /* define colors */
 :root {
-  --primary-color: #a52e4d;
-  --primary-color-transparent: #a52e4d66;
-  --secondary-color: #228039;
-  --secondary-color-transparent: #22803966;
-
+  --primary-color: #fff;
+  --secondary-color: #2a2a2a;
+  --secondary-color-transparent: #ffffff66;
+  --primary-bg-color: #1a1a1a;
+  --foreground-color: #f0f0f0;
   --red-color: #a52e4d;
-  --green-color: #228039;
-  --silver-color: #88808e;
-}
-@media(prefers-color-scheme: light) {
-  :root {
-    --primary-bg-color: #f0f0f0;
-    --secondary-bg-color: #eeeeee;
-    --tertiary-bg-color: #dddddd;
-    --foreground-color: #111111;
-    --accent-color: #000000;
-  }
-}
-@media(prefers-color-scheme: dark) {
-  :root {
-    --primary-bg-color: #111111;
-    --secondary-bg-color: #131313;
-    --tertiary-bg-color: #232323;
-    --foreground-color: #f0f0f0;
-    --accent-color: #aaaaaa;
-  }
 }
 
 main {
@@ -81,7 +61,11 @@ main {
   top: 0;
   position: absolute;
 
-  background: linear-gradient(180deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    180deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 .histories-end {
   height: 3rem;
@@ -91,7 +75,11 @@ main {
   bottom: 0;
   position: absolute;
 
-  background: linear-gradient(0deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 
 .history {
@@ -99,7 +87,7 @@ main {
   width: 100%;
   max-width: 40rem;
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   border-radius: 10px;
   border-left: 2px solid var(--primary-color);
 
@@ -109,7 +97,7 @@ main {
   opacity: var(--opacity, 1);
 }
 .history:hover {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
 }
 
 .history-delete-button {
@@ -120,14 +108,14 @@ main {
   margin: 0;
   outline: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   cursor: pointer;
   transition: 0.2s;
 }
 .history-delete-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
 }
 
@@ -135,6 +123,7 @@ main {
   overflow-y: auto;
   height: 100%;
   width: 100%;
+  max-width: 1200px;
 
   display: flex;
   flex-direction: column;
@@ -145,24 +134,19 @@ main {
 }
 
 .message {
-  width: 96%;
-  max-width: 80rem;
-
-  display: grid;
-
-  background-color: var(--secondary-bg-color);
+  max-width: 75%;
   padding: 0.5rem 1rem;
-  border-radius: 10px;
+  border-radius: 20px;
 }
 .message-role-assistant {
-  border-bottom: 2px solid var(--primary-color);
-  border-left: 2px solid var(--primary-color);
-  box-shadow: -10px 10px 20px 2px var(--primary-color-transparent);
+  background-color: var(--secondary-color);
+  margin-right: auto;
+  color: #fff;
 }
 .message-role-user {
-  border-bottom: 2px solid var(--secondary-color);
-  border-right: 2px solid var(--secondary-color);
-  box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
+  margin-left: auto;
+  background-color: var(--primary-color);
+  color: #000;
 }
 .download-progress {
   margin-bottom: 12em;
@@ -275,14 +259,14 @@ main {
   margin: 0;
   outline: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   cursor: pointer;
   transition: 0.2s;
 }
 .clipboard-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
 }
 
@@ -291,9 +275,14 @@ main {
   bottom: 0;
 
   /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(0deg, var(--primary-bg-color) 55%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
 
   width: 100%;
+  max-width: 1200px;
   display: flex;
   flex-direction: column;
   justify-content: center;
@@ -340,7 +329,7 @@ main {
   min-height: 3rem;
   max-height: 8rem;
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 10px;
   border: none;
@@ -352,8 +341,8 @@ main {
   height: 3rem;
   width: 4rem;
 
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
+  background-color: var(--primary-color);
+  color: var(--secondary-color);
   border-radius: 10px;
   padding: 0.5rem;
   cursor: pointer;
@@ -362,7 +351,7 @@ main {
   background-color: var(--secondary-color-transparent);
 }
 .input-button:disabled {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   cursor: not-allowed;
 }
 

+ 2 - 2
scripts/compile_grpc.sh

@@ -1,7 +1,7 @@
 #!/bin/bash
 source ./install.sh
 pushd exo/networking/grpc
-python3.12 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
-sed -i "s/import node_service_pb2/from . &/" node_service_pb2_grpc.py
+python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
+sed -i "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
 popd
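
The script now relies on whichever python3 is on PATH rather than pinning python3.12, and the sed pattern only escapes the space, so the rewrite itself is unchanged: the "&" in the replacement expands to the whole match, turning the generated absolute import into a package-relative one. Roughly (the "as" alias is what grpc_tools typically emits, shown here as an assumption):

    # node_service_pb2_grpc.py, illustrative:
    import node_service_pb2 as node__service__pb2          # before the sed rewrite
    from . import node_service_pb2 as node__service__pb2   # after: "from . " is prefixed to the match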
 

+ 1 - 1
setup.py

@@ -17,7 +17,7 @@ install_requires = [
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
-  "protobuf==5.27.1",
+  "protobuf==5.28.1",
   "psutil==6.0.0",
   "pydantic==2.9.2",
   "requests==2.32.3",