
The easiest thing might be for me to just post the NumPy code that I'm trying to reproduce directly in Theano, if that's possible:

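import numpy as np
from theano import shared

# Evaluate the shared variable once to get a plain NumPy array
tensor = shared(np.random.randn(7, 16, 16)).eval()

# Binarize the first matrix in place: values >= 1 become 1.0, the rest 0.0
tensor2 = tensor[0,:,:]
tensor2[tensor2 < 1] = 0.0
tensor2[tensor2 > 0] = 1.0

# Multiply every remaining matrix by the binary mask
new_tensor = [tensor2]
for i in range(1, tensor.shape[0]):
    new_tensor.append(np.multiply(tensor2, tensor[i,:,:]))

output = np.array(new_tensor).reshape(7,16,16)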

If it's not immediately obvious, what I'm trying to do is take the values from one matrix of a tensor made up of 7 different matrices and apply them as a mask to the other matrices in the tensor.
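For reference, the same operation can probably be written in NumPy without the explicit loop, which may make the intent easier to see. This is only a minimal sketch: it assumes the mask is "first matrix >= 1" (which is what the two in-place assignments amount to), and the fixed seed and the names `mask` and `rng` are illustrative choices, not part of the original code.

```python
import numpy as np

# Fixed seed only so the sketch is reproducible (an assumption, not in the original).
rng = np.random.RandomState(0)
tensor = rng.randn(7, 16, 16)

# Binary mask from the first matrix: 1.0 where the value is >= 1, else 0.0.
mask = (tensor[0] >= 1).astype(tensor.dtype)

# Work on a copy so the input tensor is left untouched.
output = tensor.copy()
output[0] = mask      # the first matrix is replaced by the mask itself
output[1:] *= mask    # the (16, 16) mask broadcasts over the other 6 matrices
```

The multiplication relies on NumPy broadcasting, so no per-matrix loop or final reshape is needed.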

Really, the problem I'm solving is implementing conditional statements in an objective function for a fully convolutional network in Keras. Basically, the loss for some of the feature map values is going to be calculated (and subsequently weighted) differently from others, depending on some of the values in one of the feature maps.
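To make that goal concrete, here is a rough NumPy sketch of what such a conditionally weighted loss could look like. Everything specific in it is an assumption for illustration only: the function name `masked_weighted_mse`, the use of squared error, the threshold, and the two weights `w_on` and `w_off` are hypothetical, and a real Keras objective would need the same logic expressed with backend tensor operations and a batch dimension.

```python
import numpy as np

def masked_weighted_mse(y_true, y_pred, threshold=1.0, w_on=2.0, w_off=0.5):
    """Hypothetical loss: squared error on feature maps 1..N-1, weighted
    per position according to a condition on feature map 0 of y_true."""
    # Boolean (H, W) condition taken from the first feature map.
    mask = y_true[0] >= threshold
    # Squared error on the remaining feature maps, shape (N-1, H, W).
    sq_err = (y_true[1:] - y_pred[1:]) ** 2
    # Pick one of two (hypothetical) weights per position.
    weights = np.where(mask, w_on, w_off)
    # The (H, W) weights broadcast across the remaining feature maps.
    return float((weights * sq_err).mean())
```

Here `np.where` plays the role of the element-wise conditional: it selects between two values per position based on the mask.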

Corey J. Nolet

1 Answer


You can easily implement conditionals with Theano's switch operation (T.switch).

Here is the equivalent code:

import theano
from theano import tensor as T
import numpy as np


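def _check_new(var):
    shape = var.shape[0]
    # Split off the first matrix (t_1) from the remaining ones (t_2)
    t_1, t_2 = T.split(var, [1, shape-1], 2, axis=0)
    ones = T.ones_like(t_1)
    # Condition: 1 where the first matrix is greater than 1, else 0
    cond = T.gt(t_1, ones)
    # Repeat the condition so it lines up with every remaining matrix
    mask = T.repeat(cond, t_2.shape[0], axis=0)
    # Keep values where the mask is set, zero them elsewhere
    out = T.switch(mask, t_2, T.zeros_like(t_2))
    # Put the binary mask back in front of the masked matrices
    output = T.join(0, cond, out)
    return output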

def _check_old(var):
    tensor = var.eval()

    tensor2 = tensor[0,:,:]
    tensor2[tensor2 < 1] = 0.0
    tensor2[tensor2 > 0] = 1.0
    new_tensor = [tensor2]

    for i in range(1, tensor.shape[0]):
        new_tensor.append(np.multiply(tensor2, tensor[i,:,:]))

    output = theano.shared(np.array(new_tensor).reshape(7,16,16))
    return output


tensor = theano.shared(np.random.randn(7, 16, 16))
out1 = _check_new(tensor).eval()
out2 = _check_old(tensor).eval()
print(out1)
print('----------------')
# Mean squared difference between the two implementations (0 if they agree)
print(((out1-out2) ** 2).mean())

Note: since you're masking on the first filter, I needed to use split and join operations.

indraforyou