import torch

from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .utils import VocabUtility


class _VocabParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # Maximum logit along the vocab dimension, reduced across all
        # model-parallel ranks for numerical stability.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX,
            group=get_model_parallel_group(),
        )
        # Subtract the maximum value (in place, reusing the logits buffer).
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Vocab range [vocab_start_index, vocab_end_index) owned by this
        # partition.
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_model_parallel_rank()
        world_size = get_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size
        )

        # Mask out target ids that live on other partitions and shift the
        # in-range ids to local (partition-relative) indices.
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Gather the logit predicted for each target id. Flatten to 2D so
        # advanced indexing can pick one entry per row, then zero out rows
        # whose target belongs to another partition.
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # Sum across ranks: exactly one rank contributes each target's logit.
        torch.distributed.all_reduce(
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_model_parallel_group(),
        )

        # Softmax denominator: sum of exp(logits) over the full vocab,
        # accumulated across ranks. exp_logits reuses the logits storage.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_model_parallel_group(),
        )

        # Per-token loss: log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize exp_logits into the softmax and stash everything the
        # backward pass needs.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the softmax, the out-of-range target mask, and the
        # flattened local target indices saved in forward.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # The gradient buffer reuses the softmax tensor.
        grad_input = softmax

        # Flatten to 2D so one entry per row can be updated.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Subtract the one-hot target, but only on the rank whose vocab
        # partition owns that target id (target_mask is 1 elsewhere).
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

        # Scale by the upstream gradient.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None
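

# --------------------------------------------------------------------------- #
# Note (added for clarity; not in the original source): for the per-token loss
#     loss_i = log(sum_j exp(z_ij)) - z_{i, t_i}
# the gradient with respect to the logits is
#     d loss_i / d z_ij = softmax(z_i)_j - 1[j == t_i].
# backward() above builds exactly this: it starts from the softmax saved in
# forward() and subtracts 1 at the target column, but only on the rank whose
# vocab partition owns that target id (the (1 - target_mask) factor), before
# scaling by the incoming grad_output.
# --------------------------------------------------------------------------- #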


def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    """Compute the cross entropy loss from vocab-parallel logits."""
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
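

# --------------------------------------------------------------------------- #
# Illustrative sanity check (added here as a sketch, not part of the original
# module). On a single rank the loss above reduces to the standard per-token
# cross entropy,
#     loss_i = log(sum_j exp(z_ij)) - z_{i, t_i},
# so the formula can be checked against torch.nn.functional.cross_entropy
# without any model-parallel setup. In actual training, model parallelism must
# be initialized first and `vocab_parallel_logits` must come from a
# vocab-parallel output layer; typical usage is simply
#     losses = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
#     loss = losses.mean()
# Run this file as a module (python -m <package>.<this_module>) so the
# relative imports above resolve.
if __name__ == "__main__":
    import torch.nn.functional as F

    z = torch.randn(4, 8, 16)          # [batch, seq, vocab] logits
    t = torch.randint(0, 16, (4, 8))   # target token ids
    ref = F.cross_entropy(z.view(-1, 16), t.view(-1), reduction="none").view(4, 8)
    ours = torch.log(torch.exp(z).sum(dim=-1)) - z.gather(-1, t.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(ref, ours, atol=1e-5)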