본문 바로가기
프로젝트

Transformer - Translation

by ornni 2024. 7. 31.
728x90
반응형

문제 분석(목적)

Attention is All You Need의 논문에서 배운 모델과 코드를 이용하기

단순한 영어를 프랑스어로 번역하는 코드 작성하기


데이터 수집 및 전처리

직접 작성

 

# 영어와 프랑스어 단어 사전
eng_vocab = {"I": 0, "am": 1, "a": 2, "student": 3, "hello": 4, "world": 5, "<sos>": 6, "<eos>": 7}
fra_vocab = {'je': 0, 'suis': 1, 'un': 2, 'étudiant': 3, 'bonjour': 4, 'monde': 5, '<sos>': 6, '<eos>': 7}

 

# 임의의 영어-프랑스어 번역 데이터셋
data = [
    ([6, 0, 1, 2, 3, 7], [6, 0, 1, 2, 3, 7]),  # "<sos> I am a student <eos>" -> "<sos> je suis un étudiant <eos>"
    ([6, 4, 5, 7], [6, 4, 5, 7]),  # "<sos> hello world <eos>" -> "<sos> bonjour monde <eos>"
    ([6, 4, 0, 1, 2, 3, 7], [6, 4, 0, 1, 2, 3, 7]),  # "<sos> hello I am a student <eos>" -> "<sos> bonjour je suis un étudiant <eos>"
]


AI 모델

Transformer의 Self-Attention 모델 사용 (" Attention is All You Need "참고)

https://ornni.tistory.com/337

 

# 모델 인스턴스 생성
encoder = TransformerEncoder(input_dim, d_model, num_layers, num_heads, d_ff, dropout)
decoder = TransformerDecoder(output_dim, d_model, num_layers, num_heads, d_ff, dropout)


하이퍼파라미터 설정

input_dim = len(eng_vocab) + 1
output_dim = len(fra_vocab) + 1
d_model = 32
num_layers = 2
num_heads = 4
d_ff = 64
dropout = 0.1


학습

 

# 학습 설정
num_epochs = 100
learning_rate = 0.001

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss(ignore_index = tgt_pad_idx)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr = learning_rate)

 

학습

...

Epoch 95/100, Loss: 1.3815
Epoch 96/100, Loss: 1.3818
Epoch 97/100, Loss: 1.3821
Epoch 98/100, Loss: 1.3804
Epoch 99/100, Loss: 1.3812
Epoch 100/100, Loss: 1.3841


결과

 

# 번역 테스트 1
sentence = ["<sos>", "I", "am", "a", "student"]
translated = translate_sentence(sentence, encoder, decoder, eng_vocab, fra_vocab_rev)
print(" ".join(translated))

# 번역 테스트 2
sentence = ["<sos>", "hello", "I", "am", "a", "student"]
translated = translate_sentence(sentence, encoder, decoder, eng_vocab, fra_vocab_rev)
print(" ".join(translated))

# 번역 테스트 3
sentence = ["<sos>", "hello", "student"]
translated = translate_sentence(sentence, encoder, decoder, eng_vocab, fra_vocab_rev)
print(" ".join(translated))

 

# 결과

je suis un étudiant <eos>
bonjour je suis un étudiant <eos>
bonjour monde <eos>
# 학습이 많이 되지 않아 틀리지만 결과를 확인할 수 있다


링크

https://github.com/ornni/Projects/tree/main/Transformer_Translation

 

Projects/Transformer_Translation at main · ornni/Projects

Deep Learning Algorithm. Contribute to ornni/Projects development by creating an account on GitHub.

github.com

 

반응형

'프로젝트' 카테고리의 다른 글

GRU - Electric Production Estimation  (0) 2022.09.02