|
@@ -16,15 +16,11 @@ async def test_dummy_inference_specific():
|
|
|
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
|
|
test_prompt = "This is a test prompt"
|
|
|
|
|
|
- result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
|
|
|
+ result = await engine.infer_prompt("test_request", test_shard, test_prompt)
|
|
|
|
|
|
print(f"Inference result shape: {result.shape}")
|
|
|
- print(f"Inference state: {state}")
|
|
|
- print(f"Is finished: {is_finished}")
|
|
|
|
|
|
assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
|
|
|
- assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
|
|
|
- assert isinstance(is_finished, bool), "is_finished should be a boolean"
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|