前言 最近,OpenAI推出的ChatGPT展現(xiàn)出了卓越的性能,引發(fā)了大規(guī)模語(yǔ)言模型(Large Language Model, LLM)的研究熱潮。大規(guī)模語(yǔ)言模型的“大”體現(xiàn)在兩個(gè)方面:模型參數(shù)規(guī)模大,訓(xùn)練數(shù)據(jù)規(guī)模大。以GPT3為例,GPT3的參數(shù)量為1750億,訓(xùn)練數(shù)據(jù)量達(dá)到了570GB。進(jìn)而,訓(xùn)練大規(guī)模語(yǔ)言模型面臨兩個(gè)主要挑戰(zhàn):顯存效率和計(jì)算效率。 現(xiàn)在業(yè)界的大語(yǔ)言模型都是基于transformer模型的,模型結(jié)構(gòu)主要有兩大類:encoder-decoder(代表模型是T5)和decoder-only,具體的,decoder-only結(jié)構(gòu)又可以分為Causal LM(代表模型是GPT系列)和Prefix LM(代表模型是GLM)。歸因于GPT系列取得的巨大成功,大多數(shù)的主流大語(yǔ)言模型都采用Causal LM結(jié)構(gòu)。因此,針對(duì)decoder-only框架,為了更好地理解訓(xùn)練訓(xùn)練大語(yǔ)言模型的顯存效率和計(jì)算效率,本文分析采用decoder-only框架transformer模型的模型參數(shù)量、計(jì)算量、中間激活值、KV cache。 為了方便分析,先定義好一些數(shù)學(xué)符號(hào)。記transformer模型的層數(shù)為 ,隱藏層維度為 ,注意力頭數(shù)為 。詞表大小為 ,訓(xùn)練數(shù)據(jù)的批次大小為,序列長(zhǎng)度為。 模型參數(shù)量 transformer模型由個(gè)相同的層組成,每個(gè)層分為兩個(gè)部分:self-attention和MLP(各層包含layer normalization層) Self-attention Self-attention模塊參數(shù)包含的權(quán)重矩陣、輸出及偏置Bias,4個(gè)權(quán)重矩陣形狀為,4個(gè)偏置形狀為, Self-attention參數(shù)量為。 class MultiHeadAttention(nn.Module): def __init__(self): super(MultiHeadAttention, self).__init__() self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False) self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False) self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False) self.fc = nn.Linear(n_heads * d_v, d_model, bias=False) def forward(self, input_Q, input_K, input_V, attn_mask): # input_Q: [batch_size, len_q, d_model] # input_K: [batch_size, len_k, d_model] # input_V: [batch_size, len_v(=len_k), d_model] # attn_mask: [batch_size, seq_len, seq_len] residual, batch_size = input_Q, input_Q.size(0) Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # Q: [batch_size, n_heads, len_q, d_k] K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # K: [batch_size, n_heads, len_k, d_k] V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # V: [batch_size, n_heads, len_v(=len_k), d_v] attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len] context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask) # context: [batch_size, n_heads, len_q, d_v] # attn: [batch_size, n_heads, len_q, len_k] context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v] output = self.fc(context) # [batch_size, len_q, d_model] return nn.LayerNorm(d_model).cuda()(output + residual), attn MLP MLP模塊由2個(gè)線性層組成,一般地,第一個(gè)線性層先將維度從映射到,第二個(gè)線性層再將維度從映射到。第一個(gè)線性層權(quán)重的權(quán)重矩陣的形狀為,偏置的形狀為,第二個(gè)線性層權(quán)重矩陣的形狀為,偏置形狀為,MLP模塊參數(shù)量為。 class PoswiseFeedForwardNet(nn.Module): def __init__(self): super(PoswiseFeedForwardNet, self).__init__() self.fc = nn.Sequential( nn.Linear(d_model, d_ff, bias=False), nn.ReLU(), nn.Linear(d_ff, d_model, bias=False)) def forward(self, inputs): # inputs: [batch_size, seq_len, d_model] residual = inputs output = self.fc(inputs) return nn.LayerNorm(d_model).cuda()(output + residual) # [batch_size, seq_len, d_model] LayerNorm Self-attention和MLP各有一個(gè)layer normalization,包含2個(gè)可訓(xùn)練參數(shù):縮放參數(shù)和平移參數(shù),形狀都是,2個(gè)layer normalization的參數(shù)量為 class LayerNorm(nn.Module): def __init__(self, d_model, eps=1e-12): super(LayerNorm, self).__init__() self.gamma = nn.Parameter(torch.ones(d_model)) self.beta = nn.Parameter(torch.zeros(d_model)) self.eps = eps def forward(self, x): mean = x.mean(-1, keepdim=True) var = x.var(-1, unbiased=False, keepdim=True) # '-1' means last dimension. out = (x - mean) / torch.sqrt(var + self.eps) out = self.gamma * out + self.beta return out 總之,每個(gè)transformer層的參數(shù)量為,除此之外,詞嵌入矩陣的參數(shù)量也較多,詞向量維度通常等于隱藏層維度,詞嵌入矩陣的參數(shù)量為;關(guān)于位置編碼,如果采用可訓(xùn)練式的位置編碼,會(huì)有一些可訓(xùn)練模型參數(shù),數(shù)量比較少。如果采用相對(duì)位置編碼,例如RoPE和ALiBi,則不包含可訓(xùn)練的模型參數(shù)。我們忽略這部分參數(shù)。 綜上所述,層transformer模型可訓(xùn)練參數(shù)量為,當(dāng)隱藏層維度較大時(shí),可忽略一次項(xiàng),模型參數(shù)量近似為。 因此可估算不同版本LLama模型參數(shù)量,如下表所示: 計(jì)算量FLOPs估計(jì) FLOPs,floating point operations,表示浮點(diǎn)數(shù)運(yùn)算次數(shù),衡量了計(jì)算量的大小。如何計(jì)算矩陣乘法的FLOPs呢?對(duì)于,計(jì)算AB需要進(jìn)行n乘法運(yùn)算和n次假發(fā)運(yùn)算,共計(jì)2n次浮點(diǎn)運(yùn)算,需要2n的FLOPS;對(duì)于,計(jì)算AB需要的浮點(diǎn)運(yùn)算次數(shù)為2mnk Input 在一次訓(xùn)練迭代中,假設(shè)輸入數(shù)據(jù)的形狀為,經(jīng)embedding層得,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。 計(jì)算,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。 矩陣乘法的輸入和輸出形狀為 ,計(jì)算量為 。 計(jì)算在V上的加權(quán),矩陣乘法的輸入和輸出形狀為 ,計(jì)算量為 。 attention后的線性映射,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。 MLP MLP計(jì)算公式如下 1. 第一個(gè)線性層,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。2. 第二個(gè)線性層,矩陣乘法的輸入和輸出形狀為,計(jì)算量為 將Self-attention和MLP計(jì)算量相加,得到每個(gè)transformer層的計(jì)算量大約為。 Output 另一個(gè)計(jì)算量的大頭是logits的計(jì)算,將隱藏向量映射為詞表大小,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。 因此,對(duì)于一個(gè)層的transformer模型,輸入數(shù)據(jù)形狀為的情況下,一次訓(xùn)練迭代計(jì)算量為 計(jì)算量與參數(shù)量關(guān)系 當(dāng)隱藏維度比較大,且遠(yuǎn)大于序列長(zhǎng)度時(shí),我們可以忽略一次項(xiàng),計(jì)算量可以近似為;前面提到當(dāng)模型參數(shù)量為,輸入的tokens數(shù)為,存在等式。我們可近似認(rèn)為在一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行2次浮點(diǎn)數(shù)運(yùn)算,即一次乘法勻運(yùn)算和一次加法運(yùn)算。 一次訓(xùn)練迭代包含了前向傳遞和后向傳遞,后向傳遞的計(jì)算量近似是前向傳遞的2倍(后向傳播除了計(jì)算梯度之外,還需要存儲(chǔ)梯度并進(jìn)行參數(shù)更新)。因此,前向傳遞 + 后向傳遞的系數(shù) =1+2=3 。一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行 2?3=6 次浮點(diǎn)數(shù)運(yùn)算。 接下來(lái),我們可以估計(jì)訓(xùn)練GPT3-175B所需要的計(jì)算量。對(duì)于GPT3,每個(gè)token,每個(gè)參數(shù)進(jìn)行了6次浮點(diǎn)數(shù)運(yùn)算,再乘以參數(shù)量和總tokens數(shù)就得到了總的計(jì)算量。GPT3的模型參數(shù)量為 174600M,訓(xùn)練數(shù)據(jù)量為 300B tokens。 訓(xùn)練時(shí)間估計(jì) 模型參數(shù)量和訓(xùn)練總tokens數(shù)決定了訓(xùn)練transformer模型需要的計(jì)算量。給定硬件GPU類型的情況下,可以估計(jì)所需要的訓(xùn)練時(shí)間。給定計(jì)算量,訓(xùn)練時(shí)間(也就是GPU算完這么多flops的計(jì)算時(shí)間)不僅跟GPU類型有關(guān),還與GPU利用率有關(guān)。計(jì)算端到端訓(xùn)練的GPU利用率時(shí),不僅要考慮前向傳遞和后向傳遞的計(jì)算時(shí)間,還要考慮CPU加載數(shù)據(jù)、優(yōu)化器更新、多卡通信和記錄日志的時(shí)間。一般來(lái)講,GPU利用率一般在 0.3~0.55 之間。 上文講到一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),進(jìn)行2次浮點(diǎn)數(shù)計(jì)算。使用激活重計(jì)算技術(shù)來(lái)減少中間激活顯存(下文會(huì)詳細(xì)介紹)需要進(jìn)行一次額外的前向傳遞,因此前向傳遞 + 后向傳遞 + 激活重計(jì)算的系數(shù)=1+2+1=4。使用激活重計(jì)算的一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行 2?4=8 次浮點(diǎn)數(shù)運(yùn)算。在給定訓(xùn)練tokens數(shù)、硬件環(huán)境配置的情況下,訓(xùn)練transformer模型的計(jì)算時(shí)間為: 以GPT3-175B為例,在1024張40GB顯存的A100上,在300B tokens的數(shù)據(jù)上訓(xùn)練175B參數(shù)量的GPT3。40GB顯存A100的峰值性能為312TFLOPS,設(shè)GPU利用率為0.45,則所需要的訓(xùn)練時(shí)間為34天, 這與相關(guān)文獻(xiàn)中的訓(xùn)練時(shí)間吻合, (ref: https:///pdf/2104.04473.pdf) 以LLaMA-65B為例,在2048張80GB顯存的A100上,在1.4TB tokens的數(shù)據(jù)上訓(xùn)練了65B參數(shù)量的模型。80GB顯存A100的峰值性能為624TFLOPS,設(shè)GPU利用率為0.3,則所需要的訓(xùn)練時(shí)間為21天, 這與相關(guān)文獻(xiàn)中的訓(xùn)練時(shí)間吻合, (ref: https:///pdf/2302.13971.pdf) 不同階段顯存占用 訓(xùn)練階段 在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的過(guò)程中,占用顯存的大頭主要分為四部分:模型參數(shù)、前向計(jì)算過(guò)程中產(chǎn)生的中間激活、后向傳遞計(jì)算得到的梯度、優(yōu)化器狀態(tài)。這里著重分析參數(shù)、梯度和優(yōu)化器狀態(tài)的顯存占用,中間激活的顯存占用后面會(huì)詳細(xì)介紹。訓(xùn)練大模型時(shí)通常會(huì)采用AdamW優(yōu)化器,并用混合精度訓(xùn)練來(lái)加速訓(xùn)練,基于這個(gè)前提分析顯存占用。 在一次訓(xùn)練迭代中,每個(gè)可訓(xùn)練模型參數(shù)都會(huì)對(duì)應(yīng)1個(gè)梯度,并對(duì)應(yīng)2個(gè)優(yōu)化器狀態(tài)(Adam優(yōu)化器梯度的一階動(dòng)量和二階動(dòng)量)。設(shè)模型參數(shù)量為 ,那么梯度的元素?cái)?shù)量為 ,AdamW優(yōu)化器的元素?cái)?shù)量為 。float16數(shù)據(jù)類型的元素占2個(gè)bytes,float32數(shù)據(jù)類型的元素占4個(gè)bytes。在混合精度訓(xùn)練中,會(huì)使用float16的模型參數(shù)進(jìn)行前向傳遞和后向傳遞,計(jì)算得到float16的梯度;在優(yōu)化器更新模型參數(shù)時(shí),會(huì)使用float32的優(yōu)化器狀態(tài)、float32的梯度、float32的模型參數(shù)來(lái)更新模型參數(shù)。因此,對(duì)于每個(gè)可訓(xùn)練模型參數(shù),占用了(2+4)+(2+4)+(4+4) = 24bytes,使用AdamW優(yōu)化器和混合精度訓(xùn)練來(lái)訓(xùn)練參數(shù)量為的大模型,模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小為bytes。 (ref: https:///pdf/2201.11990.pdf) 推理階段 在神經(jīng)網(wǎng)絡(luò)的推理階段,沒(méi)有優(yōu)化器狀態(tài)和梯度,也不需要保存中間激活。少了梯度、優(yōu)化器狀態(tài)、中間激活,模型推理階段占用的顯存要遠(yuǎn)小于訓(xùn)練階段。模型推理階段,占用顯存的大頭主要是模型參數(shù),如果使用float16來(lái)進(jìn)行推理,推理階段模型參數(shù)占用的顯存大概是bytes 。如果使用KV cache來(lái)加速推理過(guò)程,KV cache也需要占用顯存,KV cache占用的顯存下文會(huì)詳細(xì)介紹。此外,輸入數(shù)據(jù)也需要放到GPU上,還有一些中間結(jié)果(推理過(guò)程中的中間結(jié)果用完會(huì)盡快釋放掉),不過(guò)這部分占用的顯存是很小的,可以忽略。 中間激活值顯存分析 除了模型參數(shù)、梯度、優(yōu)化器狀態(tài)外,占用顯存的大頭就是前向傳遞過(guò)程中計(jì)算得到的中間激活值了,需要保存中間激活以便在后向傳遞計(jì)算梯度時(shí)使用。這里的激活(activations)指的是:前向傳遞過(guò)程中計(jì)算得到的,并在后向傳遞過(guò)程中需要用到的所有張量。這里的激活不包含模型參數(shù)和優(yōu)化器狀態(tài),但包含了dropout操作需要用到的mask矩陣 在分析中間激活的顯存占用時(shí),只考慮激活占用顯存的大頭,忽略掉一些小的buffers。比如,對(duì)于layer normalization,計(jì)算梯度時(shí)需要用到層的輸入、輸入的均值和方差,輸入包含了 個(gè)元素,而輸入的均值和方差分別包含了個(gè)元素。由于 ? 通常是比較大的(千數(shù)量級(jí)),有 ,因此,對(duì)于layer normalization,中間激活近似估計(jì)為,而不是。 大模型在訓(xùn)練過(guò)程中通常采用混合精度訓(xùn)練,中間激活值一般是float16或者bfloat16數(shù)據(jù)類型的。在分析中間激活的顯存占用時(shí),假設(shè)中間激活值是以float16或bfloat16數(shù)據(jù)格式來(lái)保存的,每個(gè)元素占了2個(gè)bytes。唯一例外的是,dropout操作的mask矩陣,每個(gè)元素只占1個(gè)bytes。在下面的分析中,單位是bytes,而不是元素個(gè)數(shù)。每個(gè)transformer層包含了一個(gè)self-attention塊和MLP塊,并分別對(duì)應(yīng)了一個(gè)layer normalization連接。 對(duì)于,需要保存它們共同的輸入,這就是中間激活。輸入的形狀為 ,元素個(gè)數(shù)為 ,占用顯存大小為。 對(duì)于矩陣乘法,需要保存中間激活,兩個(gè)張量的形狀都是,占用顯存大小合計(jì)為。 對(duì)于函數(shù),需要保存函數(shù)的輸入,占用顯存大小為,其中為注意力頭數(shù) 的形狀: 的形狀: Q的形狀:,元素個(gè)數(shù),占用顯存大小。 計(jì)算完后,會(huì)進(jìn)行dropout操作。需要保存一個(gè)mask矩陣,mask矩陣的形狀與相同,占用顯存大小。 計(jì)算在V上的attention,即,需要保存,顯存大小為;以及V的顯存大小為,共占用顯存為。 計(jì)算輸出映射以及一個(gè)dropout操作。輸入映射需要保存其輸入,大小為;dropout需要保存mask矩陣,大小為,二者占用顯存大小合計(jì)為。 因此,將上述中間激活相加得到,self-attention塊的中間激活占用顯存大小為。 第一個(gè)線性層需要保存其輸入,占用顯存大小為 激活函數(shù)需要保存其輸入,占用顯存大小為 第二個(gè)線性層需要保存其輸入,占用顯存大小為 最后有一個(gè)dropout操作,需要保存mask矩陣,占用顯存大小為 因此,對(duì)于MLP塊,需要保存的中間激活值為另外,self-attention塊和MLP塊分別對(duì)應(yīng)了一個(gè)layer normalization。每個(gè)layer norm需要保存其輸入,大小為,2個(gè)layer norm需要保存的中間激活為。 綜上,每個(gè)transformer層需要保存的中間激活占用顯存大小為,對(duì)于層transformer模型,還有embedding層、最后的輸出層。embedding層不需要中間激活??偟亩?,當(dāng)隱藏維度較大,層數(shù)較深時(shí),這部分的中間激活是很少的,可以忽略。因此,對(duì)于層transformer模型,中間激活占用的顯存大小可以近似為。 中間激活與模型參數(shù)的顯存占用對(duì)比 為什么可通過(guò)減小批次大小效緩解模型訓(xùn)練中顯存不足(OOM)的問(wèn)題? 在一次訓(xùn)練迭代中,模型參數(shù)或梯度占用顯存大小只與模型參數(shù)量和參數(shù)數(shù)據(jù)類型有關(guān),與輸入數(shù)據(jù)的大小是沒(méi)有關(guān)系的;優(yōu)化器狀態(tài)占用的顯存大小與優(yōu)化器類型有關(guān),與模型參數(shù)量有關(guān),但與輸入數(shù)據(jù)的大小無(wú)關(guān)。而中間激活值與輸入數(shù)據(jù)的大?。ㄅ未笮?和序列長(zhǎng)度)是成正相關(guān)的,隨著批次大小 和序列長(zhǎng)度的增大,中間激活占用的顯存會(huì)同步增大。當(dāng)我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)遇到顯存不足OOM(Out Of Memory)問(wèn)題時(shí),通常會(huì)嘗試減小批次大小來(lái)避免顯存不足的問(wèn)題,這種方式減少的其實(shí)是中間激活占用的顯存,而不是模型參數(shù)、梯度和優(yōu)化器的顯存。 以GPT3-175B為例,我們來(lái)直觀地對(duì)比下模型參數(shù)與中間激活的顯存大小。GPT3的模型配置如下。我們假設(shè)采用混合精度訓(xùn)練,模型參數(shù)和中間激活都采用float16數(shù)據(jù)類型,每個(gè)元素占2個(gè)bytes。 GPT3的序列長(zhǎng)度為2048,對(duì)比下不同批次大小中間激活層的顯存占用: 大約是模型參數(shù)顯存的0.79倍。 大約是模型參數(shù)顯存的50倍。 大約是模型參數(shù)顯存的101倍。 可以看到隨著批次大小的增大,中間激活占用的顯存遠(yuǎn)遠(yuǎn)超過(guò)了模型參數(shù)顯存。通常會(huì)采用激活重計(jì)算技術(shù)來(lái)減少中間激活,理論上可以將中間激活顯存從 減少到,代價(jià)是增加了一次額外前向計(jì)算的時(shí)間,本質(zhì)上是“時(shí)間換空間”。 KV cache 在LLM推斷階段,需要認(rèn)識(shí)到: 推理性能的最大瓶頸在于顯存; 自回歸模型的 keys 和 values 通常被稱為 KV cache,這些 tensors 會(huì)存在 GPU 的顯存中,用于生成下一個(gè) token; 這些 KV cache 都很大,并且大小是動(dòng)態(tài)變化難以預(yù)測(cè),已有系統(tǒng)中,由于顯存碎片和過(guò)度預(yù)留,浪費(fèi)了60%-80%的顯存。transformer模型推理加速的一個(gè)常用策略就是優(yōu)化 KV cache,一個(gè)典型的大模型生成式推斷包含了兩個(gè)階段: 預(yù)填充階段:輸入一個(gè)prompt序列,為每個(gè)transformer層生成 key cache和value cache,即KV cache。 解碼階段:使用并更新KV cache, 一個(gè)接一個(gè)地生成詞,當(dāng)前生成的詞依賴于之前已經(jīng)生成的詞。 第個(gè)transformer層的權(quán)重矩陣為 ,其中self-attention的4個(gè)權(quán)重矩陣 ,MLP的2個(gè)權(quán)重矩陣 。 預(yù)填充階段 假設(shè)第個(gè)transformer層的輸入為,self-attention塊的key、value、query和output表示為 ,key cache和value cache計(jì)算過(guò)程: 第個(gè)transformer層剩余的計(jì)算過(guò)程: 解碼階段 給定當(dāng)前生成詞在第個(gè)transformer層的向量表示為,推斷計(jì)算分兩部分,更新KV cache 和 計(jì)算第個(gè)transformer層的輸出。更新key cache和value cache的計(jì)算過(guò)程如下: 第個(gè)transformer層剩余的計(jì)算過(guò)程為: KV cache 顯存占用分析 假設(shè)輸入序列長(zhǎng)度為,輸出序列長(zhǎng)度為,以float16來(lái)保存KV cache,那么KV cache的峰值顯存占用大小為 ,其中第一個(gè)2表示K/V cache,第二個(gè)2表示float16占2個(gè)bytes。 以GPT3為例,對(duì)比KV cache與模型參數(shù)占用顯存的大小。GPT3模型占用顯存大小為350GB。假設(shè)批次大小 ,輸入序列長(zhǎng)度,輸出序列長(zhǎng)度 ,則KV cache占用顯存為 大約是模型參數(shù)顯存的0.5倍。 總結(jié) 本文首先介紹了如何計(jì)算transformer模型的參數(shù)量,基于參數(shù)量可以進(jìn)一步估計(jì)模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小。接著,本文估計(jì)了訓(xùn)練迭代中,在給定訓(xùn)練tokens數(shù)的情況下transformer模型的計(jì)算量,給予計(jì)算量和顯卡性能可以進(jìn)一步估計(jì)訓(xùn)練迭代的計(jì)算耗時(shí)。然后,本文分析了transformer模型前向計(jì)算過(guò)程中產(chǎn)生的中間激活值的顯存大小,中間激活的顯存大小與輸入數(shù)據(jù)大小正相關(guān),甚至?xí)h(yuǎn)超過(guò)模型參數(shù)占用的顯存。最后,本文介紹了transformer模型推理過(guò)程常用的加速策略:使用KV cache??偟膩?lái)說(shuō),分析transformer模型的參數(shù)量、計(jì)算量、中間激活和KV cache,有助于理解大模型訓(xùn)練和推斷過(guò)程中的顯存效率和計(jì)算效率。 特此聲明,此文主體參考知乎文章https://zhuanlan.zhihu.com/p/624740065(在此感該作者“回旋托馬斯x”的辛苦付出),本文重點(diǎn)對(duì)該文章進(jìn)行計(jì)算驗(yàn)證、計(jì)算量FLOPs估計(jì)的邏輯修正、部分符號(hào)和表述修正、部分內(nèi)容代碼增加及重新排版。 參考 [1] https:///pdf/1706.03762.pdf [2] https:///pdf/2302.13971.pdf [3] https:///pdf/2104.04473.pdf [4] https://zhuanlan.zhihu.com/p/624740065 |
|