A Systematic Explanation of Transformers
Comes with free intuition, clarity, and comprehensiveness
In recent years, NLP has experienced remarkable progress, largely attributed to the introduction of the transformer model, which made its debut in the literature in 2017. Transformers serve as the foundation for numerous breakthroughs in deep learning, including large language models such as ChatGPT and powerful speech models such as Whisper.
One remark that could be helpful in understanding the transformer architecture is that it is composed of an encoder and a decoder, each of which makes use of six main blocks: embedding blocks, positional encoding blocks, attention blocks, feed-forward blocks, Add & Norm blocks and Softmax blocks. Thus, a strategy for studying how it works could be to understand:
- How each of these six blocks functions individually
- How these six blocks come together to form the encoder
- How training/inference data flows through the encoder and its purpose
- How these blocks come together to form the decoder
- How training/inference data flows through the decoder and its purpose
- How the encoder and decoder come together to form the full transformer model and its purpose
The rest of this story will implement this strategy to help you deeply understand the transformer.
Table of Contents
· Transformer Blocks
∘ Embedding Block
∘ Positional Encoding Block
∘ Multi-head Attention Block
∘ Position-wise Feedforward Network
∘ Add & Norm Block
∘ Softmax Block
· Transformer Encoder
∘ Composition
∘ Purpose
∘ Data Flow
· Transformer Decoder
∘ Composition
∘ Purpose
∘ Data Flow
· Full Transformer
∘ Composition
∘ Purpose
∘ Data Flow
Transformer Blocks
Let’s start by understanding the different blocks used in the transformer architecture.
Embedding Block
Input:
This block assumes that each token that has appeared in the dataset (i.e., each entry in the vocabulary) corresponds to a unique integer. The input it expects is a list of integers [x₁,x₂,…,xₜ] representing the input sequence (or a batch thereof).
Purpose:
The purpose of this block (actually just a layer) is to learn to assign a numerical vector to each token (e.g., word), represented by an integer, in the given sequence (e.g., sentence).
The basic property that we want the assigned token vectors to possess is that tokens similar in meaning should have numerically similar vectors. Among other reasons, this means the model gives similar outputs for a given sequence regardless of whether a word or one of its synonyms is used for one of the tokens.
Hyperparameters:
The dimensionality of the numerical vector representation used for each token (dₑ). Set to 512 in the original paper.
Output:
A sequence of vectors, each of size dₑ, where the ith vector represents a numerical representation for the ith token: [xₑ₁,xₑ₂,…,xₑₜ].
Operation:
Suppose the vocabulary has size |v|. Randomly initialize a matrix Wₑ of dimensions (|v|, dₑ). For each token in the input sequence, if that token's integer is i, return the vector Wₑ[i,:]. Update the weights of Wₑ so as to minimize the loss function over the training data. Eventually, this should automatically yield vectors that satisfy the embedding properties mentioned above (as justified in prior works).
Observe that an equivalent operation would be to train a feedforward layer of dimensionality dₑ with no bias (and no activation). Its input is a one-hot row vector that is 1 at the token’s index.
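As a rough illustration, here is a minimal Julia sketch of the lookup. The toy sizes (|v| = 10, dₑ = 4) and the random Wₑ are assumptions that stand in for a trained matrix. Julia is 1-indexed, so token integers start at 1 here. The last lines check the one-hot equivalence noted above.

```julia
using LinearAlgebra
using Random

Random.seed!(42)

vocab_size = 10   # |v|: toy vocabulary size (assumption for illustration)
d_e = 4           # embedding dimensionality (512 in the paper)

# Randomly initialized embedding matrix Wₑ of size (|v|, dₑ); training would update it.
W_e = randn(vocab_size, d_e)

# Embedding lookup: return row Wₑ[i, :] for every token integer i in the sequence.
embed(tokens::Vector{Int}) = W_e[tokens, :]   # (t, dₑ) matrix, one row per token

tokens = [3, 7, 1]
X_e = embed(tokens)

# Equivalent formulation: one-hot row vectors times Wₑ (a bias-free linear layer).
onehot(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]'   # 1×n row vector
X_onehot = vcat([onehot(i, vocab_size) * W_e for i in tokens]...)

println(size(X_e))        # (3, 4)
println(X_e ≈ X_onehot)   # true
```

The lookup and the one-hot product give the same rows. The lookup is simply the cheaper way to compute them.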
Learnable Parameters:
The weight matrix of dimensions (|v|, dₑ).
Side Note:
This layer was also frequently used in older architectures for the same purpose. In many instances, instead of learning Wₑ, a pre-trained version is used (e.g., Word2Vec/Skip-gram).
Positional Encoding Block
Input:
A sequence of embedding vectors xₑ₁,xₑ₂,…,xₑₜ, each of dimensionality dₑ (or a batch thereof), output from the embedding layer.
Operation:
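Given that dₑ=512, for each input embedding vector located at position pos, add the vector:
$$PE_{pos} = \left[\sin\left(\frac{pos}{10000^{0/512}}\right), \cos\left(\frac{pos}{10000^{0/512}}\right), \sin\left(\frac{pos}{10000^{2/512}}\right), \cos\left(\frac{pos}{10000^{2/512}}\right), \ldots, \sin\left(\frac{pos}{10000^{510/512}}\right), \cos\left(\frac{pos}{10000^{510/512}}\right)\right]$$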
Observe that if we define i∈{0, 1, 2, …, int((dₑ-1)/2)}, then even locations in the vector are at indices 2i, and odd locations of the vector are at indices 2i+1.
Thus, more generally, consider a sequence of embedding vectors xₑ₁, xₑ₂, …, xₑₜ (pretend this is zero-indexed), where pos∈{0,1,2,…,t-1} and each vector has dimensionality dₑ. Add to every value at index 2i (even index) within the embedding vector at position pos the value:
$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_e}}\right)$$
and add to every value at index 2i+1 (odd index) within the embedding vector at position pos the value:
$$PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_e}}\right)$$
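The two formulas translate into a short Julia function. This is a sketch for illustration only: the name `positional_encoding` is ours, not from the paper. The random `X_e` simply stands in for the output of the embedding block.

```julia
# Sinusoidal positional encoding: row pos+1 holds the vector added to the embedding
# at (zero-based) position pos, following the sin/cos formulas above.
function positional_encoding(t::Int, d_e::Int)
    PE = zeros(t, d_e)
    for pos in 0:t-1, col in 0:d_e-1
        i = col ÷ 2                          # columns 2i and 2i+1 share the same frequency
        angle = pos / 10000^(2i / d_e)
        PE[pos+1, col+1] = iseven(col) ? sin(angle) : cos(angle)   # Julia is 1-indexed
    end
    return PE
end

PE = positional_encoding(50, 512)

X_e = randn(5, 512)                          # stand-in embeddings for a 5-token sequence
X_pe = X_e .+ positional_encoding(5, 512)    # inject position information
```

The encoding depends only on t and dₑ. It can therefore be computed once for the longest expected sequence and sliced as needed.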
An alternative is to let the model learn the vectors that are added to the sequence of embeddings to inject position information. The transformer paper reports that this yields nearly identical results. However, it can limit the ability to handle inference sequences that are longer than those seen in training.
Output:
A sequence of embedding vectors xₑ₁,xₑ₂,…,xₑₜ, each of dimensionality dₑ (or a batch thereof), where position information has been injected into each embedding vector by adding the aforementioned vector to it.
Purpose:
The purpose of this block is to inject position information into the given sequence of embedding vectors, in a form that later blocks in the architecture can recognize. This is necessary because the key block responsible for learning relationships between different parts of the sequence (the attention block) is blind to order. It produces the same context vector for a given token regardless of how the other tokens are ordered; permuting the input simply permutes the output. In other words, it can’t by itself make use of position information, which can greatly help the model in its learning task.
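The Julia check below makes this order-blindness concrete. It assumes a single attention head with random projection matrices and uses the matrix form of attention described in the next section. Reordering the input tokens merely reorders the context vectors, so nothing in the output reveals the original order.

```julia
using LinearAlgebra
using Random

Random.seed!(1)

# Row-wise softmax: each row of the result sums to 1.
function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

# Single-head self-attention on a (t, d) matrix X whose rows are token vectors.
function self_attention(X, Wq, Wk, Wv)
    Q, K, V = X * Wq, X * Wk, X * Wv
    return softmax_rows(Q * K' ./ sqrt(size(K, 2))) * V
end

d, t = 4, 5
Wq, Wk, Wv = randn(d, d), randn(d, d), randn(d, d)
X = randn(t, d)               # five token vectors with no position information
perm = [3, 1, 5, 2, 4]        # a reordering of the tokens

C = self_attention(X, Wq, Wk, Wv)
C_perm = self_attention(X[perm, :], Wq, Wk, Wv)

println(C_perm ≈ C[perm, :])  # true: the same context vectors, just reordered
```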
To see how this indeed encodes position information we can look at a plot of the values added for different values of pos and i:
The first row would be added to the first embedding vector, the second row to the second embedding vector, and so on. You see a similar pattern if you write 0, 1, 2, 3, … in binary: the lowest bit flips rapidly (0101…), the next bit flips half as fast (0011…), and so on. Likewise, here the lower columns oscillate quickly across positions and the higher columns oscillate ever more slowly. Unlike literally adding 1s, 2s, 3s, etc. to each position, the sinusoidal values stay bounded in [-1, 1], so they never overwhelm the embedding values, even in long sequences.
Hyperparameters:
None exist by default.
Learnable parameters:
None exist by default (but as we mentioned we can let the transformer learn the vector to add as an alternative). Notice that the values added by this layer do not depend on the specific input values.
Sidenote:
The authors apply dropout with P=0.1 to the resulting sequence of vectors during training. Each value is independently set to zero with probability 0.1, so roughly 10% of the values are zeroed. This can help prevent overfitting.
Multi-head Attention Block
Input:
A sequence of query vectors q₁,q₂,…,qₜ, a sequence of key vectors k₁,k₂,…,kₙ and a sequence of value vectors v₁,v₂,…,vₙ. We assume here that they are row vectors.
For now, you can just think of them as arbitrary vectors. The multi-head attention block gets a sequence of vectors of length t and two sequences of vectors of length n.
Output:
Sequence of context vectors C₁,C₂,…,Cₜ corresponding to q₁,q₂,…,qₜ respectively (where q₁,q₂,…,qₜ correspond to the original input tokens [x₁,x₂,…,xₜ]).
Purpose:
Compute context vectors C₁,C₂,…,Cₜ corresponding to q₁,q₂,…,qₜ where the vector Cᵢ is a weighted combination of the input value vectors v₁,v₂,…,vₙ.
Suppose t=5 and n=5 with C₃=0.4v₁+0.01v₂+0.02v₃+0.01v₄+0.56v₅ (the weights sum to 1). Then we say that the input token corresponding to q₃ “attends” to the input tokens corresponding to v₁ and v₅, as they have the highest weights.
qᵢ always corresponds to the input token xᵢ given to the parent block (encoder/decoder); meanwhile, vᵢ corresponds to the input token xᵢ in the self-attention case and to an input token yᵢ given to another parent block in the cross-attention case.
In other words, the vectors among v₁,v₂,…,vₙ that are most relevant to qᵢ for the model’s learning task should contribute the most to forming Cᵢ. Typically, if the input token corresponding to vⱼ is related in a meaningful way to the token corresponding to qᵢ, then it will contribute to forming Cᵢ. This gets clearer below.
There are two cases for this:
- Suppose (q₁,q₂,…,qₜ)=(k₁,k₂,…,kₙ)=(v₁,v₂,…,vₙ) (so t=n) and that (q₁,q₂,…,qₜ), and hence the other two, are some vector representation of the model’s input sequence. In this case, this operation finds a vector Cᵢ for each input token xᵢ (e.g., word) that captures features relating it to other tokens in the sequence. This is called self-attention.
- Suppose (q₁,q₂,…,qₜ) are some vector representations of a sequence X and (k₁,k₂,…,kₙ)=(v₁,v₂,…,vₙ) are some vector representations of a sequence Y. Then this operation finds a vector Cᵢ for each vector representation in X. That vector captures features relating it to tokens in the sequence Y that are useful for the current task. This is called cross-attention.
Operation:
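Map each of the query, key and value vectors to a new space via weight matrices Wq, Wk, Wv:
$$Q_i = q_i W_q, \qquad K_j = k_j W_k, \qquad V_j = v_j W_v$$
Return the sequence of vectors C₁,C₂,…,Cₜ, where each Cᵢ is a weighted sum of V₁,V₂,…,Vₙ whose weights depend on the similarity of Qᵢ to K₁,K₂,…,Kₙ respectively:
$$C_i = \sum_{j=1}^{n} f(Q_i, K_j)\, V_j$$
The similarity function f used is:
$$f(Q_i, K_j) = \mathrm{softmax}_j\left(\frac{Q_i K_j^\top}{\sqrt{d_k}}\right) = \frac{\exp\left(Q_i K_j^\top / \sqrt{d_k}\right)}{\sum_{l=1}^{n} \exp\left(Q_i K_l^\top / \sqrt{d_k}\right)}$$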
The dot product is a well-known similarity measure. The scores are divided by √dₖ, the square root of the key dimensionality dₖ (which equals that of the query). This keeps the dot products from growing large when dₖ is large; otherwise, the Softmax saturates and its gradients become extremely small. The Softmax ensures that the weights sum to 1: Σⱼf(Qᵢ,Kⱼ)=1.
This quantity is referred to as the attention score of Query Qᵢ towards Key Kⱼ, and it ranges from 0 to 1. A higher score indicates a greater influence of Value Vⱼ in forming the context vector for Query Qᵢ. Equivalently, it indicates that the token corresponding to Qᵢ (the ith token) is likely meaningfully related, for the learning task, to the token corresponding to Kⱼ and Vⱼ (the jth token). Keys and values always correspond to the same sequence of tokens in both self- and cross-attention.
We can also define a matrix A such that A[i,j] is the attention score of Qᵢ towards Key Kⱼ. We will extend this operation later.
Equivalent Operation:
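Suppose we define the matrices whose rows are the projected queries, keys and values:
$$Q = \begin{bmatrix} q_1 \\ \vdots \\ q_t \end{bmatrix} W_q = \begin{bmatrix} Q_1 \\ \vdots \\ Q_t \end{bmatrix}, \qquad K = \begin{bmatrix} k_1 \\ \vdots \\ k_n \end{bmatrix} W_k = \begin{bmatrix} K_1 \\ \vdots \\ K_n \end{bmatrix}, \qquad V = \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} W_v = \begin{bmatrix} V_1 \\ \vdots \\ V_n \end{bmatrix}$$
then we have that
$$(QK^\top)[i,j] = Q_i K_j^\top$$
by which you can show that it holds that
$$A = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) \ \text{(softmax applied to each row)}, \qquad A[i,j] = f(Q_i, K_j)$$
which implies that
$$C = AV = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V, \quad \text{where the } i\text{th row of } C \text{ is } C_i$$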
In other words, all the vector and index operations we described above can be implemented by matrix operations (which are much more concise).
Example:
Consider
Q = [ .1 .2 .3 #Q1
.4 .5 .6 #Q2
.7 .8 .9 #Q3
]
K = [.2 .2 .1 #K1
.3 .4 .1 #K2
.7 .2 .6] #K3
V = [.3 .3 .1 #V1
.2 .1 .5 #V2
.6 .2 .3] #V3
Compute A and C using both the vector-based and matrix-based operations above. Take the Softmax over each row (i.e., over j), so that every row of A sums to 1. Confirm that the answer is
A = [
0.317290 0.326583 0.356127
0.292497 0.317121 0.390383
0.268164 0.306246 0.425591
]
C = [
0.374180 0.199071 0.301859 #C1
0.385403 0.197538 0.304925 #C2
0.397053 0.196192 0.307616 #C3
]
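The short Julia script below is one way to check the example. It computes A and C with both the vector-based and the matrix-based formulas. As in the example, Q, K and V are taken to be already projected (equivalently, Wq = Wk = Wv = I), and the Softmax is taken over each row. The helper names are our own.

```julia
using LinearAlgebra

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

Q = [0.1 0.2 0.3; 0.4 0.5 0.6; 0.7 0.8 0.9]
K = [0.2 0.2 0.1; 0.3 0.4 0.1; 0.7 0.2 0.6]
V = [0.3 0.3 0.1; 0.2 0.1 0.5; 0.6 0.2 0.3]
d_k = size(K, 2)

# Vector-based: f(Qᵢ, Kⱼ) for every j, then Cᵢ = Σⱼ f(Qᵢ, Kⱼ) Vⱼ.
function attend_vectorwise(Q, K, V)
    t, n = size(Q, 1), size(K, 1)
    A = zeros(t, n)
    C = zeros(t, size(V, 2))
    for i in 1:t
        scores = [dot(Q[i, :], K[j, :]) / sqrt(size(K, 2)) for j in 1:n]
        weights = exp.(scores) ./ sum(exp.(scores))   # softmax over j
        A[i, :] = weights
        C[i, :] = sum(weights[j] .* V[j, :] for j in 1:n)
    end
    return A, C
end

A_vec, C_vec = attend_vectorwise(Q, K, V)

# Matrix-based: A = softmax(QKᵀ/√dₖ) applied row-wise, C = AV.
A_mat = softmax_rows(Q * K' ./ sqrt(d_k))
C_mat = A_mat * V

println(round.(A_mat; digits=6))
println(round.(C_mat; digits=6))
println(A_vec ≈ A_mat && C_vec ≈ C_mat)   # true
```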
Operation Continued:
In practice, the attention matrix that the model learns, and hence the context vectors it produces, can capture different types of relationships and dependencies between the query and key sequences depending on the initialization of the weights Wq, Wk, Wv. This is just like how different initializations of a kernel in a CNN lead it to learn different types of features from the image.
Similar to how this was exploited in CNNs to learn multiple features (performing multiple convolutions at the same layer), the attention operation is generally performed h times, each with its own initialization of Wq, Wk, Wv. h is called the number of heads.
This leads to h sequences of context vectors, (C₁,C₂,…,Cₜ)¹, (C₁,C₂,…,Cₜ)², …, (C₁,C₂,…,Cₜ)ʰ. These are merged into a single sequence (C₁,C₂,…,Cₜ) in two steps. First, each Cᵢ is concatenated across the h heads. Then the result is multiplied by a weight matrix Wₒ that learns to summarize them into one context vector of size dₒ.
Equivalently, let C¹, C², …, Cʰ be the context-vector matrices produced by the h attention heads (the ith row of each being that head’s Cᵢ). The final context-vector matrix C is then produced by:
$$C = \mathrm{Concat}(C^1, C^2, \ldots, C^h)\, W_o$$
where Concat places the matrices side by side (column-wise).
Clearly, Wₒ must have dimensions (h·dim(Cᵢ), dₒ), where dim(Cᵢ) is the size of a single head’s context vector. This way, each concatenated row of h context vectors is mapped to one context vector of size dₒ.
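Below is a minimal Julia sketch of the multi-head operation. It assumes toy sizes (d_model = 8, h = 2, head_dim = 4), random untrained weights and dₒ = d_model. Each head is stored as a hypothetical named tuple `(q, k, v)` holding its three projection matrices.

```julia
using LinearAlgebra
using Random

Random.seed!(0)

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

# One attention head: project, score, normalize each row, and mix the values.
function attention_head(Xq, Xkv, Wq, Wk, Wv)
    Q, K, V = Xq * Wq, Xkv * Wk, Xkv * Wv
    return softmax_rows(Q * K' ./ sqrt(size(K, 2))) * V
end

# Multi-head attention: h independent heads, concatenated column-wise, then mixed by Wₒ.
function multi_head_attention(Xq, Xkv, heads, Wo)
    C_heads = [attention_head(Xq, Xkv, W.q, W.k, W.v) for W in heads]
    return hcat(C_heads...) * Wo
end

d_model, h = 8, 2
head_dim = d_model ÷ h
heads = [(q = randn(d_model, head_dim), k = randn(d_model, head_dim), v = randn(d_model, head_dim))
         for _ in 1:h]
Wo = randn(h * head_dim, d_model)            # (h·head_dim, dₒ) with dₒ = d_model

X = randn(5, d_model)                        # t = 5 token vectors
C = multi_head_attention(X, X, heads, Wo)    # self-attention: queries, keys, values from X
println(size(C))                             # (5, 8)
```

Passing a different matrix as the second argument turns the same function into cross-attention. The output still has one row per query.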
Hyperparameters:
- The number of columns in Wq, Wk, Wv. The first two must be equal for the dot product to be defined (this is dₖ above); typically, all three use a single head_dim hyperparameter. After the transformation by the three matrices, all queries, keys and values then have head_dim dimensions. The same holds for each output context vector, since it is a weighted combination of the value vectors. head_dim is 64 in the original paper.
- The number of heads h (8 in the original paper).
Learnable Parameters:
Wq, Wk, Wv (one set per head) and Wₒ. The first dimension of Wq, Wk and Wv matches the dimensionality of one input query, key and value respectively; the second is the hyperparameter head_dim. By the matrix operation above, Wₒ is (h·head_dim, dₒ). dₒ is set to the dimensionality of the input query q, so the resulting context vectors have the same size as the inputs, which the Add & Norm block later assumes.
Side Note:
Unlike sequential models such as RNNs/LSTMs, attention does not struggle with long-term dependencies, since it evaluates all pairs of positions regardless of their distance. It can also be fully parallelized. This is a serious edge in both performance and speed. Attention had also been used in older sequential models to help with long-term dependencies. However, it was not yet realized that attention was all we needed and that it could replace the sequential RNN/LSTM components altogether.
We will look at the masked case of self-attention later. For now, just make sure you understand the self-attention and cross-attention cases very well.
Position-wise Feedforward Network
Input:
A sequence of vectors x₁,x₂,…,xₜ (or a batch thereof).
Operation:
Pass each given vector through two feedforward layers. The first has a ReLU activation and the second has a linear activation (i.e., no activation). The same weights are used regardless of the vector’s position (hence, position-wise).
Output:
A sequence of transformed vectors z₁,z₂,…,zₜ (or a batch thereof).
Purpose:
Similar to feedforward neural networks, this projects the data to a space where it’s easier to accomplish the model’s learning task or equivalently, learn more complex and useful features for each vector in the given sequence.
Hyperparameters:
The dimensionality of the first layer dₚ. The second one is set such that the dimensionality is the same as that of the input because that is assumed by the Add & Norm block. dₚ is set to 2048 in the original paper.
Learnable Parameters:
The weight matrices of the first and second layer, W₁ and W₂ respectively, plus their bias vectors b₁ and b₂. If the input vectors have dimensionality dₓ (as row vectors), then W₁ has dimensions (dₓ, dₚ) and W₂ has dimensions (dₚ, dₓ).
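The block is short enough to sketch directly in Julia. The sizes are toy values (the paper uses 512 and 2048) and the weights are random. The biases b₁ and b₂, which the paper’s formulation includes, are zero-initialized here. The final check confirms that each row is transformed independently of the others.

```julia
using LinearAlgebra
using Random

Random.seed!(0)

relu(x) = max(x, zero(x))

# Position-wise FFN: the same W₁, b₁, W₂, b₂ are applied to every row (token) of X.
function position_wise_ffn(X, W1, b1, W2, b2)
    H = relu.(X * W1 .+ b1)     # (t, dₚ): first layer with ReLU
    return H * W2 .+ b2         # (t, dₓ): second layer, linear
end

d_x, d_p = 8, 32
W1, b1 = randn(d_x, d_p), zeros(1, d_p)
W2, b2 = randn(d_p, d_x), zeros(1, d_x)

X = randn(5, d_x)
Z = position_wise_ffn(X, W1, b1, W2, b2)

# Position-wise: transforming row 3 alone matches row 3 of the full result.
println(Z[3, :]' ≈ position_wise_ffn(X[3:3, :], W1, b1, W2, b2))   # true
```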
Add & Norm Block
Input:
A sequence of vectors x₁,x₂,…,xₜ that has just been produced by the previous block (or a batch thereof).
Operation:
Let x’₁, x’₂, …, x’ₜ be the sequence of vectors input to the previous block. This block computes x’₁+x₁, x’₂+x₂, …, x’ₜ+xₜ and then performs layer normalization on the result. Layer normalization computes the following for each value z in each vector of the sequence:
$$\mathrm{LN}(z) = \gamma \, \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
Here μ and σ² are the mean and variance used for normalization, γ and β are learnable parameters, and ϵ is a small hyperparameter that prevents division by zero.
In the transformer, μ and σ are computed separately for each token vector, over all of that vector’s features. Each vector is thus normalized independently of the other tokens and of the other sequences in a batch. For instance, given the sequence of vectors [0, 2, 3], [4, 4, 6], [4, 2, -1], we would have μ=5/3 for the first vector, μ=14/3 for the second and μ=5/3 for the third.
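A minimal Julia sketch of Add & Norm follows, with per-token statistics as in the transformer. The defaults γ = 1 and β = 0 correspond to an untrained layer. The value ϵ = 1e-5 is an assumption, since the paper does not specify it, and `add_and_norm` is a hypothetical helper name.

```julia
# Layer normalization: statistics are computed per token (per row), over that token's
# features; γ and β have one entry per feature and are shared by all positions.
function layer_norm(X; γ = ones(1, size(X, 2)), β = zeros(1, size(X, 2)), ϵ = 1e-5)
    μ = sum(X; dims=2) ./ size(X, 2)                  # per-token mean
    σ² = sum((X .- μ) .^ 2; dims=2) ./ size(X, 2)     # per-token variance
    return γ .* (X .- μ) ./ sqrt.(σ² .+ ϵ) .+ β
end

# Add & Norm: residual connection (block input + block output), then layer normalization.
add_and_norm(X_in, X_out) = layer_norm(X_in .+ X_out)

X = [0.0 2.0 3.0; 4.0 4.0 6.0; 4.0 2.0 -1.0]   # the three token vectors from the text
Y = layer_norm(X)
println(sum(X; dims=2) ./ 3)    # per-token means: 5/3, 14/3, 5/3
println(round.(Y; digits=3))    # each row now has mean ≈ 0 and variance ≈ 1
```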
Output:
A sequence of vectors z₁,z₂,…,zₜ (or a batch thereof), in which the input of the preceding block has been added to its output and the result normalized.
Purpose:
As in older architectures, the addition helps the gradient flow through the network during training without vanishing. Suppose the gradient through f(z) vanishes because f’(z) is very small (e.g., after many chain-rule multiplications). The derivative of z+f(z) with respect to z is still 1+f’(z), so a gradient always gets through.
Similar to batch normalization (which computes μ and σ for each feature across the whole batch), layer normalization helps reduce training time. It was introduced for sequential models because batch normalization was harder to apply to them, and it has been shown to achieve this objective.
Hyperparameters:
No significant hyperparameters, just ϵ, which prevents division by zero.
Learnable Parameters:
γ and β are the only learnable parameters. Each is a vector with one entry per feature (i.e., of the same dimensionality as one input vector), shared across all positions in the sequence.
Side Note:
It is used after every attention and position-wise feedforward block in the network. The authors also apply dropout with P=0.1 to the output of each such block, just before it is added to the block’s input. During training, each value is independently set to zero with probability 0.1, which helps prevent overfitting.
Softmax Block
Input:
A sequence of vectors x₁,x₂,…,xₜ that has just been produced by the last block in the architecture (or a batch thereof).
Operation:
The linear layer projects each vector to a dimensionality equal to the number of classes. The Softmax then converts each such vector into a vector of probabilities (one for each class) that sum to 1 (i.e., a probability distribution).
Purpose:
Produce the predictions of the transformer for each token given in the input.
Output:
A sequence of vectors y₁,y₂,…,yₜ, each a probability distribution whose dimension equals the number of classes for the task the model is learning. One way to convert these into t predictions is to take the class of highest probability in each vector.
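A sketch of the Softmax block in Julia, assuming toy sizes (d_model = 8, k = 3 classes) and a random, untrained linear layer:

```julia
using LinearAlgebra
using Random

Random.seed!(0)

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

# Softmax block: a linear layer to k class scores, then a softmax over each row.
softmax_block(X, W, b) = softmax_rows(X * W .+ b)

d_model, k = 8, 3
W, b = randn(d_model, k), zeros(1, k)
X = randn(4, d_model)                 # output of the last block for t = 4 tokens

Y = softmax_block(X, W, b)            # (4, 3): one probability distribution per token
predictions = [argmax(Y[i, :]) for i in 1:size(Y, 1)]   # most probable class per token
```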
Transformer Encoder
Composition
The composition of the encoder is pretty simple:
It starts with an embedding block and a positional encoding block, which assign the input tokens meaningful, position-aware vectors. It then follows up with N encoder layers, where N is a hyperparameter set to 6 in the original paper. Each encoder layer contains a multi-head attention block (functioning as self-attention) and a position-wise feedforward block, each followed by an Add & Norm block. Together they transform each vector in the sequence into a richer, more comprehensive representation. The representation is comprehensive in the sense that, for any token, it also captures relationships with all other input tokens (thanks to self-attention). The final Softmax block is not always present; we’ll delve into that shortly.
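Putting the blocks together, here is one possible Julia sketch of the encoder’s forward pass. It assumes random, untrained weights and toy sizes. For brevity it omits biases, dropout and the learnable γ/β of layer normalization. `random_layer` and the other helper names are hypothetical.

```julia
using LinearAlgebra
using Random

Random.seed!(0)

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

# Layer normalization per token (row); γ = 1 and β = 0 for brevity.
function layer_norm(X; ϵ = 1e-5)
    μ = sum(X; dims=2) ./ size(X, 2)
    σ² = sum((X .- μ) .^ 2; dims=2) ./ size(X, 2)
    return (X .- μ) ./ sqrt.(σ² .+ ϵ)
end

# Sinusoidal positional encoding; column c uses the exponent 2i/d with 2i = c - c % 2.
positional_encoding(t, d) =
    [iseven(c) ? sin(p / 10000^((c - c % 2) / d)) : cos(p / 10000^((c - c % 2) / d))
     for p in 0:t-1, c in 0:d-1]

# Hypothetical container of random (untrained) weights for one encoder layer.
function random_layer(d, h, d_p)
    hd = d ÷ h
    heads = [(q = randn(d, hd), k = randn(d, hd), v = randn(d, hd)) for _ in 1:h]
    return (heads = heads, Wo = randn(h * hd, d), W1 = randn(d, d_p), W2 = randn(d_p, d))
end

function encoder_layer(X, L)
    # Multi-head self-attention: queries, keys and values all come from X.
    C = hcat([softmax_rows((X * H.q) * (X * H.k)' ./ sqrt(size(H.k, 2))) * (X * H.v)
              for H in L.heads]...) * L.Wo
    X = layer_norm(X .+ C)                                  # Add & Norm
    return layer_norm(X .+ max.(X * L.W1, 0.0) * L.W2)     # FFN (ReLU) + Add & Norm
end

function encoder(tokens, W_e, layers)
    X = W_e[tokens, :]                                         # embedding block
    X = X .+ positional_encoding(length(tokens), size(W_e, 2)) # positional encoding block
    for L in layers                                            # N encoder layers
        X = encoder_layer(X, L)
    end
    return X
end

vocab, d_model, h, d_p, N = 20, 16, 4, 64, 6
W_e = randn(vocab, d_model)
layers = [random_layer(d_model, h, d_p) for _ in 1:N]
Z = encoder([5, 2, 9, 9, 1], W_e, layers)   # (5, 16): one representation per token
```

Every block preserves the (t, d_model) shape. This is what lets the layers be stacked N times.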
Purpose
Because the encoder assigns a rich and comprehensive representation to each input token, it can function as a standalone model, typically for classification tasks that operate on each token in a sequence or on the entire sequence. Consider tasks like named-entity recognition or part-of-speech tagging. In both, the classification of any word as a specific entity or tag depends on its position in the sentence and its interaction with other words, which is captured in each vector of the encoder’s representation. This is where the Softmax block comes into play: we need a block that transforms each rich and comprehensive vector into a vector of probabilities of size equal to the number of classes. We can then predict a class for each word by taking the class of maximum probability. If one classification is needed for the whole sequence, the vectors can first be averaged, max-pooled or concatenated, and the resulting single vector is fed into the Softmax block.
Pre-trained Models
There are readily available encoder models, such as BERT and its many variants, that have been pre-trained on giant English corpora. They can thus produce decent representations for new, arbitrary English input sequences.
BERT stands for Bidirectional Encoder Representations from Transformers. It does not modify the architecture of the encoder, except for adding a special type of embedding signal for its pre-training task. The transformer encoder is of course bidirectional, because self-attention considers the relationship with all tokens in the sequence, not only those before it.
Given a dataset for a classification task, it can be sufficient to train only the weights in the Softmax block, keeping the weights of a pre-trained BERT model for the rest of the architecture, to achieve decent classification performance. This is one form of fine-tuning.
Data Flow
The encoder takes as input a batch of B sequences, where each sequence has S tokens coded as integers. Padding with some fixed token (e.g., represented by the integer 0) is used if sequences are not of equal length. For instance, when [B,S] = [2, 3], the batch has two sequences and each has three tokens (e.g., [[1, 77, 4], [3, 22, 11]]).
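Padding a batch of variable-length integer sequences into a [B, S] matrix might look like this in Julia. The sketch assumes 0 as the padding token, which never collides with a real token id when the vocabulary is 1-indexed.

```julia
# Pad variable-length integer sequences with a fixed token into a B×S matrix.
function pad_batch(seqs::Vector{Vector{Int}}; pad::Int = 0)
    S = maximum(length, seqs)                # longest sequence sets S
    batch = fill(pad, length(seqs), S)
    for (b, seq) in enumerate(seqs)
        batch[b, 1:length(seq)] = seq        # copy real tokens, leave the rest padded
    end
    return batch
end

batch = pad_batch([[1, 77, 4], [3, 22]])
# 2×3 Matrix{Int64}: [1 77 4; 3 22 0]
```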
It then passes the sequence of integers through the embedding layer to transform each integer into an embedding vector of dimensions dₑ. For Add & Norm purposes, this dimensionality is maintained for the rest of the model; hence, it is called d_model and is set to 512 in the original paper.
The positional encoding, attention, feedforward and Add & Norm blocks all preserve this [B, S, d_model] shape. Only the final block transforms each vector into one that assigns a probability to each of the k classes involved. This assumes the task is token classification; sequence classification was discussed above.
This data flow is the same whether the encoder is being trained or used for inference. Of course, in training, each time a batch is fed to the model there is also a backward pass. It updates all the learnable parameters so as to minimize the loss (improve classification accuracy). Optimization algorithms require differentiating the loss function with respect to each learnable parameter in the model, but the deep learning framework used to implement the transformer handles this automatically.
Transformer Decoder
Composition
Like the encoder, the decoder can also be used as a standalone model. In that case, the multi-head attention block that uses outputs from the encoder (i.e., the cross-attention block) is dropped. It belongs to neither the encoder alone nor the decoder alone, since it requires input from both, and would arguably be better drawn between them, as in older architectures.
It then becomes obvious that the decoder as a standalone model is identical to the encoder model except that it uses “Masked Multi-head Attention” for self-attention instead of ordinary attention. We will explain this shortly.
Purpose
The purpose of the decoder model is primarily to function as a language model, that is, to repeatedly predict the next word. In more general terms, it yields comprehensive and rich representations of the input, like the encoder. The constraint is that the representation of a token at any position never uses relationships with tokens that come after it in the sequence. In other words, it’s just like the encoder, but unidirectional instead of bidirectional: tokens are only allowed to look back.
Analogy
We can draw a parallel between the encoder/decoder transformer models and RNNs/LSTMs. Both RNN/LSTM-based models and transformers use an encoder-decoder setup mainly for tasks where the input and output sequences are not of the same length (e.g., translation).
Otherwise, an RNN (similar to a transformer encoder or decoder) can accomplish the task as a standalone model. This covers tasks where the output sequence has the same length as the input sequence (e.g., word classification) or where there is a single output (e.g., sentence classification).
The difference between encoder and decoder transformer models is that the former is analogous to a bidirectional RNN and the latter is analogous to a unidirectional RNN. Bidirectional RNNs are helpful when we know that we will have access to the entire sequence in inference (e.g., token/sequence classification) and it is helpful for each token to learn something about those after it as well. Meanwhile, unidirectional RNNs must be used if in inference we won’t have access to the entire sequence (e.g., as it should be generated iteratively by the model).
Pre-trained Models
There are readily available decoder models, such as GPT-2 and its many successors and peers, that have been pre-trained on giant English corpora. They can thus produce plausible continuations of text.
It turns out that when very large language models are trained (where large refers to the number of learnable parameters), they learn to accomplish many tasks just by predicting what to say based on what they have seen in the training data. A language model may not be explicitly trained on translation but still perform well on it. This is because, in many instances in the training data, it was predicting the next word in conversations or text that involved translation. This holds for a very wide range of tasks, making such models behave as general-purpose machine learning models.
We can thus say that for any NLP task, the “transformer decoder” is all you need, provided it is large enough and trained on sufficiently large corpora. So next time you think of ChatGPT or other LLMs, don’t think of the entire transformer architecture: it’s just the decoder. ChatGPT was originally based on GPT-3.5, where GPT stands for Generative Pre-trained Transformer, a language model that generates text after being pre-trained on massive data.
Data Flow
Let’s see why an encoder model can’t be used for a task such as language modeling and derive the decoder model with intuition. Let’s start by acknowledging that the data flow for the decoder during training is similar to the encoder model:
For language modeling, k (number of classes) would be set to the number of words in the vocabulary.
Take the training sentence “<sos> be the sunshine on a cloudy day <eos>”. To train the encoder model, we give it the sequence [<sos>, be, the, sunshine, on, a, cloudy, day] (in integer form) and expect it to predict the eight classes [be, the, sunshine, on, a, cloudy, day, <eos>]. We penalize it for each wrong prediction by updating weights. We do this for each training sentence until the model can predict the next word with decent accuracy.
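Building these input/target pairs amounts to shifting the sequence by one position. Here is a small Julia sketch, in which the word-to-integer mapping is a toy vocabulary built from the sentence itself:

```julia
# Build next-token training pairs from one tokenized sentence.
sentence = ["<sos>", "be", "the", "sunshine", "on", "a", "cloudy", "day", "<eos>"]
vocab = Dict(w => i for (i, w) in enumerate(unique(sentence)))   # toy vocabulary (assumption)
ids = [vocab[w] for w in sentence]

inputs  = ids[1:end-1]   # [<sos>, be, …, day]
targets = ids[2:end]     # [be, the, …, <eos>]: the next word for every input position
```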
Problem
The problem with this approach is that self-attention lets every position see the whole sequence. When predicting the word “sunshine” (the target at the position of the input “the”), the model can attend to all later input tokens, including “sunshine” itself and [on, a, cloudy, day]. At inference time, the model might be given only [be, the, sunshine] and be required to predict the next word. It cannot make use of future words it has not generated yet. This is why training the encoder for language modeling fails: it lets the model rely on features it won’t have access to at inference time.
Remedy
The remedy for this is pretty simple. Modify self-attention so that the computation of the context vector Cᵢ for the ith token does not consider any token with index j>i. Recall that Cᵢ is computed as
$$C_i = \sum_{j=1}^{n} a_{ij} V_j, \qquad a_{ij} = f(Q_i, K_j)$$
We want aᵢⱼ to be zero whenever j>i. That is, when computing Cᵢ, never make use of Kⱼ or Vⱼ corresponding to tokens that come after i.
Recall that this quantity is also A[i,j] when A is computed as
$$A = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) \ \text{(softmax applied to each row)}$$
Let E=QKᵀ/√dₖ and suppose we add -∞ to any E[i,j] where j>i.
In this case, the Softmax yields zero at any such position because exp(-∞)=0. This implies A[i,j]=0 wherever j>i, so the context vector Cᵢ only makes use of relationships where j≤i. Equivalently, we can add to E a mask matrix that is 0 on and below the main diagonal and -∞ above it, and then apply the Softmax to each row:
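# Definitions
∞ = Inf
Softmax(x) = exp.(x) ./ sum(exp.(x))
# Suppose
E = [ 0.1 0.2 0.3   #E1
      0.4 0.5 0.6   #E2
      0.7 0.8 0.9 ] #E3
# We first add
Mask = [ 0 -∞ -∞
         0  0 -∞
         0  0  0 ]
Em = E + Mask
# Apply Softmax to each row (i.e., over j) so that each row of A sums to 1
A = mapslices(Softmax, Em; dims=2)
# yields (rounded to 6 decimals)
# A = [ 1.0       0.0       0.0
#       0.475021  0.524979  0.0
#       0.300610  0.332225  0.367165 ]
# Now when C is computed it doesn't use invalid relationships.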
# Now when C is computed it doesn't use invalid relationships.
With this modification applied, we have masked self-attention.
Revisit the Flow
Now take the same training sentence, “<sos> be the sunshine on a cloudy day <eos>”. To train the decoder model, we give it the sequence [<sos>, be, the, sunshine, on, a, cloudy, day] (in integer form) and expect it to predict the classes [be, the, sunshine, on, a, cloudy, day, <eos>]. We penalize it for each wrong prediction. We do this for each training sentence until the model can predict the next word with decent accuracy.
With this training scheme, the model learns to predict the next word without using any invalid relationships. What it learns in training therefore carries over to inference, where future tokens are unavailable.
Inference Time
Note that the decoder is parallelizable at training time. For any sequence of length t, we can input the first t-1 tokens and expect it to predict the last t-1 tokens (the next token after each of the first t-1 tokens) in a single pass. At inference time, we don’t have the target sequence, so we must instead feed the decoder word by word.
Suppose we want the model to generate words starting with the prompt “Believe you can” (in reference to the quote “Believe you can and you are halfway there”).
In this case, you input these three tokens to the decoder. It makes three classifications and (ideally) gives you “you can and” (because it was trained to do this, right?).
Now you input “Believe you can and”. It performs four classifications and (ideally) gives you “you can and you”.
Now you input “Believe you can and you”. It performs five classifications and (ideally) gives you “you can and you are”.
And so on until the transformer predicts an <eos>. By that time it will have predicted all of “and you are halfway there”. This is why every other tutorial you have seen says that the transformer decoder is autoregressive (sequential). The fact that it’s parallelizable in training is frequently overlooked: we don’t feed the transformer its previous outputs; we feed it the real outputs and penalize it for each mistake.
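The generation loop above can be sketched in Julia. No trained model is available here, so `decoder_predictions` is a hypothetical stand-in that always returns the ideal next token for each position. In a real model it would be the decoder’s forward pass followed by taking the most probable class at each position.

```julia
const QUOTE = ["Believe", "you", "can", "and", "you", "are", "halfway", "there", "<eos>"]

# Hypothetical stand-in for a trained decoder: for each input position i it returns the
# token that follows position i in QUOTE (one prediction per input token).
decoder_predictions(tokens) = [QUOTE[i + 1] for i in 1:length(tokens)]

# Greedy autoregressive generation: feed everything so far, keep only the last prediction.
function generate(prompt; max_len = 20)
    tokens = copy(prompt)
    while length(tokens) < max_len
        next_token = decoder_predictions(tokens)[end]   # only the last prediction is new
        next_token == "<eos>" && break
        push!(tokens, next_token)
    end
    return tokens
end

println(join(generate(["Believe", "you", "can"]), " "))
# Believe you can and you are halfway there
```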
It’s also possible to train it sequentially in this fashion, so that the model uses its own previous outputs instead of the true ones while predicting the next word. This is usually avoided, because it gives up the parallelism above and makes training much slower.
Notice that some context vectors and their transformations are redundantly recomputed in each iteration. Once Qᵢ, Kᵢ, Vᵢ and Cᵢ have been computed for a token, they never change, because they depend only on that token and the tokens before it. An efficient implementation therefore caches them when first computed and reuses them in later iterations. We can’t stop attending to older tokens, because attention over them helps when computing the next one, but we only need to compute these quantities for the newest token. This also means the masking can be skipped at inference time: each cached Cᵢ was computed when its token was the last input, so it never looked into the future.
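The caching argument can be checked numerically. The Julia sketch below uses a single head, random weights and our own function names. It computes masked self-attention over the full sequence, then recomputes it one token at a time with cached keys and values and no mask, and confirms that the two agree.

```julia
using LinearAlgebra
using Random

Random.seed!(3)

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

d, t = 4, 6
Wq, Wk, Wv = randn(d, d), randn(d, d), randn(d, d)
X = randn(t, d)                          # token vectors arriving one at a time

# Full masked self-attention over the whole sequence (what training computes).
mask = [j > i ? -Inf : 0.0 for i in 1:t, j in 1:t]
C_masked = softmax_rows((X * Wq) * (X * Wk)' ./ sqrt(d) .+ mask) * (X * Wv)

# Incremental version: append each new token's key and value to a cache; no mask is needed
# because the cache only ever contains tokens up to the current one.
function incremental_attention(X, Wq, Wk, Wv)
    K_cache = zeros(0, size(Wk, 2))
    V_cache = zeros(0, size(Wv, 2))
    C = zeros(size(X, 1), size(Wv, 2))
    for i in 1:size(X, 1)
        x = X[i:i, :]                                   # newest token only
        K_cache = vcat(K_cache, x * Wk)                 # computed once, then reused
        V_cache = vcat(V_cache, x * Wv)
        scores = (x * Wq) * K_cache' ./ sqrt(size(Wk, 2))
        C[i, :] = vec(softmax_rows(scores) * V_cache)   # Cᵢ uses tokens 1..i only
    end
    return C
end

println(incremental_attention(X, Wq, Wk, Wv) ≈ C_masked)   # true
```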
Full Transformer
Composition
The encoder has the same composition, except that its sole purpose now is to provide the rich and comprehensive representations that help the decoder generate the new sequence (i.e., it has no Softmax block):
Meanwhile, the decoder uses an extra cross-attention block where the queries are derived from the decoder inputs and the keys and values come from the encoder outputs:
Purpose
As hinted earlier, the full encoder-decoder transformer is primarily used for tasks where the input and output sequences can have different lengths. These include translation, summarization and question answering.
Here, the decoder’s job is to generate the entire new sequence (e.g., the translated sentence) after the encoder has provided it with a rich and comprehensive representation of the input sequence (e.g., the original sentence). Cross-attention allows the queries, which represent the tokens input to the decoder, to attend to the keys and values, which come from the encoder. This way, the representation the decoder gives to each token in its input takes into account important relationships with tokens in the entire original sequence. It does so in addition to relationships with the previous tokens in the decoder.
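To make the shapes concrete, here is a minimal Julia sketch of single-head cross-attention. It uses random stand-in vectors for 6 encoder outputs and 3 decoder tokens. Each decoder token gets one context vector, while its attention weights range over all encoder tokens.

```julia
using LinearAlgebra
using Random

Random.seed!(7)

function softmax_rows(S)
    E = exp.(S .- maximum(S; dims=2))   # subtract row max for numerical stability
    return E ./ sum(E; dims=2)
end

# Cross-attention: queries from the decoder (t tokens), keys and values from the encoder (n tokens).
function cross_attention(X_dec, Z_enc, Wq, Wk, Wv)
    Q, K, V = X_dec * Wq, Z_enc * Wk, Z_enc * Wv
    A = softmax_rows(Q * K' ./ sqrt(size(K, 2)))   # (t, n): decoder token i attends to encoder token j
    return A * V, A
end

d = 8
Z_enc = randn(6, d)   # stand-in encoder output for [<sos>, let, us, think, logically, <eos>]
X_dec = randn(3, d)   # stand-in decoder vectors for [<sos>, denken, wir]
Wq, Wk, Wv = randn(d, d), randn(d, d), randn(d, d)

C, A = cross_attention(X_dec, Z_enc, Wq, Wk, Wv)
println(size(A), " ", size(C))   # (3, 6) (3, 8)
```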
Data Flow
It pretty much follows from how each of the encoder and decoder behave as single models:
Training
- Want to translate “Let us think logically” to German.
- Feed into the encoder integer tokens for [<sos>, let, us, think, logically, <eos>].
- The encoder prepares a corresponding rich and comprehensive representation [v1, v2, v3, v4, v5, v6].
- The decoder inputs [<sos>, denken, wir, logisch, nach].
- It performs masked self-attention over the decoder inputs, then cross-attention that makes use of the encoder representations, and then provides an output.
- The output is compared with the real expected output [denken, wir, logisch, nach, <eos>], and the weights of the whole network are updated accordingly.
Inference
- Want to translate “Let us think logically” to German.
- Feed into the encoder integer tokens for [<sos>, let, us, think, logically, <eos>].
- The encoder prepares a corresponding rich and comprehensive representation [v1, v2, v3, v4, v5, v6].
- The decoder inputs [<sos>] and performs masked self-attention and cross-attention with the encoder representations [v1, v2, v3, v4, v5, v6]; these help it predict the next word, denken.
- The decoder inputs [<sos>, denken] and performs masked self-attention and cross-attention with the encoder representations; these help it predict [denken, wir], i.e., the next word wir.
- The decoder inputs [<sos>, denken, wir] and performs masked self-attention and cross-attention with the encoder representations; these help it predict the next word, logisch.
- And so on until an <eos> is predicted.
Congratulations on making it this far. I hope this story has helped you deeply understand the different components found in the transformer and how it functions as an encoder, decoder and encoder-decoder model. Till next time, au revoir.