This is a living document that I’ll extend with more information as I go. As attention is dead-center when it comes to LLMs, it’s not weird that a lot has been written about it. Here I wanted to collect some of my own questions with their answers (for easy reference).
Parts of these notes are based on the new book “Deep Learning” by Bishop & Bishop1, which I recommend. The notes also include some comments of my own (and, of course, all mistakes / typos are mine).
Let’s get started. We have a matrix \(X\in\mathbb{R}^{n\times m}\) which represents a sequence of \(n\) \(m\)-dimensional vectors. For the language modeling example, each one of those could be the embedding of a single token.
Let’s suppose that we are in the business of modeling, so we would like to map \(X\) to a \(Y\in \mathbb{R}^{n\times m}\) such that each \(m\)-dimensional vector in \(Y\) contains information from all the \(X_j\) (\(X_j\) being the \(j\)-th row vector of \(X\)). Perhaps the simplest way is \(Y_i=\sum_jA_{ij}X_j,\) where we assume that \(A_{ij}\in [0,1]\) for all \(i,j\). Restricting \(A_{ij}\) such that \(\sum_{j}A_{ij}=1\) for all \(i\) has a nice property: we can now think of \(Y_i\) as a weighted mean of the \(X_j\), and we just get to decide how much of each \(X_j\) to use.
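To make the mixing step concrete, here is a minimal NumPy sketch (the shapes and the matrix \(A\) are arbitrary choices for illustration): any non-negative \(A\) whose rows sum to one blends the rows of \(X\) into \(Y\).

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3                           # sequence length, embedding dimension
X = rng.normal(size=(n, m))           # rows X_j are the input vectors

# Any non-negative A whose rows sum to one mixes the rows of X into Y.
A = rng.random(size=(n, n))
A = A / A.sum(axis=1, keepdims=True)  # normalise each row

Y = A @ X                             # Y_i = sum_j A_ij X_j
print(A.sum(axis=1))                  # all ones
print(Y.shape)                        # (4, 3): same shape as X
```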
Following this recipe further, we can pick \(A_{ij}\) according to how relevant each \(X_j\) is to \(X_i\). One way to capture relevance is through similarity, which leads to “dot-product self-attention”2:
\(A_{ij}=\text{softmax}(XX^T)_{ij}=\frac{\exp(X_i^TX_j)}{\sum_{k}\exp(X_i^TX_k)}.\)
As \(X\in\mathbb{R}^{n\times m}\), \(XX^T\) has dimensions \(n \times n\), i.e., it is quadratic in the sequence length.
To get \(Y\), we can just do \(Y=\text{softmax}(XX^T)X.\)
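Here is a minimal NumPy sketch of this parameter-free version (toy shapes; the helper names are my own):

```python
import numpy as np

def softmax(Z, axis=-1):
    Z = Z - Z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

def attention_no_params(X):
    """Parameter-free dot-product self-attention: Y = softmax(X X^T) X."""
    A = softmax(X @ X.T, axis=-1)  # (n, n) attention weights, each row sums to 1
    return A @ X                   # (n, m), same shape as X

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))        # n = 5 tokens, m = 8 dimensions
Y = attention_no_params(X)
print(Y.shape)                     # (5, 8)
```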
This is fine, but:

- there are no learnable parameters anywhere, so there is nothing the model can adapt to the data, and
- every feature dimension of \(X\) contributes equally to the dot products, whereas we might want some features to count more than others when judging similarity.
We can address these points by introducing a new matrix, \(U\in\mathbb{R}^{m\times m}\), with learnable parameters such that \(\tilde{X} = XU\), and so
\[Y=\text{softmax}(\tilde{X}\tilde{X}^T)\tilde{X}=\text{softmax}(XUU^TX^T)XU.\]Progress, but now \(\tilde{X}\tilde{X}^T\) is always a symmetric matrix regardless of \(U\), so we cannot capture asymmetric relationships: how much token \(i\) should attend to token \(j\) need not equal the reverse. This motivates using different learnable matrices for the two factors that form the attention scores and for the final mapping:
\[\begin{align} Q&=XW_Q,\ W_Q\in\mathbb{R}^{m\times D_K},\\ K&=XW_K,\ W_K\in\mathbb{R}^{m\times D_K},\\ V&=XW_V,\ W_V\in\mathbb{R}^{m\times D_V}. \end{align}\]These are the celebrated query, key, and value matrices3, respectively, all learnable. Typically one sets \(D_K=D_V=D\), which makes the bookkeeping easier. With these matrices, attention becomes \(Y=\text{softmax}(QK^T)V.\) Quick dimensionality check: \(Q,K\in\mathbb{R}^{n\times D_K}\), so \(QK^T\in\mathbb{R}^{n\times n}\), and \(V\in\mathbb{R}^{n\times D_V}\), so \(Y=\text{softmax}(QK^T)V\in\mathbb{R}^{n\times D_V}\) (with \(D_V=m\) this matches the \(Y\in\mathbb{R}^{n\times m}\) we started with).
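A quick numerical check of these shapes, and of the point above that a single shared matrix gives symmetric scores while separate projections do not (random matrices, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, D = 5, 8, 4                      # sequence length, embedding dim, D = D_K = D_V
X = rng.normal(size=(n, m))

# One shared matrix U: the score matrix is symmetric no matter what U is.
U = rng.normal(size=(m, m))
S = (X @ U) @ (X @ U).T                # = X U U^T X^T
print(np.allclose(S, S.T))             # True

# Separate query / key / value projections break that symmetry.
W_Q, W_K, W_V = (rng.normal(size=(m, D)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V    # each (n, D)
scores = Q @ K.T                       # (n, n)
print(scores.shape)                    # (5, 5)
print(np.allclose(scores, scores.T))   # False (in general)
print((scores @ V).shape)              # (5, 4): the shape Y will have after the softmax
```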
While \(Y=\text{softmax}(QK^T)V\) is very close to the usual self-attention, we are missing a scaling constant. Let’s derive this here.
Suppose that you have two \(D_K\)-dimensional vectors \(q\) and \(k\), whose elements all have zero mean and unit variance and are mutually independent. Then:
\[\text{Var}[q^Tk]=\sum_{i=1}^{D_K} \text{Var}[q_ik_i]=\sum_{i=1}^{D_K}\mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]=\sum_{i=1}^{D_K}1=D_K.\]We used independence twice: to split the variance of the sum into a sum of variances, and to write \(\text{Var}[q_ik_i]=\mathbb{E}[q_i^2]\mathbb{E}[k_i^2]\) (the term \(\mathbb{E}[q_ik_i]^2\) vanishes because of the zero means). Therefore, the standard deviation of \(q^Tk\) is \(\sqrt{D_K}\), and dividing by it is exactly what we need for the entries of \(QK^{T}\) to have unit variance. This keeps the logits fed into the softmax at a reasonable scale, which makes learning easier. So, finally, \(Y=\text{softmax}(QK^T/\sqrt{D_K})V,\) which is the usual form of scaled dot-product attention.
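Putting everything together, here is a minimal, self-contained NumPy sketch of (single-head) scaled dot-product self-attention; function and variable names are my own:

```python
import numpy as np

def softmax(Z, axis=-1):
    Z = Z - Z.max(axis=axis, keepdims=True)       # subtract the max for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    """Y = softmax(Q K^T / sqrt(D_K)) V with Q = X W_Q, K = X W_K, V = X W_V."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    D_K = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(D_K), axis=-1)  # (n, n) attention weights
    return A @ V                                  # (n, D_V)

rng = np.random.default_rng(0)
n, m, D = 5, 8, 8
X = rng.normal(size=(n, m))
W_Q, W_K, W_V = (rng.normal(size=(m, D)) for _ in range(3))
print(scaled_dot_product_attention(X, W_Q, W_K, W_V).shape)  # (5, 8)
```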
Zero mean and unit variance are a matter of pre-processing, but independence is not. In fact, you would hope a sequence would not have independent elements, as otherwise there is no information to use to predict the next element. I now see this scaling as a way to control the size of the terms and help with learning, but it’s important to remember this point as it is not always explicitly stated4.
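A quick Monte Carlo sanity check of the \(D_K\) variance result, under exactly the idealised independence assumption discussed above (synthetic Gaussian vectors, not real queries and keys):

```python
import numpy as np

rng = np.random.default_rng(0)
D_K, trials = 64, 200_000

# Entries with zero mean, unit variance, and (by construction) full independence.
q = rng.normal(size=(trials, D_K))
k = rng.normal(size=(trials, D_K))

dots = np.einsum("td,td->t", q, k)   # one dot product per trial
print(dots.var())                    # ~ 64, i.e. ~ D_K
print((dots / np.sqrt(D_K)).var())   # ~ 1 after the 1/sqrt(D_K) scaling
```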
We first need to compute \(Q\), \(K\), and \(V\), which costs \(O(nD^2)\) if we assume \(D=D_K=D_V\) (with the embedding dimension \(m\) of the same order) and a sequence of length \(n\). Then, the matrix products \(QK^{T}\) and \(\text{softmax}(QK^T)V\) each cost \(O(n^2 D)\) (I’m ignoring the cost of applying the softmax itself and the scaling).
After this part, when dealing with a transformer block, we have an MLP layer that is applied to each output of the attention layer (\(n\) of them in total). Assuming a hidden width proportional to \(D\), each application costs \(O(D^2)\), so the MLP adds \(O(nD^2)\). Therefore, the total cost is \(\max\{O(nD^2), O(n^2D)\}\).
\(D\) is fixed at the time the transformer is designed, whereas \(n\) is the length of the input sequence, so you can see which of the two is going to be a challenge during inference with large inputs.
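As a quick worked example, the two terms balance at \(n=D\), since \(nD^2=n^2D\) there. With \(D=1{,}024\) and an input of \(n=32{,}768\) tokens, the attention term \(n^2D\) is already \(n/D=32\) times larger than \(nD^2\), and the gap keeps widening linearly in \(n\).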
Bishop, C.M. and Bishop, H., 2023. Deep learning: Foundations and concepts. Springer Nature. ↩
There are so many variants of attention now: grouped attention, linearised attention, etc. ↩
The query / key / value terminology is a nod to information retrieval; see, for example, the discussion on Cross Validated. ↩
Though the authors of “Attention is all you need” do mention this assumption in the celebrated “footnote 4”. ↩