Skip to content

Commit

Permalink
[TTS]Fix losses of StarGAN v2 VC (PaddlePaddle#3184)
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 authored and luotao1 committed Jun 11, 2024
1 parent d911d5a commit 14d4c89
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 27 deletions.
3 changes: 2 additions & 1 deletion examples/vctk/vc3/local/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1
--ngpu=1 \
--speaker-dict=dump/speaker_id_map.txt
8 changes: 4 additions & 4 deletions paddlespeech/t2s/datasets/am_batch_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,12 +820,13 @@ def __init__(self, latent_dim: int=16, max_mel_length: int=192):
self.max_mel_length = max_mel_length

def random_clip(self, mel: np.array):
# [80, T]
mel_length = mel.shape[1]
# [T, 80]
mel_length = mel.shape[0]
if mel_length > self.max_mel_length:
random_start = np.random.randint(0,
mel_length - self.max_mel_length)
mel = mel[:, random_start:random_start + self.max_mel_length]

mel = mel[random_start:random_start + self.max_mel_length, :]
return mel

def __call__(self, exmaples):
Expand All @@ -843,7 +844,6 @@ def starganv2_vc_batch_fn(self, examples):
mel = [self.random_clip(item["mel"]) for item in examples]
ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]

mel = batch_sequences(mel)
ref_mel = batch_sequences(ref_mel)
ref_mel_2 = batch_sequences(ref_mel_2)
Expand Down
17 changes: 16 additions & 1 deletion paddlespeech/t2s/exps/starganv2_vc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def train_sp(args, config):
model_version = '1.0'
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
MODEL_HOME)
# 根据 speaker 的个数修改 num_domains
# 源码的预训练模型和 default.yaml 里面默认是 20
if args.speaker_dict is not None:
with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
config['mapping_network_params']['num_domains'] = spk_num
config['style_encoder_params']['num_domains'] = spk_num
config['discriminator_params']['num_domains'] = spk_num

generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
Expand All @@ -123,7 +133,7 @@ def train_sp(args, config):
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')

F0_model = JDCNet(num_class=1, seq_len=192)
F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length'])
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
F0_model.eval()

Expand Down Expand Up @@ -234,6 +244,11 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")

args = parser.parse_args()

Expand Down
42 changes: 23 additions & 19 deletions paddlespeech/t2s/models/starganv2_vc/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,38 @@

from .transforms import build_transforms


# 这些都写到 updater 里
def compute_d_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):


def compute_d_loss(
nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
# TODO: should be True here, but r1_reg has some bug now
use_r1_reg: bool=False,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):

assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)

# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_reg = paddle.zeros([1])

# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_con_reg = paddle.zeros([1])
if use_con_reg:
t = build_transforms()
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
Expand Down Expand Up @@ -118,9 +121,10 @@ def compute_g_loss(nets: Dict[str, Any],
s_trg = nets['style_encoder'](x_ref, y_trg)

# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# 源码没有用 .eval(), 使用了 no_grad()
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)

# adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
Expand Down
4 changes: 2 additions & 2 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def evaluate_core(self, batch):
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)

Expand All @@ -269,7 +269,7 @@ def evaluate_core(self, batch):
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)

Expand Down

0 comments on commit 14d4c89

Please sign in to comment.