Skip to content

Commit

Permalink
[TTS]StarGANv2 VC fix some trainer bugs, add add reset_parameters (Pa…
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 authored and luotao1 committed Jun 11, 2024
1 parent 6bd7d14 commit 135d19e
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 52 deletions.
35 changes: 18 additions & 17 deletions examples/vctk/vc3/conf/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ discriminator_params:
dim_in: 64 # same as dim_in in generator_params
num_domains: 20 # same as num_domains in mapping_network_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
n_repeat: 4
repeat_num: 4
asr_params:
input_dim: 80
hidden_dim: 256
Expand Down Expand Up @@ -77,54 +77,55 @@ loss_params:
###########################################################
batch_size: 5 # Batch size.
num_workers: 2 # Number of workers in DataLoader.
max_mel_length: 192

###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
weight_decay: 1.0e-4
epsilon: 1.0e-9
generator_scheduler_params:
max_learning_rate: 2e-4
max_learning_rate: 2.0e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
end_learning_rate: 2.0e-4
style_encoder_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
weight_decay: 1.0e-4
epsilon: 1.0e-9
style_encoder_scheduler_params:
max_learning_rate: 2e-4
max_learning_rate: 2.0e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
end_learning_rate: 2.0e-4
mapping_network_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
weight_decay: 1.0e-4
epsilon: 1.0e-9
mapping_network_scheduler_params:
max_learning_rate: 2e-6
max_learning_rate: 2.0e-6
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-6
end_learning_rate: 2.0e-6
discriminator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
weight_decay: 1.0e-4
epsilon: 1.0e-9
discriminator_scheduler_params:
max_learning_rate: 2e-4
max_learning_rate: 2.0e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
end_learning_rate: 2.0e-4

###########################################################
# TRAINING SETTING #
Expand Down
4 changes: 1 addition & 3 deletions examples/vctk/vc3/local/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
--ngpu=1
23 changes: 12 additions & 11 deletions paddlespeech/t2s/datasets/am_batch_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,20 +852,21 @@ def starganv2_vc_batch_fn(self, examples):
# (B,)
label = paddle.to_tensor(label)
ref_label = paddle.to_tensor(ref_label)
# [B, 80, T] -> [B, 1, 80, T]
mel = paddle.to_tensor(mel)
ref_mel = paddle.to_tensor(ref_mel)
ref_mel_2 = paddle.to_tensor(ref_mel_2)
# [B, T, 80] -> [B, 1, 80, T]
mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1)
ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1)
ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose(
[0, 2, 1]).unsqueeze(1)

z_trg = paddle.randn(batch_size, self.latent_dim)
z_trg2 = paddle.randn(batch_size, self.latent_dim)
z_trg = paddle.randn([batch_size, self.latent_dim])
z_trg2 = paddle.randn([batch_size, self.latent_dim])

batch = {
"x_real": mels,
"y_org": labels,
"x_ref": ref_mels,
"x_ref2": ref_mels_2,
"y_trg": ref_labels,
"x_real": mel,
"y_org": label,
"x_ref": ref_mel,
"x_ref2": ref_mel_2,
"y_trg": ref_label,
"z_trg": z_trg,
"z_trg2": z_trg2
}
Expand Down
22 changes: 11 additions & 11 deletions paddlespeech/t2s/exps/starganv2_vc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
from paddle.optimizer.lr import OneCycleLR
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import StarGANv2VC_source
from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
from paddlespeech.t2s.models.starganv2_vc import ASRCNN
from paddlespeech.t2s.models.starganv2_vc import Discriminator
from paddlespeech.t2s.models.starganv2_vc import Generator
from paddlespeech.t2s.models.starganv2_vc import JDCNet
from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
Expand Down Expand Up @@ -66,24 +69,20 @@ def train_sp(args, config):
fields = ["speech", "speech_lengths"]
converters = {"speech": np.load}

collate_fn = starganv2_vc_batch_fn
collate_fn = build_starganv2_vc_collate_fn(
latent_dim=config['mapping_network_params']['latent_dim'],
max_mel_length=config['max_mel_length'])

# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True

# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
train_dataset = StarGANv2VCDataTable(data=train_metadata)
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )
dev_dataset = StarGANv2VCDataTable(data=dev_metadata)

# collate function and dataloader
train_sampler = DistributedBatchSampler(
Expand Down Expand Up @@ -118,6 +117,7 @@ def train_sp(args, config):
generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
style_encoder = StyleEncoder(**config['style_encoder_params'])
discriminator = Discriminator(**config['discriminator_params'])

# load pretrained model
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
Expand Down
13 changes: 11 additions & 2 deletions paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .layers import ConvNorm
from .layers import LinearNorm
from .layers import MFCC
from paddlespeech.t2s.modules.nets_utils import _reset_parameters
from paddlespeech.utils.initialize import uniform_


Expand Down Expand Up @@ -59,6 +60,9 @@ def __init__(
hidden_dim=hidden_dim // 2,
n_token=n_token)

self.reset_parameters()
self.asr_s2s.reset_parameters()

def forward(self,
x: paddle.Tensor,
src_key_padding_mask: paddle.Tensor=None,
Expand Down Expand Up @@ -108,6 +112,9 @@ def get_future_mask(self, out_length: int, unmask_future_steps: int=0):
index_tensor.T + unmask_future_steps)
return mask

def reset_parameters(self):
self.apply(_reset_parameters)


class ASRS2S(nn.Layer):
def __init__(self,
Expand All @@ -118,8 +125,7 @@ def __init__(self,
n_token: int=40):
super().__init__()
self.embedding = nn.Embedding(n_token, embedding_dim)
val_range = math.sqrt(6 / hidden_dim)
uniform_(self.embedding.weight, -val_range, val_range)
self.val_range = math.sqrt(6 / hidden_dim)

self.decoder_rnn_dim = hidden_dim
self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
Expand Down Expand Up @@ -236,3 +242,6 @@ def parse_decoder_outputs(self,
hidden = paddle.stack(hidden).transpose([1, 0, 2])

return hidden, logit, alignments

def reset_parameters(self):
uniform_(self.embedding.weight, -self.val_range, self.val_range)
7 changes: 3 additions & 4 deletions paddlespeech/t2s/models/starganv2_vc/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,16 @@ def compute_d_loss(nets: Dict[str, Any],
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg=True,
use_adv_cls=False,
use_con_reg=False,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):

assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False

out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)

Expand Down
32 changes: 32 additions & 0 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import paddle.nn.functional as F
from paddle import nn

from paddlespeech.t2s.modules.nets_utils import _reset_parameters


class DownSample(nn.Layer):
def __init__(self, layer_type: str):
Expand Down Expand Up @@ -355,6 +357,8 @@ def __init__(self,
if w_hpf > 0:
self.hpf = HighPass(w_hpf)

self.reset_parameters()

def forward(self,
x: paddle.Tensor,
s: paddle.Tensor,
Expand Down Expand Up @@ -399,6 +403,9 @@ def forward(self,
out = self.to_out(x)
return out

def reset_parameters(self):
self.apply(_reset_parameters)


class MappingNetwork(nn.Layer):
def __init__(self,
Expand Down Expand Up @@ -427,6 +434,8 @@ def __init__(self,
nn.ReLU(), nn.Linear(hidden_dim, style_dim))
])

self.reset_parameters()

def forward(self, z: paddle.Tensor, y: paddle.Tensor):
"""Calculate forward propagation.
Args:
Expand All @@ -449,6 +458,9 @@ def forward(self, z: paddle.Tensor, y: paddle.Tensor):
s = out[idx, y]
return s

def reset_parameters(self):
self.apply(_reset_parameters)


class StyleEncoder(nn.Layer):
def __init__(self,
Expand Down Expand Up @@ -490,6 +502,8 @@ def __init__(self,
for _ in range(num_domains):
self.unshared.append(nn.Linear(dim_out, style_dim))

self.reset_parameters()

def forward(self, x: paddle.Tensor, y: paddle.Tensor):
"""Calculate forward propagation.
Args:
Expand All @@ -513,6 +527,9 @@ def forward(self, x: paddle.Tensor, y: paddle.Tensor):
s = out[idx, y]
return s

def reset_parameters(self):
self.apply(_reset_parameters)


class Discriminator(nn.Layer):
def __init__(self,
Expand All @@ -535,14 +552,29 @@ def __init__(self,
repeat_num=repeat_num)
self.num_domains = num_domains

self.reset_parameters()

def forward(self, x: paddle.Tensor, y: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor(float32)):
Shape (B, 1, 80, T).
y(Tensor(float32)):
Shape (B, ).
Returns:
Tensor:
Shape (B, )
"""
out = self.dis(x, y)
return out

def classifier(self, x: paddle.Tensor):
out = self.cls.get_feature(x)
return out

def reset_parameters(self):
self.apply(_reset_parameters)


class Discriminator2D(nn.Layer):
def __init__(self,
Expand Down
11 changes: 7 additions & 4 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler

from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
Expand Down Expand Up @@ -62,10 +65,10 @@ def __init__(self,
self.models = models

self.optimizers = optimizers
self.optimizer_g = optimizers['optimizer_g']
self.optimizer_s = optimizers['optimizer_s']
self.optimizer_m = optimizers['optimizer_m']
self.optimizer_d = optimizers['optimizer_d']
self.optimizer_g = optimizers['generator']
self.optimizer_s = optimizers['style_encoder']
self.optimizer_m = optimizers['mapping_network']
self.optimizer_d = optimizers['discriminator']

self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
Expand Down
Loading

0 comments on commit 135d19e

Please sign in to comment.