Skip to content

Commit

Permalink
fix dtype diff of last expand_v2 op of VITS (PaddlePaddle#3041)
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 authored and luotao1 committed Jun 11, 2024
1 parent 1ccb3ce commit ba47b52
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions paddlespeech/t2s/models/vits/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,12 @@ def forward(
unnorm_widths = h[..., :self.bins] / denom
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins:]

xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inputs=xb,
unnormalized_widths=unnorm_widths,
unnormalized_heights=unnorm_heights,
unnormalized_derivatives=unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound, )
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/t2s/models/vits/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ def rational_quadratic_spline(
def _searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each value in ``inputs``, the index of the bin containing it.

    The index is computed as (number of edges in ``bin_locations`` that are
    ``<=`` the input) minus one, counted along the last axis.

    Args:
        bin_locations: Tensor of bin edges along its last axis.
            NOTE(review): presumably sorted ascending along that axis and
            broadcast-compatible with ``inputs[..., None]`` -- confirm
            against the caller.
        inputs: Tensor of values to locate.
        eps: Amount added to the last edge before comparing.

    Returns:
        An int64 tensor of bin indices, with the last axis of the
        comparison reduced away.

    Note:
        ``bin_locations`` is modified in place: its last edge is increased
        by ``eps`` and the change is visible to the caller.
    """
    # Nudge the final edge upwards so that an input exactly equal to it is
    # not counted as having passed that edge (which would yield an index one
    # past the last bin).
    bin_locations[..., -1] += eps
    mask = inputs[..., None] >= bin_locations
    # Explicit int64 (passed by keyword) keeps the dtype of the summed
    # indices fixed regardless of the framework's default integer type.
    mask_int = paddle.cast(mask, dtype='int64')
    out = paddle.sum(mask_int, axis=-1) - 1
    return out
4 changes: 2 additions & 2 deletions paddlespeech/t2s/modules/nets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,18 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):

bs = paddle.shape(lengths)[0]
if xs is None:
maxlen = lengths.max()
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
else:
maxlen = paddle.shape(xs)[length_dim]

seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
# Location of the last expand op in VITS
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype)

if xs is not None:
assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)

if length_dim < 0:
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
Expand Down

0 comments on commit ba47b52

Please sign in to comment.