
I'm learning the concept of neural networks, and I decided to try writing the neuron class myself. What is the best way to implement different activation functions in my code? Right now it only uses the binary step function. This is my first attempt at coding a neural network, so if you have any suggestions about my code, or if it is completely dumb, please let me know.

Here is my code:

import java.util.ArrayList;

public class Neuron {

    // properties
    private ArrayList<Neuron> input;
    private ArrayList<Float> weight;
    private float pot, bias, sense, out;
    private boolean checked;

    // methods
    public float fire(){
        pot = 0f;
        if (input != null) {
            for (Neuron n : input){
                if (!n.getChecked()){
                    pot += n.fire()*weight.get(input.indexOf(n));
                } else {
                    pot += n.getOut()*weight.get(input.indexOf(n));
                } // end of condition (checked)
            } // end of loop (for input)
        } // end of condition (input exists)
        checked = true;
        pot -= bias;
        pot += sense;
        out = actFunc(pot);
        return out;
    } // end of fire()

    // getting properties
    public float getPot(){return pot;}
    public boolean getChecked(){return checked;}
    public float getOut(){return out;}

    // setting properties
    public void stimulate(float f){sense = f;}
    public void setBias(float b){bias = b;}
    public void setChecked(boolean c){checked = c;}
    public void setOut(float o){out = o;}

    // connection
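    public void connect(Neuron n, float w){
        // input/weight are null for a neuron created without inputs, so create them lazily
        if (input == null) {
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
        }
        input.add(n);
        weight.add(w);
    }
    public void deconnect(Neuron n){
        int i = (input == null) ? -1 : input.indexOf(n);
        if (i < 0) return; // n is not connected
        weight.remove(i);
        input.remove(i);
    }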

    // activation function
    private float actFunc(float x){
        if (x < 0) {
            return 0f;
        } else {
            return 1f;
        }
    }

    // constructor
    public Neuron(Neuron[] ns, float[] ws, float b, float o){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < ws.length; i++) weight.add(ws[i]);
        } else {
            input = null;
            weight = null;
        }
        bias = b;
        out = o;
    }

    public Neuron(Neuron[] ns){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < input.size(); i++) weight.add((float)Math.random()*2f-1f);
        } else {
            input = null;
            weight = null;
        }
        bias = (float)Math.random();
        out = (float)Math.random();
    }

}

1 Answer


First, define an interface that every activation function implements:

public interface ActivationFunction {
    float get(float f);
}
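Because this version of ActivationFunction has a single abstract method, it is a functional interface, so on Java 8 or later an implementation can also be supplied as a lambda instead of a named class. A minimal sketch; the class name LambdaDemo and the ReLU example are illustrative additions, not part of the original answer:

```java
class LambdaDemo {
    // Binary step, written as a lambda instead of a named class
    static final ActivationFunction STEP = x -> (x < 0) ? 0f : 1f;

    // Illustrative extra: ReLU, max(0, x)
    static final ActivationFunction RELU = x -> Math.max(0f, x);

    public static void main(String[] args) {
        System.out.println(STEP.get(-0.5f));
        System.out.println(STEP.get(0.5f));
        System.out.println(RELU.get(-2f));
        System.out.println(RELU.get(3f));
    }
}

// Same single-method interface as above; @FunctionalInterface is optional and
// only makes the compiler enforce "exactly one abstract method".
@FunctionalInterface
interface ActivationFunction {
    float get(float f);
}
```

Named classes remain the better fit once the interface grows a second method, as it does further down.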

Then write some implementations:

public class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (x < 0) ? 0f : 1f;}
}

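public class SigmoidFunction implements ActivationFunction {
    // tanh is an S-shaped ("sigmoid") curve with range (-1, 1);
    // StrictMath.tanh returns a double, hence the cast
    @Override
    public float get(float x) {return (float) StrictMath.tanh(x);}
}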
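To try the two implementations side by side, here is a small self-contained demo. Note that tanh maps to (-1, 1); if you want the logistic sigmoid 1 / (1 + e^-x) with range (0, 1), it can be written the same way. LogisticFunction and ActivationDemo are illustrative additions, not part of the original answer:

```java
class ActivationDemo {
    public static void main(String[] args) {
        ActivationFunction[] fs = {
            new StepFunction(), new SigmoidFunction(), new LogisticFunction()
        };
        float[] xs = {-2f, 0f, 2f};
        // Print each function's output at the sample points
        for (ActivationFunction f : fs) {
            StringBuilder sb = new StringBuilder(f.getClass().getSimpleName() + ":");
            for (float x : xs) sb.append(' ').append(f.get(x));
            System.out.println(sb);
        }
    }
}

interface ActivationFunction {
    float get(float x);
}

// Binary step: 0 for negative input, 1 otherwise
class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (x < 0) ? 0f : 1f; }
}

// Hyperbolic tangent, range (-1, 1)
class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) StrictMath.tanh(x); }
}

// Logistic sigmoid 1 / (1 + e^-x), range (0, 1); added for comparison
class LogisticFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) (1.0 / (1.0 + StrictMath.exp(-x))); }
}
```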

Finally, pass an implementation to your Neuron:

public class Neuron {
    private final ActivationFunction actFunc;
    // other fields...

    public Neuron(ActivationFunction actFunc) {
        this.actFunc = actFunc;
    }

    public float fire(){
        // ...
        out = actFunc.get(pot);
        return out;
    } 
}

and create neurons as follows:

Neuron n = new Neuron(new SigmoidFunction());
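Put together, a neuron with a pluggable activation function could look like the sketch below. SimpleNeuron is a hypothetical, simplified class: it takes input values directly instead of upstream Neuron objects, but subtracts the bias the same way the question's fire() does. With weights 1, 1, bias 1.5 and the step function it behaves like a logical AND:

```java
import java.util.Arrays;

class NeuronDemo {
    public static void main(String[] args) {
        ActivationFunction step = x -> (x < 0) ? 0f : 1f;
        SimpleNeuron and = new SimpleNeuron(step, new float[] {1f, 1f}, 1.5f);
        float[][] cases = {{0f, 0f}, {0f, 1f}, {1f, 0f}, {1f, 1f}};
        for (float[] in : cases) {
            System.out.println(Arrays.toString(in) + " -> " + and.fire(in));
        }
    }
}

interface ActivationFunction {
    float get(float x);
}

// Hypothetical simplified neuron: weighted sum minus bias, then the activation
class SimpleNeuron {
    private final ActivationFunction actFunc;
    private final float[] weights;
    private final float bias;

    SimpleNeuron(ActivationFunction actFunc, float[] weights, float bias) {
        this.actFunc = actFunc;
        this.weights = weights.clone();
        this.bias = bias;
    }

    float fire(float[] inputs) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("expected " + weights.length + " inputs");
        }
        float pot = -bias;
        for (int i = 0; i < weights.length; i++) {
            pot += inputs[i] * weights[i];
        }
        return actFunc.get(pot);
    }
}
```

Swapping in another activation only changes the first constructor argument; the neuron itself stays untouched.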

Note that neural networks are usually trained by propagating an error signal backwards through the neurons (backpropagation) and adjusting the weights from it. The weight update depends on the first derivative of the activation function. Therefore, I would extend ActivationFunction with a method that returns the first derivative at a given point x:

public interface ActivationFunction {
    float get(float f);
    float firstDerivative(float x);
}

So the implementations will look like:

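public class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (x < 0) ? 0f : 1f;}

    // The true derivative is 0 for every x != 0 (undefined at 0), which would never
    // change the weights; returning 1 instead yields the classic perceptron learning rule.
    @Override
    public float firstDerivative(float x) {return 1f;}
}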

public class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (float) StrictMath.tanh(x);}

    // d/dx tanh(x) = 4e^(2x) / (e^(2x) + 1)^2 = 1 - tanh(x)^2
    @Override
    public float firstDerivative(float x) {
        float t = get(x);
        return 1f - t * t;
    }
}
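A quick way to sanity-check a firstDerivative implementation is to compare it with a central finite difference, (f(x + h) - f(x - h)) / 2h. A sketch for the tanh class; DerivativeCheck, numericDerivative and the step size h = 0.01 are illustrative choices:

```java
class DerivativeCheck {
    // Central finite difference: (f(x + h) - f(x - h)) / (2h)
    static float numericDerivative(ActivationFunction f, float x, float h) {
        return (f.get(x + h) - f.get(x - h)) / (2f * h);
    }

    public static void main(String[] args) {
        ActivationFunction tanh = new SigmoidFunction();
        float h = 1e-2f;
        for (float x = -2f; x <= 2f; x += 0.5f) {
            float analytic = tanh.firstDerivative(x);
            float numeric = numericDerivative(tanh, x, h);
            System.out.printf("x=%5.2f analytic=%.5f numeric=%.5f%n", x, analytic, numeric);
        }
    }
}

interface ActivationFunction {
    float get(float x);
    float firstDerivative(float x);
}

class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) StrictMath.tanh(x); }

    // d/dx tanh(x) = 1 - tanh(x)^2
    @Override
    public float firstDerivative(float x) {
        float t = get(x);
        return 1f - t * t;
    }
}
```

The two columns should agree to about three decimal places; a larger gap usually points to a wrong formula.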

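Then call actFunc.firstDerivative(pot) wherever the weights are updated during training. fire() itself only computes the forward output, so the derivative belongs in the separate weight-update step.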
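As a sketch of what that weight-update step can look like for a single neuron, here is the delta rule, i.e. gradient descent on the squared error E = ½(t - y)² with y = f(pot) and pot = Σ wᵢxᵢ - bias. Each weight moves by rate·(t - y)·f'(pot)·xᵢ, and because pot contains -bias, the bias moves by -rate·(t - y)·f'(pot). TrainableNeuron, train, the learning rate 0.1, the 1000 epochs and the AND data (targets encoded as -1/+1 to suit tanh) are all illustrative assumptions; a multi-layer network needs full backpropagation on top of this:

```java
import java.util.Arrays;

class DeltaRuleDemo {
    public static void main(String[] args) {
        float[][] inputs = {{0f, 0f}, {0f, 1f}, {1f, 0f}, {1f, 1f}};
        float[] targets = {-1f, -1f, -1f, 1f}; // logical AND, encoded as -1/+1 for tanh
        TrainableNeuron n = new TrainableNeuron(new SigmoidFunction(), 2);
        for (int epoch = 0; epoch < 1000; epoch++) {
            for (int i = 0; i < inputs.length; i++) {
                n.train(inputs[i], targets[i], 0.1f);
            }
        }
        for (int i = 0; i < inputs.length; i++) {
            System.out.printf("%s -> %.3f%n", Arrays.toString(inputs[i]), n.fire(inputs[i]));
        }
    }
}

interface ActivationFunction {
    float get(float x);
    float firstDerivative(float x);
}

class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) StrictMath.tanh(x); }

    @Override
    public float firstDerivative(float x) {
        float t = get(x);
        return 1f - t * t;
    }
}

class TrainableNeuron {
    private final ActivationFunction actFunc;
    private final float[] weights;
    private float bias;

    TrainableNeuron(ActivationFunction actFunc, int nInputs) {
        this.actFunc = actFunc;
        this.weights = new float[nInputs]; // all zeros, so the run is reproducible
    }

    // Weighted sum minus bias, as in the question's fire()
    private float potential(float[] x) {
        float pot = -bias;
        for (int i = 0; i < weights.length; i++) pot += weights[i] * x[i];
        return pot;
    }

    float fire(float[] x) {
        return actFunc.get(potential(x));
    }

    // One gradient-descent step on E = 0.5 * (target - out)^2
    void train(float[] x, float target, float rate) {
        float pot = potential(x);
        float out = actFunc.get(pot);
        float delta = (target - out) * actFunc.firstDerivative(pot);
        for (int i = 0; i < weights.length; i++) {
            weights[i] += rate * delta * x[i];
        }
        bias -= rate * delta; // pot contains -bias, so the sign flips
    }
}
```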

matoni