免费高清特黄a大片,九一h片在线免费看,a免费国产一级特黄aa大,国产精品国产主播在线观看,成人精品一区久久久久,一级特黄aa大片,俄罗斯无遮挡一级毛片

分享

PEFT | Transformer參數(shù)量、計(jì)算量、顯存占用分析

 520jefferson 2023-09-06 發(fā)布于北京

前言

最近,OpenAI推出的ChatGPT展現(xiàn)出了卓越的性能,引發(fā)了大規(guī)模語(yǔ)言模型(Large Language Model, LLM)的研究熱潮。大規(guī)模語(yǔ)言模型的“大”體現(xiàn)在兩個(gè)方面:模型參數(shù)規(guī)模大,訓(xùn)練數(shù)據(jù)規(guī)模大。以GPT3為例,GPT3的參數(shù)量為1750億,訓(xùn)練數(shù)據(jù)量達(dá)到了570GB。進(jìn)而,訓(xùn)練大規(guī)模語(yǔ)言模型面臨兩個(gè)主要挑戰(zhàn):顯存效率和計(jì)算效率。

現(xiàn)在業(yè)界的大語(yǔ)言模型都是基于transformer模型的,模型結(jié)構(gòu)主要有兩大類:encoder-decoder(代表模型是T5)和decoder-only,具體的,decoder-only結(jié)構(gòu)又可以分為Causal LM(代表模型是GPT系列)和Prefix LM(代表模型是GLM)。歸因于GPT系列取得的巨大成功,大多數(shù)的主流大語(yǔ)言模型都采用Causal LM結(jié)構(gòu)。因此,針對(duì)decoder-only框架,為了更好地理解訓(xùn)練訓(xùn)練大語(yǔ)言模型的顯存效率和計(jì)算效率,本文分析采用decoder-only框架transformer模型的模型參數(shù)量、計(jì)算量、中間激活值、KV cache。

為了方便分析,先定義好一些數(shù)學(xué)符號(hào)。記transformer模型的層數(shù)為 ,隱藏層維度為 ,注意力頭數(shù)為 。詞表大小為 ,訓(xùn)練數(shù)據(jù)的批次大小為,序列長(zhǎng)度為。

圖片

模型參數(shù)量

transformer模型由個(gè)相同的層組成,每個(gè)層分為兩個(gè)部分:self-attention和MLP(各層包含layer normalization層)

Self-attention

Self-attention模塊參數(shù)包含的權(quán)重矩陣、輸出及偏置Bias,4個(gè)權(quán)重矩陣形狀為,4個(gè)偏置形狀為, Self-attention參數(shù)量為。

圖片

class MultiHeadAttention(nn.Module):

def __init__(self):

super(MultiHeadAttention, self).__init__()

self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)

self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)

self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)

self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

def forward(self, input_Q, input_K, input_V, attn_mask):    # input_Q: [batch_size, len_q, d_model]

# input_K: [batch_size, len_k, d_model]

# input_V: [batch_size, len_v(=len_k), d_model]

# attn_mask: [batch_size, seq_len, seq_len]

residual, batch_size = input_Q, input_Q.size(0)

Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]

K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]

V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)              # attn_mask : [batch_size, n_heads, seq_len, seq_len]

context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)          # context: [batch_size, n_heads, len_q, d_v]

# attn: [batch_size, n_heads, len_q, len_k]

context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]

output = self.fc(context)                                                # [batch_size, len_q, d_model]

return nn.LayerNorm(d_model).cuda()(output + residual), attn

MLP

MLP模塊由2個(gè)線性層組成,一般地,第一個(gè)線性層先將維度從映射到,第二個(gè)線性層再將維度從映射到。第一個(gè)線性層權(quán)重的權(quán)重矩陣的形狀為,偏置的形狀為,第二個(gè)線性層權(quán)重矩陣的形狀為,偏置形狀為,MLP模塊參數(shù)量為。

圖片

class PoswiseFeedForwardNet(nn.Module):

def __init__(self):

super(PoswiseFeedForwardNet, self).__init__()

self.fc = nn.Sequential(

nn.Linear(d_model, d_ff, bias=False),

nn.ReLU(),

nn.Linear(d_ff, d_model, bias=False))

def forward(self, inputs): # inputs: [batch_size, seq_len, d_model]

residual = inputs

output = self.fc(inputs)

return nn.LayerNorm(d_model).cuda()(output + residual)   # [batch_size, seq_len, d_model]  

LayerNorm

Self-attention和MLP各有一個(gè)layer normalization,包含2個(gè)可訓(xùn)練參數(shù):縮放參數(shù)和平移參數(shù),形狀都是,2個(gè)layer normalization的參數(shù)量為

圖片

class LayerNorm(nn.Module):

def __init__(self, d_model, eps=1e-12):

super(LayerNorm, self).__init__()

self.gamma = nn.Parameter(torch.ones(d_model))

self.beta = nn.Parameter(torch.zeros(d_model))

self.eps = eps

def forward(self, x):

mean = x.mean(-1, keepdim=True)

var = x.var(-1, unbiased=False, keepdim=True)

# '-1' means last dimension. 

out = (x - mean) / torch.sqrt(var + self.eps)

out = self.gamma * out + self.beta

return out

總之,每個(gè)transformer層的參數(shù)量為,除此之外,詞嵌入矩陣的參數(shù)量也較多,詞向量維度通常等于隱藏層維度,詞嵌入矩陣的參數(shù)量為;關(guān)于位置編碼,如果采用可訓(xùn)練式的位置編碼,會(huì)有一些可訓(xùn)練模型參數(shù),數(shù)量比較少。如果采用相對(duì)位置編碼,例如RoPE和ALiBi,則不包含可訓(xùn)練的模型參數(shù)。我們忽略這部分參數(shù)。

綜上所述,層transformer模型可訓(xùn)練參數(shù)量為,當(dāng)隱藏層維度較大時(shí),可忽略一次項(xiàng),模型參數(shù)量近似為。

因此可估算不同版本LLama模型參數(shù)量,如下表所示:

圖片

計(jì)算量FLOPs估計(jì)

FLOPs,floating point operations,表示浮點(diǎn)數(shù)運(yùn)算次數(shù),衡量了計(jì)算量的大小。如何計(jì)算矩陣乘法的FLOPs呢?對(duì)于,計(jì)算AB需要進(jìn)行n乘法運(yùn)算和n次假發(fā)運(yùn)算,共計(jì)2n次浮點(diǎn)運(yùn)算,需要2n的FLOPS;對(duì)于,計(jì)算AB需要的浮點(diǎn)運(yùn)算次數(shù)為2mnk

Input

在一次訓(xùn)練迭代中,假設(shè)輸入數(shù)據(jù)的形狀為,經(jīng)embedding層得,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。

Self-attention

圖片

計(jì)算,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。

矩陣乘法的輸入和輸出形狀為

,計(jì)算量為

。

計(jì)算在V上的加權(quán),矩陣乘法的輸入和輸出形狀為

,計(jì)算量為

。

attention后的線性映射,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。

MLP

MLP計(jì)算公式如下

1. 第一個(gè)線性層,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。2. 第二個(gè)線性層,矩陣乘法的輸入和輸出形狀為,計(jì)算量為

將Self-attention和MLP計(jì)算量相加,得到每個(gè)transformer層的計(jì)算量大約為。

Output

另一個(gè)計(jì)算量的大頭是logits的計(jì)算,將隱藏向量映射為詞表大小,矩陣乘法的輸入和輸出形狀為,計(jì)算量為。

因此,對(duì)于一個(gè)層的transformer模型,輸入數(shù)據(jù)形狀為的情況下,一次訓(xùn)練迭代計(jì)算量為

計(jì)算量與參數(shù)量關(guān)系

當(dāng)隱藏維度比較大,且遠(yuǎn)大于序列長(zhǎng)度時(shí),我們可以忽略一次項(xiàng),計(jì)算量可以近似為;前面提到當(dāng)模型參數(shù)量為,輸入的tokens數(shù)為,存在等式。我們可近似認(rèn)為在一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行2次浮點(diǎn)數(shù)運(yùn)算,即一次乘法勻運(yùn)算和一次加法運(yùn)算。

一次訓(xùn)練迭代包含了前向傳遞和后向傳遞,后向傳遞的計(jì)算量近似是前向傳遞的2倍(后向傳播除了計(jì)算梯度之外,還需要存儲(chǔ)梯度并進(jìn)行參數(shù)更新)。因此,前向傳遞 + 后向傳遞的系數(shù) =1+2=3 。一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行 2?3=6 次浮點(diǎn)數(shù)運(yùn)算。

接下來(lái),我們可以估計(jì)訓(xùn)練GPT3-175B所需要的計(jì)算量。對(duì)于GPT3,每個(gè)token,每個(gè)參數(shù)進(jìn)行了6次浮點(diǎn)數(shù)運(yùn)算,再乘以參數(shù)量和總tokens數(shù)就得到了總的計(jì)算量。GPT3的模型參數(shù)量為 174600M,訓(xùn)練數(shù)據(jù)量為 300B tokens。

訓(xùn)練時(shí)間估計(jì)

模型參數(shù)量和訓(xùn)練總tokens數(shù)決定了訓(xùn)練transformer模型需要的計(jì)算量。給定硬件GPU類型的情況下,可以估計(jì)所需要的訓(xùn)練時(shí)間。給定計(jì)算量,訓(xùn)練時(shí)間(也就是GPU算完這么多flops的計(jì)算時(shí)間)不僅跟GPU類型有關(guān),還與GPU利用率有關(guān)。計(jì)算端到端訓(xùn)練的GPU利用率時(shí),不僅要考慮前向傳遞和后向傳遞的計(jì)算時(shí)間,還要考慮CPU加載數(shù)據(jù)、優(yōu)化器更新、多卡通信和記錄日志的時(shí)間。一般來(lái)講,GPU利用率一般在 0.3~0.55 之間。

上文講到一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),進(jìn)行2次浮點(diǎn)數(shù)計(jì)算。使用激活重計(jì)算技術(shù)來(lái)減少中間激活顯存(下文會(huì)詳細(xì)介紹)需要進(jìn)行一次額外的前向傳遞,因此前向傳遞 + 后向傳遞 + 激活重計(jì)算的系數(shù)=1+2+1=4。使用激活重計(jì)算的一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行 2?4=8 次浮點(diǎn)數(shù)運(yùn)算。在給定訓(xùn)練tokens數(shù)、硬件環(huán)境配置的情況下,訓(xùn)練transformer模型的計(jì)算時(shí)間為:

圖片

以GPT3-175B為例,在1024張40GB顯存的A100上,在300B tokens的數(shù)據(jù)上訓(xùn)練175B參數(shù)量的GPT3。40GB顯存A100的峰值性能為312TFLOPS,設(shè)GPU利用率為0.45,則所需要的訓(xùn)練時(shí)間為34天,

圖片

這與相關(guān)文獻(xiàn)中的訓(xùn)練時(shí)間吻合,

圖片

(ref: https:///pdf/2104.04473.pdf)

以LLaMA-65B為例,在2048張80GB顯存的A100上,在1.4TB tokens的數(shù)據(jù)上訓(xùn)練了65B參數(shù)量的模型。80GB顯存A100的峰值性能為624TFLOPS,設(shè)GPU利用率為0.3,則所需要的訓(xùn)練時(shí)間為21天,

圖片

這與相關(guān)文獻(xiàn)中的訓(xùn)練時(shí)間吻合,

圖片

(ref: https:///pdf/2302.13971.pdf)

不同階段顯存占用

訓(xùn)練階段

在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的過(guò)程中,占用顯存的大頭主要分為四部分:模型參數(shù)、前向計(jì)算過(guò)程中產(chǎn)生的中間激活、后向傳遞計(jì)算得到的梯度、優(yōu)化器狀態(tài)。這里著重分析參數(shù)、梯度和優(yōu)化器狀態(tài)的顯存占用,中間激活的顯存占用后面會(huì)詳細(xì)介紹。訓(xùn)練大模型時(shí)通常會(huì)采用AdamW優(yōu)化器,并用混合精度訓(xùn)練來(lái)加速訓(xùn)練,基于這個(gè)前提分析顯存占用。

在一次訓(xùn)練迭代中,每個(gè)可訓(xùn)練模型參數(shù)都會(huì)對(duì)應(yīng)1個(gè)梯度,并對(duì)應(yīng)2個(gè)優(yōu)化器狀態(tài)(Adam優(yōu)化器梯度的一階動(dòng)量和二階動(dòng)量)。設(shè)模型參數(shù)量為 ,那么梯度的元素?cái)?shù)量為 ,AdamW優(yōu)化器的元素?cái)?shù)量為 。float16數(shù)據(jù)類型的元素占2個(gè)bytes,float32數(shù)據(jù)類型的元素占4個(gè)bytes。在混合精度訓(xùn)練中,會(huì)使用float16的模型參數(shù)進(jìn)行前向傳遞和后向傳遞,計(jì)算得到float16的梯度;在優(yōu)化器更新模型參數(shù)時(shí),會(huì)使用float32的優(yōu)化器狀態(tài)、float32的梯度、float32的模型參數(shù)來(lái)更新模型參數(shù)。因此,對(duì)于每個(gè)可訓(xùn)練模型參數(shù),占用了(2+4)+(2+4)+(4+4) = 24bytes,使用AdamW優(yōu)化器和混合精度訓(xùn)練來(lái)訓(xùn)練參數(shù)量為的大模型,模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小為bytes。

圖片

(ref: https:///pdf/2201.11990.pdf)

推理階段

在神經(jīng)網(wǎng)絡(luò)的推理階段,沒(méi)有優(yōu)化器狀態(tài)和梯度,也不需要保存中間激活。少了梯度、優(yōu)化器狀態(tài)、中間激活,模型推理階段占用的顯存要遠(yuǎn)小于訓(xùn)練階段。模型推理階段,占用顯存的大頭主要是模型參數(shù),如果使用float16來(lái)進(jìn)行推理,推理階段模型參數(shù)占用的顯存大概是bytes 。如果使用KV cache來(lái)加速推理過(guò)程,KV cache也需要占用顯存,KV cache占用的顯存下文會(huì)詳細(xì)介紹。此外,輸入數(shù)據(jù)也需要放到GPU上,還有一些中間結(jié)果(推理過(guò)程中的中間結(jié)果用完會(huì)盡快釋放掉),不過(guò)這部分占用的顯存是很小的,可以忽略。

中間激活值顯存分析

除了模型參數(shù)、梯度、優(yōu)化器狀態(tài)外,占用顯存的大頭就是前向傳遞過(guò)程中計(jì)算得到的中間激活值了,需要保存中間激活以便在后向傳遞計(jì)算梯度時(shí)使用。這里的激活(activations)指的是:前向傳遞過(guò)程中計(jì)算得到的,并在后向傳遞過(guò)程中需要用到的所有張量。這里的激活不包含模型參數(shù)和優(yōu)化器狀態(tài),但包含了dropout操作需要用到的mask矩陣

圖片

在分析中間激活的顯存占用時(shí),只考慮激活占用顯存的大頭,忽略掉一些小的buffers。比如,對(duì)于layer normalization,計(jì)算梯度時(shí)需要用到層的輸入、輸入的均值和方差,輸入包含了 個(gè)元素,而輸入的均值和方差分別包含了個(gè)元素。由于 ? 通常是比較大的(千數(shù)量級(jí)),有 ,因此,對(duì)于layer normalization,中間激活近似估計(jì)為,而不是。

大模型在訓(xùn)練過(guò)程中通常采用混合精度訓(xùn)練,中間激活值一般是float16或者bfloat16數(shù)據(jù)類型的。在分析中間激活的顯存占用時(shí),假設(shè)中間激活值是以float16或bfloat16數(shù)據(jù)格式來(lái)保存的,每個(gè)元素占了2個(gè)bytes。唯一例外的是,dropout操作的mask矩陣,每個(gè)元素只占1個(gè)bytes。在下面的分析中,單位是bytes,而不是元素個(gè)數(shù)。每個(gè)transformer層包含了一個(gè)self-attention塊和MLP塊,并分別對(duì)應(yīng)了一個(gè)layer normalization連接。

Self-attention的中間激活

對(duì)于,需要保存它們共同的輸入,這就是中間激活。輸入的形狀為 ,元素個(gè)數(shù)為 ,占用顯存大小為。

對(duì)于矩陣乘法,需要保存中間激活,兩個(gè)張量的形狀都是,占用顯存大小合計(jì)為。

對(duì)于函數(shù),需要保存函數(shù)的輸入,占用顯存大小為,其中為注意力頭數(shù)

的形狀:

的形狀:

Q的形狀:,元素個(gè)數(shù),占用顯存大小。

計(jì)算完后,會(huì)進(jìn)行dropout操作。需要保存一個(gè)mask矩陣,mask矩陣的形狀與相同,占用顯存大小。

計(jì)算在V上的attention,即,需要保存,顯存大小為;以及V的顯存大小為,共占用顯存為。

計(jì)算輸出映射以及一個(gè)dropout操作。輸入映射需要保存其輸入,大小為;dropout需要保存mask矩陣,大小為,二者占用顯存大小合計(jì)為。

因此,將上述中間激活相加得到,self-attention塊的中間激活占用顯存大小為。

MLP的中間激活

第一個(gè)線性層需要保存其輸入,占用顯存大小為

激活函數(shù)需要保存其輸入,占用顯存大小為

第二個(gè)線性層需要保存其輸入,占用顯存大小為

最后有一個(gè)dropout操作,需要保存mask矩陣,占用顯存大小為

因此,對(duì)于MLP塊,需要保存的中間激活值為另外,self-attention塊和MLP塊分別對(duì)應(yīng)了一個(gè)layer normalization。每個(gè)layer norm需要保存其輸入,大小為,2個(gè)layer norm需要保存的中間激活為。

綜上,每個(gè)transformer層需要保存的中間激活占用顯存大小為,對(duì)于層transformer模型,還有embedding層、最后的輸出層。embedding層不需要中間激活??偟亩?,當(dāng)隱藏維度較大,層數(shù)較深時(shí),這部分的中間激活是很少的,可以忽略。因此,對(duì)于層transformer模型,中間激活占用的顯存大小可以近似為。

中間激活與模型參數(shù)的顯存占用對(duì)比

為什么可通過(guò)減小批次大小效緩解模型訓(xùn)練中顯存不足(OOM)的問(wèn)題?

在一次訓(xùn)練迭代中,模型參數(shù)或梯度占用顯存大小只與模型參數(shù)量和參數(shù)數(shù)據(jù)類型有關(guān),與輸入數(shù)據(jù)的大小是沒(méi)有關(guān)系的;優(yōu)化器狀態(tài)占用的顯存大小與優(yōu)化器類型有關(guān),與模型參數(shù)量有關(guān),但與輸入數(shù)據(jù)的大小無(wú)關(guān)。而中間激活值與輸入數(shù)據(jù)的大?。ㄅ未笮?和序列長(zhǎng)度)是成正相關(guān)的,隨著批次大小 和序列長(zhǎng)度的增大,中間激活占用的顯存會(huì)同步增大。當(dāng)我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)遇到顯存不足OOM(Out Of Memory)問(wèn)題時(shí),通常會(huì)嘗試減小批次大小來(lái)避免顯存不足的問(wèn)題,這種方式減少的其實(shí)是中間激活占用的顯存,而不是模型參數(shù)、梯度和優(yōu)化器的顯存。

以GPT3-175B為例,我們來(lái)直觀地對(duì)比下模型參數(shù)與中間激活的顯存大小。GPT3的模型配置如下。我們假設(shè)采用混合精度訓(xùn)練,模型參數(shù)和中間激活都采用float16數(shù)據(jù)類型,每個(gè)元素占2個(gè)bytes。

圖片

GPT3的序列長(zhǎng)度為2048,對(duì)比下不同批次大小中間激活層的顯存占用:

大約是模型參數(shù)顯存的0.79倍。

大約是模型參數(shù)顯存的50倍。

大約是模型參數(shù)顯存的101倍。

可以看到隨著批次大小的增大,中間激活占用的顯存遠(yuǎn)遠(yuǎn)超過(guò)了模型參數(shù)顯存。通常會(huì)采用激活重計(jì)算技術(shù)來(lái)減少中間激活,理論上可以將中間激活顯存從 減少到,代價(jià)是增加了一次額外前向計(jì)算的時(shí)間,本質(zhì)上是“時(shí)間換空間”。

KV cache

在LLM推斷階段,需要認(rèn)識(shí)到:

推理性能的最大瓶頸在于顯存;

自回歸模型的 keys 和 values 通常被稱為 KV cache,這些 tensors 會(huì)存在 GPU 的顯存中,用于生成下一個(gè) token;

這些 KV cache 都很大,并且大小是動(dòng)態(tài)變化難以預(yù)測(cè),已有系統(tǒng)中,由于顯存碎片和過(guò)度預(yù)留,浪費(fèi)了60%-80%的顯存。transformer模型推理加速的一個(gè)常用策略就是優(yōu)化 KV cache,一個(gè)典型的大模型生成式推斷包含了兩個(gè)階段:

預(yù)填充階段:輸入一個(gè)prompt序列,為每個(gè)transformer層生成 key cache和value cache,即KV cache。

解碼階段:使用并更新KV cache, 一個(gè)接一個(gè)地生成詞,當(dāng)前生成的詞依賴于之前已經(jīng)生成的詞。

第個(gè)transformer層的權(quán)重矩陣為

,其中self-attention的4個(gè)權(quán)重矩陣

,MLP的2個(gè)權(quán)重矩陣

。

預(yù)填充階段

假設(shè)第個(gè)transformer層的輸入為,self-attention塊的key、value、query和output表示為

,key cache和value cache計(jì)算過(guò)程:

第個(gè)transformer層剩余的計(jì)算過(guò)程:

解碼階段

給定當(dāng)前生成詞在第個(gè)transformer層的向量表示為,推斷計(jì)算分兩部分,更新KV cache 和 計(jì)算第個(gè)transformer層的輸出。更新key cache和value cache的計(jì)算過(guò)程如下:

第個(gè)transformer層剩余的計(jì)算過(guò)程為:

KV cache 顯存占用分析

假設(shè)輸入序列長(zhǎng)度為,輸出序列長(zhǎng)度為,以float16來(lái)保存KV cache,那么KV cache的峰值顯存占用大小為

,其中第一個(gè)2表示K/V cache,第二個(gè)2表示float16占2個(gè)bytes。

以GPT3為例,對(duì)比KV cache與模型參數(shù)占用顯存的大小。GPT3模型占用顯存大小為350GB。假設(shè)批次大小 ,輸入序列長(zhǎng)度,輸出序列長(zhǎng)度 ,則KV cache占用顯存為

大約是模型參數(shù)顯存的0.5倍。

總結(jié)

本文首先介紹了如何計(jì)算transformer模型的參數(shù)量,基于參數(shù)量可以進(jìn)一步估計(jì)模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小。接著,本文估計(jì)了訓(xùn)練迭代中,在給定訓(xùn)練tokens數(shù)的情況下transformer模型的計(jì)算量,給予計(jì)算量和顯卡性能可以進(jìn)一步估計(jì)訓(xùn)練迭代的計(jì)算耗時(shí)。然后,本文分析了transformer模型前向計(jì)算過(guò)程中產(chǎn)生的中間激活值的顯存大小,中間激活的顯存大小與輸入數(shù)據(jù)大小正相關(guān),甚至?xí)h(yuǎn)超過(guò)模型參數(shù)占用的顯存。最后,本文介紹了transformer模型推理過(guò)程常用的加速策略:使用KV cache??偟膩?lái)說(shuō),分析transformer模型的參數(shù)量、計(jì)算量、中間激活和KV cache,有助于理解大模型訓(xùn)練和推斷過(guò)程中的顯存效率和計(jì)算效率。

特此聲明,此文主體參考知乎文章https://zhuanlan.zhihu.com/p/624740065(在此感該作者“回旋托馬斯x”的辛苦付出),本文重點(diǎn)對(duì)該文章進(jìn)行計(jì)算驗(yàn)證、計(jì)算量FLOPs估計(jì)的邏輯修正、部分符號(hào)和表述修正、部分內(nèi)容代碼增加及重新排版。

參考

[1] https:///pdf/1706.03762.pdf

[2] https:///pdf/2302.13971.pdf

[3] https:///pdf/2104.04473.pdf

[4] https://zhuanlan.zhihu.com/p/624740065

    本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn)。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購(gòu)買等信息,謹(jǐn)防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評(píng)論

    發(fā)表

    請(qǐng)遵守用戶 評(píng)論公約

    類似文章 更多