.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_translation_transformer.py:

Language Translation with Transformer
=====================================

This tutorial shows how to train a translation model from scratch using the
Transformer architecture. We will be using the Multi30k dataset to train a
German to English translation model.

Data Processing
---------------

torchtext has utilities for creating datasets that can be easily iterated
through for the purposes of creating a language translation model. In this
example, we show how to tokenize a raw text sentence, build a vocabulary, and
numericalize tokens into tensors.

To run this tutorial, first install spacy using pip or conda. Next, download
the English and German models for the Spacy tokenizers from
https://spacy.io/usage/models


.. code-block:: default

    import math
    import torchtext
    import torch
    import torch.nn as nn
    from torchtext.data.utils import get_tokenizer
    from collections import Counter
    from torchtext.vocab import Vocab
    from torchtext.utils import download_from_url, extract_archive
    from torch import Tensor
    import io
    import time

    torch.manual_seed(0)
    torch.use_deterministic_algorithms(True)

    url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
    train_urls = ('train.de.gz', 'train.en.gz')
    val_urls = ('val.de.gz', 'val.en.gz')
    test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

    train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
    val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
    test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]

    de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
    en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

    def build_vocab(filepath, tokenizer):
        counter = Counter()
        with io.open(filepath, encoding="utf8") as f:
            for string_ in f:
                counter.update(tokenizer(string_))
        return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

    de_vocab = build_vocab(train_filepaths[0], de_tokenizer)
    en_vocab = build_vocab(train_filepaths[1], en_tokenizer)

    def data_process(filepaths):
        raw_de_iter = iter(io.open(filepaths[0], encoding="utf8"))
        raw_en_iter = iter(io.open(filepaths[1], encoding="utf8"))
        data = []
        for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
            de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de.rstrip("\n"))],
                                      dtype=torch.long)
            en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en.rstrip("\n"))],
                                      dtype=torch.long)
            data.append((de_tensor_, en_tensor_))
        return data

    train_data = data_process(train_filepaths)
    val_data = data_process(val_filepaths)
    test_data = data_process(test_filepaths)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    BATCH_SIZE = 128
    PAD_IDX = de_vocab['<pad>']
    BOS_IDX = de_vocab['<bos>']
    EOS_IDX = de_vocab['<eos>']
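As a quick check of the pipeline above, you can tokenize a single German
sentence and look up the indices of its tokens. This snippet is only
illustrative and is not part of the original training code; the sentence is an
arbitrary example.


.. code-block:: default

    # Illustrative check (not required for training): tokenize one German
    # sentence and numericalize it with the vocabulary built above.
    sample_sentence = "Zwei Hunde spielen im Schnee ."
    sample_tokens = de_tokenizer(sample_sentence)
    sample_indices = [de_vocab[token] for token in sample_tokens]
    print(sample_tokens)
    print(sample_indices)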
DataLoader
----------

The last torch-specific feature we'll use is the ``DataLoader``, which is easy
to use since it takes the data as its first argument. Specifically, as the
docs say: ``DataLoader`` combines a dataset and a sampler, and provides an
iterable over the given dataset. The ``DataLoader`` supports both map-style
and iterable-style datasets with single- or multi-process loading, customized
loading order, and optional automatic batching (collation) and memory pinning.

Please pay attention to ``collate_fn`` (optional), which merges a list of
samples to form a mini-batch of Tensor(s). It is used when batched loading is
done from a map-style dataset.


.. code-block:: default

    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    def generate_batch(data_batch):
        de_batch, en_batch = [], []
        for (de_item, en_item) in data_batch:
            de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
            en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
        de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
        en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
        return de_batch, en_batch

    train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                            shuffle=True, collate_fn=generate_batch)
    valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                            shuffle=True, collate_fn=generate_batch)
    test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                           shuffle=True, collate_fn=generate_batch)

Transformer!
------------

Transformer is a Seq2Seq model introduced in the `“Attention is all you need”
<https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
paper for solving the machine translation task. The Transformer model consists
of an encoder and a decoder block, each containing a fixed number of layers.

The encoder processes the input sequence by propagating it through a series of
Multi-head Attention and Feed-forward network layers. The output from the
encoder, referred to as ``memory``, is fed to the decoder along with the
target tensors. The encoder and decoder are trained in an end-to-end fashion
using the teacher forcing technique.


.. code-block:: default

    from torch.nn import (TransformerEncoder, TransformerDecoder,
                          TransformerEncoderLayer, TransformerDecoderLayer)


    class Seq2SeqTransformer(nn.Module):
        def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                     emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                     dim_feedforward: int = 512, dropout: float = 0.1):
            super(Seq2SeqTransformer, self).__init__()
            encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                    dim_feedforward=dim_feedforward)
            self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
            decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                    dim_feedforward=dim_feedforward)
            self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

            self.generator = nn.Linear(emb_size, tgt_vocab_size)
            self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
            self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
            self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

        def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                    tgt_mask: Tensor, src_padding_mask: Tensor,
                    tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
            src_emb = self.positional_encoding(self.src_tok_emb(src))
            tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
            memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
            outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                            tgt_padding_mask, memory_key_padding_mask)
            return self.generator(outs)

        def encode(self, src: Tensor, src_mask: Tensor):
            return self.transformer_encoder(
                self.positional_encoding(self.src_tok_emb(src)), src_mask)

        def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
            return self.transformer_decoder(
                self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

Text tokens are represented using token embeddings. Positional encoding is
added to the token embedding to introduce a notion of word order.
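For reference, the ``PositionalEncoding`` module below implements the standard
sinusoidal encoding from the “Attention is all you need” paper, where
:math:`pos` is the token position, :math:`i` indexes the embedding dimension,
and :math:`d_{\text{model}}` is the embedding size:

.. math::

    PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
    \qquad
    PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)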
.. code-block:: default

    class PositionalEncoding(nn.Module):
        def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
            super(PositionalEncoding, self).__init__()
            den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
            pos = torch.arange(0, maxlen).reshape(maxlen, 1)
            pos_embedding = torch.zeros((maxlen, emb_size))
            pos_embedding[:, 0::2] = torch.sin(pos * den)
            pos_embedding[:, 1::2] = torch.cos(pos * den)
            pos_embedding = pos_embedding.unsqueeze(-2)

            self.dropout = nn.Dropout(dropout)
            self.register_buffer('pos_embedding', pos_embedding)

        def forward(self, token_embedding: Tensor):
            return self.dropout(token_embedding +
                                self.pos_embedding[:token_embedding.size(0), :])

    class TokenEmbedding(nn.Module):
        def __init__(self, vocab_size: int, emb_size):
            super(TokenEmbedding, self).__init__()
            self.embedding = nn.Embedding(vocab_size, emb_size)
            self.emb_size = emb_size

        def forward(self, tokens: Tensor):
            return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

We create a ``subsequent word`` mask to stop a target word from attending to
its subsequent words. We also create masks for masking source and target
padding tokens.


.. code-block:: default

    def generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def create_mask(src, tgt):
        src_seq_len = src.shape[0]
        tgt_seq_len = tgt.shape[0]

        tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
        src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

        src_padding_mask = (src == PAD_IDX).transpose(0, 1)
        tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Define the model parameters and instantiate the model.


.. code-block:: default

    SRC_VOCAB_SIZE = len(de_vocab)
    TGT_VOCAB_SIZE = len(en_vocab)
    EMB_SIZE = 512
    NHEAD = 8
    FFN_HID_DIM = 512
    BATCH_SIZE = 128
    NUM_ENCODER_LAYERS = 3
    NUM_DECODER_LAYERS = 3
    NUM_EPOCHS = 16

    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                     EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                     FFN_HID_DIM)

    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    transformer = transformer.to(device)

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    optimizer = torch.optim.Adam(
        transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
    )
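Before writing the training loop, it can help to sanity-check the shapes that
the data pipeline and ``create_mask`` produce. The snippet below is an
optional check added for illustration (it is not part of the original
tutorial code); it simply pulls one batch from ``train_iter`` and prints the
resulting shapes.


.. code-block:: default

    # Optional sanity check (illustrative only). With the default settings,
    # pad_sequence returns tensors of shape (seq_len, batch_size).
    src, tgt = next(iter(train_iter))
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    tgt_input = tgt[:-1, :]
    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

    print(src.shape, tgt.shape)            # e.g. (S, 128) and (T, 128)
    print(src_mask.shape, tgt_mask.shape)  # (S, S) and (T-1, T-1)
    print(src_padding_mask.shape, tgt_padding_mask.shape)  # (128, S) and (128, T-1)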
.. code-block:: default

    def train_epoch(model, train_iter, optimizer):
        model.train()
        losses = 0
        for idx, (src, tgt) in enumerate(train_iter):
            src = src.to(device)
            tgt = tgt.to(device)

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)

            optimizer.zero_grad()

            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            loss.backward()

            optimizer.step()
            losses += loss.item()
        return losses / len(train_iter)


    def evaluate(model, val_iter):
        model.eval()
        losses = 0
        for idx, (src, tgt) in enumerate(val_iter):
            src = src.to(device)
            tgt = tgt.to(device)

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
        return losses / len(val_iter)

Train the model.


.. code-block:: default

    for epoch in range(1, NUM_EPOCHS + 1):
        start_time = time.time()
        train_loss = train_epoch(transformer, train_iter, optimizer)
        end_time = time.time()
        val_loss = evaluate(transformer, valid_iter)
        print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
               f"Epoch time = {(end_time - start_time):.3f}s"))

We get the following results during model training.

::

    Epoch: 1, Train loss: 5.316, Val loss: 4.065, Epoch time = 35.322s
    Epoch: 2, Train loss: 3.727, Val loss: 3.285, Epoch time = 36.283s
    Epoch: 3, Train loss: 3.131, Val loss: 2.881, Epoch time = 37.096s
    Epoch: 4, Train loss: 2.741, Val loss: 2.625, Epoch time = 37.714s
    Epoch: 5, Train loss: 2.454, Val loss: 2.428, Epoch time = 38.263s
    Epoch: 6, Train loss: 2.223, Val loss: 2.291, Epoch time = 38.415s
    Epoch: 7, Train loss: 2.030, Val loss: 2.191, Epoch time = 38.412s
    Epoch: 8, Train loss: 1.866, Val loss: 2.104, Epoch time = 38.511s
    Epoch: 9, Train loss: 1.724, Val loss: 2.044, Epoch time = 38.367s
    Epoch: 10, Train loss: 1.600, Val loss: 1.994, Epoch time = 38.491s
    Epoch: 11, Train loss: 1.488, Val loss: 1.969, Epoch time = 38.490s
    Epoch: 12, Train loss: 1.390, Val loss: 1.929, Epoch time = 38.194s
    Epoch: 13, Train loss: 1.299, Val loss: 1.898, Epoch time = 38.430s
    Epoch: 14, Train loss: 1.219, Val loss: 1.885, Epoch time = 38.406s
    Epoch: 15, Train loss: 1.141, Val loss: 1.890, Epoch time = 38.365s
    Epoch: 16, Train loss: 1.070, Val loss: 1.873, Epoch time = 38.439s

Models trained with the Transformer architecture train faster and converge to
a lower validation loss compared to RNN models.
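The tutorial itself does not save the trained weights. If you want to reuse
the model later without retraining, a minimal sketch is shown below; the
filename is arbitrary and chosen only for this example.


.. code-block:: default

    # Optional: persist the trained weights so inference can be run later
    # without retraining. The filename is just an example.
    torch.save(transformer.state_dict(), 'transformer_multi30k_de_en.pt')

    # To reload, rebuild the model with the same hyperparameters and then:
    # transformer.load_state_dict(
    #     torch.load('transformer_multi30k_de_en.pt', map_location=DEVICE))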
Finally, to generate a translation we use greedy decoding: starting from
``<bos>``, the decoder repeatedly picks the most probable next token until
``<eos>`` is produced or the maximum length is reached.


.. code-block:: default

    def greedy_decode(model, src, src_mask, max_len, start_symbol):
        src = src.to(device)
        src_mask = src_mask.to(device)
        memory = model.encode(src, src_mask)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
        for i in range(max_len - 1):
            memory = memory.to(device)
            memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([ys,
                            torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == EOS_IDX:
                break
        return ys


    def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
        model.eval()
        tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer(src)] + [EOS_IDX]
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        tgt_tokens = greedy_decode(model, src, src_mask,
                                   max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
        return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")


.. code-block:: default

    translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu .",
              de_vocab, en_vocab, de_tokenizer)

Output: ``A group of people stand in front of an igloo .``

References
----------

1. Attention is all you need paper.
   https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
2. The annotated transformer.
   https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding