-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ASR] Support Hubert, fintuned on the librispeech dataset #3088
Changes from 3 commits
fd61a61
ce75f8e
0a8d95c
c85d61c
3a31163
35c75fe
d036c2e
9c7c2ca
958bfe5
559627c
ac08645
7658e54
6f3585b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
############################################ | ||
# Network Architecture # | ||
############################################ | ||
freeze_hubert: True | ||
normalize_wav: True | ||
output_norm: True | ||
init_type: kaiming_uniform # !Warning: need to convergence | ||
enc: | ||
input_shape: 1024 | ||
dnn_blocks: 2 | ||
dnn_neurons: 1024 | ||
activation: True | ||
ctc: | ||
enc_n_units: 1024 | ||
blank_id: 0 | ||
dropout_rate: 0.0 | ||
hubert_params_path: "exp/hubert/pd_hubert.pdparams" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个模型是否可以给出下载链接? |
||
|
||
|
||
task_cfg: | ||
sample_rate: 16000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议能否添加pretrain/finetune的标签 |
||
|
||
model_cfg: | ||
dropout_input: 0.0 | ||
final_dropout: 0.0 | ||
dropout: 0.0 | ||
attention_dropout: 0.0 | ||
activation_dropout: 0.1 | ||
apply_mask: True | ||
mask_length: 10 | ||
mask_prob: 0.5 | ||
mask_selection: static | ||
mask_other: 0.0 | ||
no_mask_overlap: False | ||
mask_channel_length: 64 | ||
mask_channel_prob: 0.25 | ||
mask_channel_selection: static | ||
mask_channel_other: 0.0 | ||
no_mask_channel_overlap: False | ||
freeze_finetune_updates: 10000 | ||
feature_grad_mult: 0.0 | ||
layerdrop: 0.1 | ||
normalize: True | ||
fp16: True | ||
label_rate: 50 | ||
extractor_mode: layer_norm | ||
encoder_layers: 24 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是Large的配置?配置文件区分下吧 |
||
encoder_embed_dim: 1024 | ||
encoder_ffn_embed_dim: 4096 | ||
encoder_attention_heads: 16 | ||
activation_fn: gelu | ||
encoder_layerdrop: 0.1 | ||
dropout_features: 0.0 | ||
final_dim: 768 | ||
untie_final_proj: True | ||
layer_norm_first: True | ||
conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" | ||
conv_bias: False | ||
logit_temp: 0.1 | ||
target_glu: False | ||
mask_min_space: 1 | ||
mask_channel_min_space: 1 | ||
conv_pos: 128 | ||
conv_pos_groups: 16 | ||
latent_temp: [2.0, 0.5, 0.999995] | ||
skip_masked: False | ||
skip_nomask: True | ||
|
||
########################################### | ||
# Data # | ||
########################################### | ||
train_manifest: data/manifest.train | ||
dev_manifest: data/manifest.dev | ||
test_manifest: data/manifest.test-clean | ||
|
||
########################################### | ||
# Dataloader # | ||
########################################### | ||
vocab_filepath: data/lang_char/vocab.txt | ||
unit_type: char | ||
mean_std_filepath: "" | ||
preprocess_config: conf/preprocess.yaml | ||
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs | ||
batch_size: 8 # Different batch_size may cause large differences in results | ||
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced | ||
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced | ||
minibatches: 0 # for debug | ||
batch_count: auto | ||
batch_bins: 0 | ||
batch_frames_in: 0 | ||
batch_frames_out: 0 | ||
batch_frames_inout: 0 | ||
num_workers: 0 | ||
subsampling_factor: 1 | ||
num_encs: 1 | ||
dist_sampler: True | ||
shortest_first: True | ||
return_lens_rate: True | ||
|
||
############################################ | ||
# Data Augmentation # | ||
############################################ | ||
audio_augment: # for raw audio | ||
sample_rate: 16000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么需要两个sample_rate参数 |
||
|
||
########################################### | ||
# Training # | ||
########################################### | ||
n_epoch: 1 | ||
accum_grad: 1 | ||
global_grad_clip: 5.0 | ||
model_optim: adadelta | ||
model_optim_conf: | ||
lr: 1.0 | ||
epsilon: 1.0e-6 | ||
rho: 0.95 | ||
model_scheduler: constantlr | ||
model_scheduler_conf: | ||
warmup_steps: 25000 | ||
lr_decay: 1.0 | ||
hubert_optim: adadelta | ||
hubert_optim_conf: | ||
lr: 0.9 | ||
epsilon: 1.0e-6 | ||
rho: 0.95 | ||
hubert_scheduler: constantlr | ||
hubert_scheduler_conf: | ||
warmup_steps: 25000 | ||
lr_decay: 1.0 | ||
log_interval: 1 | ||
checkpoint: | ||
kbest_n: 50 | ||
latest_n: 5 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
#!/bin/bash | ||
|
||
stage=-1 | ||
stage=0 | ||
stop_stage=100 | ||
|
||
unit_type=char | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} | |
|
||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ | ||
|
||
|
||
MODEL=wav2vec2 | ||
MODEL=$1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个不需要固定,不能用传参的方式。如果是和wav2vec一个asr目录的话就单开个吧。 |
||
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
. ./path.sh || exit 1; | ||
MODEL=hubert | ||
. ./path.sh ${MODEL} || exit 1; | ||
. ./cmd.sh || exit 1; | ||
|
||
gpus=0 | ||
stage=0 | ||
stop_stage=0 | ||
conf_path=conf/wav2vec2ASR.yaml | ||
gpus=2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 记得够改回默认值。 |
||
stage=1 | ||
stop_stage=3 | ||
conf_path=conf/${MODEL}ASR.yaml | ||
ips= #xx.xx.xx.xx,xx.xx.xx.xx | ||
decode_conf_path=conf/tuning/decode.yaml | ||
avg_num=1 | ||
|
@@ -19,6 +20,7 @@ audio_file=data/demo_002_en.wav | |
|
||
avg_ckpt=avg_${avg_num} | ||
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') | ||
ckpt=test6 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删一下 |
||
echo "checkpoint name ${ckpt}" | ||
|
||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Evaluation for hubert model.""" | ||
import cProfile | ||
|
||
from yacs.config import CfgNode | ||
|
||
from paddlespeech.s2t.exps.hubert.model import HubertASRTester as Tester | ||
from paddlespeech.s2t.training.cli import default_argument_parser | ||
from paddlespeech.s2t.utils.utility import print_arguments | ||
|
||
|
||
def main_sp(config, args): | ||
exp = Tester(config, args) | ||
with exp.eval(): | ||
exp.setup() | ||
exp.run_test() | ||
|
||
|
||
def main(config, args): | ||
main_sp(config, args) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = default_argument_parser() | ||
# save asr result to | ||
parser.add_argument( | ||
'--dict-path', type=str, default=None, help='dict path.') | ||
parser.add_argument( | ||
"--result_file", type=str, help="path of save the asr result") | ||
args = parser.parse_args() | ||
print_arguments(args, globals()) | ||
|
||
# https://yaml.org/type/float.html | ||
config = CfgNode(new_allowed=True) | ||
if args.config: | ||
config.merge_from_file(args.config) | ||
if args.decode_cfg: | ||
decode_confs = CfgNode(new_allowed=True) | ||
decode_confs.merge_from_file(args.decode_cfg) | ||
config.decode = decode_confs | ||
if args.opts: | ||
config.merge_from_list(args.opts) | ||
config.freeze() | ||
print(config) | ||
if args.dump_config: | ||
with open(args.dump_config, 'w') as f: | ||
print(config, file=f) | ||
|
||
# Setting for profiling | ||
pr = cProfile.Profile() | ||
pr.runcall(main, config, args) | ||
pr.dump_stats('test.profile') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要变?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原先的代码似乎和librispeech解压出的结果不太一致,本地已有librispeech数据集的情况下不太方便
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
什么意思?这里不是有的话就不下载了吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok,按照之前的吧