I have been using Keras for a while now, but I usually don't need custom layers or more complex flow control, so I'm struggling to understand a few things.

I am modeling a neural network with a custom layer on top. This custom layer calls another function (search_sigma); inside that function I execute tf.while_loop, and inside tf.while_loop I execute tf.cond.

I cannot understand why the conditions are not working.

  • tf.while_loop stops even though the condition (l1) is still true
  • tf.cond executes both f1 and f2 (the true_fn and false_fn callables)

Could someone help me understand what I am missing?

I already tried replacing both the tf.cond and tf.while_loop conditions with constant true tensors, just to see what would happen. The behavior (exactly the same errors) did not change.

I also tried to write this code without implementing a class (using just functions). Nothing changed.

I looked for solutions in the TensorFlow documentation, other Stack Overflow questions, and websites covering tf.while_loop and tf.cond.

I left some print() calls in the body of the code to try to track what was happening.

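# Imports assumed by this snippet (not shown in the original post)
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

class find_sigma: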
    
    def __init__ (self, t_inputs,  inputs,  expected_perp=10. ):       
        self.sigma, self.cluster = t_inputs
        self.inputs = inputs
        self.expected_perp = expected_perp
        self.min_sigma=tf.constant([0.01],tf.float32)
        self.max_sigma=tf.constant([50.],tf.float32)
 

    def search_sigma(self):

        
        def cond(s,sigma_not_found): return sigma_not_found


        def body(s,sigma_not_found):   

            print('loop')
            pi = K.exp( - K.sum( (K.expand_dims(self.inputs, axis=1) - self.cluster)**2, axis=2  )/(2*s**2) )        
            pi = pi / K.sum(pi)
            MACHINE_EPSILON = np.finfo(np.double).eps
            pi = K.maximum(pi, MACHINE_EPSILON)
            H = - K.sum ( pi*(K.log(pi)/K.log(2.)) , axis=0 )
            perp = 2**H

            print('0')

            l1 = tf.logical_and (tf.less(perp , self.expected_perp), tf.less(0.01, self.max_sigma-s))
            l2 = tf.logical_and (tf.less(  self.expected_perp , perp) , tf.less(0.01, s-self.min_sigma) )
    
            def f1():
                print('f1')
                self.min_sigma = s 
                s2 = (s+self.max_sigma)/2 
                return  [s2, tf.constant([True])]
                

            def f2():
                # tf.cond expects zero-argument callables that return their outputs
                return tf.cond( l2, true_fn=f3 , false_fn = f4)

            def f3(): 
                print('f3')
                self.max_sigma = s 
                s2 = (s+self.min_sigma)/2
                return [s2, tf.constant([True])]

            def f4(): 
                print('f4')
                return [s, tf.constant([False])]
            
            output = tf.cond( l1, f1 ,  f4 ) # put f2 in place of f4

            s, sigma_not_found = output
            
            print('sigma_not_found = ',sigma_not_found)
            return [s,sigma_not_found]

        print('01')

        sigma_not_found = tf.constant([True])

        new_sigma, sigma_not_found = tf.while_loop(
            cond , body, loop_vars=[self.sigma,sigma_not_found]
        )

        print('exited')
        
        print(new_sigma)

        return new_sigma

The piece of code that calls the above code is:

self.sigma = tf.map_fn(fn=lambda t: find_sigma(t,  inputs).search_sigma() , elems=(self.sigma,self.clusters), dtype=tf.float32)

'inputs' is a (None, 10) size tensor

'self.sigma' is a (10,) size tensor

'self.clusters' is a (N, 10) size tensor

Red

1 Answer

First of all, your first question was outstanding! Lots of info!

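tf.while_loop is very confusing, and that is one of the reasons TensorFlow moved to eager execution by default in 2.x. With eager execution you don't need tf.while_loop for this anymore; ordinary Python loops and conditionals work.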
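To make the eager-style alternative concrete, here is a minimal sketch of the same bisection idea as an ordinary loop. It is written in plain Python with only the standard library, so it is not TensorFlow code and not the asker's exact computation: the names perplexity and search_sigma (as a free function), the argument names, the defaults, and the single list of squared distances are all assumptions made for illustration. Under eager execution the same if/elif/else structure should carry over to tensors with concrete values.

```python
import math
import sys


def perplexity(sq_dists, sigma):
    """Perplexity (2**entropy) of a Gaussian kernel over squared distances."""
    # Subtracting the smallest distance keeps exp() from underflowing to all
    # zeros for tiny sigma; it does not change the normalized probabilities.
    d_min = min(sq_dists)
    weights = [math.exp(-(d - d_min) / (2.0 * sigma ** 2)) for d in sq_dists]
    total = sum(weights)
    eps = sys.float_info.epsilon  # same value as np.finfo(np.double).eps
    probs = [max(w / total, eps) for w in weights]
    entropy = -sum(p * math.log2(p) for p in probs)
    return 2.0 ** entropy


def search_sigma(sq_dists, expected_perp, sigma=1.0,
                 min_sigma=0.01, max_sigma=50.0, tol=0.01, max_iter=100):
    """Bisect on sigma until the perplexity brackets expected_perp within tol."""
    for _ in range(max_iter):
        perp = perplexity(sq_dists, sigma)
        if perp < expected_perp and max_sigma - sigma > tol:
            # Perplexity too low: raise the lower bound and move up.
            min_sigma = sigma
            sigma = (sigma + max_sigma) / 2.0
        elif perp > expected_perp and sigma - min_sigma > tol:
            # Perplexity too high: lower the upper bound and move down.
            max_sigma = sigma
            sigma = (sigma + min_sigma) / 2.0
        else:
            # Close enough, or no room left to move: stop searching.
            break
    return sigma


# Hypothetical data: squared distances from one point to five clusters.
example_dists = [0.0, 1.0, 4.0, 9.0, 16.0]
found_sigma = search_sigma(example_dists, expected_perp=3.0)
print(found_sigma, perplexity(example_dists, found_sigma))
```

Because the bounds are ordinary local variables here, updating them inside a branch behaves as expected, which is not the case for the self.min_sigma and self.max_sigma assignments made inside callables that are only traced.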

Anyway, back to your 2 questions. The answer is the same for both: you are never executing your graph, you are only building it. While building the execution graph, TensorFlow has to trace your Python code, and that is why you think tf.cond is running both f1 and f2. It is "sort of running" them, because it has to call each branch once to figure out which tensors/operations will be added to the graph. The print() calls are plain Python, so they fire during that tracing step, not when a branch is actually selected.

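The same applies to your question about tf.while_loop. It is never actually executed at that point: the body is traced once to build the graph, so the prints appear once and it looks as if the loop stopped after a single iteration.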
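The build-versus-run distinction described above can be imitated with a toy model. This sketch is not TensorFlow and does not reproduce its internals; Node, const, cond, and trace_log are hypothetical names invented here. It only shows why a Python side effect such as print() (stood in for by a log append) fires in both branches while the graph is being built, even though only one branch is evaluated when the graph runs.

```python
# Toy model of "build first, run later". All names here are hypothetical.

trace_log = []  # records Python side effects, in the order they happen


class Node:
    """A deferred computation: nothing is computed until run() is called."""

    def __init__(self, fn):
        self.fn = fn

    def run(self):
        return self.fn()


def const(value):
    """Wrap a plain value as a node."""
    return Node(lambda: value)


def cond(pred, true_fn, false_fn):
    # Build step: BOTH branch functions are called once to obtain their nodes.
    true_node = true_fn()
    false_node = false_fn()
    # Run step (deferred): evaluate the predicate, then only one branch node.
    return Node(lambda: true_node.run() if pred.run() else false_node.run())


def f1():
    trace_log.append("f1")  # Python side effect: happens at build time
    return const("from f1")


def f4():
    trace_log.append("f4")  # also happens at build time
    return const("from f4")


graph = cond(const(True), f1, f4)  # building: both branches are traced
log_after_build = list(trace_log)

result = graph.run()               # running: only the true branch is evaluated
log_after_run = list(trace_log)

print(log_after_build, result, log_after_run)
```

The same reasoning suggests why assignments like self.min_sigma = s inside f1 and f3 are risky in graph mode: they are Python side effects, so they would take place while tracing rather than when a branch is chosen.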

I recommend a small change which might help you understand what I'm saying and also fix your problem. Take the tf.while_loop call out of the method where it currently lives. Create another method, let's say run(), and move the loop there. Something like this:

def run(self):
   out = tf.while_loop(cond, body, loop_vars)
   return out

Then, invoke run(). It will force the graph to get executed.

CrazyBrazilian