Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[s2t] fix cli args to config #3194

Merged
merged 2 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddlespeech/dataset/s2t/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def build_vocab(manifest_paths="",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
manifest_paths = [manifest_paths] if isinstance(manifest_paths,
str) else manifest_paths

fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1
Expand Down
1 change: 1 addition & 0 deletions paddlespeech/dataset/s2t/format_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def format_data(
unit_type="char",
vocab_path="examples/librispeech/data/vocab.txt",
spm_model_prefix=""):
manifest_paths = [manifest_paths] if isinstance(manifest_paths, str) else manifest_paths

fout = open(output_path, 'w', encoding='utf-8')

Expand Down
24 changes: 4 additions & 20 deletions paddlespeech/s2t/exps/u2/bin/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from yacs.config import CfgNode

from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments


Expand All @@ -32,26 +32,10 @@ def main(config, args):

if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)

maybe_dump_config(args.dump_config, config)
main(config, args)
20 changes: 4 additions & 16 deletions paddlespeech/s2t/exps/u2/bin/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode

from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments


Expand All @@ -32,22 +32,10 @@ def main(config, args):

if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())

# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)

maybe_dump_config(args.dump_config, config)
main(config, args)
30 changes: 3 additions & 27 deletions paddlespeech/s2t/exps/u2/bin/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import paddle
from kaldiio import ReadHelper
from paddleslim import PTQ
from yacs.config import CfgNode

from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig

logger = Log(__name__).getlog()


Expand Down Expand Up @@ -173,32 +174,7 @@ def main(config, args):

if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_scp", type=str, help="path of the input audio file")
parser.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
parser.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the input audio file")
args = parser.parse_args()

config = CfgNode(new_allowed=True)

if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)
23 changes: 4 additions & 19 deletions paddlespeech/s2t/exps/u2/bin/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
"""Evaluation for U2 model."""
import cProfile

from yacs.config import CfgNode

from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments


Expand All @@ -34,27 +34,12 @@ def main(config, args):

if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)

# Setting for profiling
pr = cProfile.Profile()
Expand Down
25 changes: 2 additions & 23 deletions paddlespeech/s2t/exps/u2/bin/test_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
import sys
from pathlib import Path

import distutils
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode

from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
Expand Down Expand Up @@ -125,27 +124,7 @@ def main(config, args):

if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args()

config = CfgNode(new_allowed=True)

if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)
18 changes: 4 additions & 14 deletions paddlespeech/s2t/exps/u2/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@
import cProfile
import os

from yacs.config import CfgNode

from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments

# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer


def main_sp(config, args):
exp = Trainer(config, args)
Expand All @@ -39,17 +37,9 @@ def main(config, args):
args = parser.parse_args()
print_arguments(args, globals())

# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_path, config)

# Setting for profiling
pr = cProfile.Profile()
Expand Down
59 changes: 58 additions & 1 deletion paddlespeech/s2t/training/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
import argparse

import distutils
from yacs.config import CfgNode


class ExtendAction(argparse.Action):
"""
Expand Down Expand Up @@ -68,7 +71,15 @@ def default_argument_parser(parser=None):
parser.register('action', 'extend', ExtendAction)
parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="logging with debug mode.")
parser.add_argument(
"--dump_path", type=str, default=None, help="path to dump config file.")

# train group
train_group = parser.add_argument_group(
title='Train Options', description=None)
train_group.add_argument(
Expand Down Expand Up @@ -103,14 +114,35 @@ def default_argument_parser(parser=None):
train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.")

# test group
test_group = parser.add_argument_group(
title='Test Options', description=None)

test_group.add_argument(
"--decode_cfg",
metavar="DECODE_CONFIG_FILE",
help="decode config file.")
test_group.add_argument(
"--result_file", type=str, help="path of save the asr result")
test_group.add_argument(
"--audio_file", type=str, help="path of the input audio file")

# quant & export
quant_group = parser.add_argument_group(
title='Quant Options', description=None)
quant_group.add_argument(
"--audio_scp", type=str, help="path of the input audio scp file")
quant_group.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
quant_group.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the jit model to save")

# profile group
profile_group = parser.add_argument_group(
title='Benchmark Options', description=None)
profile_group.add_argument(
Expand All @@ -131,3 +163,28 @@ def default_argument_parser(parser=None):
help='max iteration for benchmark.')

return parser


def config_from_args(args):
    """Build a frozen ``CfgNode`` from parsed command-line arguments.

    Sources are applied in this order, each one skipped when its argument
    is falsy:

    1. the file named by ``args.config`` is merged into the root node;
    2. the file named by ``args.decode_cfg`` is loaded into its own node
       and assigned to ``config.decode``;
    3. the list in ``args.opts`` is merged via ``merge_from_list``.

    Args:
        args: Parsed arguments exposing ``config``, ``decode_cfg`` and
            ``opts`` attributes.

    Returns:
        The resulting ``CfgNode``, frozen before it is returned.
    """

    def _node_from_file(path):
        # YAML float syntax reference: https://yaml.org/type/float.html
        node = CfgNode(new_allowed=True)
        node.merge_from_file(path)
        return node

    if args.config:
        cfg = _node_from_file(args.config)
    else:
        cfg = CfgNode(new_allowed=True)

    if args.decode_cfg:
        cfg.decode = _node_from_file(args.decode_cfg)

    if args.opts:
        cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def maybe_dump_config(dump_path, config):
    """Write ``config`` to ``dump_path`` when a path is given.

    Args:
        dump_path: Destination file path. A falsy value (``None`` or ``""``)
            disables dumping and the function does nothing.
        config: Object whose ``print`` representation is written to the
            file, followed by a newline.
    """
    if not dump_path:
        return

    # Explicit UTF-8 so a config holding non-ASCII text is written the same
    # way on every platform instead of depending on the locale encoding.
    with open(dump_path, 'w', encoding='utf-8') as f:
        print(config, file=f)
    print(f"save config to {dump_path}")