Fixed unit tests

Nel Nibcord 8 months ago
parent
commit
1cd3efbe4c

+ 2 - 6
exo/inference/test_dummy_inference_engine.py

@@ -36,21 +36,17 @@ async def test_dummy_inference_engine():
   shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
 
   # Test infer_prompt
-  output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
+  output = await engine.infer_prompt("test_id", shard, "Test prompt")
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
-  assert isinstance(state, str), "State should be a string"
-  assert isinstance(is_finished, bool), "is_finished should be a boolean"
 
   # Test infer_tensor
   input_tensor = np.array([[1, 2, 3]])
-  output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
+  output = await engine.infer_tensor("test_id", shard, input_tensor)
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
-  assert isinstance(state, str), "State should be a string"
-  assert isinstance(is_finished, bool), "is_finished should be a boolean"
 
   print("All tests passed!")
 
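The assertions removed above imply the engine's public contract changed: infer_prompt and infer_tensor now return a bare numpy array instead of an (output, state, is_finished) tuple. A minimal sketch of what that simplified interface looks like, with return shapes and parameter names assumed from the test's assertions rather than taken from the actual exo implementation:

import numpy as np

class DummyInferenceEngine:
  async def infer_prompt(self, request_id: str, shard, prompt: str) -> np.ndarray:
    # Only the output array is returned now; the state string and
    # is_finished flag asserted by the old test are gone.
    return np.random.random((1, 8))

  async def infer_tensor(self, request_id: str, shard, input_data: np.ndarray) -> np.ndarray:
    # Same contract for tensor input: a single 2-D array out,
    # matching the test's output.ndim == 2 assertion.
    return np.random.random((input_data.shape[0], 8))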

+ 6 - 6
exo/inference/test_inference_engine.py

@@ -11,8 +11,8 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp_full,
@@ -20,20 +20,20 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   )
 
   pp = n_layers // 2
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
     inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
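Taken together, these changes collapse each call site to a single array result that feeds directly into the next shard. A condensed sketch of the cross-engine round trip this test exercises under the new API; the Shard import path is assumed from the repo layout, and the inference_state keyword still visible in the hunk's context lines is omitted here, since the new return value no longer produces one:

import numpy as np
from exo.inference.shard import Shard  # import path assumed, not shown in this diff

async def round_trip(engine_1, engine_2, model_id: str, n_layers: int) -> np.ndarray:
  # Split the layers at the midpoint, as the test does with pp = n_layers // 2.
  pp = n_layers // 2
  first = Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers)
  second = Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers)

  # Each call now yields a bare array that is passed straight to the next shard.
  resp1 = await engine_1.infer_prompt("B", shard=first, prompt="Test prompt")
  resp2 = await engine_2.infer_tensor("B", shard=second, input_data=resp1)
  return resp2

# e.g. asyncio.run(round_trip(engine_1, engine_2, "test_model", 2))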