You need to "implement your own" in this case - and the catch is that you will need an auxiliary dictionary indexed by the stand-alone keys, so that when the time comes to delete one of the paired keys, you are able to find every entry it takes part in.
The implementation of the weakref dictionaries themselves is pretty simple and straightforward, using the collections.abc helpers for mappings - you could even pick the code from there and evolve it - but I think a minimal one can be done from scratch, like below.
To be clear: this is a fresh implementation of weak-key dicts, doing exactly what you asked in the question: the keys should be any sequence of weak-referenceable objects, and when any object from the sequence is destroyed, the item is cleared from the dictionary. This is done using the callback mechanism of the low-level weakref.ref objects.
import weakref
from collections.abc import MutableMapping
class MultiWeakKeyDict(MutableMapping):
    """Mapping whose keys are sequences of weakly referenced objects.

    A key is any iterable of weak-referenceable, hashable objects, e.g.
    ``d[a, b] = value``.  The mapping holds no strong reference to the
    objects that make up a key; when the weakref callback for any one of
    them fires, every entry whose key contains that object is removed.

    Keys are matched through the referents' own ``__hash__``/``__eq__``,
    because that is how live ``weakref.ref`` objects hash and compare.
    """

    def __init__(self, *args, **kw):
        """Create the mapping, optionally seeding it like ``dict.update``.

        Positional and keyword arguments are forwarded unchanged to
        ``update``; a positional mapping such as ``{(a, b): 1}`` is the
        practical way to pre-populate, since keyword names are strings.
        """
        # key tuple of callback-bearing weakrefs -> stored value
        self.data = {}
        # weakref to one object -> set of the key tuples that contain it
        self.helpers = {}
        self.update(*args, **kw)

    def _remove(self, wref):
        """Weakref callback: drop every entry whose key contains ``wref``."""
        for data_key in self.helpers.pop(wref, ()):
            self.data.pop(data_key, None)
            # Also unlink the dead key from the buckets of the objects that
            # are still alive, so no stale tuple lingers in ``helpers``.
            for other in data_key:
                self._unlink(other, data_key)

    def _unlink(self, wref, data_key):
        """Remove ``data_key`` from the helper bucket of ``wref``, if any."""
        bucket = self.helpers.get(wref)
        if bucket is None:
            return
        bucket.discard(data_key)
        if not bucket:
            self.helpers.pop(wref, None)

    def _build_key(self, keys):
        """Return a tuple of weakrefs to ``keys`` wired to ``_remove``."""
        return tuple(weakref.ref(item, self._remove) for item in keys)

    @staticmethod
    def _lookup_key(keys):
        """Return a tuple of callback-free weakrefs, for lookups only.

        While the referents are alive these hash and compare equal to the
        stored callback-bearing refs, without registering new callbacks.
        """
        return tuple(weakref.ref(item) for item in keys)

    def __setitem__(self, keys, value):
        # Materialize once: supports one-shot iterables and keeps the
        # referents strongly alive for the duration of the call.
        keys = tuple(keys)
        lookup = self._lookup_key(keys)
        if lookup in self.data:
            # Existing entry: dict assignment keeps the stored key object,
            # so the bookkeeping in ``helpers`` stays valid.
            self.data[lookup] = value
            return
        weakrefs = self._build_key(keys)
        # Insert into ``data`` first: an unhashable referent raises here,
        # before ``helpers`` has been touched.
        self.data[weakrefs] = value
        for item in weakrefs:
            self.helpers.setdefault(item, set()).add(weakrefs)

    def __getitem__(self, keys):
        keys = tuple(keys)
        try:
            return self.data[self._lookup_key(keys)]
        except KeyError:
            # Report the caller's key rather than the internal weakref tuple.
            raise KeyError(keys) from None

    def __delitem__(self, keys):
        keys = tuple(keys)
        lookup = self._lookup_key(keys)
        try:
            del self.data[lookup]
        except KeyError:
            raise KeyError(keys) from None
        # Without this, a later re-insert of the same key would leave the
        # old tuple in the buckets and the new entry would never be
        # cleaned up when one of its objects dies.
        for wref in lookup:
            self._unlink(wref, lookup)

    def __iter__(self):
        # Iterate over a snapshot so that a weakref callback firing during
        # iteration cannot raise "dictionary changed size during iteration".
        for key in list(self.data):
            referents = tuple(ref() for ref in key)
            # Skip keys whose objects died after the snapshot was taken.
            if any(obj is None for obj in referents):
                continue
            yield referents

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        body = ', '.join('{!r}:{!r}'.format(k, v) for k, v in self.items())
        return f"{self.__class__.__name__}({body})"
And working:
In [142]: class A:
...: def __repr__(s): return "A obj"
...:
In [143]: a, b, c = [A() for _ in (1,2,3)]
In [144]: d = MultiWeakKeyDict()
In [145]: d[a,b] = 1
In [146]: d[b,c] = 2
In [147]: d
Out[147]: MultiWeakKeyDict((A obj, A obj):1, (A obj, A obj):2)
In [148]: len(d)
Out[148]: 2
In [149]: del b
In [150]: len(d)
Out[150]: 0
In [151]: d
Out[151]: MultiWeakKeyDict()