PyTorch Transformer Language Model Clarified

PyTorch provides a pretty thorough tutorial for building a complete pipeline for training and evaluating a Transformer-based language model (link). It provides enough detail for beginners, but I found several places that could be clarified further.

1. The batch_size in the "run the model" section should be seq_len

First, let's look at the for loop in train(), where a batch is taken from train_data and split into data and targets:

for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
    data, targets = get_batch(train_data, i)
    batch_size = data.size(0)

What is the shape of data? We need to trace back to the function get_batch():

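from typing import Tuple
from torch import Tensor

bptt = 35  # BPTT window length (the value used in the tutorial)

def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    # source has shape [full_seq_len, batch_size]
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]  # shape [seq_len, batch_size]
    target = source[i+1:i+1+seq_len].reshape(-1)  # shape [seq_len * batch_size]
    return data, target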

Here, data is a slice of source taken along its first dimension, so its first dimension is seq_len and its second dimension is batch_size. Thus, data.size(0) returns the sequence length, not the batch size.
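To see the shapes concretely, here is a small sketch that reuses get_batch() together with a batchify() helper modeled on the tutorial's (device placement omitted). The token stream of 1000 ids and the 20 columns are hypothetical values chosen only for illustration.

```python
import torch
from torch import Tensor
from typing import Tuple

bptt = 35  # BPTT window length

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Reshape a 1-D token stream into shape [full_seq_len, bsz] (sketch of the tutorial's helper)."""
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]              # drop tokens that do not fill a full column
    return data.view(bsz, seq_len).t().contiguous()

def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

# Hypothetical data: 1000 token ids split into 20 columns -> source has shape [50, 20]
tokens = torch.arange(1000)
source = batchify(tokens, 20)

for i in range(0, source.size(0) - 1, bptt):
    data, targets = get_batch(source, i)
    # data.size(0) is the sequence length; data.size(1) is the batch size
    print(i, tuple(data.shape), tuple(targets.shape))
# prints:
# 0 (35, 20) (700,)
# 35 (14, 20) (280,)
```

Note that the second dimension stays at 20 for every batch, while the first dimension changes (35, then 14 for the last, shorter batch). That first dimension is exactly what the tutorial's batch_size = data.size(0) picks up.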

2. Slight difference in using criterion() between train() and evaluate()

In train(), after loss is computed, it is directly added to total_loss:

total_loss += loss.item()

But in evaluate(), there is an additional multiplication by batch_size:

total_loss += batch_size * criterion(output_flat, targets).item()

Why is there such a difference?

First of all, the batch_size in evaluate() should also be renamed to seq_len, for the same reason explained above.

Now let's look at the complete code of train():

The complete train() function

loss is the value returned by criterion(), which is an nn.CrossEntropyLoss() instance. The documentation for CrossEntropyLoss() shows that its default reduction is `mean`. This means that if output contains N logit vectors and targets contains N target words (here N = seq_len * batch_size, since get_batch() flattens the targets), the loss function returns the average cross-entropy loss over the entire mini-batch.

Link: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
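A quick sketch to confirm what the default `mean` reduction does. The logits and targets are random, and the sizes (N = 12 targets, a 50-word vocabulary) are hypothetical; the check is that the default loss equals the summed loss divided by N, and also equals the mean of the unreduced per-target losses.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, ntokens = 12, 50                      # hypothetical: N target positions, 50-word vocabulary
logits = torch.randn(N, ntokens)         # one logit vector per target position
targets = torch.randint(0, ntokens, (N,))

mean_loss = nn.CrossEntropyLoss()(logits, targets)                 # default reduction='mean'
sum_loss = nn.CrossEntropyLoss(reduction='sum')(logits, targets)   # sum over the N targets
per_item = nn.CrossEntropyLoss(reduction='none')(logits, targets)  # shape [N]

print(torch.allclose(mean_loss, sum_loss / N))     # True
print(torch.allclose(mean_loss, per_item.mean()))  # True
```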

In the if statement that prints training progress, total_loss is divided by log_interval, which is the number of batches between two printouts. Since each loss.item() is already a per-batch mean, cur_loss is the average of the per-batch mean losses over the most recent log_interval batches, i.e., a running (smoothed) estimate of the mean loss.
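The logging arithmetic can be sketched with made-up per-batch losses. This is a simplified version that logs after every log_interval batches and assumes total_loss is reset after each printout; the tutorial's exact if condition and its log_interval value may differ.

```python
# Hypothetical per-batch mean losses, standing in for loss.item() in train()
batch_losses = [6.0, 5.5, 5.0, 4.5, 4.25, 4.0, 3.75, 3.5]
log_interval = 4  # hypothetical value, much smaller than in the tutorial

total_loss = 0.0
logged = []
for batch, loss_value in enumerate(batch_losses):
    total_loss += loss_value                   # same role as total_loss += loss.item()
    if (batch + 1) % log_interval == 0:
        cur_loss = total_loss / log_interval   # mean of the last log_interval batch means
        logged.append(cur_loss)
        total_loss = 0.0                       # start a new logging window

print(logged)  # [5.25, 3.875]
```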

In evaluate(), on the other hand, we do not need to print progress every few batches; we only need the average loss across the entire validation/test dataset. Therefore, the returned value is total_loss / (len(eval_data) - 1).

The complete evaluate() function

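Remember that criterion() returns the mean loss over all N = seq_len * batch_size targets in a batch. To turn these per-batch means into an average over ALL targets in the evaluation dataset, each batch's mean must be weighted by the batch's size. Every batch has the same number of columns (batch_size), so weighting by seq_len is proportional to weighting by N, and the seq_len values of all batches add up to exactly len(eval_data) - 1. Dividing total_loss by len(eval_data) - 1 therefore gives the per-token mean loss over the whole dataset. The weighting matters because the last batch is usually shorter than bptt; a plain average of the batch means would give its targets too much weight.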
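The sketch below checks this reasoning numerically. Random logits and targets stand in for the model output and eval_data, and all sizes are hypothetical; with 50 rows and bptt = 35, the loop produces one batch of 35 rows and a shorter one of 14 rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bptt, batch_size, ntokens = 35, 4, 10  # hypothetical sizes
eval_len = 50                          # rows in a hypothetical eval_data
criterion = nn.CrossEntropyLoss()      # default reduction='mean'

# Random stand-ins for the model output and the (shifted) targets at every position
all_logits = torch.randn(eval_len - 1, batch_size, ntokens)
all_targets = torch.randint(0, ntokens, (eval_len - 1, batch_size))

total_loss = 0.0
batch_means = []
for i in range(0, eval_len - 1, bptt):
    seq_len = min(bptt, eval_len - 1 - i)
    output_flat = all_logits[i:i + seq_len].reshape(-1, ntokens)
    targets = all_targets[i:i + seq_len].reshape(-1)
    mean_i = criterion(output_flat, targets).item()  # mean over seq_len * batch_size targets
    batch_means.append(mean_i)
    total_loss += seq_len * mean_i                   # weight each batch mean by its length

weighted = total_loss / (eval_len - 1)               # what evaluate() returns
exact = criterion(all_logits.reshape(-1, ntokens), all_targets.reshape(-1)).item()
unweighted = sum(batch_means) / len(batch_means)     # plain average of batch means

print(abs(weighted - exact) < 1e-5)  # True
print(weighted, exact, unweighted)
```

The weighted value matches the loss computed over all targets at once, while the unweighted average generally does not, because it treats the 14-row batch as if it were as large as the 35-row one.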