I'm using strict type checking with Pyright. When I have a method that returns a PyTorch `DataLoader`, Pyright complains about my return type annotation:

```
Declared return type, "DataLoader[Unknown]", is partially unknown Pyright (reportUnknownVariableType)
```
Here is the type stub for PyTorch's `DataLoader` (reduced to the relevant parts):
```python
class DataLoader(Generic[T_co]):
    dataset: Dataset[T_co]
    @overload
    def __init__(self, dataset: Dataset[T_co], ...
```
As far as I can tell, the generic type `T_co` of `DataLoader` should be determined by the `dataset` parameter of `__init__`.
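The binding described above can be reproduced without PyTorch. The sketch below is a minimal stand-in, not PyTorch's implementation. It uses the hypothetical classes `ToyDataset`, `ListDataset` and `ToyLoader`, which copy only the generic structure of the stub. Because `ToyLoader.__init__` accepts a `ToyDataset[T_co]`, a type checker binds `T_co` from the type of the dataset that is passed in.

```python
from typing import Generic, Sequence, TypeVar

T_co = TypeVar("T_co", covariant=True)


class ToyDataset(Generic[T_co]):
    """Hypothetical stand-in for torch.utils.data.Dataset (not the real class)."""

    def __getitem__(self, index: int) -> T_co:
        raise NotImplementedError


class ListDataset(ToyDataset[T_co]):
    """Hypothetical dataset whose element type comes from the items it wraps."""

    def __init__(self, items: Sequence[T_co]) -> None:
        self.items = items

    def __getitem__(self, index: int) -> T_co:
        return self.items[index]


class ToyLoader(Generic[T_co]):
    """Hypothetical stand-in for DataLoader: T_co is bound via the dataset argument."""

    dataset: ToyDataset[T_co]

    def __init__(self, dataset: ToyDataset[T_co]) -> None:
        self.dataset = dataset


# A checker infers ListDataset[int] from the list, then ToyLoader[int]
# from the dataset argument, without any explicit annotation.
loader = ToyLoader(ListDataset([10, 20, 30]))
print(loader.dataset[1])  # 20
```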
Pyright also complains about my `Dataset` type annotation:

```
Type of parameter "dataset" is partially unknown Parameter type is "Dataset[Unknown]" Pyright (reportUnknownParameterType)
```
The `Dataset` type stub looks like this:

```python
class Dataset(Generic[T_co]):
    def __getitem__(self, index: int) -> T_co: ...
```

This suggests to me that the type should be inferred from the return type of `__getitem__`.
The `__getitem__` method of my dataset has this signature:

```python
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
```
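For comparison, the following stdlib-only sketch shows the two ways a subclass can relate to a generic base. It uses a hypothetical `ToyDataset`, and `Tuple[int, int]` stands in for `Tuple[Tensor, Tensor]`. The type argument of a generic base class comes from the subscript in the `class` statement. Overriding `__getitem__` does not fill it in, so a bare subclass is seen by a checker such as Pyright as deriving from `ToyDataset[Unknown]`.

```python
from typing import Generic, Tuple, TypeVar, get_args

T_co = TypeVar("T_co", covariant=True)


class ToyDataset(Generic[T_co]):
    """Hypothetical stand-in for torch.utils.data.Dataset (not the real class)."""

    def __getitem__(self, index: int) -> T_co:
        raise NotImplementedError


class BarePairs(ToyDataset):
    """Base left unparameterized: a checker sees ToyDataset[Unknown] here,
    even though __getitem__ returns a concrete tuple type."""

    def __getitem__(self, index: int) -> Tuple[int, int]:
        return (index, index + 1)


class TypedPairs(ToyDataset[Tuple[int, int]]):
    """Base parameterized explicitly: T_co is bound to Tuple[int, int]."""

    def __getitem__(self, index: int) -> Tuple[int, int]:
        return (index, index + 1)


# At runtime the explicit type argument is recorded in __orig_bases__;
# the bare subclass records no __orig_bases__ of its own.
print(get_args(TypedPairs.__orig_bases__[0]))  # (typing.Tuple[int, int],)
print("__orig_bases__" in vars(BarePairs))     # False
```

Note that the parameterized form subscripts the base class at runtime. It therefore only works if the runtime class is actually generic.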
Based on this, I would expect `Dataset` and `DataLoader` to be inferred as `Dataset[Tuple[Tensor, Tensor]]` and `DataLoader[Tuple[Tensor, Tensor]]`, but that is not the case.
My guess is that Pyright cannot infer the types statically here.
I thought I could declare the type myself, like this:

```python
Dataset[Tuple[Tensor, Tensor]]
```

but that actually makes my Python script crash with:

```
TypeError: 'type' object is not subscriptable
```
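The crash comes from runtime evaluation. In Python versions before 3.14, an unquoted annotation is executed when the function is defined unless `from __future__ import annotations` is in effect, so `Dataset[...]` is actually subscripted. The sketch below reproduces one possible cause with a hypothetical `PlainDataset`. It assumes the installed PyTorch release declares `Dataset` as generic only in its `.pyi` stub while the runtime class is a plain class. Subscripting such a class raises `TypeError`. A quoted annotation, by contrast, is stored as a string and is only interpreted by the type checker.

```python
from typing import Tuple


class PlainDataset:
    """Hypothetical class that is NOT generic at runtime (no Generic base,
    no __class_getitem__), like a class that is generic only in a .pyi stub."""

    def __getitem__(self, index: int) -> Tuple[int, int]:
        return (index, index)


def subscript_raises() -> bool:
    """Return True if subscripting PlainDataset at runtime raises TypeError."""
    try:
        # This expression is evaluated at runtime, exactly like an unquoted annotation.
        PlainDataset[Tuple[int, int]]  # type: ignore
    except TypeError:
        return True
    return False


# Quoted annotation: kept as a plain string in __annotations__ and not
# evaluated when the function is defined, so this definition does not raise.
def make_dataset() -> "PlainDataset[Tuple[int, int]]":
    return PlainDataset()


print(subscript_raises())                      # True
print(make_dataset.__annotations__["return"])  # PlainDataset[Tuple[int, int]]
```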
How can I properly declare the types for `Dataset` and `DataLoader`?