
I converted the following code from Keras to PyTorch. The main challenge for me is building a multi-input, multi-output model similar to keras.models.Model, i.e. how to implement the code below in PyTorch so that it accepts multiple inputs and produces multiple outputs.

from tensorflow import keras as k
import tensorflow as tf
class NetworkKeys:
    NUM_UNITS = "num_units"
    ACTIVATION = "activation"
    L2_REG_FACT = "l2_reg_fact"
    DROP_PROB = "drop_prob"
    BATCH_NORM = "batch_norm"


def build_dense_network(input_dim, output_dim,
                        output_activation, params, with_output_layer=True):

    model = k.models.Sequential()

    activation = params.get(NetworkKeys.ACTIVATION, "relu")
    l2_reg_fact = params.get(NetworkKeys.L2_REG_FACT, 0.0)
    regularizer = k.regularizers.l2(l2_reg_fact) if l2_reg_fact > 0 else None
    drop_prob = params.get(NetworkKeys.DROP_PROB, 0.0)
    batch_norm = params.get(NetworkKeys.BATCH_NORM, False)

    last_dim = input_dim
    for i in range(len(params[NetworkKeys.NUM_UNITS])):
        model.add(k.layers.Dense(units=params[NetworkKeys.NUM_UNITS][i],
                                    kernel_regularizer=regularizer,
                                    input_dim=last_dim))
        if batch_norm:
            model.add(k.layers.BatchNormalization())
        model.add(k.layers.Activation(activation))
        last_dim = params[NetworkKeys.NUM_UNITS][i]

        if drop_prob > 0.0:
            model.add(k.layers.Dropout(rate=drop_prob))
    if with_output_layer:
        model.add(k.layers.Dense(units=output_dim, activation=output_activation))
    return model

ldre_net = build_dense_network(input_dim=input_dim, output_dim=1,
                               output_activation=k.activations.linear,
                               params=hidden_params)

p_samples = k.layers.Input(shape=(input_dim,))
q_samples = k.layers.Input(shape=(input_dim,))

train_model = k.models.Model(inputs=[p_samples, q_samples],
                             outputs=[ldre_net(p_samples),ldre_net(q_samples)])

Here is my attempt at converting the above code to PyTorch:

import torch
from torch import nn


def l2_penalty(model, l2_lambda=0.001):
    """Returns the L2 penalty of the params."""
    l2_norm = sum(p.pow(2).sum() for p in model.parameters())
    return l2_lambda*l2_norm
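
# Note: to mirror Keras's kernel_regularizer, this penalty has to be recomputed
# and added to the loss at every training step, e.g.
#   loss = criterion(output, target) + l2_penalty(model, l2_lambda)
# (criterion/output/target here stand for whatever loss and data you use)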
    
def build_dense_network(input_dim, output_dim,
                        output_activation, params, with_output_layer=True):
    activation = params.get(NetworkKeys.ACTIVATION, "relu")
    l2_reg_fact = params.get(NetworkKeys.L2_REG_FACT, 0.0)
    drop_prob = params.get(NetworkKeys.DROP_PROB, 0.0)
    batch_norm = params.get(NetworkKeys.BATCH_NORM, False)
    layers=[]
    last_dim = input_dim
    for i in range(len(params[NetworkKeys.NUM_UNITS])):
        layers.append(nn.Linear(last_dim,params[NetworkKeys.NUM_UNITS][i]))
        if batch_norm:
            layers.append(torch.nn.BatchNorm1d(params[NetworkKeys.NUM_UNITS][i]))

        if activation=="relu":
            layers.append(nn.ReLU())
        elif activation=="LeakyRelu":
            layers.append(nn.LeakyReLU(0.1,inplace=True))
        else:
            pass

        last_dim = params[NetworkKeys.NUM_UNITS][i]

        if drop_prob > 0.0:
            layers.append(torch.nn.Dropout(p=drop_prob))

    if with_output_layer:

        layers.append(nn.Linear(params[NetworkKeys.NUM_UNITS][-1],output_dim))
    model = nn.Sequential(*layers)
    regularizer = l2_penalty(model, l2_lambda=0.001) if l2_reg_fact > 0 else None
    return model, regularizer
    
class Split(torch.nn.Module):
    def __init__(self, module, n_parts: int, dim=1):
        super().__init__()
        self._n_parts = n_parts
        self._dim = dim
        self._module = module

    def forward(self, inputs):
        output = self._module(inputs)
        chunk_size = output.shape[self._dim] // self._n_parts
        return torch.split(output, chunk_size, dim=self._dim)
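
# Quick check of Split: with dim=0 and n_parts=2, an output of shape (2N, ...)
# is chunked back into two (N, ...) tensors, e.g.
#   a, b = Split(nn.Linear(10, 1), n_parts=2, dim=0)(torch.randn(6, 10))
#   # a.shape == b.shape == (3, 1)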

class Net(nn.Module):
    def __init__(self, hidden_params, input_dim):
        self._ldre_net, ldre_regularizer = build_dense_network(input_dim=input_dim,
                    output_dim=1,output_activation="linear", params=hidden_params)

        self._p_samples = nn.Linear(input_dim,input_dim)
        self._q_samples = nn.Linear(input_dim,input_dim)
        self._split_layers = Split(
                self._ldre_net,
                n_parts=2,
                dim = 0
            )
    def forward(self, x, inTrain=True):
        if inTrain:
            p = self._p_samples(x)
            q = self._q_samples(x)
            p = x[:, 0, :]
            q = x[:, 1, :]
            combined = torch.cat((p.view(p.size(0), -1),
                                q.view(q.size(0), -1)), dim=0)

            p_output, q_output =self._split_layers(combined)
            return p_output, q_output
        else:
            return self._ldre_net(x)

I am wondering whether my implementation of the Net class is correct.


1 Answer


TL;DR: In PyTorch you control the number of inputs and outputs yourself, in the form of tensors (or any number of variables). The missing super initialization and the order of operations in forward should be fixed. I also don't particularly like the way the arguments are passed; I'd recommend using *args and **kwargs.

Explanation

There were a few things I had to sort out to make it run, namely that the NetworkKeys constants are used to index the dictionary that is passed through. That seems like an overly complicated way to do things: you tried to provide default values, but it still throws an exception when a key is missing (namely num_units). I'd recommend just using *args and **kwargs and passing the dictionary as a parameter. I tried it with the following example:

values = {NetworkKeys.BATCH_NORM: False,
    NetworkKeys.L2_REG_FACT: 0.0,
    NetworkKeys.DROP_PROB: 0.0,
    NetworkKeys.ACTIVATION: "relu",
    NetworkKeys.NUM_UNITS: [10, 10]
}
print(values)
Net(values, 10)
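
If you go down the keyword-argument route, the builder can take the settings directly and be called with that dictionary unpacked. Here is a rough sketch of what I mean (not a drop-in replacement: it ignores output_activation and, like your version, leaves the L2 penalty to l2_penalty):

def build_dense_network(input_dim, output_dim, num_units,
                        activation="relu", l2_reg_fact=0.0,
                        drop_prob=0.0, batch_norm=False,
                        with_output_layer=True):
    # defaults live in the signature instead of NetworkKeys lookups
    layers = []
    last_dim = input_dim
    for units in num_units:
        layers.append(nn.Linear(last_dim, units))
        if batch_norm:
            layers.append(nn.BatchNorm1d(units))
        if activation == "relu":
            layers.append(nn.ReLU())
        last_dim = units
        if drop_prob > 0.0:
            layers.append(nn.Dropout(p=drop_prob))
    if with_output_layer:
        layers.append(nn.Linear(last_dim, output_dim))
    return nn.Sequential(*layers)

model = build_dense_network(10, 1, **values)  # the dict keys match the argument names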

There were a few things to fix in the Net class:

  • The class needs the super initialization (e.g. super(Net, self).__init__()).
  • The order of operations in the forward pass didn't make sense: you were overriding the output of the linear layers. Note that we now call self._p_samples(p), where p is one slice of the input, p = x[:, 0, :] (and likewise for q), as in the corrected class below.
class Net(nn.Module):
    def __init__(self, hidden_params, input_dim):
        super(Net, self).__init__()  # the missing super initialization
        self._ldre_net, ldre_regularizer = build_dense_network(
            input_dim=input_dim, output_dim=1,
            output_activation="linear", params=hidden_params)

        self._p_samples = nn.Linear(input_dim, input_dim)
        self._q_samples = nn.Linear(input_dim, input_dim)
        self._split_layers = Split(
            self._ldre_net,
            n_parts=2,
            dim=0
        )

    def forward(self, x, inTrain=True):
        if inTrain:
            # slice the two sample sets out of x first ...
            p = x[:, 0, :]
            q = x[:, 1, :]
            # ... then feed each slice through its own linear layer
            p = self._p_samples(p)
            q = self._q_samples(q)
            combined = torch.cat((p.view(p.size(0), -1),
                                  q.view(q.size(0), -1)), dim=0)

            p_output, q_output = self._split_layers(combined)
            return p_output, q_output
        else:
            return self._ldre_net(x)

Displaying the network after a successful forward pass with an input of torch.randn((1, 2, 10)) gives:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Net                                      --                        --
├─Linear: 1-1                            [1, 10]                   110
├─Linear: 1-2                            [1, 10]                   110
├─Split: 1-3                             [1, 1]                    --
├─Sequential: 1-4                        [2, 1]                    --
│    └─Linear: 2-1                       [2, 10]                   110
│    └─ReLU: 2-2                         [2, 10]                   --
│    └─Linear: 2-3                       [2, 10]                   110
│    └─ReLU: 2-4                         [2, 10]                   --
│    └─Linear: 2-5                       [2, 1]                    11
==========================================================================================
Total params: 451
Trainable params: 451
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
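
For reference, the forward pass and the table above can be reproduced with something along these lines (the summary table comes from the torchinfo package, which is my tooling choice rather than part of your code):

import torch
from torchinfo import summary

model = Net(values, 10)
p_out, q_out = model(torch.randn((1, 2, 10)))  # the forward pass used above
print(summary(model, input_size=(1, 2, 10)))   # prints the layer table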

The example output will be of the form:

(tensor([[-0.0699]], grad_fn=<SplitBackward0>),
 tensor([[0.0394]], grad_fn=<SplitBackward0>))

Note: I didn't try to overfit this model (which you should do) to validate that it can indeed learn what you want.

Also, as a side note: if you really want multiple auxiliary outputs that aren't part of one tensor and have to be computed separately, you can just return x, y in the forward pass, as sketched below.
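
To make that concrete: the PyTorch counterpart of a multi-input/multi-output keras.models.Model is simply a forward that takes and returns several tensors. A minimal sketch, with the class name and variable names chosen only for illustration:

class TwoHeaded(nn.Module):
    def __init__(self, shared):
        super().__init__()
        self.shared = shared  # e.g. the ldre_net Sequential from build_dense_network

    def forward(self, p_samples, q_samples):
        # the same shared module is applied to both inputs, mirroring
        # k.models.Model(inputs=[p, q], outputs=[net(p), net(q)])
        return self.shared(p_samples), self.shared(q_samples)

p_out, q_out = TwoHeaded(ldre_net)(p_batch, q_batch)  # ldre_net, p_batch, q_batch are placeholders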

  • Swapping the order of splitting x completely messed up my code, and its output no longer makes sense; you are right about overriding the output, though. – Dalek Sep 27 '22 at 15:54
  • @Dalek Don't know your exact use case, I am only suggesting it from a model-building perspective. But if you return it to the original, it just means you are skipping the `self._p_sample` and `self._q_sample` linear layers. – Warkaz Sep 27 '22 at 15:57
  • @Dalek also would you mind clarifying `output doesn't make sense`, of course, if you change the behavior of the model you would have to retrain the model. As I assume those 2 linear layers were never trained – Warkaz Sep 27 '22 at 16:06
  • It is a big model, but at the end I compute a KL divergence. The KL term was of order `~1e-2` to `0.1`, more or less the same order as in the tensorflow code even with this mistake, but now with your suggestion its value is `~50`. – Dalek Sep 27 '22 at 16:08
  • @Dalek Did you retrain, or did you just change the code and reuse the pre-trained model? If the latter, this is expected behavior as noted before (those layers were never used before). Also, the change shouldn't alter the expressivity of the model: without a `ReLU` applied before sending the result further down, the extra linear layer simply collapses into the next one. – Warkaz Sep 27 '22 at 16:13
  • I tried to run it by removing the splitting part and just keeping `p = self._p_samples(x)` and `q = self._q_samples(x)`. The code crashed. – Dalek Sep 27 '22 at 16:19
  • Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/248387/discussion-between-warkaz-and-dalek). – Warkaz Sep 27 '22 at 16:21