Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS]Fix losses of StarGAN v2 VC #3184

Merged
merged 7 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/vctk/vc3/local/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1
--ngpu=1 \
--speaker-dict=dump/speaker_id_map.txt
8 changes: 4 additions & 4 deletions paddlespeech/t2s/datasets/am_batch_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,12 +820,13 @@ def __init__(self, latent_dim: int=16, max_mel_length: int=192):
self.max_mel_length = max_mel_length

def random_clip(self, mel: np.array):
    """Randomly crop a mel spectrogram along the time axis.

    Args:
        mel (np.array): mel spectrogram, time-major, shape (T, 80).

    Returns:
        np.array: mel with at most ``self.max_mel_length`` frames; returned
            unchanged when it is already short enough.
    """
    # Time axis is axis 0 (the batch function feeds (T, 80) arrays).
    mel_length = mel.shape[0]
    if mel_length > self.max_mel_length:
        # Random window start so training sees varied segments of long clips.
        random_start = np.random.randint(0,
                                         mel_length - self.max_mel_length)
        mel = mel[random_start:random_start + self.max_mel_length, :]
    return mel

def __call__(self, exmaples):
Expand All @@ -843,7 +844,6 @@ def starganv2_vc_batch_fn(self, examples):
mel = [self.random_clip(item["mel"]) for item in examples]
ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]

mel = batch_sequences(mel)
ref_mel = batch_sequences(ref_mel)
ref_mel_2 = batch_sequences(ref_mel_2)
Expand Down
17 changes: 16 additions & 1 deletion paddlespeech/t2s/exps/starganv2_vc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def train_sp(args, config):
model_version = '1.0'
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
MODEL_HOME)
# 根据 speaker 的个数修改 num_domains
# 源码的预训练模型和 default.yaml 里面默认是 20
if args.speaker_dict is not None:
with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
config['mapping_network_params']['num_domains'] = spk_num
config['style_encoder_params']['num_domains'] = spk_num
config['discriminator_params']['num_domains'] = spk_num

generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
Expand All @@ -123,7 +133,7 @@ def train_sp(args, config):
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')

F0_model = JDCNet(num_class=1, seq_len=192)
F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length'])
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
F0_model.eval()

Expand Down Expand Up @@ -234,6 +244,11 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")

args = parser.parse_args()

Expand Down
42 changes: 23 additions & 19 deletions paddlespeech/t2s/models/starganv2_vc/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,38 @@

from .transforms import build_transforms


# 这些都写到 updater 里
def compute_d_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):


def compute_d_loss(
nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
# TODO: should be True here, but r1_reg has some bug now
use_r1_reg: bool=False,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):

assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)

# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_reg = paddle.zeros([1])

# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_con_reg = paddle.zeros([1])
if use_con_reg:
t = build_transforms()
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
Expand Down Expand Up @@ -118,9 +121,10 @@ def compute_g_loss(nets: Dict[str, Any],
s_trg = nets['style_encoder'](x_ref, y_trg)

# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# 源码没有用 .eval(), 使用了 no_grad()
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)

# adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
Expand Down
4 changes: 2 additions & 2 deletions paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def evaluate_core(self, batch):
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)

Expand All @@ -269,7 +269,7 @@ def evaluate_core(self, batch):
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)

Expand Down