
I am trying to make a Python class (called Stroke) whose instances can be used as elements of NumPy arrays and operated on by NumPy universal functions (ufuncs).

I wish to do the following comparison operation:

```python
x = Stroke(...)
y = Stroke(...)
z = Stroke(...)

arr = np.array([x, y, z], dtype=object)

output = np.greater(arr, 5)

# OR

output = np.greater(arr, Stroke(...))
```

For my case, I would also like to customize how the universal function np.greater works on arrays of Strokes. I tried doing this with the __array_ufunc__ method as described here.

Unfortunately, when I call np.greater, the Stroke.__gt__ dunder method is used to evaluate it instead of my custom-defined function. How can I force NumPy to use my custom np.greater implementation instead of the dunder method, without removing the dunder method?
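A minimal sketch of what seems to be happening, using a hypothetical `Probe` class as a stand-in for Stroke: NumPy looks for `__array_ufunc__` only on the arguments handed to the ufunc, not on the elements stored inside an object array. Here the argument is a plain `ndarray`, so NumPy runs its built-in object loop, which evaluates `x > 5` (and therefore `__gt__`) for each element and stores the truth value of each result in a `bool` array. The element-level `__array_ufunc__` only runs when an instance is passed to the ufunc directly.

```python
import numpy as np


class Probe:
    """Hypothetical stand-in for Stroke that counts which hook NumPy uses."""

    def __init__(self, value):
        self.value = value
        self.gt_calls = 0
        self.ufunc_calls = 0

    def __gt__(self, other):
        # Reached through Python's rich comparison, once per element.
        self.gt_calls += 1
        return Probe(self.value)  # a non-bool result, like a modified Stroke

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only reached when a Probe is itself an argument of the ufunc.
        self.ufunc_calls += 1
        return "handled by __array_ufunc__"


p, q = Probe(1), Probe(2)
arr = np.array([p, q], dtype=object)

out = np.greater(arr, 5)
print(out.dtype)                  # bool: each __gt__ result was truth-tested
print(p.gt_calls, p.ufunc_calls)  # 1 0

print(np.greater(p, 5))           # handled by __array_ufunc__
```

If this reading is right, overriding `np.greater` for an object array has to happen at the array level (for example a wrapper type or a separate ufunc) rather than on the individual elements.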

A relevant snippet of the Stroke code is shown below:

```python
from polare._numpy_ufunc_overrides import HANDLED_FUNCTIONS
import numpy as np


class Stroke:

    def __gt__(self, other):
        return self._binary_operation(np.greater, other)

    ...

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method == '__call__':
            if ufunc in HANDLED_FUNCTIONS:
                # Dispatch to the registered override, e.g. for np.greater.
                func = HANDLED_FUNCTIONS[ufunc]
                return self._handle_functions(func, *inputs)
            else:
                # Handle the ufunc in a different way (omitted here).
                ...
        else:
            # Ufunc methods such as reduce, accumulate and outer are not handled.
            return NotImplemented
```

where the HANDLED_FUNCTIONS dictionary is populated with the following code:

```python
import numpy as np


HANDLED_FUNCTIONS = {}


def implements(np_function):
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


@implements(np.greater)
def greater(a: np.ndarray, b: np.ndarray) -> list:
    ...
```
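Since NumPy only checks the ufunc's direct arguments for `__array_ufunc__`, one way to keep `Stroke.__gt__` while still overriding `np.greater` is to move the override from the elements onto a container: a small wrapper class that defines `__array_ufunc__` is itself the ufunc argument, so NumPy calls its hook, and the hook can dispatch through a registry like the `HANDLED_FUNCTIONS` dictionary above. This is a sketch under assumptions: `StrokeArray`, the toy `Stroke` with `length` and `flagged` attributes, and the body of `greater` are hypothetical, and the handler only covers a wrapped array on the left with a number or a single Stroke on the right.

```python
import numpy as np

HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register func as the override for a NumPy ufunc (same pattern as in the question)."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class Stroke:
    """Hypothetical toy Stroke: a length plus a flag set by comparisons."""

    def __init__(self, length, flagged=False):
        self.length = length
        self.flagged = flagged

    def __gt__(self, other):
        # The ordinary dunder method stays in place for plain `stroke > x`.
        other_len = other.length if isinstance(other, Stroke) else other
        return self.length > other_len


class StrokeArray:
    """Hypothetical container: NumPy sees this object as the ufunc argument."""

    def __init__(self, strokes):
        self.strokes = np.array(strokes, dtype=object)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__" or ufunc not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Unwrap containers so the handler sees plain object arrays.
        args = [x.strokes if isinstance(x, StrokeArray) else x for x in inputs]
        return StrokeArray(HANDLED_FUNCTIONS[ufunc](*args, **kwargs))


@implements(np.greater)
def greater(a, b):
    """Return new Stroke objects marked by the comparison instead of bools."""
    other_len = b.length if isinstance(b, Stroke) else b
    return [Stroke(s.length, flagged=s.length > other_len) for s in a]


strokes = StrokeArray([Stroke(3), Stroke(7), Stroke(10)])
out = np.greater(strokes, 5)                 # routed to StrokeArray.__array_ufunc__
print([s.flagged for s in out.strokes])      # [False, True, True]
print(type(out.strokes[0]).__name__)         # Stroke
print(Stroke(7) > 5)                         # True, via Stroke.__gt__
```

Only calls that receive a `StrokeArray` are rerouted; a bare `Stroke > 5` still goes through `Stroke.__gt__`. Operators such as `>` on the container itself would need their own dunder methods forwarding to `np.greater`, which this sketch leaves out.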
  • Needs a [mcve]. What do you mean by "used in arrays"? `np.array([Stroke(), Stroke()])` or something else? – hpaulj Feb 02 '22 at 02:07
  • Thanks for pointing this out, I have made edits. – Jai Willems Feb 02 '22 at 02:19
  • So `arr` is `object` dtype? Normally `np.greater(arr, 5)` would be the equivalent of `np.array([x > 5 for x in arr])`. I don't know whether your added definitions change what's happening. `arr > 5` would do the same. Most operators work with objects; functions like `np.exp(arr)` raise an error saying that the object does not have an `exp` method. – hpaulj Feb 02 '22 at 02:26
  • The problem is that when the `Stroke.__gt__` method is used, the elements of the output are cast to `bool`; I want to keep them as `Stroke` objects that have been modified internally. I left out the code, but I have defined general handling of ufunc operations, so I do not get the error you noted. – Jai Willems Feb 02 '22 at 02:38
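If the main goal is to keep the results as `Stroke` objects instead of booleans, another option is `np.frompyfunc`, which turns a Python function into a new ufunc whose outputs are always object arrays, so nothing is truth-tested into `bool`. Note that this defines a separate ufunc rather than changing what `np.greater` itself does. The toy `Stroke` and the `stroke_greater` handler below are hypothetical stand-ins for the real code.

```python
import numpy as np


class Stroke:
    """Hypothetical toy Stroke with a length and a comparison flag."""

    def __init__(self, length, flagged=False):
        self.length = length
        self.flagged = flagged

    def __gt__(self, other):
        # Plain Python comparison is left untouched.
        other_len = other.length if isinstance(other, Stroke) else other
        return self.length > other_len


def stroke_greater(a, b):
    """Element-wise handler: return a new Stroke instead of a bool."""
    other_len = b.length if isinstance(b, Stroke) else b
    return Stroke(a.length, flagged=a.length > other_len)


# A new ufunc with 2 inputs and 1 output; its results are object arrays.
stroke_gt = np.frompyfunc(stroke_greater, 2, 1)

arr = np.array([Stroke(3), Stroke(7), Stroke(10)], dtype=object)
out = stroke_gt(arr, 5)
print(out.dtype)                 # object
print([s.flagged for s in out])  # [False, True, True]
print(Stroke(7) > 5)             # True, via Stroke.__gt__
```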
