Browse Source

Merge pull request #436 from blindcrone/unit-tests

Missed a spot
Alex Cheema 8 months ago
parent
commit
854a7c22ac

+ 1 - 5
exo/inference/test_dummy_inference_engine.py

@@ -16,15 +16,11 @@ async def test_dummy_inference_specific():
   test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
   test_prompt = "This is a test prompt"
 
-  result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
+  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
 
   print(f"Inference result shape: {result.shape}")
-  print(f"Inference state: {state}")
-  print(f"Is finished: {is_finished}")
 
   assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-  assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
-  assert isinstance(is_finished, bool), "is_finished should be a boolean"
 
 
 @pytest.mark.asyncio

+ 0 - 4
exo/inference/test_inference_engine.py

@@ -16,7 +16,6 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp_full,
-    inference_state=inference_state_full,
   )
 
   pp = n_layers // 2
@@ -25,19 +24,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
   resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=resp2,
-    inference_state=inference_state_2,
   )
   resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
-    inference_state=inference_state_3,
   )
 
   assert np.array_equal(resp_full, resp2)