
I am trying to check whether an instance of the NamedTuple "Transition" is equal to any object in the list "self.memory".

Here is the code I tried to run:

from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)


class ReplayMemory:

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        if Transition(*args) in self.memory:
            return
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        ...

And here is the output:

False
False

And the error I got:

   ...
        if Transition(*args) in self.memory:
    RuntimeError: bool value of Tensor with more than one value is ambiguous

This seems weird to me because the print is telling me that the "==" operation returns a boolean.

How could this be done correctly?

Thank you

Edit:

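*args is a tuple consisting of (in the order of the Transition fields):

state: a tensor of shape torch.Size([16, 12])
action: int
reward: int
next_state: a tensor of shape torch.Size([16, 12])
done: int
hidden: a tensor of shape torch.Size([4])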
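A minimal reproduction of the output above, as a sketch: it assumes random tensors as stand-ins for the real data (only the shapes are known), uses the class form of NamedTuple (same fields, same order as the functional call in the question), and leaves out the part of store that the question elides. The first call prints False and succeeds because the list is still empty; the second prints False and then raises inside the `in` check, matching the two False lines followed by the RuntimeError.

```python
from typing import NamedTuple

import torch as t


# Class form of the question's Transition; same fields, same order.
class Transition(NamedTuple):
    state: t.Tensor
    action: int
    reward: int
    next_state: t.Tensor
    done: int
    hidden: t.Tensor


class ReplayMemory:
    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        # A list never equals a tuple, so this prints False without comparing any tensors.
        print(self.memory == Transition(*args))
        # `in` compares the new Transition with each stored one, field by field.
        if Transition(*args) in self.memory:
            return
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        # The rest of the method is elided in the question and omitted here.


def random_args():
    # Stand-in values with the shapes listed in the edit; the real data is unknown.
    return (t.rand(16, 12), 0, 1, t.rand(16, 12), 0, t.rand(4))


buffer = ReplayMemory(capacity=10)
buffer.store(*random_args())  # prints False; the list is empty, so `in` compares nothing

error_message = None
try:
    buffer.store(*random_args())  # prints False, then `in` reaches the stored tensors
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```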
Luke C
  • what does `Transition(*args)` return? – Kenan Jan 09 '20 at 21:58
  • Can you show us the values in `*args`? – John Gordon Jan 09 '20 at 22:00
  • @Kenan it returns a NamedTuple as specified in line 5 of the first code block – Luke C Jan 09 '20 at 22:07
    Your problem is equality of the values *inside* the named tuple, not the tuple itself. The `in` operator compares the single tuple against each element in `self.memory`. In contrast, the `==` operator compares the tuple against the list, which is trivially `False` because the types are not equal; no elements are compared in this case. – MisterMiyagi Jan 09 '20 at 22:20
  • @MisterMiyagi how is it possible that the assertion prints a single boolean though? – Luke C Jan 09 '20 at 22:21
    @LukeC The print runs an entirely different comparison. It compares the tuple against the list, not its elements. – MisterMiyagi Jan 09 '20 at 22:23
    @LukeC something to note when using `==`: `a=5;b=[];a==b` -> `False` and `a!=b` -> `True`, even though the comparison doesn't really make sense. – Kenan Jan 10 '20 at 00:16
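The distinction drawn in these comments can be checked with a few lines of plain PyTorch. This is an illustrative sketch with arbitrary small tensors, not data from the question: `==` between tensors is element-wise, tuple equality then needs a single bool per field, membership tests check identity before equality, and a list compared with a tuple is False without looking at any element.

```python
import torch as t

a = t.zeros(3)
b = t.zeros(3)

# Tensor == Tensor is element-wise: the result is a tensor, not a bool.
elementwise = a == b
print(elementwise)  # tensor([True, True, True])

# Tuple equality needs one bool per field, so Python calls bool() on that
# tensor, which raises because the tensor holds more than one value.
tuple_error = None
try:
    _ = (a,) == (b,)
except RuntimeError as err:
    tuple_error = str(err)

# Membership checks identity first, so the very same tuple object is found
# without its tensors ever being compared.
pair = (a, 1)
same_object_found = pair in [pair]

# Comparing a list with a tuple is simply False; no elements are compared.
list_vs_tuple = [pair] == pair

print(tuple_error is not None, same_object_found, list_vs_tuple)  # True True False
```

The identity shortcut is also why the error only appears once the list holds a Transition built from different tensor objects.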

1 Answer


I believe you should define equality explicitly, so that each tensor field is reduced to a single bool instead of an element-wise tensor:

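from typing import NamedTuple
import torch as t


class Sample(NamedTuple):
    state: t.Tensor
    action: int

    def __eq__(self, other):
        # Only compare against other Samples; let Python handle anything else.
        if not isinstance(other, Sample):
            return NotImplemented
        # t.equal returns a plain bool (False on a shape mismatch), so no
        # multi-element tensor is ever passed to bool().
        return t.equal(self.state, other.state) and self.action == other.action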
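Extending the same idea to all six fields of the question's Transition might look like the sketch below. It is an illustration under assumptions, not code from the answer: tensor fields are compared with torch.equal (a plain bool, False for different shapes), other fields with ordinary ==, and __ne__ is overridden as well, because a tuple subclass otherwise keeps tuple's own __ne__, which would still compare the tensors element by element. The helper make() is hypothetical test data.

```python
from typing import NamedTuple

import torch as t


class Transition(NamedTuple):
    state: t.Tensor
    action: int
    reward: int
    next_state: t.Tensor
    done: int
    hidden: t.Tensor

    def __eq__(self, other):
        # Only compare against other Transitions; let Python handle anything else.
        if not isinstance(other, Transition):
            return NotImplemented
        for mine, theirs in zip(self, other):
            mine_is_tensor = isinstance(mine, t.Tensor)
            theirs_is_tensor = isinstance(theirs, t.Tensor)
            if mine_is_tensor and theirs_is_tensor:
                # Plain bool; False if the shapes differ.
                same = t.equal(mine, theirs)
            elif mine_is_tensor or theirs_is_tensor:
                # A tensor never counts as equal to a non-tensor here.
                same = False
            else:
                same = mine == theirs
            if not same:
                return False
        return True

    def __ne__(self, other):
        # The inherited tuple.__ne__ would bypass __eq__ and compare tensors element-wise.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result


def make(state_value):
    # Hypothetical test data with the shapes from the question.
    return Transition(t.full((16, 12), state_value), 0, 1, t.zeros(16, 12), 0, t.zeros(4))


memory = [make(0.0)]
duplicate_found = make(0.0) in memory  # equal values, different tensor objects
new_found = make(1.0) in memory        # state differs
print(duplicate_found, new_found)      # True False
```

With a Transition defined this way, the `if Transition(*args) in self.memory:` check in store can stay unchanged, since `in` calls the overridden __eq__.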