fix develop bug: change view to reshape (PaddlePaddle#3633)
luyao-cv authored and luotao1 committed Jun 11, 2024
1 parent 44f1626 commit e2d1e9d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions paddlespeech/s2t/modules/attention.py
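The fix swaps torch-style Tensor.view(*dims) calls for Paddle's Tensor.reshape, which takes the target shape as a single list. A minimal sketch of the API difference (illustrative tensor and sizes, not part of the commit):

    import paddle

    x = paddle.randn([4, 10, 64])   # (batch, time, d_model), example sizes only
    # Old torch-style call replaced by this commit (positional dims):
    #   x.view(4, -1, 8, 8)
    # Paddle idiom: pass the whole target shape as one list
    y = x.reshape([4, -1, 8, 8])    # (batch, time, head, d_k)
    print(y.shape)                  # [4, 10, 8, 8]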
@@ -79,9 +79,9 @@ def forward_qkv(self,
         """
         n_batch = query.shape[0]

-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
+        k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
+        v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])

         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
@@ -129,8 +129,8 @@ def forward_attention(

         p_attn = self.dropout(attn)
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h *
-                                           self.d_k)  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
+                                               self.d_k])  # (batch, time1, d_model)

         return self.linear_out(x)  # (batch, time1, d_model)

@@ -349,7 +349,7 @@ def forward(self,
         new_cache = paddle.concat((k, v), axis=-1)

         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)

         # (batch, head, time1, d_k)
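For context, the reshape/transpose pairs above implement the usual split-heads and merge-heads steps of multi-head attention. A minimal, self-contained sketch of that pattern in Paddle (assumed shapes and a stand-in projection layer, not the repository's exact classes):

    import paddle
    import paddle.nn as nn

    batch, time, d_model, h = 2, 5, 16, 4
    d_k = d_model // h
    linear_q = nn.Linear(d_model, d_model)            # stand-in for self.linear_q

    query = paddle.randn([batch, time, d_model])
    q = linear_q(query).reshape([batch, -1, h, d_k])  # split heads: (batch, time, head, d_k)
    q = q.transpose([0, 2, 1, 3])                     # (batch, head, time, d_k)

    # ... attention scores and the weighted sum over values would go here ...

    # merge heads back: (batch, time, d_model)
    out = q.transpose([0, 2, 1, 3]).reshape([batch, -1, h * d_k])
    print(out.shape)                                  # [2, 5, 16]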
