DCGAN model for Flux v0.10 #207
Conversation
DCGAN model for Flux v0.10
vision/mnist/dcgan.jl
Outdated
end

cd(@__DIR__)
train()
newline here
Sorry, I meant a newline at the very end of the file, so that GitHub doesn't complain with a red arrow.
Oh, sorry. I'll fix that.
Very nice. Is this the architecture from the original paper (or another one)? If so, can you add a reference in a comment?
You should add the model to the README.
Maybe this is the right moment to start reorganizing this repo. We could have a folder structure like this:
vision/mnist/dcgan.jl
Outdated
@info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")
# Save generated fake image
output_image = create_output_image(gen, fixed_noise)
save(@sprintf("dcgan_steps_%06d.png", train_steps), output_image)
You could also add to this PR the figures you produced; it's nice to have a reference.
Could you also use https://github.com/JuliaML/MLDatasets.jl instead of Flux.Data.MNIST? It's more future-proof, since we are going to excise Flux.Data soon.
@CarloLucibello, thank you for reviewing my PR.
The architecture basically follows the TensorFlow DCGAN tutorial (https://www.tensorflow.org/tutorials/generative/dcgan). I can change the hyperparameters if necessary.
I'm on board. It will make it easier to maintain the models and update dependencies.
I will replace Flux.Data.MNIST with MLDatasets.
vision/dcgan_mnist/dcgan.jl
Outdated
using Printf

const BATCH_SIZE = 128
const NOISE_DIM = 100
maybe LATENT_DIM sounds better here
I very much like the overall style; this would be a nice template for the other models.
I used global variables just because they are used in the other vision models. So if it is time to change, I would switch them to
vision/dcgan_mnist/dcgan.jl
Outdated
function train()
    # Model Parameters
    hparams = HyperParams()
Here we can do:

    function train(; kws...)
        hparams = HyperParams(; kws...)
        ...
Thanks!
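For readers following along, a minimal sketch of the suggested pattern. `Base.@kwdef` generates a keyword constructor for the struct, so `train` can forward any overrides straight to `HyperParams`; the field names and default values here are illustrative, not necessarily the PR's actual ones.

```julia
# Hypothetical hyperparameter struct with a keyword constructor.
Base.@kwdef struct HyperParams
    batch_size::Int = 128
    latent_dim::Int = 100
    epochs::Int = 20
    lr::Float32 = 0.0002f0
end

function train(; kws...)
    hparams = HyperParams(; kws...)   # forward user overrides
    # ... build the models and data, then run the training loop ...
    return hparams
end

train(latent_dim = 64, epochs = 5)    # override only what you need
```

The nice property is that every hyperparameter stays overridable from the call site without touching global constants.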
# Load MNIST dataset
images, _ = MLDatasets.MNIST.traindata(Float32)
# Normalize to [-1, 1] and convert it to WHCN
image_tensor = permutedims(reshape(@.(2f0 * images - 1f0), 28, 28, 1, :), (2, 1, 3, 4))
ouch, the permutedims is a bit annoying but I guess we cannot do much about it without breaking changes in MLDatasets
good catch anyways
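For reference, a self-contained sketch of the preprocessing step under discussion, assuming the MLDatasets API of that era (where `MNIST.traindata(Float32)` returns a 28×28×60000 array with values in [0, 1]):

```julia
using MLDatasets

# Pixels come in [0, 1]; rescale to [-1, 1] to match the tanh output of
# the generator, add a singleton channel dimension, and swap the first
# two axes so the layout matches Flux's WHCN (width, height, channel, N).
images, _ = MLDatasets.MNIST.traindata(Float32)
image_tensor = permutedims(reshape(@.(2f0 * images - 1f0), 28, 28, 1, :),
                           (2, 1, 3, 4))
# size(image_tensor) == (28, 28, 1, 60000)
```

The `permutedims` is the unavoidable part the comment above refers to: the dataset's storage order differs from the layout Flux's convolutions expect.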
Could you commit the PNGs as well? Then I think we are good to go.
Co-Authored-By: Carlo Lucibello <[email protected]>
Sure, I will rerun the whole training process to make sure everything works fine, and add the output images.
vision/dcgan_mnist/dcgan.jl
Outdated
generator_loss(fake_output) = mean(logitbinarycrossentropy.(fake_output, 1f0))

function train_discriminator!(gen, dscr, batch, opt_dscr, hparams)
    noise = randn(Float32, hparams.latent_dim, hparams.batch_size) |> gpu
This is not ideal; we should avoid the data transfer and use CuArrays.randn when training on GPU. I don't know what an elegant way to do that would be; maybe we can revisit it in the future.
We might write something like

    randn!(similar(batch, (hparams.latent_dim, hparams.batch_size)))
https://docs.julialang.org/en/v1/stdlib/Random/#Random.randn!, https://github.com/JuliaGPU/CuArrays.jl/blob/da2389f0db46d611a5b625c2ce6bba79aa20163f/src/rand/random.jl#L88-L95
nice!
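A sketch of the device-agnostic sampling suggested above (`sample_noise` is a hypothetical helper name, not part of the PR):

```julia
using Random: randn!

# `similar(batch, dims)` allocates an uninitialized array of the same
# type and device as `batch` (Array on CPU, CuArray on GPU), and
# `randn!` fills it with standard-normal values in place, so no
# CPU-to-GPU transfer is needed.
function sample_noise(batch, latent_dim, batch_size)
    randn!(similar(batch, (latent_dim, batch_size)))
end

batch = rand(Float32, 28, 28, 1, 8)         # stand-in CPU batch
noise = sample_noise(batch, 100, 8)          # 100×8 Float32 noise
```

When `batch` is a CuArray, the same call dispatches to the GPU `randn!`, so the noise is generated directly on the device.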
vision/dcgan_mnist/dcgan.jl
Outdated
for batch in data
    # Update discriminator and generator
    loss_dscr = train_discriminator!(gen, dscr, batch, opt_dscr, hparams)
    loss_gen = train_generator!(gen, dscr, batch, opt_gen, hparams)
Here, part of the computation done in train_discriminator! could be reused in train_generator!; in principle, only a single forward pass of the generator is needed.
This is also done here: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
I don't know how to do it properly with Zygote; it needs some thinking. If you have a simple solution in mind you could add this optimization; otherwise we can revisit it in the future.
Unfortunately, I don't know how to do it either.
Better to move the images to a subfolder.
Last tiny detail: consider renaming
Thanks, I applied the changes.
Terrific work, thanks!
@CarloLucibello, thank you for your awesome, detailed review!
This is a DCGAN implementation for Flux v0.10. I know there are already pending pull requests for DCGAN (#47, #111), but they are incompatible with Zygote.
A linear layer is used as the last layer of the discriminator, and the losses are calculated using logitbinarycrossentropy. This is because the combination of sigmoid and binarycrossentropy may cause numerical issues (FluxML/Flux.jl#914).
It outputs generated digits for a fixed noise vector every 1000 iterations. I believe it is helpful for tracing the training process :)
0 steps
3000 steps
6000 steps
final result (=9380 steps)
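To illustrate the loss setup described in the PR, here is a sketch of the loss pair. Only `generator_loss` appears verbatim in the diff above; `discriminator_loss` is my reconstruction of the standard DCGAN counterpart and may differ from the PR's exact code.

```julia
using Flux: logitbinarycrossentropy
using Statistics: mean

# Both losses take the raw (pre-sigmoid) discriminator outputs, so the
# sigmoid and the log inside the cross-entropy are fused into a single
# numerically stable expression, avoiding log(0) for saturated outputs.
discriminator_loss(real_output, fake_output) =
    mean(logitbinarycrossentropy.(real_output, 1f0)) +   # real labeled 1
    mean(logitbinarycrossentropy.(fake_output, 0f0))     # fake labeled 0

generator_loss(fake_output) = mean(logitbinarycrossentropy.(fake_output, 1f0))
```

With a plain `binarycrossentropy(sigmoid(x), y)`, a strongly negative logit `x` drives `sigmoid(x)` toward zero and the `log` toward minus infinity, which is exactly the instability FluxML/Flux.jl#914 discusses.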