|
@@ -292,7 +292,7 @@ class StandardNode(Node):
|
|
|
self.outstanding_requests[request_id] = "training"
|
|
|
loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
|
|
|
self.outstanding_requests.pop(request_id)
|
|
|
- return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
|
|
|
+ return loss.reshape(1, -1) if shard.is_first_layer() else grad
|
|
|
else:
|
|
|
if shard.is_last_layer():
|
|
|
self.outstanding_requests[request_id] = "evaluating"
|
|
@@ -303,7 +303,7 @@ class StandardNode(Node):
|
|
|
self.outstanding_requests[request_id] = "waiting"
|
|
|
loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
|
|
|
self.outstanding_requests.pop(request_id)
|
|
|
- return loss.reshape(example.shape[0], -1)
|
|
|
+ return loss.reshape(1, -1)
|
|
|
except Exception as e:
|
|
|
self.outstanding_requests.pop(request_id)
|
|
|
print(f"Error processing example for shard {shard}: {e}")
|