GQA
How nanochat uses GQA
- GQA is implemented in CausalSelfAttention.forward() in the gpt.py file.
- The reason for applying GQA is to make training and inference faster than with MHA.
Paper info
- Title: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- Authors: Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai
- Venue: EMNLP 2023
- URL: https://arxiv.org/pdf/2305.13245
- Length: 7 pages
0. Background knowledge
- The concept of GPU memory bandwidth.
- The effect of moving a large volume of data under a small memory bandwidth.
- The computational complexity of the attention calculation.
- The KV cache.
- Matrix multiplication with broadcasting/expand.
- Compute efficiency vs. memory-bandwidth efficiency.
1. Motivation
The memory-bandwidth bottleneck has a much larger adverse effect on autoregressive decoders like GPT than on encoders like BERT.
- In the case of a BERT classification model: we run the forward pass only once and read the classification result from the CLS token.
- In the case of a GPT autoregressive model: we must run the forward pass many times, predicting the next token step by step, to generate the full sentence.
- Since the number of forward passes is much greater than one, the keys/values (including those of each newly predicted token) must be read from HBM into on-chip cache at every step. So reducing this per-step memory traffic benefits decoders several times more than it benefits other models.
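To make the bandwidth argument concrete, here is a rough back-of-the-envelope sketch of how many bytes of KV cache a single decode step must read. All the model dimensions below are made-up illustrative numbers, not nanochat's actual config:

```python
# Rough estimate of KV-cache bytes read per decode step.
# All dimensions are illustrative assumptions, not nanochat's real config.
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # 2x for keys AND values; the whole cache is read once per decode step.
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem

seq_len, n_layers, head_dim = 4096, 32, 128
mha = kv_cache_bytes(seq_len, n_layers, n_kv_heads=32, head_dim=head_dim)  # MHA: 32 KV heads
gqa = kv_cache_bytes(seq_len, n_layers, n_kv_heads=8,  head_dim=head_dim)  # GQA: 8 KV heads
print(f"MHA KV cache read/step: {mha / 1e6:.0f} MB")
print(f"GQA KV cache read/step: {gqa / 1e6:.0f} MB ({mha // gqa}x smaller)")
```

The ratio is exactly n_head / n_kv_head, since all other terms cancel; that ratio is the bandwidth saving GQA targets.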
In previous research there were attempts to reduce this memory traffic, such as Multi-Query Attention (MQA).
- MQA uses multiple query heads with only one key head and one value head. It reduces inference time a lot compared to the original Transformer Multi-Head Attention (MHA), but it also degrades quality considerably.
- TODO: need a concrete number for how much the time is reduced. Check the MQA paper.
2. Explanation
2-1. Goals
There are actually two goals in the paper:
1-1. Uptraining an MHA-based checkpoint into an MQA-based model while keeping its performance.
1-2. Proposing the GQA methodology, which is an interpolation between MQA and MHA.
But in nanochat, Andrej didn't use the first method, only the GQA methodology, so I won't cover uptraining in this post.
The goals of Grouped-Query Attention are the following:
- Keep performance similar to (or almost the same as) the original Multi-Head Attention, which uses an equal number of Q/K/V heads.
- At the same time, by reducing the number of K/V heads, cut the memory traffic (the slowdown of MHA caused by the memory-bandwidth limitation), speeding up inference and training to nearly the speed of MQA.
GQA is an interpolation between MHA (slow inference, high-quality results) and MQA (fast inference, lower-quality results):
- MHA: one key/value head per query head
- MQA: one key/value head shared by every query head
- GQA: one key/value head per group of query heads
2-2. Notation
GQA-G refers to grouped-query attention with G groups.
- If G = 1 (GQA-1), there is a single group containing all query heads; this is the same as MQA.
- If G = H (GQA-H), there are as many groups as query heads, so each K/V head serves exactly one query head; this is the same as MHA.
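The notation can be sketched as a tiny helper that maps G to head counts (H = 8 is an arbitrary example, not nanochat's config):

```python
# Illustrative head-count bookkeeping for GQA-G (example numbers only).
def gqa_heads(n_head, G):
    assert n_head % G == 0, "query heads must divide evenly into groups"
    n_kv_head = G                   # one K/V head per group
    queries_per_kv = n_head // G    # query heads sharing each K/V head
    return n_kv_head, queries_per_kv

H = 8
print(gqa_heads(H, 1))   # GQA-1 == MQA: one K/V head for all 8 query heads
print(gqa_heads(H, H))   # GQA-H == MHA: one K/V head per query head
print(gqa_heads(H, 4))   # GQA-4: one K/V head per two query heads
```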
3. GQA Methodology
3-1. Concept
Let's suppose GQA-(H/2): one K/V head per two query heads.
Then the mapping relation is
$Q_1, Q_2 \rightarrow (K_1, V_1)$
$Q_3, Q_4 \rightarrow (K_2, V_2)$
$…$
Then the attention calculation per head is
$\text{softmax}\left(\frac{Q_1 K_1^\top}{\sqrt{d_k}}\right) V_1$
$\text{softmax}\left(\frac{Q_2 K_1^\top}{\sqrt{d_k}}\right) V_1$
$\text{softmax}\left(\frac{Q_3 K_2^\top}{\sqrt{d_k}}\right) V_2$
$\text{softmax}\left(\frac{Q_4 K_2^\top}{\sqrt{d_k}}\right) V_2$
$\dots$
Compared to MHA, only the mapping between Q and K/V changes: one K/V head per n query heads.
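The mapping above can be checked with a small PyTorch sketch (toy shapes, not nanochat's code). Note that matmul broadcasting lets each K/V head serve its whole group without an explicit copy:

```python
import torch

# GQA-(H/2) worked example: H = 4 query heads, G = 2 K/V heads,
# so Q1,Q2 -> (K1,V1) and Q3,Q4 -> (K2,V2). Shapes are toy values.
B, T, H, G, d_k = 1, 5, 4, 2, 8
q = torch.randn(B, H, T, d_k)   # 4 query heads
k = torch.randn(B, G, T, d_k)   # 2 key heads
v = torch.randn(B, G, T, d_k)   # 2 value heads

# Group view: (B, G, H//G, T, d_k); k/v get a broadcast dim of size 1.
qg = q.view(B, G, H // G, T, d_k)
kg = k.unsqueeze(2)             # (B, G, 1, T, d_k)
vg = v.unsqueeze(2)

scores = qg @ kg.transpose(-1, -2) / d_k**0.5    # (B, G, H//G, T, T)
out = torch.softmax(scores, dim=-1) @ vg          # (B, G, H//G, T, d_k)
out = out.view(B, H, T, d_k)

# Sanity check: attending query head 0 with K1/V1 directly matches.
ref0 = torch.softmax(q[:, 0] @ k[:, 0].transpose(-1, -2) / d_k**0.5, dim=-1) @ v[:, 0]
print(torch.allclose(out[:, 0], ref0, atol=1e-6))  # True
```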
3-2. Implementation trial
I tried to implement it myself but got stuck; ChatGPT told me that a truly zero-copy version is hard to do in plain eager PyTorch.
I'll show you the code snippet that I tried first:
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
num_of_expand = self.n_head//self.n_kv_head
k = k.view(B,T,self.n_kv_head, 1, self.head_dim).expand(B,T,self.n_kv_head, num_of_expand,self.head_dim).view(B, T, self.n_head, self.head_dim)
v = v.view(B,T,self.n_kv_head, 1, self.head_dim).expand(B,T,self.n_kv_head, num_of_expand,self.head_dim).view(B, T, self.n_head, self.head_dim)
The purpose of GQA is to reduce memory-bandwidth overhead by reducing the number of K/V heads. So, as you can see in the code above, k and v have fewer heads than q.
But to calculate every head at once with a single matmul, we have to match the shapes of Q and K/V anyway.
Because the .repeat() function copies (reads) the data into newly allocated memory, the result is no different from MHA, whose K/V are the same as the repeated tensors. It just reads additional data from HBM into the L2 cache.
We also can't use a reshape that triggers .contiguous(), which rearranges the data in memory (creating new data); that likewise reads additional data from HBM into the L2 cache.
So I chose the .expand() function, which matches the shapes of q and k/v without copying or rearranging memory.
But after I expanded k/v, I couldn't use .view(), because the expanded dimension has stride 0. And to work around the .view() problem I can't use .contiguous() either, because of the issue mentioned above.
So I got stuck at that code snippet.
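One way out (a sketch under my own assumptions, not necessarily what nanochat does) is to never take the final .view() at all: keep the group dimension explicit and let matmul broadcasting pair each K/V head with its group, so k/v are never repeated or made contiguous. Recent PyTorch releases also expose an enable_gqa=True flag on F.scaled_dot_product_attention that accepts fewer K/V heads than query heads directly (check your version; it was added around PyTorch 2.5).

```python
import torch

# Broadcast-based GQA that avoids .view() after .expand().
# Names (n_head, n_kv_head, head_dim) mirror the snippet above;
# the projections are replaced by random tensors for a self-contained demo.
B, T, n_head, n_kv_head, head_dim = 2, 6, 8, 2, 16
g = n_head // n_kv_head                      # query heads per K/V head

q = torch.randn(B, T, n_head, head_dim)
k = torch.randn(B, T, n_kv_head, head_dim)
v = torch.randn(B, T, n_kv_head, head_dim)

# Move heads before time; split query heads into (kv_head, group) pairs.
qg = q.transpose(1, 2).reshape(B, n_kv_head, g, T, head_dim)
kg = k.transpose(1, 2).unsqueeze(2)          # (B, n_kv_head, 1, T, head_dim)
vg = v.transpose(1, 2).unsqueeze(2)

# matmul broadcasts the size-1 group dim: no repeat, no contiguous copy of k/v.
att = torch.softmax(qg @ kg.transpose(-1, -2) / head_dim**0.5, dim=-1)
y = (att @ vg).reshape(B, n_head, T, head_dim).transpose(1, 2)  # (B, T, n_head, head_dim)

# Sanity check against the copy-based version using repeat_interleave.
k_rep = k.repeat_interleave(g, dim=2).transpose(1, 2)   # (B, n_head, T, head_dim)
v_rep = v.repeat_interleave(g, dim=2).transpose(1, 2)
qh = q.transpose(1, 2)
ref = torch.softmax(qh @ k_rep.transpose(-1, -2) / head_dim**0.5, dim=-1) @ v_rep
print(torch.allclose(y, ref.transpose(1, 2), atol=1e-6))  # True
```

Only q is reshaped with a copy here (it is full-size anyway, same as in MHA); k and v stay at n_kv_head heads all the way through, which is the whole point of GQA.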
Things I still need to dig into
- We know that MQA and GQA reduce inference time, but by how much? Is the reduction significant?
- Real tracking of memory reads/writes.
- The KV cache.