|
@@ -16,7 +16,6 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
|
|
"A",
|
|
|
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
|
|
|
input_data=resp_full,
|
|
|
- inference_state=inference_state_full,
|
|
|
)
|
|
|
|
|
|
pp = n_layers // 2
|
|
@@ -25,19 +24,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
|
|
input_data=resp1,
|
|
|
- inference_state=inference_state_1,
|
|
|
)
|
|
|
resp3 = await inference_engine_1.infer_tensor(
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
|
|
|
input_data=resp2,
|
|
|
- inference_state=inference_state_2,
|
|
|
)
|
|
|
resp4 = await inference_engine_2.infer_tensor(
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
|
|
input_data=resp3,
|
|
|
- inference_state=inference_state_3,
|
|
|
)
|
|
|
|
|
|
assert np.array_equal(resp_full, resp2)
|