Quellcode durchsuchen

fix inference engine for inference state

Pranav Veldurthi vor 4 Monaten
Ursprung
Commit
fff8a1a690

+ 3 - 1
exo/inference/inference_engine.py

@@ -43,9 +43,11 @@ class InferenceEngine(ABC):
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
+      output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
     else:
       x = tokens
-    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+      output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+
     return output_data, inference_state
 
 inference_engine_classes = {

+ 4 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -82,7 +82,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
-    output_data,inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+    if self.model.model_type != 'StableDiffusionPipeline':
+      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+    else:
+      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
     output_data = np.array(output_data)
     return output_data, inference_state
 

Datei-Diff unterdrückt, da er zu groß ist
+ 2 - 2
exo/networking/grpc/node_service_pb2.py


+ 50 - 50
exo/networking/grpc/node_service_pb2_grpc.py

@@ -3,7 +3,7 @@
 import grpc
 import warnings
 
-from exo.networking.grpc import node_service_pb2 as node__service__pb2
+from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
 
 GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
 if _version_not_supported:
     raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
-                request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Loss.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
                 _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
                 _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
                 '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.HealthCheck = channel.unary_unary(
                 '/node_service.NodeService/HealthCheck',
-                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
                 _registered_method=True)
 
 
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
-                    request_deserializer=node__service__pb2.ExampleRequest.FromString,
-                    response_serializer=node__service__pb2.Loss.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
             ),
             'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                     servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
             ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
             ),
             'SendResult': grpc.unary_unary_rpc_method_handler(
                     servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                     servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
             ),
             'HealthCheck': grpc.unary_unary_rpc_method_handler(
                     servicer.HealthCheck,
-                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +196,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -223,8 +223,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -250,8 +250,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendExample',
-            node__service__pb2.ExampleRequest.SerializeToString,
-            node__service__pb2.Loss.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
             options,
             channel_credentials,
             insecure,
@@ -277,8 +277,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
             options,
             channel_credentials,
             insecure,
@@ -304,8 +304,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
             options,
             channel_credentials,
             insecure,
@@ -331,8 +331,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -358,8 +358,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -385,8 +385,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/HealthCheck',
-            node__service__pb2.HealthCheckRequest.SerializeToString,
-            node__service__pb2.HealthCheckResponse.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
             options,
             channel_credentials,
             insecure,

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.