
I'm trying to use overloading to make the return type of a variadic function depend on the type of its arguments in a certain way. Specifically, I want the return type to be X if and only if any of its arguments is of type X.

Consider the following minimal example:

from typing import overload

class Safe:
    pass

class Dangerous:
    pass

@overload
def combine(*args: Safe) -> Safe: ...

@overload
def combine(*args: Safe | Dangerous) -> Safe | Dangerous: ...

def combine(*args: Safe | Dangerous) -> Safe | Dangerous:
    if all(isinstance(arg, Safe) for arg in args):
        return Safe()
    else:
        return Dangerous()

reveal_type(combine())
reveal_type(combine(Safe()))
reveal_type(combine(Dangerous()))
reveal_type(combine(Safe(), Safe()))
reveal_type(combine(Safe(), Dangerous()))

Running mypy on this file outputs:

example.py:21: note: Revealed type is "example.Safe"
example.py:22: note: Revealed type is "example.Safe"
example.py:23: note: Revealed type is "Union[example.Safe, example.Dangerous]"
example.py:24: note: Revealed type is "example.Safe"
example.py:25: note: Revealed type is "Union[example.Safe, example.Dangerous]"
Success: no issues found in 1 source file
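The implementation already follows the intended rule at runtime. The result is `Dangerous` exactly when at least one argument is `Dangerous`, and `combine()` with no arguments returns `Safe` because `all()` over an empty iterable is `True`. Below is a minimal runtime check of that rule. It exercises only the function body, not mypy's inference. The `from __future__ import annotations` line is an addition so that the `Safe | Dangerous` annotations also load on Python versions before 3.10.

```python
from __future__ import annotations

from itertools import product
from typing import overload


class Safe:
    pass


class Dangerous:
    pass


@overload
def combine(*args: Safe) -> Safe: ...

@overload
def combine(*args: Safe | Dangerous) -> Safe | Dangerous: ...

def combine(*args: Safe | Dangerous) -> Safe | Dangerous:
    if all(isinstance(arg, Safe) for arg in args):
        return Safe()
    else:
        return Dangerous()


# Try every combination of zero to three arguments and check that the
# runtime result is Dangerous exactly when some argument is Dangerous.
for n in range(4):
    for kinds in product([Safe, Dangerous], repeat=n):
        result = combine(*(kind() for kind in kinds))
        expected = Dangerous if Dangerous in kinds else Safe
        assert type(result) is expected, (kinds, result)
```

The open problem is therefore purely static: the overloads need to say the same thing the loop checks.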

I want to set things up so that the inferred types of combine(Dangerous()) and combine(Safe(), Dangerous()), for example, are Dangerous rather than Safe | Dangerous. Changing the return type of the second overload to just Dangerous yields an error:

example.py:10: error: Overloaded function signatures 1 and 2 overlap with incompatible return types  [misc]
example.py:21: note: Revealed type is "example.Safe"
example.py:22: note: Revealed type is "example.Safe"
example.py:23: note: Revealed type is "example.Dangerous"
example.py:24: note: Revealed type is "example.Safe"
example.py:25: note: Revealed type is "example.Dangerous"
Found 1 error in 1 file (checked 1 source file)

Thus it seems that I need a way to annotate the second overload to explicitly state that at least one of its arguments is Dangerous. Is there a way to do this?

It occurs to me that the desired type for the argument sequence is Sequence[Safe | Dangerous] - Sequence[Safe], but I don't think type subtraction is supported yet.

user76284
  • Hmm, I wonder if using [`*args` as a Type Variable Tuple](https://peps.python.org/pep-0646/#args-as-a-type-variable-tuple) from PEP 646 would be useful here, though I don't know if mypy supports it yet. – Brian61354270 Mar 26 '23 at 20:58
  • @Aran-Fey Added more cases for clarity. – user76284 Mar 26 '23 at 21:24
  • FWIW, pyright accepts the solution of removing the `Safe |` from the return type of the 2nd overload. – Aran-Fey Mar 26 '23 at 21:25
  • See my answer [here](https://stackoverflow.com/questions/60222982/type-annotation-for-overloads-that-exclude-types-something-vs-everything-else/74567241#74567241) for a very similar question (incompatible overloads on X and `Any \ X`), which explains the error in detail. – STerliakov Mar 26 '23 at 21:45

1 Answer


With Pyright, your solution actually works. mypy can also be made to accept it with a simple # type: ignore comment on the first overload, which is where mypy reports the overlap:

from typing import overload

# Safe and Dangerous as defined in the question

@overload
def combine(*args: Safe) -> Safe: ...  # type: ignore

@overload
def combine(*args: Safe | Dangerous) -> Dangerous: ...

def combine(*args: Safe | Dangerous) -> Safe | Dangerous:
    if all(isinstance(arg, Safe) for arg in args):
        return Safe()
    else:
        return Dangerous()


reveal_type(combine())  # Safe
reveal_type(combine(Safe()))  # Safe
reveal_type(combine(Dangerous()))  # Dangerous
reveal_type(combine(Safe(), Safe()))  # Safe
reveal_type(combine(Safe(), Dangerous()))  # Dangerous

But beware that this can lead to discrepancies between static and dynamic types if the declared type of the input is wider than its actual runtime type:

import random

arg: Safe | Dangerous = random.choice([Safe(), Dangerous()])
result = combine(arg)
reveal_type(result)  # Dangerous
print(type(result))  # Can be Safe or Dangerous
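The random choice makes the mismatch intermittent. The deterministic sketch below uses the answer's overloads, with the first one silenced by # type: ignore. It routes an actual Safe instance through a parameter annotated as Safe | Dangerous. The helper name combine_one is hypothetical and introduced only for illustration. The claim that mypy types the call as Dangerous comes from the reveal_type note in the answer above; the printed runtime type is what actually executes.

```python
from __future__ import annotations

from typing import overload


class Safe:
    pass


class Dangerous:
    pass


@overload
def combine(*args: Safe) -> Safe: ...  # type: ignore

@overload
def combine(*args: Safe | Dangerous) -> Dangerous: ...

def combine(*args: Safe | Dangerous) -> Safe | Dangerous:
    if all(isinstance(arg, Safe) for arg in args):
        return Safe()
    else:
        return Dangerous()


def combine_one(arg: Safe | Dangerous) -> Dangerous:
    # Inside this hypothetical helper only the declared union is known, so the
    # call resolves to the second overload and is typed as Dangerous.
    return combine(arg)


result = combine_one(Safe())
print(type(result).__name__)  # prints "Safe", despite the Dangerous annotation
```

The annotated return type of combine_one is accepted statically, yet the function returns a Safe at runtime. This is the unsoundness the warning describes.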
Aran-Fey
  • The reason why this is (correctly) forbidden by `mypy` is the following: `x: tuple[Safe | Dangerous, ...] = (Safe(),); reveal_type(combine(*x))` (guess what). Here's [the playground link](https://mypy-play.net/?mypy=master&python=3.10&flags=strict&gist=01868d5182b71871c71be7d80ea6fc51). Would be great to add this "beware" to the answer. – STerliakov Mar 26 '23 at 21:42
  • @SUTerliakov I don't see the problem? `Safe | Dangerous` goes in, `Dangerous` comes out. That's how it should be, isn't it? If you expected to get `Safe` as output, you shouldn't have declared `x` as `Safe | Dangerous`. Garbage in, garbage out. – Aran-Fey Mar 26 '23 at 22:08
  • Yes, but this is why the overloads are not strictly safe. If a subtype gives a result incompatible with the result for its supertype, you're violating the LSP. For any subtype of `T`, `combine` should produce a result compatible with the result of `combine` on `T` itself. Upcasting should never produce results incompatible with those for the exact type (for example, because that breaks inference). – STerliakov Mar 26 '23 at 22:23
  • E.g. if the second overload returns `Safe | Dangerous` and the first returns only `Safe`, it's alright, because upcasting just gives you a less specific type (a union instead of plain `Safe`). – STerliakov Mar 26 '23 at 22:26
  • @SUTerliakov Ah, now I see what you mean. That's a fair point. I've added a warning. – Aran-Fey Mar 26 '23 at 22:46
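STerliakov's comment about the union return describes the sound alternative. You keep the question's original overloads, where the second one returns `Safe | Dangerous`, and callers that need a specific class narrow the result with `isinstance`. The sketch below illustrates this. `describe` is a hypothetical helper, and the typing remarks in its comments describe what these overloads should give mypy, not verified output.

```python
from __future__ import annotations

from typing import overload


class Safe:
    pass


class Dangerous:
    pass


@overload
def combine(*args: Safe) -> Safe: ...

@overload
def combine(*args: Safe | Dangerous) -> Safe | Dangerous: ...

def combine(*args: Safe | Dangerous) -> Safe | Dangerous:
    if all(isinstance(arg, Safe) for arg in args):
        return Safe()
    else:
        return Dangerous()


def describe(args: tuple[Safe | Dangerous, ...]) -> str:
    # Unpacking a tuple of the wider element type matches only the second
    # overload, so the result is typed Safe | Dangerous: less precise than
    # Safe, but correct whatever the tuple actually contains.
    result = combine(*args)
    if isinstance(result, Dangerous):
        return "dangerous"
    return "safe"  # narrowed to Safe on this branch


print(describe((Safe(),)))  # prints "safe"
print(describe((Safe(), Dangerous())))  # prints "dangerous"
```

The price of soundness is the `isinstance` check at call sites where the arguments' static types are unions.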