|
@@ -106,7 +106,11 @@ class Model(nn.Module):
|
|
|
shard_state_dict[key] = value
|
|
|
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
|
|
shard_state_dict[key] = value
|
|
|
- elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
|
|
|
+ elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
|
|
|
+ shard_state_dict[key] = value
|
|
|
+ elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
|
|
|
+ shard_state_dict[key] = value
|
|
|
+ elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
|
|
|
shard_state_dict[key] = value
|
|
|
|
|
|
return shard_state_dict
|