Search code examples
python pytorch conv-neural-network encoder text2image

How to implement Text2Image with CNNs and Transposed CNNs


I want to implement a text-to-image neural network like the one in the image below, using CNNs and transposed CNNs together with an embedding layer. Please see the image: ![Image2text](https://drive.google.com/file/d/1A82iC29omu2yQrKEJrv1ropNtaL3urfH/view?usp=share_link)

import torch
from torch import nn

Input text :

text = "A cat wearing glasses and playing the guitar "

# Simple preprocessing of the text: split on whitespace and give every distinct
# word an integer id in order of first appearance, so the mapping is derived
# from `text` instead of being maintained by hand next to it.
tokens = text.split()
word_to_ix = {word: ix for ix, word in enumerate(dict.fromkeys(tokens))}

# One integer id per token of the sentence (sentence order is preserved).
lookup_tensor = torch.tensor([word_to_ix[word] for word in tokens], dtype=torch.long)

# Number of distinct words, i.e. how many rows the embedding table needs.
# (Counting the vocabulary rather than the tokens keeps this correct even if a
# word occurs more than once in the sentence.)
vocab_size = len(word_to_ix)

architecture implementation :

class TextToImage(nn.Module):
    """Text-to-image network as posted in the question (embedding -> conv -> transposed conv).

    NOTE(review): this is the version that produces the RuntimeError quoted
    further down; the first ``torch.cat`` in ``forward`` fails before any
    convolution runs. The code is kept exactly as posted.
    """
    def __init__(self, vocab_size):
        """Create the noise tensor and all layers.

        Args:
            vocab_size: number of rows in the embedding table.
        """
        super(TextToImage, self).__init__()
        
        self.vocab_size = vocab_size
        # Random tensor of shape (56, 64), drawn once here and reused on every forward call.
        self.noise = torch.rand((56,64))
        
        # DEFINE the layers
        # Embedding
        # Each token id is mapped to a 64-dimensional vector.
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim = 64)
        
        # Conv
        # NOTE(review): in_channels=64 here; the posted answer changes this to 1.
        self.conv2d_1 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        self.conv2d_2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        
        # Transposed CNNs
        self.conv2dTran_1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.conv2dTran_2 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding=0)
        # in_channels=6: 3 channels from conv2d_1 plus 3 from conv2dTran_2 (the skip connection in forward).
        self.conv2dTran_3 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2), padding=0)
        
        self.relu    = torch.nn.ReLU(inplace=False)
        self.dropout = torch.nn.Dropout(0.4)
        

    def forward(self, text_tensor):
        """Run the token ids through the network and return the result as ``image``.

        Args:
            text_tensor: integer token ids; the example below passes a 1-D
                tensor of 8 ids.
        """
        #SEND the input text tensor to the embedding layer
        emb = self.embed(text_tensor)
        
        #COMBINE the embedding with the noise tensor. Make it have 3 dimensions
        # NOTE(review): concatenating along dim=1 requires dim 0 to match, but emb has one
        # row per token (8 in the example) while noise has 56 rows -- this is the line
        # that raises the "Expected size 8 but got size 56" error.
        combine1 = torch.cat((emb, self.noise), dim=1, out=None)

        #SEND the noisy embedding to the convolutional and transposed convolutional layers
        conv2d_1 = self.conv2d_1(combine1)
        conv2d_2 = self.conv2d_2(conv2d_1)
        dropout  = self.dropout(conv2d_2)
                                               
        conv2dTran_1 = self.conv2dTran_1(dropout)
        conv2dTran_2 = self.conv2dTran_2(conv2dTran_1)
                                               
        #COMBINE the outputs having a skip connection in the image of the architecture
        combine2 = torch.cat((conv2d_1, conv2dTran_2), dim=1, out=None)
        conv2dTran_3 = self.conv2dTran_3(combine2)

        #SEND the combined outputs to the final layer. Please name your final output variable as "image" so you that it can be returned
        image = self.relu(conv2dTran_3)

        return image

Expected output size: `torch.Size([3, 64, 64])`

# Build the model for the toy vocabulary and push the tokenised sentence through it.
texttoimage = TextToImage(vocab_size=vocab_size)
output = texttoimage(lookup_tensor)

# Inspect the size of the generated tensor (expected: 3 x 64 x 64).
output.shape

Generated random noisy image :

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# The network output is channels-first (C, H, W); imshow wants channels-last (H, W, C).
image_hwc = output.detach().permute(1, 2, 0).numpy()
plt.imshow(image_hwc)

The error I got :

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 8 but got size 56 for tensor number 1 in the list.

Does anyone know how to solve this issue? I think it comes from concatenating the noise with the embedding.


Solution

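  • The fix is to concatenate the embedding and the noise along dim = 0 and then expand the result to 3 dimensions by adding a leading channel axis. In addition, the input channel count of the first convolution (conv2d_1) was wrong and has been changed from 64 to 1. The skip connection is likewise concatenated along dimension 0.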

    class TextToImage(nn.Module):
        """Toy text-to-image generator: embedding -> conv encoder -> transposed-conv decoder.

        The token embeddings are stacked on top of a fixed random noise matrix to
        form a single-channel 2-D "image", which is encoded by two strided
        convolutions and decoded by three transposed convolutions. The output of
        the first convolution is fed forward to the last decoder layer as a skip
        connection.

        The module works on a single, unbatched sentence: all intermediate
        tensors are laid out as (channels, height, width), which requires a
        PyTorch version whose conv layers accept inputs without a batch dimension.
        """

        def __init__(self, vocab_size):
            """Create the noise buffer and all layers.

            Args:
                vocab_size: number of rows in the embedding table, i.e. the
                    number of distinct token ids the model can look up.
            """
            super().__init__()

            self.vocab_size = vocab_size

            # Fixed noise of shape (56, 64), drawn once at construction. It is
            # registered as a buffer so it follows the module across
            # .to()/.cuda()/.double(); persistent=False keeps it out of the
            # state_dict, so saved checkpoints look exactly as before.
            self.register_buffer("noise", torch.rand((56, 64)), persistent=False)

            # DEFINE the layers
            # Embedding: one 64-dimensional vector per token id.
            self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=64)

            # Conv encoder. The input has a single channel (the stacked
            # embedding/noise matrix), hence in_channels=1.
            self.conv2d_1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding='valid')
            self.conv2d_2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding='valid')

            # Transposed-conv decoder.
            self.conv2dTran_1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1)
            self.conv2dTran_2 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding=0)
            # in_channels=6: 3 channels from conv2d_1 (skip) + 3 from conv2dTran_2.
            self.conv2dTran_3 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2), padding=0)

            self.relu = torch.nn.ReLU(inplace=False)
            self.dropout = torch.nn.Dropout(0.4)

        def forward(self, text_tensor):
            """Generate an image tensor from a 1-D tensor of token ids.

            Args:
                text_tensor: 1-D integer tensor of token ids. With 8 tokens the
                    stacked input is 64 x 64 and the output is (3, 64, 64);
                    the shape comments below assume that case.

            Returns:
                The non-negative (ReLU-activated) image tensor, channels first.
            """
            # SEND the input text tensor to the embedding layer -> (seq_len, 64)
            emb = self.embed(text_tensor)

            # COMBINE the embedding with the noise tensor by stacking rows
            # -> (seq_len + 56, 64), then add a leading channel axis
            # -> (1, 64, 64) so the conv layers see a one-channel image.
            combined = torch.cat((emb, self.noise), dim=0)
            combined_3d = combined.unsqueeze(0)

            # SEND the noisy embedding to the convolutional and transposed convolutional layers
            conv2d_1 = self.conv2d_1(combined_3d)           # (3, 31, 31)
            conv2d_2 = self.conv2d_2(conv2d_1)              # (16, 15, 15)
            dropout = self.dropout(conv2d_2)

            conv2dTran_1 = self.conv2dTran_1(dropout)       # (16, 15, 15)
            conv2dTran_2 = self.conv2dTran_2(conv2dTran_1)  # (3, 31, 31)

            # COMBINE the outputs (skip connection). The tensors are unbatched,
            # so dim 0 is the channel axis -> (6, 31, 31).
            combined_2 = torch.cat((conv2d_1, conv2dTran_2), dim=0)
            conv2dTran_3 = self.conv2dTran_3(combined_2)    # (3, 64, 64)

            # SEND the combined outputs to the final layer.
            image = self.relu(conv2dTran_3)

            return image