Browse Source

fix test_inference_engine unittest reshape token output tensor

Alex Cheema 7 months ago
parent
commit
8a741485df
1 changed file with 2 additions and 0 deletions
  1. 2 0
      exo/inference/test_inference_engine.py

+ 2 - 0
exo/inference/test_inference_engine.py

@@ -13,6 +13,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
   next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
@@ -27,6 +28,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
   resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),