|
@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
|
|
_tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
|
|
|
|
|
|
prompt = "In a single word only, what is the last name of the president of the United States? "
|
|
|
- resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
|
|
|
- next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
|
|
|
+ resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
|
|
|
+ token_full = await inference_engine_1.sample(resp_full)
|
|
|
+
|
|
|
+ next_resp_full = await inference_engine_1.infer_tensor(
|
|
|
"A",
|
|
|
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
|
|
|
- input_data=resp_full,
|
|
|
- inference_state=inference_state_full,
|
|
|
+ input_data=token_full,
|
|
|
)
|
|
|
|
|
|
- resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
|
|
|
- resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
|
|
|
+ resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
|
|
|
+ resp2 = await inference_engine_2.infer_tensor(
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
|
|
input_data=resp1,
|
|
|
- inference_state=inference_state_1,
|
|
|
)
|
|
|
- resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
|
|
|
+ token2 = await inference_engine_2.sample(resp2)
|
|
|
+ resp3 = await inference_engine_1.infer_tensor(
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
|
|
|
- input_data=resp2,
|
|
|
- inference_state=inference_state_2,
|
|
|
+ input_data=token2,
|
|
|
)
|
|
|
- resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
|
|
|
+ resp4 = await inference_engine_2.infer_tensor(
|
|
|
"B",
|
|
|
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
|
|
input_data=resp3,
|
|
|
- inference_state=inference_state_3,
|
|
|
)
|
|
|
|
|
|
print(f"{resp2=}")
|