
I have a model in PyTorch whose updatable weights I want to access and change manually.

How would that be done correctly?

Ideally, I would like a tensor of those weights.

It seems to me that

for parameter in model.parameters():
    do_something_to_parameter(parameter)

wouldn't be the right way to go, because

  1. It doesn't utilize the GPU, and isn't able to
  2. It doesn't even use a low-level implementation

What is the correct way of accessing a model's weights manually (not through loss.backward and optimizer.step)?

Gulzar
  • What do you mean it doesn't utilize GPU? Any tensor operations you do to tensors that are on GPU utilize GPU. That includes the model parameters. – Coolness Feb 18 '19 at 14:15
  • @Coolness I mean any operation that goes over parameters using a Python for loop is not going through any GPU parallelism. For any parallelism to happen, some function has to be passed, which is clearly not what's going on here. For instance, loss.backward() can be calculated on GPU. – Gulzar Feb 18 '19 at 14:18
  • But the sets of parameters have distinct shapes and meanings. It's not possible to have them as a single tensor, unless you flatten them to a vector, in which case they lose their meaning. What kind of operation would you like to perform on them? – Coolness Feb 18 '19 at 14:20
  • I would expect some way to update all parameters together using some vector operation. A Python for loop can't possibly be the way to go. More concretely, I am trying to implement https://stackoverflow.com/questions/54734556/pytorch-how-to-create-an-update-rule-that-doesnt-come-from-derivatives and I am afraid a large network would never finish optimization if implemented like this. – Gulzar Feb 18 '19 at 14:20
  • @Coolness Also no luck in datascience.stackexchange: https://datascience.stackexchange.com/questions/45718/pytorch-how-to-create-an-update-rule-the-doesnt-come-from-derivatives – Gulzar Feb 18 '19 at 14:22
  • There is no way to update all parameters using some vector operation, because different sets of parameters are located in different tensors. Have a look at the [source code](https://pytorch.org/docs/0.3.1/_modules/torch/optim/adam.html) for the Adam optimizer. It literally loops through the different parameter sets and applies tensor operations on each set of parameters. – Coolness Feb 18 '19 at 14:24
  • So... Why not calculate the relevant grads, put them all in one big vector, then `params = params + lr * grads` (or adam equivalent) in the optimizer? Why would the parameters lose their meaning? Anyway, they do use some vector operations, and not operating on each parameter separately. I think looking at that optimizer code is what I really needed, thanks. I wish I knew a better way to search around the pytorch docs... – Gulzar Feb 18 '19 at 14:34
  • 1
    Because "putting them in a big vector" would create a copy of all the parameters. You need to store them separately to maintain the shapes, and e.g keep biases separated from weights (think of the linear layer, how would you do `A x + b` with a single concatenated vector?). The overhead in the for loop is completely dominated by the actual forward / backward passes of the model. – Coolness Feb 18 '19 at 14:40
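
Following the comments above (the built-in optimizers themselves loop over the per-layer parameter tensors and apply a vectorized operation to each one), here is a minimal sketch of a manual, in-place update. The random `update` is only a placeholder for whatever rule you actually want, and the model here is just an example; `parameters_to_vector` / `vector_to_parameters` give the "one big vector" view, at the cost of a copy:

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = torch.nn.Linear(10, 2)  # any nn.Module; its parameters may already live on the GPU

with torch.no_grad():  # keep the manual update out of the autograd graph
    # Per-parameter, in-place update: the same pattern the built-in optimizers use.
    # The Python loop only dispatches one vectorized (GPU) op per parameter tensor.
    for p in model.parameters():
        update = 0.01 * torch.randn_like(p)  # placeholder for your own update rule
        p.add_(update)

    # "One big vector" view of all parameters (note: parameters_to_vector makes a copy)
    flat = parameters_to_vector(model.parameters())
    flat += 0.01 * torch.randn_like(flat)  # some whole-vector update
    vector_to_parameters(flat, model.parameters())  # write the updated copy back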

1 Answer


Here's my method: you can generally pass any model to it and it will return a list of all the leaf torch.nn.* modules; just add a wrapper around it that returns each module's weights instead of the module itself.

def flatten_model(modules):
    """Recursively collect the leaf torch.nn.* modules of a model into a flat list."""
    def flatten_list(_2d_list):
        flat_list = []
        # Iterate through the outer list
        for element in _2d_list:
            if isinstance(element, list):
                # If the element is itself a list, append its items one by one
                for item in element:
                    flat_list.append(item)
            else:
                flat_list.append(element)
        return flat_list

    ret = []
    try:
        # Case 1: an iterable of (name, module) pairs, e.g. model._modules.items()
        for _, n in modules:
            ret.append(flatten_model(n))
    except TypeError:
        try:
            if len(modules._modules) == 0:
                # A leaf module with no children: keep it
                ret.append(modules)
            else:
                # Recurse into the children of this module
                for _, n in modules._modules.items():
                    ret.append(flatten_model(n))
        except AttributeError:
            # Not an nn.Module at all: keep it as-is
            ret.append(modules)
    return flatten_list(ret)
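
A hypothetical usage sketch (the model below and the `hasattr` filtering are my own illustration, not part of the answer): collect the leaf modules, then pull out the weight tensors of the ones that have any.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

leaf_modules = flatten_model(model)  # the three leaf modules above
weights = [m.weight for m in leaf_modules if hasattr(m, "weight")]
print([w.shape for w in weights])  # [torch.Size([8, 4]), torch.Size([2, 8])]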