[TTS] StarGANv2 VC: fix some trainer bugs, add reset_parameters #3182

Merged (4 commits) on Apr 20, 2023
35 changes: 18 additions & 17 deletions examples/vctk/vc3/conf/default.yaml
@@ -41,7 +41,7 @@ discriminator_params:
     dim_in: 64          # same as dim_in in generator_params
     num_domains: 20     # same as num_domains in mapping_network_params
     max_conv_dim: 512   # same as max_conv_dim in generator_params
-    n_repeat: 4
+    repeat_num: 4
 asr_params:
     input_dim: 80
     hidden_dim: 256
@@ -77,54 +77,55 @@ loss_params:
 ###########################################################
 batch_size: 5       # Batch size.
 num_workers: 2      # Number of workers in DataLoader.
+max_mel_length: 192
 
 ###########################################################
 #             OPTIMIZER & SCHEDULER SETTING               #
 ###########################################################
 generator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 generator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 style_encoder_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 style_encoder_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 mapping_network_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 mapping_network_scheduler_params:
-    max_learning_rate: 2e-6
+    max_learning_rate: 2.0e-6
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-6
+    end_learning_rate: 2.0e-6
 discriminator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 discriminator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 
 ###########################################################
 #                     TRAINING SETTING                    #
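
Why the trailing zeros matter: the config is loaded through yacs/PyYAML, and PyYAML implements YAML 1.1, whose float resolver requires a dot in the mantissa. A bare "2e-4" therefore loads as the string '2e-4' and breaks learning-rate arithmetic downstream. A minimal check, independent of PaddleSpeech:

    import yaml

    # YAML 1.1 (PyYAML) only resolves floats that contain a dot.
    print(yaml.safe_load("lr: 2e-4"))    # {'lr': '2e-4'}   -> str, wrong
    print(yaml.safe_load("lr: 2.0e-4"))  # {'lr': 0.0002}   -> float, intended
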
4 changes: 1 addition & 3 deletions examples/vctk/vc3/local/train.sh
@@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=1 \
-    --phones-dict=dump/phone_id_map.txt \
-    --speaker-dict=dump/speaker_id_map.txt
+    --ngpu=1
23 changes: 12 additions & 11 deletions paddlespeech/t2s/datasets/am_batch_fn.py
@@ -852,20 +852,21 @@ def starganv2_vc_batch_fn(self, examples):
         # (B,)
         label = paddle.to_tensor(label)
         ref_label = paddle.to_tensor(ref_label)
-        # [B, 80, T] -> [B, 1, 80, T]
-        mel = paddle.to_tensor(mel)
-        ref_mel = paddle.to_tensor(ref_mel)
-        ref_mel_2 = paddle.to_tensor(ref_mel_2)
+        # [B, T, 80] -> [B, 1, 80, T]
+        mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose(
+            [0, 2, 1]).unsqueeze(1)
 
-        z_trg = paddle.randn(batch_size, self.latent_dim)
-        z_trg2 = paddle.randn(batch_size, self.latent_dim)
+        z_trg = paddle.randn([batch_size, self.latent_dim])
+        z_trg2 = paddle.randn([batch_size, self.latent_dim])
 
         batch = {
-            "x_real": mels,
-            "y_org": labels,
-            "x_ref": ref_mels,
-            "x_ref2": ref_mels_2,
-            "y_trg": ref_labels,
+            "x_real": mel,
+            "y_org": label,
+            "x_ref": ref_mel,
+            "x_ref2": ref_mel_2,
+            "y_trg": ref_label,
             "z_trg": z_trg,
             "z_trg2": z_trg2
         }
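
Two runtime bugs are fixed here. First, the mels produced by preprocessing are laid out as [B, T, 80], so they need a transpose plus a channel axis to reach the [B, 1, 80, T] layout the 2D convolutional networks expect; second, paddle.randn takes its shape as a single list, not as separate integers the way torch.randn does. The old batch dict also referenced plural names (mels, labels, ref_mels, ...) that were never defined in this scope, which would raise NameError; the fix uses the singular tensors built above. A minimal sketch of the two tensor fixes (batch size 4 and latent_dim 16 are illustrative):

    import paddle

    mel = paddle.zeros([4, 192, 80])           # [B, T, 80], as stored on disk
    x = mel.transpose([0, 2, 1]).unsqueeze(1)  # -> [B, 1, 80, T]
    print(x.shape)                             # [4, 1, 80, 192]

    z = paddle.randn([4, 16])  # correct: shape is one list argument
    # paddle.randn(4, 16)      # wrong: 16 lands in the dtype slot and errors
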
22 changes: 11 additions & 11 deletions paddlespeech/t2s/exps/starganv2_vc/train.py
@@ -29,9 +29,12 @@
 from paddle.optimizer.lr import OneCycleLR
 from yacs.config import CfgNode
 
-from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
-from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.resource.pretrained_models import StarGANv2VC_source
+from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
+from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
 from paddlespeech.t2s.models.starganv2_vc import ASRCNN
+from paddlespeech.t2s.models.starganv2_vc import Discriminator
 from paddlespeech.t2s.models.starganv2_vc import Generator
 from paddlespeech.t2s.models.starganv2_vc import JDCNet
 from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
@@ -66,24 +69,20 @@ def train_sp(args, config):
     fields = ["speech", "speech_lengths"]
     converters = {"speech": np.load}
 
-    collate_fn = starganv2_vc_batch_fn
+    collate_fn = build_starganv2_vc_collate_fn(
+        latent_dim=config['mapping_network_params']['latent_dim'],
+        max_mel_length=config['max_mel_length'])
 
     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
 
     # construct dataset for training and validation
     with jsonlines.open(args.train_metadata, 'r') as reader:
         train_metadata = list(reader)
-    train_dataset = DataTable(
-        data=train_metadata,
-        fields=fields,
-        converters=converters, )
+    train_dataset = StarGANv2VCDataTable(data=train_metadata)
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
-    dev_dataset = DataTable(
-        data=dev_metadata,
-        fields=fields,
-        converters=converters, )
+    dev_dataset = StarGANv2VCDataTable(data=dev_metadata)
 
     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
@@ -118,6 +117,7 @@ def train_sp(args, config):
     generator = Generator(**config['generator_params'])
     mapping_network = MappingNetwork(**config['mapping_network_params'])
     style_encoder = StyleEncoder(**config['style_encoder_params'])
+    discriminator = Discriminator(**config['discriminator_params'])
 
     # load pretrained model
     jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
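
A bare function cannot carry configuration, so the collate function is now produced by a factory that bakes latent_dim and max_mel_length into the returned callable (a class instance or closure stays picklable for DataLoader workers). A hedged usage sketch; only the two keyword arguments visible in the diff are assumed, latent_dim=16 is illustrative, and max_mel_length=192 matches the new config key:

    from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn

    collate_fn = build_starganv2_vc_collate_fn(
        latent_dim=16,       # mapping_network_params.latent_dim
        max_mel_length=192)  # the new key added to default.yaml above
    # collate_fn(examples) returns the batch dict consumed by the updater:
    # {"x_real": ..., "y_org": ..., "x_ref": ..., "z_trg": ..., ...}
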
13 changes: 11 additions & 2 deletions paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py
@@ -22,6 +22,7 @@
 from .layers import ConvNorm
 from .layers import LinearNorm
 from .layers import MFCC
+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
 from paddlespeech.utils.initialize import uniform_
 
 
@@ -59,6 +60,9 @@ def __init__(
             hidden_dim=hidden_dim // 2,
             n_token=n_token)
 
+        self.reset_parameters()
+        self.asr_s2s.reset_parameters()
+
     def forward(self,
                 x: paddle.Tensor,
                 src_key_padding_mask: paddle.Tensor=None,
@@ -108,6 +112,9 @@ def get_future_mask(self, out_length: int, unmask_future_steps: int=0):
                        index_tensor.T + unmask_future_steps)
         return mask
 
+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+
 
 class ASRS2S(nn.Layer):
     def __init__(self,
@@ -118,8 +125,7 @@ def __init__(self,
                  n_token: int=40):
         super().__init__()
         self.embedding = nn.Embedding(n_token, embedding_dim)
-        val_range = math.sqrt(6 / hidden_dim)
-        uniform_(self.embedding.weight, -val_range, val_range)
+        self.val_range = math.sqrt(6 / hidden_dim)
 
         self.decoder_rnn_dim = hidden_dim
         self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
@@ -236,3 +242,6 @@ def parse_decoder_outputs(self,
         hidden = paddle.stack(hidden).transpose([1, 0, 2])
 
         return hidden, logit, alignments
+
+    def reset_parameters(self):
+        uniform_(self.embedding.weight, -self.val_range, self.val_range)
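
How the initialization refactor fits together: nn.Layer.apply(fn) visits a layer and every sublayer, so ASRCNN.reset_parameters() re-initializes the whole tree in one pass. Because that generic pass would overwrite the embedding's special bound, ASRS2S now defers its uniform init into its own reset_parameters(), which ASRCNN calls afterwards. A hedged sketch of the pattern; the body of the real _reset_parameters in nets_utils is an assumption, while uniform_'s signature is taken from the diff:

    import math

    from paddle import nn
    from paddlespeech.utils.initialize import uniform_

    def _reset_parameters(module: nn.Layer):
        # Assumed shape of the helper: re-init known layer types, skip others.
        if isinstance(module, nn.Linear):
            bound = math.sqrt(6 / module.weight.shape[0])  # fan-in bound
            uniform_(module.weight, -bound, bound)

    class TinyNet(nn.Layer):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 4)
            self.reset_parameters()  # generic pass over all sublayers

        def reset_parameters(self):
            # Layer.apply(fn) calls fn on self and every sublayer.
            self.apply(_reset_parameters)
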
7 changes: 3 additions & 4 deletions paddlespeech/t2s/models/starganv2_vc/losses.py
@@ -27,17 +27,16 @@ def compute_d_loss(nets: Dict[str, Any],
                    y_trg: paddle.Tensor,
                    z_trg: paddle.Tensor=None,
                    x_ref: paddle.Tensor=None,
-                   use_r1_reg=True,
-                   use_adv_cls=False,
-                   use_con_reg=False,
+                   use_r1_reg: bool=True,
+                   use_adv_cls: bool=False,
+                   use_con_reg: bool=False,
                    lambda_reg: float=1.,
                    lambda_adv_cls: float=0.1,
                    lambda_con_reg: float=10.):
 
     assert (z_trg is None) != (x_ref is None)
     # with real audios
     x_real.stop_gradient = False
-
     out = nets['discriminator'](x_real, y_org)
     loss_real = adv_loss(out, 1)
 
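
Setting x_real.stop_gradient = False is what makes the R1 gradient penalty (enabled by use_r1_reg) possible: input tensors in Paddle do not track gradients by default, and the penalty differentiates the discriminator's real-sample score with respect to its input. A hedged sketch of the standard zero-centered form (Mescheder et al., 2018); the exact reduction in losses.py may differ:

    import paddle

    def r1_reg(d_out: paddle.Tensor, x_in: paddle.Tensor) -> paddle.Tensor:
        # Zero-centered gradient penalty on real samples.
        batch_size = x_in.shape[0]
        grad = paddle.grad(
            outputs=d_out.sum(),
            inputs=x_in,        # needs x_in.stop_gradient == False
            create_graph=True,  # keep the graph so the penalty is trainable
            retain_graph=True)[0]
        return 0.5 * grad.reshape([batch_size, -1]).pow(2).sum(1).mean()
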
32 changes: 32 additions & 0 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
@@ -25,6 +25,8 @@
 import paddle.nn.functional as F
 from paddle import nn
 
+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
+
 
 class DownSample(nn.Layer):
     def __init__(self, layer_type: str):
@@ -355,6 +357,8 @@ def __init__(self,
         if w_hpf > 0:
             self.hpf = HighPass(w_hpf)
 
+        self.reset_parameters()
+
     def forward(self,
                 x: paddle.Tensor,
                 s: paddle.Tensor,
@@ -399,6 +403,9 @@ def forward(self,
         out = self.to_out(x)
         return out
 
+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+
 
 class MappingNetwork(nn.Layer):
     def __init__(self,
@@ -427,6 +434,8 @@ def __init__(self,
             nn.ReLU(), nn.Linear(hidden_dim, style_dim))
         ])
 
+        self.reset_parameters()
+
     def forward(self, z: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:
@@ -449,6 +458,9 @@ def forward(self, z: paddle.Tensor, y: paddle.Tensor):
         s = out[idx, y]
         return s
 
+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+
 
 class StyleEncoder(nn.Layer):
     def __init__(self,
@@ -490,6 +502,8 @@ def __init__(self,
         for _ in range(num_domains):
             self.unshared.append(nn.Linear(dim_out, style_dim))
 
+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:
@@ -513,6 +527,9 @@ def forward(self, x: paddle.Tensor, y: paddle.Tensor):
         s = out[idx, y]
         return s
 
+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+
 
 class Discriminator(nn.Layer):
     def __init__(self,
@@ -535,14 +552,29 @@
             repeat_num=repeat_num)
         self.num_domains = num_domains
 
+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
+        """Calculate forward propagation.
+        Args:
+            x(Tensor(float32)):
+                Shape (B, 1, 80, T).
+            y(Tensor(float32)):
+                Shape (B, ).
+        Returns:
+            Tensor:
+                Shape (B, )
+        """
         out = self.dis(x, y)
         return out
 
     def classifier(self, x: paddle.Tensor):
         out = self.cls.get_feature(x)
         return out
 
+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+
 
 class Discriminator2D(nn.Layer):
     def __init__(self,
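
The new docstring pins down the discriminator's input/output contract. A hedged smoke test against it; the constructor kwargs mirror discriminator_params in default.yaml above:

    import paddle
    from paddlespeech.t2s.models.starganv2_vc import Discriminator

    net = Discriminator(
        dim_in=64, num_domains=20, max_conv_dim=512, repeat_num=4)
    x = paddle.randn([2, 1, 80, 192])  # (B, 1, 80, T) batch of mels
    y = paddle.randint(0, 20, [2])     # (B,) integer domain (speaker) ids
    out = net(x, y)
    print(out.shape)                   # [2], i.e. one logit per sample
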
11 changes: 7 additions & 4 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
@@ -21,10 +21,13 @@
 from paddle.optimizer import Optimizer
 from paddle.optimizer.lr import LRScheduler
 
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
 
 logging.basicConfig(
     format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
     datefmt='[%Y-%m-%d %H:%M:%S]')
@@ -62,10 +65,10 @@ def __init__(self,
         self.models = models
 
         self.optimizers = optimizers
-        self.optimizer_g = optimizers['optimizer_g']
-        self.optimizer_s = optimizers['optimizer_s']
-        self.optimizer_m = optimizers['optimizer_m']
-        self.optimizer_d = optimizers['optimizer_d']
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_s = optimizers['style_encoder']
+        self.optimizer_m = optimizers['mapping_network']
+        self.optimizer_d = optimizers['discriminator']
 
         self.schedulers = schedulers
         self.scheduler_g = schedulers['generator']
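
The old keys ('optimizer_g', ...) did not match what the trainer actually passes in; after the fix, the optimizers dict is keyed by component name, consistent with the schedulers dict. A hedged sketch of one component's wiring; the optimizer class is an assumption (the beta/epsilon/weight_decay names fit an AdamW-style API), and the hyperparameters mirror default.yaml:

    import paddle
    from paddle.optimizer import AdamW
    from paddle.optimizer.lr import OneCycleLR

    generator = paddle.nn.Linear(4, 4)  # stand-in for the real Generator

    scheduler_g = OneCycleLR(
        max_learning_rate=2.0e-4,
        total_steps=200000,
        divide_factor=1,
        phase_pct=0.0,
        end_learning_rate=2.0e-4)
    optimizer_g = AdamW(
        learning_rate=scheduler_g,
        beta1=0.0,
        beta2=0.99,
        epsilon=1.0e-9,
        weight_decay=1.0e-4,
        parameters=generator.parameters())

    # Both dicts now share the same component-name keys:
    schedulers = {"generator": scheduler_g}
    optimizers = {"generator": optimizer_g}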