
I am trying to make a progressive autoencoder and I have thought of a couple of ways of growing my network during training. However, I am always stuck on one part: I don't know whether changing the input (encoder) and output (decoder) channels would affect my network. See the example below.

X = torch.randn(8, 1, 4, 4)  # A batch of 8 grayscale images of 4x4 pixels in size

Encoder = nn.Sequential( nn.Conv2d(1, 16, 3, 1, 1), nn.ReLU() )  # starting setup: 16 3x3 kernels

If I print the weights from the network above, I get a size of [16, 1, 3, 3]: 16 kernels, each of size 1x3x3. If I want to grow the network, I need to save those weights because hopefully the layer is already well trained on those 4x4 image inputs.

X = torch.randn( 8, 1, 8, 8) # increase the image size from 4x4 to 8x8

...
new_model = nn.Sequential()

then do...
# copy the previous layer and its weights from the original encoder 

# BTW My issue starts here.

# Add/grow the new_model with new layers concatenated with the old layers, also modifying the input channels so they link correctly

# Final result would be like something below.

new_model = nn.Sequential( nn.Conv2d(1, 8, 3, 1, 1), nn.ReLU(), nn.Conv2d(8, 16, 3, 1, 1), nn.ReLU() )  # note the changed input channels: 1 and 8

Encoder = new_model

# Repeat process

Everything looks good, but because I changed the input channels, the size of the weights changed as well, and this is the issue I have been stuck on for a while now. You can check this by running:

foo_1 = nn.Conv2d(1, 1, 3, 1, 1)  # Think of this as the starting Conv2d from the starting encoder

foo_2 = nn.Conv2d(3, 1, 3, 1, 1)  # Think of this as the modified starting Conv2d, with an outer layer outputting 3 channels connected to it

print(foo_1.weight.size()) # torch.Size([1, 1, 3, 3])

print(foo_2.weight.size()) # torch.Size([1, 3, 3, 3])

Initially, I thought foo_1 and foo_2 would both have the same weight size, as in both would only use one 3x3 kernel, but that doesn't seem to be the case. I hope you can see my dilemma now: after x number of epochs I need to grow another convolution, and I have to change the input channels to make the new layers chain properly. But if I change the input channels, the shape of the weights is different, and I don't know how pasting the old state would work.

I have been looking at ProGAN implementations in PyTorch and IMO they are not easy to read. How can I build more intuition on how to properly grow a network progressively?


1 Answer


By progressive autoencoder I assume you are referring to something like Pioneer Networks: Progressively Growing Generative Autoencoder, which builds on Progressive Growing of GANs for Improved Quality, Stability, and Variation.

First of all, don't use nn.Sequential. It is great for modeling simple, direct network structures, which is definitely not the case here. You should use plain nn.Conv2d modules and F.relu calls instead of building an nn.Sequential object.

Second, this isn't really about implementation but theory. You cannot magically convert a convolution layer from accepting 1 channel to accepting 8 channels. There are a number of ways to expand a convolution filter, such as appending random weights, but I don't think that is what you want.
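For completeness, here is a minimal sketch of what "appending random weights" would look like: copy the trained kernel slice into a wider layer and leave the extra input-channel slices at their random initialization. It relies only on PyTorch's [out_channels, in_channels, kH, kW] weight layout:

```python
import torch
import torch.nn as nn

old_conv = nn.Conv2d(1, 16, 3, 1, 1)  # trained on 1-channel input
new_conv = nn.Conv2d(8, 16, 3, 1, 1)  # now accepts 8-channel input

with torch.no_grad():
    # keep the trained weights for the original input channel...
    new_conv.weight[:, :1] = old_conv.weight
    # ...the remaining 7 input-channel slices keep their random init
    new_conv.bias.copy_(old_conv.bias)

print(old_conv.weight.shape)  # torch.Size([16, 1, 3, 3])
print(new_conv.weight.shape)  # torch.Size([16, 8, 3, 3])
```

This "works" shape-wise, but the randomly initialized slices will disturb the trained behavior, which is why the papers below take a different route.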

From the second paper (it is a GAN, but the idea is the same), no filters are expanded. Instead, the filters keep their shape for the entire training process. This means that you would have a

nn.Conv2d(8, 16, 3, 1, 1)

from the very beginning (assuming you only have those two layers). An obvious problem pops out: your grayscale image is a 1-channel input, but your convolution requires an 8-channel input in the first stage of training. The second paper uses an extra 1x1 convolution layer to map RGB <-> feature maps. In your case, that would be

nn.Conv2d(1, 8, 1)

which maps the 1-channel input to an 8-channel output. This adapter can be thrown away after you are done with the first stage.
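Putting that together, a two-stage version of your example might look like the sketch below. The name from_gray is my own (the papers call the analogous layer fromRGB); the point is that conv1 keeps its shape across stages, and only the input side changes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(8, 16, 3, 1, 1)  # kept, unchanged, for the whole training run
from_gray = nn.Conv2d(1, 8, 1)     # stage-1 adapter: 1 -> 8 channels, thrown away later
conv0 = nn.Conv2d(1, 8, 3, 1, 1)   # real first layer, introduced in stage 2

# Stage 1: the 1x1 adapter feeds conv1 directly
x1 = torch.randn(8, 1, 4, 4)
h1 = F.relu(conv1(F.relu(from_gray(x1))))

# Stage 2: the grown layer conv0 replaces the adapter
x2 = torch.randn(8, 1, 8, 8)
h2 = F.relu(conv1(F.relu(conv0(x2))))

print(h1.shape)  # torch.Size([8, 16, 4, 4])
print(h2.shape)  # torch.Size([8, 16, 8, 8])
```

Note that conv1's weights never need resizing or copying between stages; only the adapter is swapped out.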

There are other techniques, like gradually fading in the new layers using a weight term, as described in the papers. I suggest you read them, especially the second one.
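A much-simplified sketch of the fade-in idea, assuming (for illustration only, not the paper's exact scheme) that the old adapter path and the new layer's path are blended at the same resolution with a weight alpha that ramps from 0 to 1 over training:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from_gray_old = nn.Conv2d(1, 8, 1)  # adapter from the previous stage
conv0 = nn.Conv2d(1, 8, 3, 1, 1)    # newly grown layer being faded in

def faded_input(x, alpha):
    # alpha = 0: only the old, trained adapter path is used
    # alpha = 1: only the new layer's path is used
    old = from_gray_old(x)
    new = F.relu(conv0(x))
    return (1 - alpha) * old + alpha * new

x = torch.randn(8, 1, 8, 8)
out = faded_input(x, 0.3)
print(out.shape)  # torch.Size([8, 8, 8, 8])
```

The gradual ramp lets the randomly initialized conv0 start contributing without abruptly destroying what the earlier stage learned.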

hkchengrex
  • 4,361
  • 23
  • 33
  • Hi thank you so much I have read the pro gan paper already and finally understood what toRGB and fromRGB layers are now. Everything finally makes sense, so that's what they mean by fading in the layers. You have no idea how happy I am right now this was a major confusion and obstacles for me thanks for clearing this up! – Inkplay_ Jul 27 '18 at 12:12
  • @Inkplay_ Glad I could help! – hkchengrex Jul 27 '18 at 12:54