
handle is_finished

Alex Cheema, 1 year ago · commit e6f387a690
3 changed files with 15 additions and 5 deletions:
  1. example_user_2.py (+6, -0)
  2. networking/grpc/grpc_server.py (+1, -1)
  3. orchestration/standard_node.py (+8, -4)

example_user_2.py (+6, -0)

@@ -50,20 +50,26 @@ async def run_prompt(prompt: str):
         print(e)
 
     import sys
+    import time
     # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
     previous_length = 0
+    n_tokens = 0
+    start_time = time.perf_counter()
     while True:
         result, is_finished = await peer2.get_inference_result("request-id-1")
         await asyncio.sleep(0.1)
 
         # Print the updated string in place
         updated_string = tokenizer.decode(result)
+        n_tokens = len(result)
         print(updated_string[previous_length:], end='', flush=True)
         previous_length = len(updated_string)
 
         if is_finished:
             print("\nDone")
             break
+    end_time = time.perf_counter()
+    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run prompt")
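
The new timing code measures wall-clock throughput across the whole polling loop. As a minimal sketch of the same pattern factored into a reusable coroutine (the helper name and the poll_interval parameter are illustrative and not part of this commit; it only assumes, as the diff does, that get_inference_result(request_id) returns a (token_ids, is_finished) pair):

    import asyncio
    import time

    async def poll_until_finished(peer, tokenizer, request_id, poll_interval=0.1):
        # Poll the peer for streamed tokens, print the newly decoded suffix each
        # round, and report end-to-end throughput once is_finished comes back True.
        previous_length = 0
        n_tokens = 0
        start_time = time.perf_counter()
        while True:
            token_ids, is_finished = await peer.get_inference_result(request_id)
            n_tokens = len(token_ids)
            text = tokenizer.decode(token_ids)
            print(text[previous_length:], end='', flush=True)
            previous_length = len(text)
            if is_finished:
                break
            await asyncio.sleep(poll_interval)
        elapsed = time.perf_counter() - start_time
        # Note: elapsed includes the polling sleeps, so this is end-to-end time
        # rather than pure generation speed.
        print(f"\nDone. {n_tokens} tokens in {elapsed:.2f}s ({n_tokens / elapsed:.2f} tokens/second)")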

networking/grpc/grpc_server.py (+1, -1)

@@ -55,7 +55,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         request_id = request.request_id
         result = await self.node.get_inference_result(request_id)
         tensor_data = result[0].tobytes() if result[0] is not None else None
-        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype))) if result[0] is not None else node_service_pb2.InferenceResult()
+        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
 
     async def ResetShard(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
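
On the client side (not part of this commit), the peer handle presumably unpacks the new is_finished field from the response alongside the tensor. A rough sketch, in which the RPC name (GetInferenceResult) and its request message are assumptions made for illustration; only the Tensor fields and is_finished come from the handler above:

    import numpy as np
    # node_service_pb2 is the module generated from the service's .proto definition.

    async def get_inference_result(stub, request_id):
        # 'GetInferenceResult' and 'GetInferenceResultRequest' are hypothetical names.
        response = await stub.GetInferenceResult(
            node_service_pb2.GetInferenceResultRequest(request_id=request_id))
        if not response.HasField("tensor"):
            # Server returned InferenceResult(is_finished=...) with no tensor yet.
            return None, response.is_finished
        arr = np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype))
        return arr.reshape(tuple(response.tensor.shape)), response.is_finished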

orchestration/standard_node.py (+8, -4)

@@ -44,7 +44,9 @@ class StandardNode(Node):
 
         print(f"[{request_id}] process prompt: {shard}, {prompt}")
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
+        is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+        if is_finished:
+            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 
         if result.size == 1:
             self.buffered_token_output[request_id][0].append(result.item())
@@ -52,7 +54,7 @@ class StandardNode(Node):
 
         print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-        if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+        if not is_finished:
             asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
         return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
@@ -66,14 +68,16 @@ class StandardNode(Node):
         try:
             print(f"[{request_id}] process_tensor: {shard}, {tensor}")
             result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
-            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
+            is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+            if is_finished:
+                self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 
             if result.size == 1:  # we got a new token out
                 self.buffered_token_output[request_id][0].append(result.item())
                 self.on_token(self.buffered_token_output[request_id][0])
             print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-            if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+            if not is_finished:
                 asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
             return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
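
Taken together, the node-side change folds the max_generate_tokens cap into the is_finished flag that is stored next to the buffered tokens, so the gRPC handler and the polling client above only need to read that stored boolean. A simplified sketch of the bookkeeping this relies on (names mirror the diff, but this is an illustration rather than the project's actual StandardNode):

    from typing import Dict, List, Tuple

    class TokenBuffer:
        def __init__(self, max_generate_tokens: int = 256):  # default value is an assumption
            self.max_generate_tokens = max_generate_tokens
            # request_id -> (accumulated token ids, is_finished flag)
            self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}

        def append_token(self, request_id: str, token: int, engine_finished: bool) -> bool:
            tokens, _ = self.buffered_token_output.setdefault(request_id, ([], False))
            tokens.append(token)
            # Generation ends either when the inference engine reports completion or
            # when the buffered output reaches the max_generate_tokens cap.
            finished = engine_finished or len(tokens) >= self.max_generate_tokens
            if finished:
                self.buffered_token_output[request_id] = (tokens, True)
            return finished

        def get_inference_result(self, request_id: str) -> Tuple[List[int], bool]:
            # What the gRPC handler reads back: the tokens so far plus the stored flag.
            return self.buffered_token_output.get(request_id, ([], False))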