
Transformer Code

by ornni 2024. 7. 24.

Transformer Self-Attention

Note that the code below does not include positional embedding! (A minimal sketch of a sinusoidal positional encoding module is added after the decoder code.)
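
For reference, the scaled dot-product attention implemented below computes Attention(Q, K, V) = softmax(Q·K^T / sqrt(d_k))·V, and multi-head attention runs this in parallel over num_heads independent projections, each of dimension d_k = d_model / num_heads, before concatenating the heads and projecting back to d_model.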


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Scaled Dot-Product Attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, mask = None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype = torch.float32))
        #scores = [batch size, n heads, query len, key len]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(scores, dim = -1)
        #attn = [batch size, n heads, query len, key len]

        output = torch.matmul(attn, V)
        #output = [batch size, n heads, query len, head dim]

        return output, attn

# Multi Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]

        # note: this dropout module is declared but not applied in forward();
        # residual/sub-layer dropout is handled in EncoderLayer and DecoderLayer below
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, Q, K, V, mask = None):
        batch_size = Q.size(0)

        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]

        attn_output, attn = ScaledDotProductAttention()(Q, K, V, mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
    
        #attn_output = [batch size, query len, n heads, head dim]
        #attn_output = [batch size, query len, hid dim]

        output = self.fc(attn_output)
        #output = [batch size, query len, hid dim]

        return output, attn
    
# Position-wise Feed-Forward Networks
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.1):
        super(PositionwiseFeedForward, self).__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #x = [batch size, seq len, hid dim]
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, mask = None):
        src1, _ = self.attention(src, src, src, mask)
        src = src + self.dropout1(src1)
        src = self.norm1(src)

        ff_output = self.feed_forward(src)
        src = src + self.dropout2(ff_output)
        src = self.norm2(src)

        return src

# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, d_model, num_layers, num_heads, d_ff, dropout = 0.1):
        super(TransformerEncoder, self).__init__()

        self.embedding = nn.Embedding(input_dim, d_model)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, src, mask = None):
        src = self.embedding(src)
        src = self.dropout(src)

        for layer in self.layers:
            src = layer(src, mask)
        #src = [batch size, src len, hid dim]

        return self.norm(src)

# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        super(DecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, src_mask = None, tgt_mask = None):
        tgt2, _ = self.self_attention(tgt, tgt, tgt, tgt_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        tgt2, attn = self.encoder_attention(tgt, memory, memory, src_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.feed_forward(tgt)
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt, attn
    
# Transformer Decoder
class TransformerDecoder(nn.Module):
    def __init__(self, output_dim, d_model, num_layers, num_heads, d_ff, dropout = 0.1):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt, memory, src_mask = None, tgt_mask = None):
        tgt = self.embedding(tgt)
        tgt = self.dropout(tgt)

        for layer in self.layers:
            tgt, attn = layer(tgt, memory, src_mask, tgt_mask)
        
        tgt = self.fc(tgt)
        #tgt = [batch size, tgt len, output dim]
        # note: nn.CrossEntropyLoss expects raw logits, so for training one would
        # typically return tgt here and drop the softmax below
        tgt = F.softmax(tgt, dim = -1)

        return tgt, attn
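
A quick shape check of the encoder/decoder above, using hypothetical hyperparameters and random token ids; the masks follow the convention in ScaledDotProductAttention (1 = attend, 0 = masked), with a broadcastable padding mask for the source and a lower-triangular causal mask for the target.

import torch

# hypothetical settings for a small model
INPUT_DIM, OUTPUT_DIM = 1000, 1000          # source / target vocabulary sizes
D_MODEL, NUM_LAYERS, NUM_HEADS, D_FF = 128, 2, 4, 512

encoder = TransformerEncoder(INPUT_DIM, D_MODEL, NUM_LAYERS, NUM_HEADS, D_FF)
decoder = TransformerDecoder(OUTPUT_DIM, D_MODEL, NUM_LAYERS, NUM_HEADS, D_FF)

src = torch.randint(0, INPUT_DIM, (2, 10))  # [batch size, src len]
tgt = torch.randint(0, OUTPUT_DIM, (2, 7))  # [batch size, tgt len]

# padding mask: [batch size, 1, 1, src len] (all ones here, i.e. no padding)
src_mask = torch.ones(2, 1, 1, 10)
# causal mask: [1, 1, tgt len, tgt len], lower triangular
tgt_mask = torch.tril(torch.ones(7, 7)).unsqueeze(0).unsqueeze(0)

memory = encoder(src, src_mask)             # [2, 10, 128]
out, attn = decoder(tgt, memory, src_mask, tgt_mask)

print(out.shape)                            # torch.Size([2, 7, 1000])
print(attn.shape)                           # torch.Size([2, 4, 7, 10])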

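As noted at the top, the model above carries no positional information. Below is a minimal sketch of the sinusoidal positional encoding from the paper that could be added after each embedding lookup; the module name PositionalEncoding and the max_len parameter are illustrative additions, not part of the repository code.

import math
import torch
import torch.nn as nn

# Sinusoidal Positional Encoding (sketch)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len = 5000):
        super(PositionalEncoding, self).__init__()

        # pe = [max len, d model], assuming an even d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype = torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype = torch.float32) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # registered as a buffer: moves with the model, but is not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))
        #pe = [1, max len, d model]

    def forward(self, x):
        #x = [batch size, seq len, d model]
        return x + self.pe[:, :x.size(1)]

One way to wire it in would be to create self.pos_encoding = PositionalEncoding(d_model) in TransformerEncoder and TransformerDecoder and apply it right after self.embedding, before the dropout (the paper additionally scales the embedding by sqrt(d_model) first).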

Related paper: Attention Is All You Need

https://ornni.tistory.com/241

 



Link

https://github.com/ornni/DL_algorithm/tree/main/Transformer

 

