fix: image pre-processing again..., #3

kozistr committed Oct 15, 2018
1 parent 26ccd30 commit f4c717a

Showing 2 changed files with 24 additions and 26 deletions.
44 changes: 22 additions & 22 deletions model.py
@@ -91,7 +91,7 @@ def __init__(self,
        self.x_hr = tf.placeholder(tf.float32, shape=(None,) + self.hr_img_size, name='x-hr-img')

        self.lr = tf.placeholder(tf.float32, name='learning_rate')
-        self.is_train = tf.placeholder(tf.bool, name='is_train')
+        # self.is_train = tf.placeholder(tf.bool, name='is_train')

        # setting stuffs
        self.setup()
@@ -122,11 +122,14 @@ def setup(self):

    def image_processing(self, x, sign, name):
        with tf.variable_scope(name):
-            rgb_mean = (sign * self.rgb_mean[0] * 255.,
-                        sign * self.rgb_mean[1] * 255.,
-                        sign * self.rgb_mean[2] * 255.)
-            x = tfutil.mean_shift(x, rgb_mean=rgb_mean)
-            return x
+            r, g, b = tf.split(x, num_or_size_splits=3, axis=-1)
+
+            # Sub/Add the mean value
+            rgb = tf.concat([r + sign * self.rgb_mean[0] * 255.,
+                             g + sign * self.rgb_mean[1] * 255.,
+                             b + sign * self.rgb_mean[2] * 255.], axis=-1)
+            # x = tfutil.mean_shift(x, rgb_mean)  # for fast pre-processing
+            return rgb

    def channel_attention(self, x, f, reduction, name):
        """
@@ -149,30 +152,29 @@ def channel_attention(self, x, f, reduction, name):
            x = tf.nn.sigmoid(x)
            return skip_conn * x

-    def residual_channel_attention_block(self, x, f, kernel_size, reduction, use_bn, name, is_train=True):
+    def residual_channel_attention_block(self, x, f, kernel_size, reduction, use_bn, name):
        with tf.variable_scope("RCAB-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            x = tfutil.conv2d(x, f=f, k=kernel_size, name="conv2d-1")
-            x = tf.layers.BatchNormalization(epsilon=self._eps, trainable=is_train, name="bn-1")(x) if use_bn else x
+            x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-1")(x) if use_bn else x
            x = self.act(x)

            x = tfutil.conv2d(x, f=f, k=kernel_size, name="conv2d-2")
-            x = tf.layers.BatchNormalization(epsilon=self._eps, trainable=is_train, name="bn-2")(x) if use_bn else x
+            x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-2")(x) if use_bn else x

            x = self.channel_attention(x, f, reduction, name="RCAB-%s" % name)
-            return skip_conn + self.res_scale * x
+            return self.res_scale * x + skip_conn

-    def residual_group(self, x, f, kernel_size, reduction, use_bn, name, is_train=True):
+    def residual_group(self, x, f, kernel_size, reduction, use_bn, name):
        with tf.variable_scope("RG-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            for i in range(self.n_res_blocks):
-                x = self.residual_channel_attention_block(x, f, kernel_size, reduction, use_bn, name=str(i),
-                                                          is_train=is_train)
+                x = self.residual_channel_attention_block(x, f, kernel_size, reduction, use_bn, name=str(i))

            x = tfutil.conv2d(x, f=f, k=kernel_size)
-            return skip_conn + x
+            return x + skip_conn

    def up_scaling(self, x, f, scale_factor, name):
        """
@@ -195,8 +197,7 @@ def up_scaling(self, x, f, scale_factor, name):
                raise NotImplementedError("[-] Not supported scaling factor (%d)" % scale_factor)
            return x

-    def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale,
-                                           is_train=True, reuse=False, gpu_idx=0):
+    def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale, reuse=False, gpu_idx=0):
        with tf.variable_scope("Residual_Channel_Attention_Network-gpu%d" % gpu_idx, reuse=reuse):
            x = self.image_processing(x, sign=-1, name='pre-processing')
@@ -206,7 +207,7 @@ def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale, reuse=False, gpu_idx=0):
            # 2. body
            x = head
            for i in range(self.n_res_groups):
-                x = self.residual_group(x, f, kernel_size, reduction, use_bn, name=str(i), is_train=is_train)
+                x = self.residual_group(x, f, kernel_size, reduction, use_bn, name=str(i))

            body = tfutil.conv2d(x, f=f, k=kernel_size, name="conv2d-body")
            body += head
@@ -226,9 +227,8 @@ def build_model(self):
            reduction=self.reduction,
            use_bn=self.use_bn,
            scale=self.img_scale,
-            is_train=self.is_train
        )
-        self.output = tf.clip_by_value(self.output, 0, 255)
+        self.output = tf.cast(tf.clip_by_value(self.output, 0, 255), dtype=tf.uint8)

        # l1 loss
        self.loss = tf.reduce_mean(tf.abs(self.output - self.x_hr))
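Clipping keeps the network output inside the displayable range, and the added cast produces a uint8 tensor that can be summarized and saved directly. The equivalent post-processing step in NumPy terms (a sketch, not the repository's code):

import numpy as np

def to_uint8(img):
    # Clamp a float32 network output to [0, 255] and cast for saving.
    return np.clip(img, 0.0, 255.0).astype(np.uint8)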
@@ -240,9 +240,9 @@
        self.ssim = tf.reduce_mean(metric.ssim(self.output, self.x_hr, m_val=255))

        # summaries
-        tf.summary.image('lr', self.x_lr)
-        tf.summary.image('hr', self.x_hr)
-        tf.summary.image('generated-hr', self.output)
+        tf.summary.image('lr', self.x_lr, max_outputs=self.batch_size)
+        tf.summary.image('hr', self.x_hr, max_outputs=self.batch_size)
+        tf.summary.image('generated-hr', self.output, max_outputs=self.batch_size)

        tf.summary.scalar("loss/l1_loss", self.loss)
        tf.summary.scalar("metric/psnr", self.psnr)
6 changes: 2 additions & 4 deletions train.py
@@ -137,7 +137,7 @@ def main():
    rcan_model.global_step.assign(tf.constant(global_step))
    start_epoch = global_step // (ds.n_images // config.batch_size)

-    best_loss = 1e2
+    best_loss = 2e2
    for epoch in range(start_epoch, config.epochs):
        for x_lr, x_hr in di.iterate():
            # training
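The resume logic above recovers the epoch counter from a restored global step: ds.n_images // config.batch_size is the number of steps per epoch, and integer-dividing the global step by it yields the epoch to resume from. With illustrative numbers (the real values come from the dataset and config):

n_images, batch_size = 32000, 16          # assumed, for illustration
steps_per_epoch = n_images // batch_size  # 2000 steps per epoch
global_step = 10500                       # restored from the checkpoint
start_epoch = global_step // steps_per_epoch  # resumes at epoch 5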
@@ -146,7 +146,6 @@
                rcan_model.x_lr: x_lr,
                rcan_model.x_hr: x_hr,
                rcan_model.lr: lr,
-                rcan_model.is_train: True,
            })

            if global_step % config.logging_step == 0:
@@ -159,7 +158,6 @@
                    rcan_model.x_lr: x_lr,
                    rcan_model.x_hr: x_hr,
                    rcan_model.lr: lr,
-                    rcan_model.is_train: False,
                })
                rcan_model.writer.add_summary(summary, global_step)

@@ -168,8 +166,8 @@
                    feed_dict={
                        rcan_model.x_lr: sample_lr,
                        rcan_model.lr: lr,
-                        rcan_model.is_train: False,
                    })
+                print(output)

                util.img_save(img=util.merge(output, (patch, patch)),
                              path=config.output_dir + "/%d.png" % global_step,
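util.merge tiles the sampled batch into a (patch, patch) grid before util.img_save writes it out. A sketch of what such a merge helper typically does (the repository's version may differ):

import numpy as np

def merge(images, size):
    # Tile a batch (N, H, W, C) into a single (rows*H, cols*W, C) grid image.
    rows, cols = size
    n, h, w, c = images.shape
    grid = np.zeros((rows * h, cols * w, c), dtype=images.dtype)
    for idx in range(min(n, rows * cols)):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = images[idx]
    return grid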
