
[ASR] Support Hubert, finetuned on the LibriSpeech dataset #3088

Merged: 13 commits (May 4, 2023)
2 changes: 1 addition & 1 deletion dataset/librispeech/librispeech.py
@@ -133,7 +133,7 @@ def create_manifest(data_dir, manifest_path):
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
Collaborator: Why does this need to change?

Contributor Author: The original check doesn't seem to match the directory layout produced when LibriSpeech is extracted, which is inconvenient when the dataset already exists locally.

Collaborator: What do you mean? Doesn't this already skip the download when the data is present?

Contributor Author: OK, let's keep it the way it was before.

if not os.path.exists(os.path.join(target_dir)):
# download
filepath = download(url, md5sum, target_dir)
# unpack
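For context, a minimal Python sketch of the guard being discussed, assuming the download() and unpack() helpers that dataset/librispeech/librispeech.py already uses; the reviewer's point is that checking for the extracted LibriSpeech/ folder already skips the download when the data exists locally:

import os

from utils.utility import download, unpack  # assumed helpers, as used elsewhere in this script


def prepare_dataset_sketch(url, md5sum, target_dir):
    """Download and unpack LibriSpeech only if it is not already extracted."""
    if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
        # Data not present yet: fetch the archive and extract it under target_dir.
        filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
    else:
        # Dataset already exists locally, so no download is needed.
        print("Skip downloading and unpacking. Data already exists in %s." % target_dir)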
133 changes: 133 additions & 0 deletions examples/librispeech/asr3/conf/hubertASR.yaml
@@ -0,0 +1,133 @@
############################################
# Network Architecture #
############################################
freeze_hubert: True
normalize_wav: True
output_norm: True
init_type: kaiming_uniform # !Warning: need to convergence
enc:
  input_shape: 1024
  dnn_blocks: 2
  dnn_neurons: 1024
  activation: True
ctc:
  enc_n_units: 1024
  blank_id: 0
  dropout_rate: 0.0
hubert_params_path: "exp/hubert/pd_hubert.pdparams"
Collaborator: Could you provide a download link for this model?

task_cfg:
  sample_rate: 16000
Collaborator: Suggestion: could a pretrain/finetune label be added?

model_cfg:
  dropout_input: 0.0
  final_dropout: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  activation_dropout: 0.1
  apply_mask: True
  mask_length: 10
  mask_prob: 0.5
  mask_selection: static
  mask_other: 0.0
  no_mask_overlap: False
  mask_channel_length: 64
  mask_channel_prob: 0.25
  mask_channel_selection: static
  mask_channel_other: 0.0
  no_mask_channel_overlap: False
  freeze_finetune_updates: 10000
  feature_grad_mult: 0.0
  layerdrop: 0.1
  normalize: True
  fp16: True
  label_rate: 50
  extractor_mode: layer_norm
  encoder_layers: 24
Collaborator: Is this the Large model configuration? Please split it into separate config files.

  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16
  activation_fn: gelu
  encoder_layerdrop: 0.1
  dropout_features: 0.0
  final_dim: 768
  untie_final_proj: True
  layer_norm_first: True
  conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
  conv_bias: False
  logit_temp: 0.1
  target_glu: False
  mask_min_space: 1
  mask_channel_min_space: 1
  conv_pos: 128
  conv_pos_groups: 16
  latent_temp: [2.0, 0.5, 0.999995]
  skip_masked: False
  skip_nomask: True

###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean

###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
unit_type: char
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs
batch_size: 8 # Different batch_size may cause large differences in results
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1
dist_sampler: True
shortest_first: True
return_lens_rate: True

############################################
# Data Augmentation #
############################################
audio_augment: # for raw audio
  sample_rate: 16000
Collaborator: Why are two sample_rate parameters needed?


###########################################
# Training #
###########################################
n_epoch: 1
accum_grad: 1
global_grad_clip: 5.0
model_optim: adadelta
model_optim_conf:
  lr: 1.0
  epsilon: 1.0e-6
  rho: 0.95
model_scheduler: constantlr
model_scheduler_conf:
  warmup_steps: 25000
  lr_decay: 1.0
hubert_optim: adadelta
hubert_optim_conf:
  lr: 0.9
  epsilon: 1.0e-6
  rho: 0.95
hubert_scheduler: constantlr
hubert_scheduler_conf:
  warmup_steps: 25000
  lr_decay: 1.0
log_interval: 1
checkpoint:
  kbest_n: 50
  latest_n: 5
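As a usage reference, a minimal sketch of how a config like this could be consumed: the YAML is loaded with the same yacs CfgNode pattern that the new paddlespeech/s2t/exps/hubert/bin/test.py uses, and the model_optim/model_optim_conf entries map onto paddle.optimizer.Adadelta. The Linear layer and the hard-coded config path below are illustrative placeholders, not part of the PR:

import paddle
from yacs.config import CfgNode

# Load the YAML above into an attribute-style config, mirroring bin/test.py.
config = CfgNode(new_allowed=True)
config.merge_from_file("examples/librispeech/asr3/conf/hubertASR.yaml")
config.freeze()

# Stand-in for the trainable ASR head; the real trainer would pass its own parameters.
head = paddle.nn.Linear(config.ctc.enc_n_units, 32)

# Build the downstream optimizer from model_optim / model_optim_conf (adadelta here).
optimizer = paddle.optimizer.Adadelta(
    learning_rate=config.model_optim_conf.lr,  # 1.0
    epsilon=config.model_optim_conf.epsilon,   # 1.0e-6
    rho=config.model_optim_conf.rho,           # 0.95
    parameters=head.parameters())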
2 changes: 1 addition & 1 deletion examples/librispeech/asr3/local/data.sh
100644 → 100755
@@ -1,6 +1,6 @@
#!/bin/bash

stage=-1
stage=0
stop_stage=100

unit_type=char
2 changes: 1 addition & 1 deletion examples/librispeech/asr3/local/test.sh
100644 → 100755
@@ -31,7 +31,7 @@ python3 utils/format_rsl.py \

for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=16
batch_size=8
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
Empty file modified examples/librispeech/asr3/local/test_wav.sh
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion examples/librispeech/asr3/local/train.sh
100644 → 100755
@@ -38,7 +38,7 @@ python3 -u ${BIN_DIR}/train.py \
--seed ${seed} \
--resume ${resume}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} --log_dir=exp/log/${ckpt_name} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
3 changes: 1 addition & 2 deletions examples/librispeech/asr3/path.sh
@@ -10,6 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/


MODEL=wav2vec2
MODEL=$1
Collaborator: This shouldn't be hard-coded here, but passing it in as an argument isn't right either. If it has to share the asr directory with wav2vec, just create a separate directory for it.

export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
12 changes: 7 additions & 5 deletions examples/librispeech/asr3/run.sh
100644 → 100755
@@ -1,13 +1,14 @@
#!/bin/bash
set -e

. ./path.sh || exit 1;
MODEL=hubert
. ./path.sh ${MODEL} || exit 1;
. ./cmd.sh || exit 1;

gpus=0
stage=0
stop_stage=0
conf_path=conf/wav2vec2ASR.yaml
gpus=2
Collaborator: Remember to change this back to the default value.

stage=1
stop_stage=3
conf_path=conf/${MODEL}ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
@@ -19,6 +20,7 @@ audio_file=data/demo_002_en.wav

avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=test6
Collaborator: Please delete this line.

echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
13 changes: 13 additions & 0 deletions paddlespeech/s2t/exps/hubert/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions paddlespeech/s2t/exps/hubert/bin/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
64 changes: 64 additions & 0 deletions paddlespeech/s2t/exps/hubert/bin/test.py
@@ -0,0 +1,64 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for hubert model."""
import cProfile

from yacs.config import CfgNode

from paddlespeech.s2t.exps.hubert.model import HubertASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments


def main_sp(config, args):
    exp = Tester(config, args)
    with exp.eval():
        exp.setup()
        exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    # save asr result to
    parser.add_argument(
        '--dict-path', type=str, default=None, help='dict path.')
    parser.add_argument(
        "--result_file", type=str, help="path of save the asr result")
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    if args.decode_cfg:
        decode_confs = CfgNode(new_allowed=True)
        decode_confs.merge_from_file(args.decode_cfg)
        config.decode = decode_confs
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # Setting for profiling
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats('test.profile')
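Since the profiler dumps its stats to test.profile, a small sketch of how those numbers could be inspected afterwards with the standard-library pstats module (the file name simply matches the dump_stats call above):

import pstats

# Load the profile written by the test run and print the 20 most expensive calls,
# sorted by cumulative time.
stats = pstats.Stats("test.profile")
stats.strip_dirs().sort_stats("cumulative").print_stats(20)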