My question is simple:
>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False
Why is the last check False when the first two are True?
And now I will ramble so SE will accept my reasonable question.
This happens because jax.numpy.ndarray overrides instance checks with a metaclass:
class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self, instance):
    try:
      return isinstance(instance.aval, _arraylike_types)
    except AttributeError:
      return isinstance(instance, _arraylike_types)

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")
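Your code returns what it does because x is not an instance of numpy.ndarray, yet this __instancecheck__ method returns True for it. The issubclass check is True for a separate reason: jax.numpy.ndarray is declared with np.ndarray as a base class, and the metaclass leaves subclass checks untouched.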
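The same three results can be reproduced without JAX or NumPy. The following is a minimal sketch in plain Python; the names Base, DuckArray, _LooseMeta and Facade are hypothetical stand-ins invented for this illustration, not anything from the JAX source.

```python
class Base:
    """Hypothetical stand-in for numpy.ndarray."""


class DuckArray:
    """Hypothetical stand-in for JAX's real array type; unrelated to Base."""


class _LooseMeta(type):
    """Metaclass that answers isinstance checks itself."""

    def __instancecheck__(cls, instance):
        # Ignore the real class hierarchy and accept any DuckArray.
        return isinstance(instance, DuckArray)


class Facade(Base, metaclass=_LooseMeta):
    """Plays the role of jax.numpy.ndarray: subclasses Base, loosens isinstance."""


x = DuckArray()

print(isinstance(x, Facade))     # True: decided by _LooseMeta.__instancecheck__
print(issubclass(Facade, Base))  # True: Facade really does inherit from Base
print(isinstance(x, Base))       # False: DuckArray is not derived from Base
```

The point of the sketch is that isinstance is not transitive once a metaclass overrides __instancecheck__: "x is a Facade" and "Facade is a subclass of Base" no longer imply "x is a Base".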
Why this kind of subterfuge in JAX? Well, for the purpose of JIT compilation, auto-differentiation, and other transforms, JAX uses stand-in objects called tracers that are meant to look and act like an array, despite not actually being an array. This overriding of instance checks is one of the tricks JAX uses to make such tracing work.
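To see how the `aval` branch of that metaclass lets a tracer pass for an array, here is a rough sketch in plain Python. It only mirrors the try/except shape of the snippet above; ConcreteArray, AbstractValue, Tracer and ArrayLike are hypothetical names made up for this example, and real JAX tracers are considerably more involved.

```python
class ConcreteArray:
    """Hypothetical array that actually holds data."""

    def __init__(self, values):
        self.values = list(values)


class AbstractValue:
    """Hypothetical description of an array (shape only, no data)."""

    def __init__(self, shape):
        self.shape = shape


class Tracer:
    """Hypothetical stand-in object: carries only an abstract value."""

    def __init__(self, aval):
        self.aval = aval


_arraylike = (ConcreteArray, AbstractValue)


class _ArrayLikeMeta(type):
    def __instancecheck__(cls, instance):
        try:
            # Tracers are judged by the abstract value they carry.
            return isinstance(instance.aval, _arraylike)
        except AttributeError:
            # Everything else is judged by its own type.
            return isinstance(instance, _arraylike)


class ArrayLike(metaclass=_ArrayLikeMeta):
    pass


def describe(value):
    return "array" if isinstance(value, ArrayLike) else "other"


print(describe(ConcreteArray([1, 2, 3])))     # array
print(describe(Tracer(AbstractValue((3,)))))  # array
print(describe(3.0))                          # other
```

User code that guards on isinstance therefore takes the same path whether it receives real data or a tracer, which is what a tracing transform needs.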