Browse Source

dynamic halfway partition point in unit test

Alex Cheema 11 tháng trước cách đây
mục cha
commit
17065d879b
1 tập tin đã thay đổi với 1 bổ sung1 xóa
  1. 1 1
      exo/inference/test_inference_engine.py

+ 1 - 1
exo/inference/test_inference_engine.py

@@ -19,7 +19,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     inference_state=inference_state_full,
   )
 
-  pp = 15
+  pp = n_layers // 2
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",