I'm trying to implement neural style transfer using the pre-trained VGG19 model in Google Colab. Running the code below raises an error. The content loss prints the correct value, but I'm not sure what is wrong with the style loss.

def content_loss(target_conv4_2,content_conv4_2):
  loss=torch.mean((target_conv4_2-content_conv4_2)**2)
  return loss

style_grams = {layer : gram_matrix(style_f[layer]) for layer in style_f}

def style_loss(style_weights,target_features,style_grams):
  loss = 0
  for layer in style_weights:
    target_f = target_features[layer]
    target_gram = gram_matrix[target_f]
    style_gram = style_grams[layer]
    b,c,h,w = target_f.shape
    layer.loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
    loss += layer_loss/(c*h*w)

  return loss

target = content_p.clone().requires_grad_(True).to(device)
target_f = get_features(target,vgg)
print("Content Loss: ",content_loss(target_f['conv4_2'],content_f['conv4_2']))
print("Style Loss: ",style_loss(style_weights,target_f,style_grams))

This is the error:

Content Loss:  tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-101-0b8ad8e6d456> in <module>()
      3 #style_grams = {layer : gram_matrix(style_f[layer]) for layer in style_f}
      4 print("Content Loss: ",content_loss(target_f['conv4_2'],content_f['conv4_2']))
----> 5 print("Style Loss: ",style_loss(style_weights,target_f,style_grams))

<ipython-input-98-25679c9fd886> in style_loss(style_weights, target_features, style_grams)
      5   for layer in style_weights:
      6     target_f = target_features[layer]
----> 7     target_gram = gram_matrix[target_f]
      8     style_gram = style_grams[layer]
      9     b,c,h,w = target_f.shape

TypeError: 'function' object is not subscriptable

According to this answer, it's because two objects share the same name, but I can't see where that happens in my code.
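For reference, the same `TypeError` can be reproduced without any name clash. The sketch below is only an illustration: the `gram_matrix` here is a hypothetical stand-in (its body does not matter), and it shows that square brackets try to index the function object, while parentheses call it.

```python
def gram_matrix(x):
    # Hypothetical stand-in for the real helper; the body is irrelevant here.
    return x


try:
    gram_matrix[3]  # square brackets try to index the function object itself
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # the message reports that the object is not subscriptable

result = gram_matrix(3)  # parentheses call the function as intended
print(result)
```

If that is what is happening, the traceback points at the line where the brackets are used rather than at a second definition of the name.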

  • `gram_matrix` is used as an array/list here: `target_gram = gram_matrix[target_f]` and as a function here: `gram_matrix(style_f[layer])`. This of course doesn't look right! – hesham_EE Jan 12 '22 at 02:26
  • `gram_matrix` is a function, but it seems like you're expecting it to be a dictionary. – John Gordon Jan 12 '22 at 02:28
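Following the comments, a minimal sketch of a corrected `style_loss` might look like the code below. Two things change: `gram_matrix` is called with parentheses, and the per-layer value is stored in `layer_loss` (the posted code assigns to `layer.loss` but then reads `layer_loss`, which would probably fail next). The question does not show `gram_matrix`, so the definition here is an assumption, and the layer names and random tensors only stand in for real VGG19 features.

```python
import torch


def gram_matrix(tensor):
    # Assumed helper: the question does not show its definition.
    b, c, h, w = tensor.shape
    features = tensor.reshape(b * c, h * w)
    return features @ features.t()


def style_loss(style_weights, target_features, style_grams):
    loss = 0
    for layer in style_weights:
        target_f = target_features[layer]
        target_gram = gram_matrix(target_f)  # call with (), not []
        style_gram = style_grams[layer]
        b, c, h, w = target_f.shape
        # Use one consistent name, layer_loss, for the per-layer term.
        layer_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
        loss += layer_loss / (c * h * w)
    return loss


# Illustrative data only: random tensors instead of VGG19 feature maps.
torch.manual_seed(0)
style_weights = {"conv1_1": 1.0, "conv2_1": 0.5}
style_f = {
    "conv1_1": torch.rand(1, 4, 8, 8),
    "conv2_1": torch.rand(1, 6, 4, 4),
}
style_grams = {layer: gram_matrix(style_f[layer]) for layer in style_f}

# A target identical to the style features should give a loss of about zero.
same_target = {name: feat.clone().requires_grad_(True) for name, feat in style_f.items()}
same_loss = style_loss(style_weights, same_target, style_grams)

# A different target should give a positive loss that supports backprop.
other_target = {name: torch.rand_like(feat).requires_grad_(True) for name, feat in style_f.items()}
other_loss = style_loss(style_weights, other_target, style_grams)

print("Style Loss (same features): ", same_loss)
print("Style Loss (other features): ", other_loss)
```

With these changes the function returns a tensor that carries a `grad_fn`, so it can be combined with the content loss and passed to `backward()` in the usual optimization loop.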
