
fix test_inference_engine

Alex Cheema · 1 year ago · commit 30ab126c08
1 changed file with 2 additions and 2 deletions

exo/inference/test_inference_engine.py (+2 −2)

@@ -7,7 +7,7 @@ import numpy as np
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-    prompt = "In a single word only, what is the capital of Japan? "
+    prompt = "In a single word only, what is the last name of the current president of the USA?"
     resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
     next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
 
@@ -33,5 +33,5 @@ asyncio.run(test_inference_engine(
 asyncio.run(test_inference_engine(
     TinygradDynamicShardInferenceEngine(),
     TinygradDynamicShardInferenceEngine(),
-    "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
+    "llama3-8b-sfr",
 ))
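
The comment at the top of the diff states the invariant this test relies on: an inference engine should behave the same for any number of shards, as long as the shards are continuous. Below is a minimal sketch of that invariant, not the repository's actual test: the Shard fields and the infer_prompt/infer_tensor signatures are copied from the lines visible in this diff, while the import path, the two-shard split point, and the final equality check are illustrative assumptions.

# Sketch: the same prompt run through one full 32-layer shard and through two
# continuous shards (layers 0-15 and 16-31) should produce the same final output.
# Shard(...) / infer_prompt / infer_tensor signatures are copied from the diff
# above; the import path and the equality check are assumptions for illustration.
import asyncio
import numpy as np
from exo.inference.shard import Shard  # assumed module path

async def check_sharding_is_transparent(engine_full, engine_sharded, model_id: str):
    prompt = "In a single word only, what is the last name of the current president of the USA?"

    # Full model in a single shard covering all 32 layers.
    full = Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32)
    resp_full, _, _ = await engine_full.infer_prompt(shard=full, prompt=prompt)

    # Same model split into two continuous shards.
    first_half = Shard(model_id=model_id, start_layer=0, end_layer=15, n_layers=32)
    second_half = Shard(model_id=model_id, start_layer=16, end_layer=31, n_layers=32)
    resp_a, state_a, _ = await engine_sharded.infer_prompt(shard=first_half, prompt=prompt)
    resp_b, _, _ = await engine_sharded.infer_tensor(shard=second_half, input_data=resp_a, inference_state=state_a)

    # If sharding is transparent, both paths should agree on the final output.
    assert np.array_equal(np.array(resp_full), np.array(resp_b))

# Hypothetical invocation, mirroring the asyncio.run call in the test:
# asyncio.run(check_sharding_is_transparent(engine_1, engine_2, "llama3-8b-sfr"))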