Skip to content

Commit

Permalink
update: mean_shift implementation - v0.1, #3
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Oct 10, 2018
1 parent e20850b commit bf5b1bb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
21 changes: 11 additions & 10 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,21 @@ def setup(self):
# Optimizer
if self.optimizer == 'adam':
self.opt = tf.train.AdamOptimizer(learning_rate=self.lr,
beta1=self.beta1, beta2=self.beta2)
beta1=self.beta1, beta2=self.beta2, epsilon=self.opt_eps)
elif self.optimizer == 'sgd': # gonna use mm opt actually
self.opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=self.momentum)
# self.opt = tf.train.GradientDescentOptimizer(learning_rate=self.lr)
else:
raise NotImplementedError("[-] Not supported optimizer (%s)" % self.optimizer)

def image_process(self, x, sign=-1):
r, g, b = tf.split(x, 3, 3)
# Sub/Add the mean value
rgb = tf.concat([r + sign * self.rgb_mean[0],
g + sign * self.rgb_mean[1],
b + sign * self.rgb_mean[2]], axis=3)
return rgb
def image_processing(self, x, sign, name):
with tf.variable_scope(name):
rgb_mean = tf.convert_to_tensor((sign * self.rgb_mean[0] * 255.,
sign * self.rgb_mean[1] * 255.,
sign * self.rgb_mean[2] * 255.),
dtype=tf.float32)
x = tfutil.mean_shift(x, rgb_mean=rgb_mean)
return x

def channel_attention(self, x, f, reduction, name):
"""
Expand Down Expand Up @@ -198,7 +199,7 @@ def up_scaling(self, x, f, scale_factor, name):
def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale,
is_train=True, reuse=False, gpu_idx=0):
with tf.variable_scope("Residual_Channel_Attention_Network-gpu%d" % gpu_idx, reuse=reuse):
x = self.image_process(x, sign=-1)
x = self.image_processing(x, sign=-1, name='pre-processing')

# 1. head
head = tfutil.conv2d(x, f=f, k=kernel_size, name="conv2d-head")
Expand All @@ -215,7 +216,7 @@ def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_b
x = self.up_scaling(body, f, scale, name='up-scaling')
tail = tfutil.conv2d(x, f=self.n_channel, k=kernel_size, name="conv2d-tail") # (-1, 384, 384, 3)

x = self.image_process(tail, sign=1)
x = self.image_processing(tail, sign=1, name='post-processing')
return x

def build_model(self):
Expand Down
15 changes: 15 additions & 0 deletions tfutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def pixel_shuffle(x, scaling_factor):
return x


def mean_shift(x, rgb_mean, f=3, k=1, s=1, pad='SAME', name='mean_shift'):
with tf.variable_scope(name):
weight_shape = [k, k, f, f]
weight = tf.get_variable(shape=weight_shape, trainable=False, name='ms_weight')
weight.assign(tf.reshape(tf.eye(f), weight_shape))

bias_shape = [k, k, k, f]
bias = tf.get_variable(shape=bias_shape, trainable=False, name='ms_bias')
bias.assign(tf.reshape(rgb_mean, bias_shape))

x = tf.nn.conv2d(x, weight, strides=s, padding=pad, name='ms_conv2d')
x = tf.nn.bias_add(x, bias)
return x


# ---------------------------------------------------------------------------------------------
# Gradients (for supporting multi-gpu in tensorflow)

Expand Down

0 comments on commit bf5b1bb

Please sign in to comment.