-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
144 lines (111 loc) · 4.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import config
from data import Dataset
from modules import Transformer
from nltk.translate.bleu_score import sentence_bleu
from utils.experiment import Experiment
from torch.optim.lr_scheduler import ReduceLROnPlateau
print('[~] Training')
print(f' ~ Using device: {Transformer.device}')
# Download and preprocess data
dataset = Dataset(config.LANGUAGE_PAIR, batch_size=config.BATCH_SIZE)
# Initialize model
model = Transformer(
config.D_MODEL,
len(dataset.src_vocab),
len(dataset.trg_vocab),
dataset.src_vocab[dataset.pad_token],
dataset.trg_vocab[dataset.pad_token]
)
print(f' ~ Parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
# Set up experiment
experiment = Experiment(model, category='-'.join(config.LANGUAGE_PAIR))
# Optimizer
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.LEARNING_RATE,
betas=(config.BETA1, config.BETA2),
eps=config.EPS
)
# Lambda LR Scheduler as described in paper:
"""
import os
import seaborn as sns
from matplotlib import pyplot as plt
def get_lr(x):
x += 1 # x is originally zero-indexed
return (config.D_MODEL ** (-0.5)) * min(x ** (-0.5), x * (config.NUM_WARMUP ** (-1.5)))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
ax = sns.lineplot(
x=range(config.NUM_EPOCHS),
y=[get_lr(x) for x in range(config.NUM_EPOCHS)]
)
ax.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
plt.savefig(os.path.join(experiment.path, 'lr_schedule.png'))
"""
# Instead, reducing LR by factor on loss plateau works much, much better
scheduler = ReduceLROnPlateau(optimizer, factor=config.LR_REDUCTION_FACTOR)
# Cross entropy loss
loss_function = torch.nn.CrossEntropyLoss()
# Train
# Train
def train(epoch):
    """Run one training epoch over ``dataset.train_loader``, then validate.

    Logs the mean per-batch training loss under ``loss/train`` and finishes
    by calling :func:`validate` for the same epoch.

    Args:
        epoch: Zero-indexed epoch number, used as the logging x-axis.
    """
    model.train()
    train_loss = 0
    num_batches = 0  # Using DataPipe, cannot use len() to get number of batches
    for data in dataset.train_loader:
        src = data['source'].to(model.device)
        trg = data['target'].to(model.device)
        # Teacher forcing: given tokens 1..N the transformer predicts token
        # N+1, so it takes trg[:, :-1] as input and is scored on trg[:, 1:].
        optimizer.zero_grad()
        predictions = model(src, trg[:, :-1])
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so reshape
        # (batch, seq_len, vocab_len) -> (batch * seq_len, vocab_len) and
        # (batch, seq_len) -> (batch * seq_len).
        loss = loss_function(
            predictions.reshape(-1, predictions.size(-1)),
            trg[:, 1:].reshape(-1)
        )
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_batches += 1
        del src, trg  # release references before the next batch
    # Guard: an empty loader would otherwise raise ZeroDivisionError
    # (batch count is unknowable up front — see num_batches comment above).
    if num_batches > 0:
        experiment.add_scalar('loss/train', epoch, train_loss / num_batches)
    validate(epoch)
# Evaluate against validation set and calculate BLEU
# Evaluate against validation set and calculate BLEU
def validate(epoch):
    """Evaluate on the validation set; log loss, BLEU, and current LR.

    Steps the ``ReduceLROnPlateau`` scheduler ONCE per epoch on the mean
    validation loss. (Previously the scheduler was stepped inside the batch
    loop on each batch's raw loss — noisy per-batch values exhaust the
    scheduler's patience almost immediately and trigger spurious LR drops;
    ReduceLROnPlateau is designed to receive one aggregate metric per epoch.)

    Args:
        epoch: Zero-indexed epoch number, used as the logging x-axis.
    """
    with torch.no_grad():
        model.eval()
        valid_loss = 0
        num_batches = 0
        bleu_score = 0
        for data in dataset.valid_loader:
            src = data['source'].to(model.device)
            trg = data['target'].to(model.device)
            predictions = model(src, trg[:, :-1])
            loss = loss_function(
                predictions.reshape(-1, predictions.size(-1)),
                trg[:, 1:].reshape(-1)
            )
            # Calculate BLEU: greedy-decode (argmax) each sequence, strip
            # special tokens, and score against the shifted ground truth.
            batch_size = predictions.size(0)
            batch_bleu = 0
            p_indices = torch.argmax(predictions, dim=-1)
            for i in range(batch_size):
                p_tokens = dataset.trg_vocab.lookup_tokens(p_indices[i].tolist())
                t_tokens = dataset.trg_vocab.lookup_tokens(trg[i, 1:].tolist())
                # Filter out special tokens (all contain '<', e.g. <pad>)
                p_tokens = list(filter(lambda x: '<' not in x, p_tokens))
                t_tokens = list(filter(lambda x: '<' not in x, t_tokens))
                if len(p_tokens) > 0 and len(t_tokens) > 0:
                    batch_bleu += sentence_bleu([t_tokens], p_tokens)
            bleu_score += batch_bleu / batch_size
            valid_loss += loss.item()
            num_batches += 1
            del src, trg  # release references before the next batch
    # Guard against an empty loader (ZeroDivisionError) and step the
    # scheduler exactly once per epoch on the aggregate metric.
    if num_batches > 0:
        mean_loss = valid_loss / num_batches
        scheduler.step(mean_loss)
        experiment.add_scalar('loss/validation', epoch, mean_loss)
        experiment.add_scalar('bleu', epoch, bleu_score / num_batches)
    # param_groups is a list; index the first group directly
    experiment.add_scalar('lr', epoch, optimizer.param_groups[0]['lr'])
experiment.loop(config.NUM_EPOCHS, train)