Skip to content

Commit

Permalink
Fixed expand dims bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fbcotter committed Oct 3, 2019
1 parent 3711ec5 commit 3eb0779
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions pytorch_wavelets/dtcwt/lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def prep_filt(h, c, transpose=False):
""" Prepares an array to be of the correct format for pytorch.
Can also specify whether to make it a row filter (set tranpose=True)"""
h = _as_col_vector(h)[::-1]
#h = np.reshape(h, [1, 1, *h.shape])
h = np.expand_dims(h, (0,1))
h = h[None, None, :]
h = np.repeat(h, repeats=c, axis=0)
if transpose:
h = h.transpose((0,1,3,2))
Expand Down

0 comments on commit 3eb0779

Please sign in to comment.