-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
102 lines (79 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from generator import Generator
from discriminator import Discriminator
import torch.nn as nn
import matplotlib.pyplot as plt
# Mini-batch size used when building the MNIST training DataLoader.
batch_size = 100
def main():
    """Script entry point: build the MNIST training loader and train the GAN on it."""
    train_model(get_data())
def train_model(data, num_epochs=50, learning_rate=0.01, noise_dim=64):
    """Train a Generator/Discriminator pair adversarially on ``data['train']``.

    Both networks are optimised with plain SGD at the same learning rate.
    Progress is printed every 100 steps. After training, the last generated
    batch is displayed with ``show_result``.

    Args:
        data: Mapping whose ``'train'`` entry is an iterable of
            ``(images, labels)`` batches, as returned by ``get_data``.
            Labels are ignored.
        num_epochs: Number of passes over the training loader.
        learning_rate: SGD learning rate used for both networks.
        noise_dim: Length of the latent noise vector fed to the generator.
    """
    image_dim = 28 * 28 * 1  # flattened single-channel 28x28 image
    train_loader = data['train']
    num_steps = len(train_loader)

    gen = Generator(noise_dim, image_dim)
    disc = Discriminator(image_dim)
    opt_gen = optim.SGD(gen.parameters(), lr=learning_rate)
    opt_disc = optim.SGD(disc.parameters(), lr=learning_rate)

    fake_data = None
    for epoch in range(num_epochs):
        for i, (real_data, _) in enumerate(train_loader):
            real_data = real_data.view(-1, image_dim)
            # Draw one fresh noise vector per real image. Both halves of the
            # discriminator loss then see equally sized batches, even if the
            # loader's final batch is smaller than batch_size.
            noise = torch.randn((real_data.size(0), noise_dim))
            fake_data = gen(noise)

            # --- Discriminator step ---
            # detach() keeps this backward pass out of the generator's graph.
            # The graph therefore stays intact for the generator step without
            # retain_graph=True, and no gradient work is spent on gen params.
            disc_fake = disc(fake_data.detach()).view(-1)
            disc_real = disc(real_data).view(-1)
            disc_loss = discriminator_loss(disc_real, disc_fake)
            disc.zero_grad()
            disc_loss.backward()
            opt_disc.step()

            # --- Generator step ---
            # Re-score the fakes with the freshly updated discriminator.
            disc_output = disc(fake_data).view(-1)
            gen_loss = generator_loss(disc_output)
            gen.zero_grad()
            gen_loss.backward()
            opt_gen.step()

            if (i + 1) % 100 == 0:
                print(
                    f'Epoch: [{epoch + 1}/{num_epochs}], '
                    f'Step: [{i + 1}/{num_steps}], '
                    f'Disc Loss: {disc_loss.item():.4f}, '
                    f'Gen Loss: {gen_loss.item():.4f}'
                )

    # Nothing was generated if there were zero epochs or an empty loader.
    if fake_data is not None:
        show_result(fake_data)
def show_result(generated_data, num_images=10, ncols=5):
    """Plot the first few generated samples as 28x28 grayscale images.

    Args:
        generated_data: Tensor holding flattened 28x28 single-channel images,
            one per row (e.g. a generator output batch). Any batch size is
            accepted, and the tensor may live on any device.
        num_images: Maximum number of samples to display (default 10).
        ncols: Number of columns in the subplot grid. The default of 5 gives
            the 2x5 layout for 10 images.
    """
    # Detach from autograd and move to CPU before converting to NumPy.
    # reshape(-1, ...) infers the batch size instead of assuming 100.
    images = generated_data.detach().cpu().reshape(-1, 1, 28, 28).numpy()
    # Never index past the batch, even if it holds fewer than num_images samples.
    count = min(num_images, images.shape[0])
    nrows = max(1, -(-count // ncols))  # ceil(count / ncols), at least one row
    for i in range(count):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's binary cross-entropy loss.

    Scores on real samples are compared with target 1 and scores on generated
    samples with target 0. The two mean BCE terms are summed. Inputs must be
    probabilities in [0, 1], as required by ``nn.BCELoss``.
    """
    bce = nn.BCELoss()
    scored_pairs = (
        (real_output, torch.ones_like(real_output)),   # real samples -> label 1
        (fake_output, torch.zeros_like(fake_output)),  # fake samples -> label 0
    )
    return sum(bce(pred, target) for pred, target in scored_pairs)
def generator_loss(fake_output):
    """Return the generator's loss: mean BCE of fake-sample scores against label 1.

    This is -log D(G(z)) averaged over the batch. It shrinks as the
    discriminator rates generated samples as real. Inputs must be
    probabilities in [0, 1], as required by ``nn.BCELoss``.
    """
    real_labels = torch.ones_like(fake_output)
    return nn.BCELoss()(fake_output, real_labels)
def get_data():
    """Build the MNIST training DataLoader.

    The dataset is downloaded into ``data/`` if it is not already there and
    converted to tensors with ``ToTensor``. It is served in shuffled batches
    of ``batch_size``.

    Returns:
        dict: ``{'train': DataLoader}``.
    """
    mnist_train = datasets.MNIST(
        root='data', train=True, transform=ToTensor(), download=True
    )
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    return {'train': train_loader}
# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()