
I want to implement dropout myself, and I wonder if something like the following is sufficient (taken from Implementing dropout from scratch):

import torch
import torch.nn as nn


class MyDropout(nn.Module):
    def __init__(self, p: float = 0.5):
        super(MyDropout, self).__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, but got {}".format(p))
        self.p = p

    def forward(self, X):
        if self.training:
            # Sample a 0/1 keep-mask with keep probability (1 - p) and rescale by
            # 1/(1 - p) (inverted dropout) so the expected activation is unchanged.
            binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)
            return X * binomial.sample(X.size()) * (1.0 / (1 - self.p))
        return X

My concern is that even if the unwanted weights are masked out (either this way or by using a mask tensor), gradients can still flow through the zeroed weights (https://discuss.pytorch.org/t/custom-connections-in-neural-network-layers/3027/9). Is my concern valid?
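For context, here is a minimal usage sketch of the module above (the layer sizes and batch size are made up for illustration):

net = nn.Sequential(
    nn.Linear(20, 10),
    nn.ReLU(),
    MyDropout(p=0.5),
    nn.Linear(10, 2),
)
net.train()                    # a fresh dropout mask is sampled on every forward pass
out = net(torch.randn(3, 20))
net.eval()                     # dropout becomes the identity
out = net(torch.randn(3, 20))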

Sam

1 Answer


Dropout does not mask the weights - it masks the features (activations).
For a linear layer computing y = <w, x>, the gradient w.r.t. the parameters w is x. Therefore, if you set entries of x to zero, there is no update for the corresponding weights in the adjacent linear layer.
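As a quick illustration (a self-contained sketch, not from the original answer; the sizes are arbitrary), you can check that zeroing an input feature zeroes the gradient of the corresponding weight column:

import torch

lin = torch.nn.Linear(4, 3)
x = torch.randn(1, 4)
x[0, 1] = 0.0                     # "drop" input feature 1
lin(x).sum().backward()
print(lin.weight.grad[:, 1])      # column for the dropped feature: all zeros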

Shai
  • I see. So is there a convenient way of doing that? Apparently I need conditional masking such that when x is multiplied with different rows of w (considering the input to the first hidden layer), I can mask out the unwanted forward passes. It seems this will break a simple expression like x = F.relu(self.fc(x)). (Note: what I want is somewhat different from standard dropout; I want to zero-mask certain specified forward passes deterministically.) – Sam Aug 03 '21 at 07:56
  • @Zzy1130 the math of derivatives and the chain rule does it for you. – Shai Aug 03 '21 at 07:58
  • Yes, now I can see why, with x being 0 when multiplied with a certain row of w, those parameters won't be updated. But how can I prevent x from multiplying with certain rows of w? The math (matrix multiplication) itself dictates that x multiplies with all rows of w. – Sam Aug 03 '21 at 08:02
  • After referring to the formula in the first link I cited, I see the masking is done afterwards (see the sketch below). Thanks for your reply. – Sam Aug 03 '21 at 08:12
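For completeness, a minimal sketch of the deterministic masking discussed in the comments (the layer sizes and the fixed 0/1 mask are hypothetical): the mask is applied to the layer's output rather than to the weights, and the gradient math then ensures the masked units contribute no update to the corresponding rows of the weight matrix.

import torch
import torch.nn.functional as F

fc = torch.nn.Linear(8, 4)
mask = torch.tensor([1.0, 0.0, 1.0, 1.0])   # fixed, deterministic 0/1 mask over output units
x = torch.randn(2, 8)
h = F.relu(fc(x)) * mask                     # mask the output features, not the weights
h.sum().backward()
print(fc.weight.grad[1])                     # row for the masked unit: all zeros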