Skip to content

Commit

Permalink
Removed dropout from ResidualBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
Sulabh Kumra committed Nov 21, 2020
1 parent 0a9eb40 commit 08efed7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
7 changes: 1 addition & 6 deletions inference/models/grasp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,15 @@ class ResidualBlock(nn.Module):
A residual block with dropout option
"""

def __init__(self, in_channels, out_channels, kernel_size=3, dropout=False, prob=0.0):
def __init__(self, in_channels, out_channels, kernel_size=3):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
self.bn2 = nn.BatchNorm2d(in_channels)

self.dropout = dropout
self.dropout1 = nn.Dropout(p=prob)

def forward(self, x_in):
x = self.bn1(self.conv1(x_in))
x = F.relu(x)
if self.dropout:
x = self.dropout1(x)
x = self.bn2(self.conv2(x))
return x + x_in
10 changes: 5 additions & 5 deletions inference/models/grconvnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout
self.conv3 = nn.Conv2d(channel_size * 2, channel_size * 4, kernel_size=4, stride=2, padding=1)
self.bn3 = nn.BatchNorm2d(channel_size * 4)

self.res1 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob)
self.res2 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob)
self.res3 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob)
self.res4 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob)
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4, dropout=dropout, prob=prob)
self.res1 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res2 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res3 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res4 = ResidualBlock(channel_size * 4, channel_size * 4)
self.res5 = ResidualBlock(channel_size * 4, channel_size * 4)

self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1,
output_padding=1)
Expand Down

0 comments on commit 08efed7

Please sign in to comment.