Let's say c = a + b, where a and b are ndarrays whose shapes are not necessarily the same. That is, they could be any two arrays that are compatible under the general broadcasting rules.

I have the derivative of some output l with respect to c, dl/dc, and I'd like to compute dl/da. If a and b had the same shape, then dl/da = dl/db = dl/dc. However, I might have an addition where a.shape == (3,) and b.shape == (2,3), so c[i][j] = a[j] + b[i][j]. This means that dl/da[j] = sum_i dl/dc[i][j]. In general, dl/da is the sum of dl/dc over all axes along which a was broadcast.
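For concreteness, here is a minimal sketch of the (3,) + (2,3) case. The values of a, b, and the upstream gradient dl_dc are made up purely for illustration:

```python
import numpy as np

# Made-up inputs: a has shape (3,), b has shape (2, 3).
a = np.array([1.0, 2.0, 3.0])
b = np.arange(6.0).reshape(2, 3)
c = a + b  # a is broadcast along axis 0; c has shape (2, 3)

# Made-up upstream gradient dl/dc, same shape as c.
dl_dc = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# a was broadcast along axis 0, so dl/da sums dl/dc over that axis.
dl_da = dl_dc.sum(axis=0)  # shape (3,)

# b was not broadcast, so dl/db is dl/dc unchanged.
dl_db = dl_dc
```

Here dl_da[j] collects the contribution of a[j] to every row of c, which is why the gradient is summed over the row index i.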

To compute the chain rule derivatives of a and b in general, I wrote the following function, but I feel it's not very pythonic, and could probably be done more efficiently:

import numpy as np

def addition_derivatives(x, y, d):
    flip = False
    if x.ndim < y.ndim:  # x should have higher ndim
        flip = True
        x, y = y, x

    S = x.shape  # shape of array with higher ndim
    s = y.shape  # shape of array with lower ndim

    # figure out which axes will be broadcast in which arrays
    n = len(S)
    # impute missing ones in the shape of the smaller array as per:
    # https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
    s = (1,) * (len(S) - len(s)) + s
    axis_x = []
    axis_y = []
    for i in range(n):
        assert s[i] == S[i] or s[i] == 1 or S[i] == 1
        if S[i] == 1 and s[i] != 1:
            axis_x.append(i)
        if s[i] == 1 and S[i] != 1:
            axis_y.append(i)
    axis_x, axis_y = map(tuple, (axis_x, axis_y))

    # compute the derivatives
    dx = np.sum(d, axis=axis_x).reshape(x.shape)
    dy = np.sum(d, axis=axis_y).reshape(y.shape)
    if flip:
        dx, dy = dy, dx

    return dx, dy
michaelsnowden

1 Answer

I ended up finding a sort of hack that does this using np.broadcast_arrays and the strides attribute of the arrays it returns. I'm not sure this will work in all cases, but it has worked so far because the broadcast views report a stride of 0 along every axis that was broadcast.

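import numpy as np

def addition_derivatives(x, y, d):
    # Broadcast views have stride 0 along every axis that was broadcast.
    bx, by = np.broadcast_arrays(x, y)
    strides = list(zip(bx.strides, by.strides))
    # Axes broadcast in x only, and axes broadcast in y only.
    ax = tuple(i for i, (sx, sy) in enumerate(strides) if sx == 0 and sy != 0)
    ay = tuple(i for i, (sx, sy) in enumerate(strides) if sx != 0 and sy == 0)
    # Sum the upstream gradient over those axes and restore the input shapes.
    dx = np.sum(d, ax).reshape(x.shape)
    dy = np.sum(d, ay).reshape(y.shape)
    return dx, dy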
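
As a possible alternative, here is a sketch that compares shapes rather than strides. The helper name unbroadcast is hypothetical (it is not a NumPy function), and the sketch assumes d has the broadcast shape of x and y. As far as I can tell, strides can mislead when an input is itself a zero-stride view, for example one produced by np.broadcast_to, whereas shapes do not depend on memory layout:

```python
import numpy as np

def unbroadcast(grad, shape):
    """Sum grad over the axes that broadcasting added or stretched.

    Hypothetical helper: `shape` is the original shape of the operand and
    `grad` is assumed to have the broadcast result's shape.
    """
    # Leading axes that broadcasting prepended to the operand.
    extra = grad.ndim - len(shape)
    if extra:
        grad = grad.sum(axis=tuple(range(extra)))
    # Axes where the operand had size 1 but the result is larger.
    stretched = tuple(i for i, n in enumerate(shape)
                      if n == 1 and grad.shape[i] != 1)
    if stretched:
        grad = grad.sum(axis=stretched, keepdims=True)
    return grad

def addition_derivatives(x, y, d):
    # For c = x + y, each gradient is d reduced back to the operand's shape.
    return unbroadcast(d, np.shape(x)), unbroadcast(d, np.shape(y))

# Usage with made-up data: shapes (2, 1) and (1, 3) broadcast to (2, 3).
x = np.ones((2, 1))
y = np.ones((1, 3))
d = np.arange(6.0).reshape(2, 3)
dx, dy = addition_derivatives(x, y, d)
```

Using keepdims=True for the stretched axes leaves each result in the operand's original shape, so no reshape is needed afterwards.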
michaelsnowden