[feat] Added some comments for reading the code #214

Open
wants to merge 1 commit into master
7 changes: 4 additions & 3 deletions preprocess.py
@@ -12,7 +12,8 @@
import tarfile
import torchtext.data
import torchtext.datasets
from torchtext.datasets import TranslationDataset
#from torchtext.datasets import TranslationDataset
from torchtext.legacy.datasets import TranslationDataset
import transformer.Constants as Constants
from learn_bpe import learn_bpe
from apply_bpe import BPE
@@ -332,5 +333,5 @@ def filter_examples_with_length(x):


if __name__ == '__main__':
main_wo_bpe()
#main()
#main_wo_bpe()
main()
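If the script also needs to run on older torchtext releases, a version-tolerant import is a common pattern; a minimal sketch (not part of this PR, assuming the dataset classes moved under torchtext.legacy starting with torchtext 0.9):

try:
    # torchtext >= 0.9 keeps the old dataset/field APIs under torchtext.legacy
    from torchtext.legacy.datasets import TranslationDataset
except ImportError:
    # older torchtext still exposes TranslationDataset at the original path
    from torchtext.datasets import TranslationDataset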
8 changes: 7 additions & 1 deletion train.py
@@ -76,6 +76,7 @@ def train_epoch(model, training_data, optimizer, opt, device, smoothing):
total_loss, n_word_total, n_word_correct = 0, 0, 0

desc = ' - (Training) '
# * train on every batch
for batch in tqdm(training_data, mininterval=2, desc=desc, leave=False):

# prepare data
@@ -89,7 +90,8 @@ def train_epoch(model, training_data, optimizer, opt, device, smoothing):
# backward and update parameters
loss, n_correct, n_word = cal_performance(
pred, gold, opt.trg_pad_idx, smoothing=smoothing)
loss.backward()
loss.backward() # * compute the gradients
# * update the parameters (step_and_update_lr also advances the LR schedule)
optimizer.step_and_update_lr()

# note keeping
@@ -162,6 +164,7 @@ def print_performances(header, ppl, accu, start_time, lr):
print('[ Epoch', epoch_i, ']')

start = time.time()
# * train for a single epoch
train_loss, train_accu = train_epoch(
model, training_data, optimizer, opt, device, smoothing=opt.label_smoothing)
train_ppl = math.exp(min(train_loss, 100))
@@ -170,6 +173,7 @@ def print_performances(header, ppl, accu, start_time, lr):
print_performances('Training', train_ppl, train_accu, start, lr)

start = time.time()
# * evaluate on the validation set
valid_loss, valid_accu = eval_epoch(model, validation_data, device, opt)
valid_ppl = math.exp(min(valid_loss, 100))
print_performances('Validation', valid_ppl, valid_accu, start, lr)
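# Perplexity note: ppl is exp(average per-word cross-entropy), and min(loss, 100) only caps the
# loss so math.exp cannot overflow. A quick check with a hypothetical loss value:
import math
assert abs(math.exp(min(4.6, 100)) - 99.48) < 0.1   # 4.6 nats/word corresponds to a ppl of ~99.5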
@@ -278,6 +282,7 @@ def main():

print(opt)

# * this is where the Transformer model and its hyperparameters are set up
transformer = Transformer(
opt.src_vocab_size,
opt.trg_vocab_size,
@@ -295,6 +300,7 @@ def main():
dropout=opt.dropout,
scale_emb_or_prj=opt.scale_emb_or_prj).to(device)

# * set up the optimizer
optimizer = ScheduledOptim(
optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
opt.lr_mul, opt.d_model, opt.n_warmup_steps)
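ScheduledOptim wraps Adam with the warmup schedule from the paper, and step_and_update_lr() both applies the gradients and advances that schedule. A minimal sketch of the schedule itself (a re-derivation for illustration, with made-up default values, not a copy of this repo's Optim.py):

def warmup_lr(step, d_model=512, n_warmup_steps=4000, lr_mul=2.0):
    # warmup for the first n_warmup_steps, then decay proportional to 1 / sqrt(step)
    return lr_mul * d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)

# the two branches meet at step == n_warmup_steps, which is where the learning rate peaks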
4 changes: 2 additions & 2 deletions transformer/Layers.py
@@ -7,8 +7,9 @@
__author__ = "Yu-Hsiang Huang"


class EncoderLayer(nn.Module):
''' Compose with two layers '''
# * one encoder layer contains one MHA block plus layer norm + feed-forward

def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
@@ -21,7 +22,6 @@ def forward(self, enc_input, slf_attn_mask=None):
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
''' Compose with three layers '''

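As the new comment on EncoderLayer says, an encoder layer is multi-head self-attention followed by a position-wise feed-forward block, with dropout, residual connections and layer norm handled inside those sublayers. A stripped-down sketch of that composition (illustrative only; the real classes live in transformer/SubLayers.py):

import torch.nn as nn
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward

class TinyEncoderLayer(nn.Module):
    ''' Self-attention sublayer followed by a feed-forward sublayer. '''
    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        # q, k and v are all enc_input: that is what makes it *self*-attention
        enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn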
13 changes: 8 additions & 5 deletions transformer/Models.py
@@ -42,9 +42,10 @@ def get_position_angle_vec(position):
return torch.FloatTensor(sinusoid_table).unsqueeze(0)

def forward(self, x):
# * the positional encoding needs no gradient flow, hence the detach()
# * here the input and the positional encoding are added together
return x + self.pos_table[:, :x.size(1)].clone().detach()
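# Reference for the two comments above: the table is the sinusoid from the paper,
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))
# and it is precomputed rather than learned, so detach() just makes the no-gradient intent explicit.
# An equivalent torch-only construction, as an illustrative sketch (the class above builds the
# same table with numpy and converts it via torch.FloatTensor):
import torch

def sinusoid_table(n_position, d_hid):
    pos = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)        # (n_position, 1)
    j = torch.arange(d_hid, dtype=torch.float32).unsqueeze(0)               # (1, d_hid)
    angle = pos / torch.pow(10000.0, 2.0 * torch.floor(j / 2.0) / d_hid)    # (n_position, d_hid)
    table = torch.empty_like(angle)
    table[:, 0::2] = torch.sin(angle[:, 0::2])    # even dimensions use sin
    table[:, 1::2] = torch.cos(angle[:, 1::2])    # odd dimensions use cos
    return table.unsqueeze(0)                     # (1, n_position, d_hid), like self.pos_table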


class Encoder(nn.Module):
''' A encoder model with self attention mechanism. '''

@@ -57,8 +58,9 @@ def __init__(
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
# * this looks like the main encoder logic: a stack of EncoderLayer modules; each EncoderLayer contains MHA plus layer norm
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb
@@ -69,17 +71,18 @@ def forward(self, src_seq, src_mask, return_attns=False):
enc_slf_attn_list = []

# -- Forward
enc_output = self.src_word_emb(src_seq)
enc_output = self.src_word_emb(src_seq) # * src_seq holds vocabulary indices; the embedding lookup maps them to word vectors
if self.scale_emb:
enc_output *= self.d_model ** 0.5
enc_output = self.dropout(self.position_enc(enc_output))
enc_output = self.layer_norm(enc_output)

for enc_layer in self.layer_stack:
# * run every encoder layer; each layer's output is the next layer's input
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_slf_attn_list += [enc_slf_attn] if return_attns else []

if return_attns:
if return_attns: # * this looks like it is for printing out the attention weights
return enc_output, enc_slf_attn_list
return enc_output,

@@ -185,7 +188,7 @@ def __init__(


def forward(self, src_seq, trg_seq):

# ? how exactly are these masks used?
src_mask = get_pad_mask(src_seq, self.src_pad_idx)
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)

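Answering the mask question above: get_pad_mask flags the non-padding positions, get_subsequent_mask additionally hides future target positions, and the combined boolean mask is broadcast over the heads and applied inside ScaledDotProductAttention via masked_fill(mask == 0, -1e9) before the softmax (see transformer/Modules.py below). A small sketch of what those helpers typically look like in this codebase (reconstructed for illustration; check Models.py for the exact definitions):

import torch

def get_pad_mask(seq, pad_idx):
    # (batch, 1, seq_len): True where the token is a real word, False on padding
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    # (1, seq_len, seq_len): lower-triangular, so position i can only attend to positions <= i
    len_s = seq.size(1)
    return (1 - torch.triu(torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()

trg_seq = torch.tensor([[5, 7, 9, 0]])   # hypothetical target batch of one sentence, 0 = pad
trg_mask = get_pad_mask(trg_seq, pad_idx=0) & get_subsequent_mask(trg_seq)
# trg_mask has shape (1, 4, 4); row i is True only for non-pad positions j with j <= i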
2 changes: 1 addition & 1 deletion transformer/Modules.py
@@ -18,7 +18,7 @@ def forward(self, q, k, v, mask=None):

if mask is not None:
attn = attn.masked_fill(mask == 0, -1e9)

# * looks like there is an extra dropout here (on the attention weights) compared with the original paper
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
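Putting the pieces together, this module computes softmax(Q K^T / sqrt(d_k)) V, with masked positions pushed to -1e9 so they end up with near-zero weight; the dropout on the attention weights is the extra bit the comment above refers to. A self-contained functional sketch (illustrative, mirroring the module rather than copying it):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None, dropout_p=0.1):
    d_k = q.size(-1)
    attn = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5   # (b, n_head, len_q, len_k)
    if mask is not None:
        attn = attn.masked_fill(mask == 0, -1e9)               # masked scores become ~ -inf
    attn = F.dropout(F.softmax(attn, dim=-1), p=dropout_p)     # attention weights, with dropout
    output = torch.matmul(attn, v)                             # (b, n_head, len_q, d_v)
    return output, attn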

17 changes: 12 additions & 5 deletions transformer/SubLayers.py
@@ -16,10 +16,11 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
self.d_k = d_k
self.d_v = d_v

self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
# * the per-head weight matrices are concatenated into one; d_model is the size of the input embedding
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) # note this is only a linear multiplication, no bias
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False) # * used to compress the concatenated heads back to d_model

self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

@@ -28,9 +29,10 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):


def forward(self, q, k, v, mask=None):

# ! the q, k, v passed in here are actually all the same tensor: batch_size x seq_length x word_emb_dim (self-attention)
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
# * sz_b means batch_size, len_q means seq_length

residual = q

@@ -45,13 +47,18 @@ def forward(self, q, k, v, mask=None):

if mask is not None:
mask = mask.unsqueeze(1) # For head axis broadcasting.

# * compute the attention matmuls
# ! after the head split: q/k are b x n x lq x dk, v is b x n x lq x dv
# ! matmul(q, k^T): [b x n x lq x dk] x [b x n x dk x lq] = b x n x lq x lq   # Q @ K^T
# ! matmul(attn, v): [b x n x lq x lq] x [b x n x lq x dv] = b x n x lq x dv  # (Q @ K^T) @ V
# e.g. batch_size = 11; seq_len = 6; d_k = d_v = 4; n_head = 3
# torch.rand(batch_size, n_head, seq_len, d_k)
q, attn = self.attention(q, k, v, mask=mask)

# Transpose to move the head dimension back: b x lq x n x dv
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
q = self.dropout(self.fc(q)) # * project the concatenated heads back to d_model, then apply dropout
q += residual

q = self.layer_norm(q)
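To make the shape comments above concrete, here is a tiny walk-through with the sizes used in the example comment (batch of 11, sequence length 6, 3 heads, d_k = d_v = 4; purely illustrative):

import torch

sz_b, len_q, n_head, d_k, d_v = 11, 6, 3, 4, 4
q = torch.rand(sz_b, n_head, len_q, d_k)          # shape after w_qs + view + transpose
k = torch.rand(sz_b, n_head, len_q, d_k)
v = torch.rand(sz_b, n_head, len_q, d_v)

scores = torch.matmul(q, k.transpose(-2, -1))     # (11, 3, 6, 6)  = Q @ K^T
weights = torch.softmax(scores / d_k ** 0.5, dim=-1)
out = torch.matmul(weights, v)                    # (11, 3, 6, 4)  = weights @ V

# move the head axis back and concatenate the heads: (11, 6, 3 * 4) = (11, 6, 12),
# which the final fc then maps back to d_model
out = out.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
print(out.shape)                                  # torch.Size([11, 6, 12])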
1 change: 1 addition & 0 deletions transformer_docker_run.sh
@@ -0,0 +1 @@
docker run --rm -it -v /Users/bytedance/Desktop/ai_infra/attention-is-all-you-need-pytorch:/transformer ss4g/transformer_env_torch /bin/bash