NASA recently released the first images taken by the James Webb Space Telescope, and the world was amazed. The orbiting telescope was first launched on December 25, 2021, and images were released to the public on July 12, 2022. It is the most advanced telescope launched by humans to date, taking photographs of the furthest galaxies ever seen along with beautifully detailed images of distant nebulae and stars.
But who needs that when you can make up those kinds of pictures with AI? Kidding, of course, but it would be magical to be able to create your own small snapshot of space. That's where the GAN comes in.
The Generative Adversarial Network (GAN) is a deep-learning neural network that can be trained to generate data, in our case images, that do not actually exist. The way a GAN works is by pitting two models against each other: one to generate synthetic data, and the other to discriminate between the synthetic data and real data. Together, these two "adversaries" working against each other result in a generator that can produce synthetic data that looks real enough to trick the discriminator and perhaps even a human.
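In miniature, a single training step looks something like the sketch below (a toy example with made-up sizes and stand-in data; the real models and training loop appear later in this post):

import torch
import torch.nn as nn

# toy generator and discriminator (stand-ins for the real models below)
G = nn.Sequential(nn.Linear(8, 16), nn.Tanh())   # noise -> fake "data"
D = nn.Sequential(nn.Linear(16, 1))              # data -> real/fake logit
bce = nn.BCEWithLogitsLoss()

real = torch.randn(4, 16)                        # stand-in for a batch of real data
fake = G(torch.randn(4, 8))                      # generator's synthetic batch
ones, zeros = torch.ones(4, 1), torch.zeros(4, 1)

# the discriminator learns real -> 1, fake -> 0 ...
d_loss = bce(D(real), ones) + bce(D(fake.detach()), zeros)
# ... while the generator learns to make D output 1 on its fakes
g_loss = bce(D(fake), ones)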
For our model, the real data will be 64x64 subsets of the images taken by the James Webb Telescope.
import os
import math
import numpy as np
import time
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
data_path = '/Users/nigelstory/Documents/jw_training_images/'
snippets_path = '/Users/nigelstory/Documents/jw_image_snippets/'
model_path = '/Users/nigelstory/Documents/jw_gan_models/'
full_image_files = [x for x in os.listdir(data_path) if x.endswith('.tif')]
full_image_files
The full-sized images from the telescope are quite large -- over 4000x4000 pixels -- which is too much to process all at once. Instead, we will scan each image and extract 512x512 squares, which are large enough to capture some interesting features of the photos, and then downsize each square to 64x64.
img = Image.open(data_path + full_image_files[0])
img_arr = np.array(img)
img_arr.shape
stride = 512
snippets = np.lib.stride_tricks.sliding_window_view(img_arr, (stride, stride, 3))
So overall, the shape of the resulting snippets array represents the following:
(number of vertical steps, number of horizontal steps, number of images produced per step, vertical dimension of output image, horizontal dimension of output image, color channels)
snippets.shape
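To make this concrete, here is a toy example: scanning a (6, 6, 3) array with a (2, 2, 3) window gives 5 vertical steps, 5 horizontal steps, and one window per step (the window spans the full channel axis), each window being a (2, 2, 3) sub-image.

# toy illustration of sliding_window_view's output shape
toy = np.arange(6 * 6 * 3).reshape(6, 6, 3)
toy_windows = np.lib.stride_tricks.sliding_window_view(toy, (2, 2, 3))
print(toy_windows.shape)  # (5, 5, 1, 2, 2, 3)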
tst = snippets[0][0][0]
Below, we can see one of the 512x512 extracted "snippets." This will be downsized further to 64x64.
plt.imshow(tst)
plt.show()
And below, we see the same example at its final 64x64 resolution.
small_tst = Image.fromarray(tst).resize((64, 64), Image.LANCZOS)
plt.imshow(small_tst)
plt.show()
In order to loop through the snippet images, we need to loop through the indices of the first two dimensions of snippets, with the third dimension taken at 0. This yields the (512, 512, 3) arrays representing the color image snippets, which we then downsize to 64x64.
def get_snippets(image_file, output_path, stride=512):
"""Create (stride, stride, 3) color image snippets from
larger image file.
Args:
image_file (str): Name of image file from which to create snippets.
output_path (str): Path to which snippets are to be saved.
stride (int): Shape of square snippet. Also governs the stride taken
between snippet recordings.
Returns:
None
"""
img_name_idx = len(os.listdir(output_path))
img = Image.open(data_path + image_file)
img_arr = np.array(img)
snippets = np.lib.stride_tricks.sliding_window_view(img_arr, (stride, stride, 3))
bounds = snippets.shape[:2]
for i in range(0, bounds[0], stride):
for j in range(0, bounds[1], stride):
snippet_arr = snippets[i][j][0]
            out_img = Image.fromarray(snippet_arr).resize((64, 64), Image.LANCZOS)
out_img.save(output_path + f"{img_name_idx}.png")
img_name_idx += 1
We then save the final snippets to a training directory that we can use with PyTorch's image loader. This will provide us with 390 training images.
# for f in tqdm(full_image_files):
# get_snippets(f, snippets_path + 'train/1/')
img_size = 64
batch_size = 128
# create image transformations
transform = transforms.Compose([
transforms.ToTensor(),
# shift data range to (-1, 1)
transforms.Normalize(mean=(0.5,), std=(0.5,))
])
dataset = torchvision.datasets.ImageFolder(
snippets_path,
transform=transform
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True
)
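Before training, it's worth a quick sanity check that the loader yields tensors of the expected shape and value range (batches of (3, 64, 64) images scaled to (-1, 1) by the Normalize transform):

# peek at one batch from the loader
imgs, _ = next(iter(data_loader))
print(imgs.shape, imgs.min().item(), imgs.max().item())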
# set gpu or cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
The GAN model is composed of two component models: the generator and the discriminator. The better these two models perform, the better your generated data will be. For our models, we will use a Convolutional Neural Network as our discriminator and a decoding feed-forward neural network for our generator.
We will be experimenting between two different GAN architectures: one using a linear feed-forward network for image generation and the other using transpose convolutions.
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(32),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(32),
nn.MaxPool2d(2)
)
self.conv_layers2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(64),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(64),
nn.MaxPool2d(2),
nn.Flatten()
)
self.dense_layers = nn.Sequential(
nn.Dropout(0.5),
            nn.Linear(64*16*16, 1024),  # 64 channels at 16x16 after two 2x2 poolings
nn.LeakyReLU(0.2),
nn.Dropout(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1)
)
def forward(self, x):
x = self.conv_layers(x)
x = self.conv_layers2(x)
x = self.dense_layers(x)
return x
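As a quick sanity check on the dimensions: each 2x2 max-pool halves the spatial size, so a 64x64 input leaves 64 channels at 16x16, i.e. 64*16*16 = 16384 flattened features feeding the dense layers.

# verify the discriminator maps (N, 3, 64, 64) -> (N, 1)
d_test = CNN()
print(d_test(torch.zeros(2, 3, 64, 64)).shape)  # torch.Size([2, 1])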
The generator model takes random noise as input and outputs its generated data. The higher the dimensionality of the input noise, the more granular your generated data can be. That is because the noise is representative of the latent space of the model, or the space from which the model can select its synthetic data. The trade-off is that the higher the dimension of the latent space, the larger your generator network will have to be and the longer the model will take to converge. After experimentation, we found that 100 dimensions yields the best results for our purposes.
The architecture of our generator is a feed-forward network, taking 100 features as input (noise) and outputting a vector the same length as a flattened image from our training images (64*64*3). This works similarly to the back half, or decoding layers, of an autoencoder.
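As an aside, a fun way to build intuition for the latent space is to interpolate between two noise vectors: a trained generator will morph one synthetic image smoothly into another along the line between them. A minimal sketch (the helper name is hypothetical, and it assumes a trained model instance like the one constructed below):

# hypothetical helper: walk the latent space between two random points
def interpolate_latent(model, steps=8):
    model.G.eval()
    z1 = torch.randn(1, model.latent_dim).to(device)
    z2 = torch.randn(1, model.latent_dim).to(device)
    with torch.no_grad():
        return [model.G((1 - t) * z1 + t * z2).reshape(3, 64, 64)
                for t in torch.linspace(0, 1, steps).tolist()]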
class GAN(nn.Module):
def __init__(self):
super(GAN, self).__init__()
# discriminator
self.D = CNN()
# generator
self.latent_dim = 100
self.G = nn.Sequential(
nn.Linear(self.latent_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256, momentum=0.7),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512, momentum=0.7),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024, momentum=0.7),
nn.Linear(1024, 2048),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(2048, momentum=0.7),
nn.Linear(2048, 4096),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(4096, momentum=0.7),
nn.Linear(4096, img_size*img_size*3),
nn.Tanh()
)
    def forward(self):
        # unused: the training loop calls model.D and model.G directly
        pass
Now it's time to train the model. We will start with 150 epochs, with batch sizes of 128.
The model will first train the discriminator on a batch of real images and a batch of fake images produced from random noise by the (initially untrained) generator. The discriminator then undergoes backpropagation, which reinforces its ability to distinguish the generated data from the real images.
Next we train the generator. After the discriminator has trained once, the generator generates a new batch of images and feeds them to the discriminator. The discriminator produces predictions, identifying some of the generated images as fake. Since the generator wants all of its outputs to be identified as real, we calculate the loss and backpropagate as though we had true labels for all of our prediction targets. This iteratively trains the generator to produce images that the discriminator will think are real.
def scale_image(img):
# scale images back to (0, 1)
out = (img + 1) / 2
return out
def batch_train(model, criterion, optimizers, data_loader, batch_size, epochs=10):
# label placeholders
ones_ = torch.ones(batch_size).to(device)
zeros_ = torch.zeros(batch_size).to(device)
# optimizers
g_optimizer = optimizers[0]
d_optimizer = optimizers[1]
# losses (not necessary)
g_losses = []
d_losses = []
for epoch in range(epochs):
start_time = time.time()
for inputs,_ in data_loader:
inputs = inputs.to(device)
n = inputs.size(0)
ones = ones_[:n].unsqueeze(1)
zeros = zeros_[:n].unsqueeze(1)
g_optimizer.zero_grad()
d_optimizer.zero_grad()
### Train Discriminator
# real images
real_outputs = model.D(inputs)
d_loss_real = criterion(real_outputs, ones)
# fake images
noise = torch.randn(n, model.latent_dim).to(device)
fake_images = model.G(noise)
fake_images = fake_images.reshape(n, 3, img_size, img_size)
            fake_outputs = model.D(fake_images.detach())  # detach so this pass only updates D
d_loss_fake = criterion(fake_outputs, zeros)
d_loss = (d_loss_real + d_loss_fake) / 2
d_loss.backward()
d_optimizer.step()
### Train Generator (x2)
for _ in range(2):
d_optimizer.zero_grad()
g_optimizer.zero_grad()
noise = torch.randn(n, model.latent_dim).to(device)
fake_images = model.G(noise)
fake_images = fake_images.reshape(n, 3, img_size, img_size)
fake_outputs = model.D(fake_images)
g_loss = criterion(fake_outputs, ones) # fake images as real
g_loss.backward()
g_optimizer.step()
g_losses.append(g_loss.item())
d_losses.append(d_loss.item())
print(f"Epoch {epoch+1}/{epochs}: d_loss={d_loss.item():0.4f}, g_loss={g_loss.item():0.4f}," + \
f" elapsed_time={time.time()-start_time:0.2f}s")
fake_images = fake_images.reshape(-1, 3, img_size, img_size)
save_image(scale_image(fake_images), f"/Users/nigelstory/Documents/jw_generated_images/epoch_{epoch+1}.png")
return g_losses, d_losses
model = GAN()
model.to(device)
criterion = nn.BCEWithLogitsLoss()
g_optimizer = torch.optim.Adam(model.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(model.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
losses = batch_train(model, criterion, (g_optimizer, d_optimizer), data_loader, batch_size, epochs=150)
We trained the model for 386 epochs, ending with a model that produces fairly nice images. Early stopping criteria for GAN models are still an area of active research, and something I plan to explore in future projects, but for now, we simply stopped training when the output images looked good to the human eye. See below for an example synthetic image.
def generate_image(model):
"""Generate a random image from a GAN's generator model.
Args:
model (nn.Module): GAN model with which to generate image.
Returns:
PIL.Image: Generated image.
"""
model.G.eval()
    rand_vec = torch.randn(1, model.latent_dim).to(device)
    # generate without tracking gradients
    with torch.no_grad():
        output_image = model.G(rand_vec)
output_image = output_image.reshape(3, 64, 64)
output_image = scale_image(output_image)
transform = transforms.ToPILImage()
output_image = transform(output_image)
    output_image = output_image.resize((512, 512), Image.LANCZOS)
return output_image
img = generate_image(model)
img = np.array(img)
plt.figure(figsize=(4,4))
plt.imshow(img)
plt.show()
# torch.save(model.G, model_path + 'gan_2.torch')
The main drawback of this model is a lack of variance in the generated images. The generator has trouble producing unique star placements: either the stars land in the same positions with new background colors, or the image is generated without stars at all. This is what we will try to remedy with the next model by using a transposed convolutional generator rather than a simple linear feed-forward generator.
With this model, we seek to produce images with wider variance in features. Transposed convolutions, sometimes called inverse convolutions (though the term is not technically correct), use kernels to expand the spatial dimensions of their inputs rather than reduce them, as a standard convolution does. The kernels can be trained via backpropagation and gradient descent, which allows for wide variance in the output tensors.
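For a single transposed-convolution layer, the output size follows out = (in - 1) * stride - 2 * padding + kernel_size. For example, a 4x4 kernel with stride 1 and no padding turns a 1x1 input into a 4x4 output, and with stride 2 and padding 1 it doubles the spatial size: (4 - 1) * 2 - 2 + 4 = 8. A quick check of the first case:

# one transposed-convolution layer expands (N, 100, 1, 1) -> (N, 512, 4, 4)
up = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False)
print(up(torch.randn(2, 100, 1, 1)).shape)  # torch.Size([2, 512, 4, 4])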
We will make a few other adjustments to our training as well. First, we've reduced the training dataset to exclude image snippets that are dominated by a single color, i.e. empty space, homogeneous selections of nebulae, etc. The new training set has 135 images, and this lets us reduce the batch size a little as well -- from 128 to 64. Also, we will add another dense layer to our discriminator to make sure it can compete with our updated generator.
batch_size = 64

# rebuild the data loader with the reduced batch size
# (the training directory has been pruned to the 135 selected snippets)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True
)
class DeepCNN(nn.Module):
def __init__(self):
super(DeepCNN, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(32),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(32),
nn.MaxPool2d(2)
)
self.conv_layers2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(64),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(64),
nn.MaxPool2d(2),
nn.Flatten()
)
self.dense_layers = nn.Sequential(
nn.Dropout(0.5),
            nn.Linear(64*16*16, 1024),  # 64 channels at 16x16 after two 2x2 poolings
nn.LeakyReLU(0.2),
nn.Dropout(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
# added layer
nn.Dropout(0.2),
nn.Linear(512, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
def forward(self, x):
x = self.conv_layers(x)
x = self.conv_layers2(x)
x = self.dense_layers(x)
return x
class TCNN_GAN(nn.Module):
def __init__(self):
super(TCNN_GAN, self).__init__()
# discriminator
self.D = DeepCNN()
# generator
self.latent_dim = 100
self.G = nn.Sequential(
nn.ConvTranspose2d(self.latent_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# state size. 512 x 4 x 4
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# state size. 256 x 8 x 8
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# state size. 128 x 16 x 16
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# state size. 64 x 32 x 32
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
nn.Tanh()
# final state size. 3 x 64 x 64
)
    def forward(self):
        # unused: the training loop calls model.D and model.G directly
        pass
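Before training, we can verify the state sizes annotated in the generator above (latent noise in, 64x64 RGB image out):

# verify the generator maps (N, latent_dim, 1, 1) -> (N, 3, 64, 64)
tcnn_test = TCNN_GAN()
z = torch.randn(2, tcnn_test.latent_dim, 1, 1)
print(tcnn_test.G(z).shape)  # torch.Size([2, 3, 64, 64])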
def tcnn_batch_train(model, criterion, optimizers, data_loader, batch_size, epochs=10):
# label placeholders
ones_ = torch.ones(batch_size).to(device)
zeros_ = torch.zeros(batch_size).to(device)
# optimizers
g_optimizer = optimizers[0]
d_optimizer = optimizers[1]
# losses (not necessary)
g_losses = []
d_losses = []
for epoch in range(epochs):
start_time = time.time()
for inputs,_ in data_loader:
inputs = inputs.to(device)
n = inputs.size(0)
ones = ones_[:n].unsqueeze(1)
zeros = zeros_[:n].unsqueeze(1)
g_optimizer.zero_grad()
d_optimizer.zero_grad()
### Train Discriminator
# real images
real_outputs = model.D(inputs)
d_loss_real = criterion(real_outputs, ones)
# fake images
noise = torch.randn(n, model.latent_dim, 1, 1).to(device)
fake_images = model.G(noise)
fake_images = fake_images.reshape(n, 3, img_size, img_size)
            fake_outputs = model.D(fake_images.detach())  # detach so this pass only updates D
d_loss_fake = criterion(fake_outputs, zeros)
d_loss = (d_loss_real + d_loss_fake) / 2
d_loss.backward()
d_optimizer.step()
### Train Generator (x2)
for _ in range(2):
d_optimizer.zero_grad()
g_optimizer.zero_grad()
noise = torch.randn(n, model.latent_dim, 1, 1).to(device)
fake_images = model.G(noise)
fake_images = fake_images.reshape(n, 3, img_size, img_size)
fake_outputs = model.D(fake_images)
g_loss = criterion(fake_outputs, ones) # fake images as real
g_loss.backward()
g_optimizer.step()
g_losses.append(g_loss.item())
d_losses.append(d_loss.item())
print(f"Epoch {epoch+1}/{epochs}: d_loss={d_loss.item():0.4f}, g_loss={g_loss.item():0.4f}," + \
f" elapsed_time={time.time()-start_time:0.2f}s")
fake_images = fake_images.reshape(-1, 3, img_size, img_size)
save_image(scale_image(fake_images), f"/Users/nigelstory/Documents/jw_generated_images/epoch_{epoch+1}.png")
return g_losses, d_losses
tcnn_model = TCNN_GAN()
tcnn_model.to(device)

criterion = nn.BCEWithLogitsLoss()
g_optimizer = torch.optim.Adam(tcnn_model.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(tcnn_model.D.parameters(), lr=0.0002, betas=(0.5, 0.999))

losses = tcnn_batch_train(tcnn_model, criterion, (g_optimizer, d_optimizer), data_loader, batch_size, epochs=300)
We trained this model for 300 epochs and achieved the wider variance in outputs we were after. The images also look more believable and more representative of the source images. However, we see remnants of a kind of tiling effect caused by the transposed convolutions: the kernels, being square, leave a square or checkerboard pattern on some of the output images, though after sufficient training this effect is minimized. Using smaller kernel sizes may lead to less noticeable tiling, but we will leave that for future explorations.
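Another common mitigation we did not try here is the "resize-convolution": replace each transposed convolution with a fixed upsampling step followed by an ordinary convolution, which avoids the uneven kernel overlap that produces the checkerboard pattern. A sketch of one such block, as a possible future experiment:

# hypothetical replacement for a ConvTranspose2d(512, 256, 4, 2, 1) block
resize_conv_block = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),  # e.g. 4x4 -> 8x8
    nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True)
)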
def tcnn_generate_image(model):
"""Generate a random image from a GAN's generator model.
Args:
model (nn.Module): GAN model with which to generate image.
Returns:
PIL.Image: Generated image.
"""
    model.eval()
    rand_vec = torch.randn(1, 100, 1, 1).to(device)
    # generate without tracking gradients
    with torch.no_grad():
        output_image = model(rand_vec)
output_image = output_image.reshape(3, 64, 64)
output_image = scale_image(output_image)
transform = transforms.ToPILImage()
output_image = transform(output_image)
    output_image = output_image.resize((512, 512), Image.LANCZOS)
return output_image
tcnn_model = torch.load(model_path + 'tcnn_gan_1.torch')
img = tcnn_generate_image(tcnn_model)
img = np.array(img)
plt.figure(figsize=(4,4))
plt.imshow(img)
plt.show()
Above, we see that we are getting a much wider array of colors and star placements, making the generated images appear more realistic than the results from the linear feed-forward generator.
Our GANs produce some nice looking, if a bit grainy, synthetic images of space. The transposed convolutional generator was able to produce more realistic and more "creative" images than the simple linear generator, but it had the drawback of producing a checkerboard pattern in many of its generations. I am surprised that the models did not pick up more on the classic hexagonal artifacts that we see in the real images from the James Webb Telescope. I believe this is because these artifacts were not as common in the larger training data as one would think.
I was also surprised by the number of feasible models produced in passing during the training process, which highlights how difficult defining an early stopping criterion can be for GANs. In future projects, I plan to explore criteria, such as inception scores, that could provide more objectivity when it comes to terminating training.