[ASR] add squeezeformer model #2755

Merged: 9 commits, Mar 15, 2023
Changes from 3 commits
98 changes: 98 additions & 0 deletions examples/aishell/asr1/conf/chunk_squeezeformer.yaml
@@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256    # dimension of attention
    output_size: 256    # dimension of output
    attention_heads: 4
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 4
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    adaptive_scale: true
    cnn_module_kernel: 31
    normalize_before: false
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    time_reduction_layer_type: 'conv2d'
    causal: true
    use_dynamic_chunk: true
    use_dynamic_left_chunk: false

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1        # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    init_type: 'kaiming_uniform'    # !Warning: needed for convergence

###########################################
# Data #
###########################################

train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test


###########################################
# Dataloader #
###########################################

vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
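For context on the Training block: `warmuplr` with `warmup_steps: 25000` is, in WeNet/ESPnet-style recipes, the usual Noam-style warmup schedule. Below is a minimal sketch of that schedule under the assumption that this is what the PaddleSpeech scheduler implements (the exact class may differ in detail):

```python
# Sketch of a Noam-style warmup schedule (assumption: this is what `warmuplr` implements).
def warmup_lr(step: int, base_lr: float = 0.001, warmup_steps: int = 25000) -> float:
    # Roughly linear warmup toward base_lr over warmup_steps, then inverse-sqrt decay.
    step = max(step, 1)
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

for s in (100, 25000, 100000):
    print(s, round(warmup_lr(s), 6))   # rises toward 1e-3, peaks at step 25000, then decays
```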
93 changes: 93 additions & 0 deletions examples/aishell/asr1/conf/squeezeformer.yaml
@@ -0,0 +1,93 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256    # dimension of attention
    output_size: 256    # dimension of output
    attention_heads: 4
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 4
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    adaptive_scale: true
    cnn_module_kernel: 31
    normalize_before: false
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    time_reduction_layer_type: 'conv2d'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    init_type: 'kaiming_uniform'    # !Warning: needed for convergence

###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test

###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training #
###########################################
n_epoch: 150
accum_grad: 8
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
    lr: 0.002
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
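Compared with the chunk config, this offline recipe uses `accum_grad: 8`, i.e. an effective batch of 32 x 8 = 256 utterances per optimizer step (per device), and a higher peak lr of 0.002. The `reduce_idx`/`recover_idx` pair reflects Squeezeformer's temporal U-Net: the frame rate is halved partway through the encoder stack and restored near the end. A rough, illustrative sketch of the per-block frame period, assuming a 10 ms hop, x4 convolutional subsampling in the frontend, and x2 time reduction in the middle (details may differ in the PaddleSpeech implementation):

```python
# Rough frame-period bookkeeping for a 12-block Squeezeformer stack (illustrative only).
STRIDE_MS = 10.0        # stride_ms from the config
SUBSAMPLING = 4         # assumption: the conv frontend subsamples time by 4
REDUCE_IDX, RECOVER_IDX, NUM_BLOCKS = 5, 11, 12

frame_ms = STRIDE_MS * SUBSAMPLING           # 40 ms per frame entering block 0
for block in range(NUM_BLOCKS):
    if block == REDUCE_IDX:
        frame_ms *= 2                        # time-reduction layer halves the frame rate
    if block == RECOVER_IDX:
        frame_ms /= 2                        # recover layer restores the original rate
    print(f"block {block:2d}: {frame_ms:.0f} ms per frame")
```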
4 changes: 4 additions & 0 deletions paddlespeech/s2t/models/u2/u2.py
@@ -43,6 +43,7 @@
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
@@ -905,6 +906,9 @@ def _init_from_config(cls, configs: dict):
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'squeezeformer':
            encoder = SqueezeformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"not support encoder type:{encoder_type}")
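To make the new dispatch concrete, here is a minimal sketch of how a recipe config reaches the `squeezeformer` branch; the file path is the example config added in this PR, while the surrounding U2Model plumbing (input_dim, global_cmvn, vocab handling) is elided:

```python
import yaml

# The recipe YAML's `encoder` key selects the encoder class, and `encoder_conf`
# is unpacked into its constructor inside U2Model._init_from_config.
with open('examples/aishell/asr1/conf/squeezeformer.yaml') as f:
    configs = yaml.safe_load(f)

assert configs['encoder'] == 'squeezeformer'
encoder_conf = configs['encoder_conf']   # encoder_dim, num_blocks, reduce_idx, ...
# Inside _init_from_config this becomes roughly:
#   SqueezeformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)
```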

164 changes: 163 additions & 1 deletion paddlespeech/s2t/modules/attention.py
@@ -26,7 +26,10 @@

logger = Log(__name__).getlog()

__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
__all__ = [
    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
    "RelPositionMultiHeadedAttention2"
]

# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
@@ -330,3 +333,162 @@ def forward(self,
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head,
                 n_feat,
                 dropout_rate,
                 do_rel_shift=False,
[Review comment, Collaborator]: It looks like only the last three parameters are new, all defaulting to False. As with the other attention classes, check whether this can be merged into a single class.
                 adaptive_scale=False,
                 init_weights=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = Linear(n_feat, n_feat)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.do_rel_shift = do_rel_shift
        pos_bias_u = self.create_parameter(
            [self.h, self.d_k], default_initializer=I.XavierUniform())
        self.add_parameter('pos_bias_u', pos_bias_u)
        pos_bias_v = self.create_parameter(
            [self.h, self.d_k], default_initializer=I.XavierUniform())
        self.add_parameter('pos_bias_v', pos_bias_v)
        self.adaptive_scale = adaptive_scale
        ada_scale = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(1.0))
        self.add_parameter('ada_scale', ada_scale)
        ada_bias = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(0.0))
        self.add_parameter('ada_bias', ada_bias)
        if init_weights:
            self.init_weights()

    def init_weights(self):
        input_max = (self.h * self.d_k)**-0.5
        self.linear_q._param_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_q._bias_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_k._param_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_k._bias_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_v._param_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_v._bias_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_pos._param_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_pos._bias_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_out._param_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)
        self.linear_out._bias_attr = paddle.nn.initializer.Uniform(
            low=-input_max, high=input_max)

    def rel_shift(self, x, zero_triu: bool=False):
        """Compute relative positional encoding.
        Args:
            x (paddle.Tensor): Input tensor (batch, head, time1, time1).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.
        Returns:
            paddle.Tensor: Output tensor. (batch, head, time1, time1)
        """
        zero_pad = paddle.zeros(
            [x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype)
        x_padded = paddle.concat([zero_pad, x], axis=-1)

        x_padded = x_padded.reshape(
            [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
        x = x_padded[:, :, 1:].reshape(paddle.shape(x))  # [B, H, T1, T1]

        if zero_triu:
            ones = paddle.ones((x.shape[2], x.shape[3]))
            x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]

        return x

    def forward(self,
                query: paddle.Tensor,
                key: paddle.Tensor,
                value: paddle.Tensor,
                mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
                pos_emb: paddle.Tensor=paddle.empty([0]),
                cache: paddle.Tensor=paddle.zeros(
                    (0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (paddle.Tensor): Query tensor (#batch, time1, size).
            key (paddle.Tensor): Key tensor (#batch, time2, size).
            value (paddle.Tensor): Value tensor (#batch, time2, size).
            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (paddle.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            paddle.Tensor: Output tensor (#batch, time1, d_model).
            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        if self.adaptive_scale:
            query = self.ada_scale * query + self.ada_bias
            key = self.ada_scale * key + self.ada_bias
            value = self.ada_scale * value + self.ada_bias

        q, k, v = self.forward_qkv(query, key, value)
        if cache.shape[0] > 0:
            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
            k = paddle.concat([key_cache, k], axis=2)
            v = paddle.concat([value_cache, v], axis=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        # non-trivial to calculate `next_cache_start` here.
        new_cache = paddle.concat((k, v), axis=-1)

        n_batch_pos = pos_emb.shape[0]
        p = self.linear_pos(pos_emb).reshape(
            [n_batch_pos, -1, self.h, self.d_k])
        p = p.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)

        # (batch, head, time1, d_k)
        # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
        # (batch, head, time1, d_k)
        # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
        matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
        matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
        # rel_shift is skipped by default: it brings little benefit for speech
        # recognition and needs special handling for streaming.
        if self.do_rel_shift:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
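As a quick sanity check, here is a hedged sketch of exercising the new class with dummy tensors; shapes follow the docstring above, the import path matches this diff, and the extra constructor flags (`do_rel_shift`, `adaptive_scale`, `init_weights`) are left at their defaults:

```python
import paddle
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2

# Batch of 1, 50 frames, 256-dim features, 4 heads (d_k = 64).
attn = RelPositionMultiHeadedAttention2(n_head=4, n_feat=256, dropout_rate=0.1)
q = k = v = paddle.randn([1, 50, 256])             # (batch, time, feature)
pos_emb = paddle.randn([1, 50, 256])               # relative positional embeddings
mask = paddle.ones([1, 1, 50], dtype=paddle.bool)  # (batch, 1, time2), all-True = no masking
out, cache = attn(q, k, v, mask=mask, pos_emb=pos_emb)
print(out.shape, cache.shape)                      # [1, 50, 256] and [1, 4, 50, 128]
```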