Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS]fix elementwise_floordiv's fill_constant #3075

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions paddlespeech/t2s/modules/conformer/encoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,19 @@ def forward(self, x_input, mask, cache=None):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None

skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = paddle.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

if skip_layer:
if cache is not None:
x = paddle.concat([cache, x], axis=1)
if pos_emb is not None:
return (x, pos_emb), mask
return x, mask

# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
Expand All @@ -138,7 +135,6 @@ def forward(self, x_input, mask, cache=None):
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)

# multi-headed self-attention module
residual = x
if self.normalize_before:
Expand Down
8 changes: 3 additions & 5 deletions paddlespeech/t2s/modules/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def forward_attention(self, value, scores, mask=None):
mask = paddle.logical_not(mask)
# assume scores.dtype==paddle.float32, we only use "float32" here
dtype = str(scores.dtype).split(".")[-1]
min_value = numpy.finfo(dtype).min
min_value = float(numpy.finfo(dtype).min)
scores = masked_fill(scores, mask, min_value)
# (batch, head, time1, time2)
self.attn = softmax(scores)
Expand Down Expand Up @@ -192,12 +192,11 @@ def rel_shift(self, x):
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = x_padded.reshape([b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = x_padded[:, :, 1:].reshape([b, h, t1, t2])[:, :, :, :t2 // 2 + 1]

new_t = paddle.cast(paddle.floor(t2 / 2) + 1, dtype='int32')
x = x_padded[:, :, 1:].reshape([b, h, t1, t2])[:, :, :, :new_t]
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]

return x

def forward(self, query, key, value, pos_emb, mask):
Expand All @@ -221,7 +220,6 @@ def forward(self, query, key, value, pos_emb, mask):
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = q.transpose([0, 2, 1, 3])

n_batch_pos = paddle.shape(pos_emb)[0]
p = self.linear_pos(pos_emb).reshape(
[n_batch_pos, -1, self.h, self.d_k])
Expand Down
3 changes: 2 additions & 1 deletion paddlespeech/t2s/modules/transformer/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def forward(self, x: paddle.Tensor):
x = x * self.xscale
T = paddle.shape(x)[1]
pe_size = paddle.shape(self.pe)
pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
tmp = paddle.cast(paddle.floor(pe_size[1] / 2), dtype='int32')
pos_emb = self.pe[:, tmp - T + 1:tmp + T, ]
return self.dropout(x), self.dropout(pos_emb)


Expand Down
4 changes: 2 additions & 2 deletions paddlespeech/t2s/modules/transformer/multi_layer_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def forward(self, x):
Tensor: Batch of output tensors (B, T, in_chans).
"""
x = self.relu(self.w_1(x.transpose([0, 2, 1]))).transpose([0, 2, 1])
return self.w_2(self.dropout(x).transpose([0, 2, 1])).transpose(
[0, 2, 1])
out = self.w_2(self.dropout(x).transpose([0, 2, 1])).transpose([0, 2, 1])
return out


class Conv1dLinear(nn.Layer):
Expand Down