
I want to create a custom activation function in TF2. The math is as follows:

def sqrt_activation(x):
    if x >= 0:
        return tf.math.sqrt(x)
    else:
        return -tf.math.sqrt(-x)

The problem is that I can't compare x with 0 using a Python if, since x is a tensor. How can I achieve this?

David H. J.

2 Answers


You can avoid the comparison altogether by multiplying the square root of the absolute value by the sign of x:

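import tensorflow as tf

def sqrt_activation(x):
    # tf.math.sign(x) is -1, 0 or 1 element-wise, so this equals sqrt(x)
    # for x >= 0 and -sqrt(-x) for x < 0, without any branching.
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))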
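Because the sign trick works element-wise on a tensor of any shape, the function can be passed straight to a Keras layer as its activation. A minimal usage sketch (the layer size and input values are arbitrary choices for illustration):

```python
import tensorflow as tf


def sqrt_activation(x):
    # Signed square root, sign(x) * sqrt(|x|), applied element-wise.
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))


# Applied directly to a tensor: no Python comparison with 0 is needed.
y = sqrt_activation(tf.constant([-4.0, 0.0, 9.0]))
print(y.numpy())  # [-2.  0.  3.]

# Any callable that maps a tensor to a tensor can serve as a Keras activation.
layer = tf.keras.layers.Dense(4, activation=sqrt_activation)
out = layer(tf.ones((2, 3)))
print(out.shape)  # (2, 4)
```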
Vijay Mariappan

You need to use TensorFlow ops such as tf.where instead of a Python if, and convert your code as follows:

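import tensorflow as tf

@tf.function
def sqrt_activation(x):
    zeros = tf.zeros_like(x)
    # Element-wise: sqrt(x) where x >= 0, otherwise 0
    pos = tf.where(x >= 0, tf.math.sqrt(x), zeros)
    # Element-wise: -sqrt(-x) where x < 0, otherwise 0
    neg = tf.where(x < 0, -tf.math.sqrt(-x), zeros)
    # Each element comes from exactly one of pos and neg
    return pos + neg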

Note that each tf.where checks its condition element-wise over the whole tensor and fills the non-matching elements with zeros, so adding pos + neg assembles the final result.
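One caveat: tf.where computes both of its branches for every element before selecting, so tf.math.sqrt(x) is still evaluated (giving NaN) for negative inputs. The forward result is unaffected, but the tf.where documentation warns that a NaN gradient in either branch can make the whole gradient NaN. A hedged sketch of a safer variant, given the hypothetical name sqrt_activation_safe, clamps the square-root arguments with tf.maximum and checks the gradient with tf.GradientTape:

```python
import numpy as np
import tensorflow as tf


@tf.function
def sqrt_activation_safe(x):
    # Hypothetical variant of the function above: tf.maximum keeps every
    # sqrt argument non-negative, so neither branch ever produces NaN.
    zeros = tf.zeros_like(x)
    pos = tf.where(x >= 0, tf.math.sqrt(tf.maximum(x, 0.0)), zeros)
    neg = tf.where(x < 0, -tf.math.sqrt(tf.maximum(-x, 0.0)), zeros)
    return pos + neg


x = tf.constant([-4.0, 4.0, 9.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = sqrt_activation_safe(x)
grad = tape.gradient(y, x)

print(y.numpy())  # [-2.  2.  3.]
# For x != 0 the derivative of the signed square root is 1 / (2 * sqrt(|x|)).
print(np.allclose(grad.numpy(), 1.0 / (2.0 * np.sqrt(np.abs(x.numpy())))))  # True
```

The sample inputs deliberately avoid x = 0, where the true derivative is unbounded; an input of exactly 0 would still need separate handling, for example via tf.custom_gradient.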

ahmet hamza emra