Nel Nibcord 6 месяцев назад
Родитель
Commit
42172b2c39
1 изменённый файл с 11 добавлениями и 12 удалениями
  1. 11 12
      exo/inference/debug_inference_engine.py

+ 11 - 12
exo/inference/debug_inference_engine.py

@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=token2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
-    inference_state=inference_state_3,
   )
 
   print(f"{resp2=}")