Skip to content

Commit

Permalink
fix vits reduce_sum's input/output dtype, test=tts (PaddlePaddle#3028)
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 authored and luotao1 committed Jun 11, 2024
1 parent f057fc0 commit 4b17e83
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
14 changes: 6 additions & 8 deletions paddlespeech/t2s/models/vits/duration_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,19 @@ def forward(
z_u, z1 = paddle.split(z_q, [1, 1], 1)
u = F.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += paddle.sum(
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
logq = (paddle.sum(-0.5 *
(math.log(2 * math.pi) +
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)

tmp1 = (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask
logdet_tot_q += paddle.sum(tmp1, [1, 2])
tmp2 = -0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask
logq = (paddle.sum(tmp2, [1, 2]) - logdet_tot_q)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = paddle.concat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
(z**2)) * x_mask, [1, 2]) - logdet_tot)
tmp3 = 0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask
nll = (paddle.sum(tmp3, [1, 2]) - logdet_tot)
# (B,)
return nll + logq
else:
Expand Down
13 changes: 8 additions & 5 deletions paddlespeech/t2s/models/vits/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ def forward(
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
tmp1 = -0.5 * math.log(2 * math.pi) - logs_p
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
tmp1,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
Expand All @@ -384,8 +385,9 @@ def forward(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
tmp2 = -0.5 * (m_p**2) * s_p_sq_r
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
tmp2,
[1],
keepdim=True, )
# (B, T_feats, T_text)
Expand All @@ -403,7 +405,6 @@ def forward(
w = attn.sum(2)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / paddle.sum(x_mask)

# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(attn.squeeze(1),
Expand Down Expand Up @@ -511,8 +512,9 @@ def inference(
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
tmp3 = -0.5 * math.log(2 * math.pi) - logs_p
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
tmp3,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
Expand All @@ -524,8 +526,9 @@ def inference(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
tmp4 = -0.5 * (m_p**2) * s_p_sq_r
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
tmp4,
[1],
keepdim=True, )
# (B, T_feats, T_text)
Expand Down
11 changes: 9 additions & 2 deletions paddlespeech/t2s/models/vits/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ def piecewise_rational_quadratic_transform(


def mask_preprocess(x, mask):
# bins.dtype = int32
B, C, T, bins = paddle.shape(x)
new_x = paddle.zeros([mask.sum(), bins])
mask_int = paddle.cast(mask, dtype='int64')
# paddle.sum 输入是 int32 或 bool 的时候,输出是 int64
# paddle.zeros (fill_constant) 的 shape 会被强制转成 int32 类型
new_x = paddle.zeros([paddle.sum(mask_int), bins])
for i in range(bins):
new_x[:, i] = x[:, :, :, i][mask]
return new_x
Expand Down Expand Up @@ -240,4 +244,7 @@ def rational_quadratic_spline(

def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
mask = inputs[..., None] >= bin_locations
mask_int = paddle.cast(mask, 'int64')
out = paddle.sum(mask_int, axis=-1) - 1
return out

0 comments on commit 4b17e83

Please sign in to comment.