
I wrote an implementation of MaxPool2d (it runs correctly; I compared its output with PyTorch). When I test it on the MNIST dataset, the updateOutput function takes a very long time to complete. How can I optimize this code using NumPy?

class MaxPool2d(Module):
    def __init__(self, kernel_size):
        super(MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.gradInput = None

    def updateOutput(self, input):
        #print("MaxPool updateOutput")
        #start_time = time.time()
        kernel = self.kernel_size
        poolH = input.shape[2] // kernel
        poolW = input.shape[3] // kernel
        self.output = np.zeros((input.shape[0], 
                                input.shape[1], 
                                poolH,
                                poolW))
        self.index = np.zeros((input.shape[0],
                                    input.shape[1],
                                    poolH,
                                    poolW,
                                    2), 
                                    dtype='int32')

        for i in range(input.shape[0]):
            for j in range(input.shape[1]):
                for k in range(0, input.shape[2] - kernel+1, kernel):
                    for m in range(0, input.shape[3] - kernel+1, kernel):
                        M = input[i, j, k : k+kernel, m : m+kernel]
                        self.output[i, j, k // kernel, m // kernel] = M.max()
                        self.index[i, j, k // kernel, m // kernel] = np.array(np.unravel_index(M.argmax(), M.shape)) + np.array((k, m))

        #print(f"time: {time.time() - start_time:.3f}s")
        return self.output

input shape = (batch_size, n_input_channels, h, w)

output shape = (batch_size, n_output_channels, h // kern_size, w // kern_size)

annaFerdsf

1 Answer


For clarity, I've simplified your example by removing the batch and channel dimensions. Most of the time is spent computing M.max(). To show this, I've created a benchmark function, update_output_b, that runs the same loop but takes the max of a constant array of ones.

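import functools
import time
import numpy as np

def timeit(cycles):
    """Decorator factory: run the function `cycles` times and print the mean wall-clock time."""
    def timed(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = None
            start_t = time.perf_counter()
            for _ in range(cycles):
                result = func(*args, **kwargs)
            t = (time.perf_counter() - start_t) / cycles
            print(f'{func.__name__} mean execution time: {t:.3f}s')
            # return the last result so the decorated function stays usable
            return result

        return wrapper
    return timed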

@timeit(100)
def update_output_b(input, kernel):
    ones = np.ones((kernel, kernel))

    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    output = np.zeros((pool_h, pool_w))

    for i in range(0, input.shape[0] - kernel + 1, kernel):
        for j in range(0, input.shape[1] - kernel + 1, kernel):
            output[i // kernel, j // kernel] = ones.max()

    return output

in_arr = np.random.rand(3001, 200)
update_output_b(in_arr, 3)

This prints update_output_b mean execution time: 0.277s. The loop is slow because it does not use NumPy's vectorized operations: every iteration pays Python-level overhead for a tiny max() call. Whenever possible, prefer native NumPy functions over Python loops.

In addition, slicing the input array slows execution further, since access to contiguous memory is in most cases faster.

@timeit(100)
def update_output_1(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    output = np.zeros((pool_h, pool_w))

    for i in range(0, input.shape[0] - kernel + 1, kernel):
        for j in range(0, input.shape[1] - kernel + 1, kernel):
            M = input[i : i + kernel, j : j + kernel]
            output[i // kernel, j // kernel] = M.max()

    return output

update_output_1(in_arr, 3)

This prints update_output_1 mean execution time: 0.332s (55 ms slower than the previous version).
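To make the memory-layout point concrete, here is a small sketch that inspects one 3x3 window of the same shape of array. It only illustrates that the slice is a non-contiguous view; it does not measure how much of the 55 ms comes from layout as opposed to the cost of creating the slice object itself.

```python
import numpy as np

in_arr = np.random.rand(3001, 200)

# A 3x3 window is a view: no data is copied, it reuses in_arr's buffer.
window = in_arr[0:3, 0:3]
shares = np.shares_memory(window, in_arr)

# Consecutive rows of the window are 200 elements apart in memory,
# so the view is not C-contiguous.
view_contiguous = window.flags['C_CONTIGUOUS']

# An explicit copy packs the 9 values next to each other.
packed = np.ascontiguousarray(window)
copy_contiguous = packed.flags['C_CONTIGUOUS']

print(shares, view_contiguous, copy_contiguous)  # True False True
```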

I've added vectorized code below. It runs ~20x faster (update_output_2 mean execution time: 0.015s), although it is probably still far from optimal.

@timeit(100)
def update_output_2(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    input_h = pool_h * kernel
    input_w = pool_w * kernel

    # crop input
    output = input[:input_h, :input_w]
    # calculate max along second axis
    output = output.reshape((-1, kernel))
    output = output.max(axis=1)
    # calculate max along first axis
    output = output.reshape((pool_h, kernel, pool_w))
    output = output.max(axis=1)

    return output

update_output_2(in_arr, 3)

It generates output in 3 steps:

  • Cropping input to size divisible by kernel
  • Calculating max along second axis (this reduces the offsets between slices along the first axis)
  • Calculating max along first axis
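The same idea can be carried back to the (batch_size, n_channels, h, w) layout from the question. The sketch below is a variant, not the code above: instead of two separate reductions it reshapes each spatial axis into (window count, position inside window) and reduces both inner axes in one call. The name max_pool2d_forward is made up for this example, and it assumes non-overlapping windows (stride equal to kernel), as in the question.

```python
import numpy as np

def max_pool2d_forward(x, kernel):
    """Non-overlapping max pooling over the last two axes of an (N, C, H, W) array."""
    n, c, h, w = x.shape
    pool_h, pool_w = h // kernel, w // kernel
    # crop so both spatial dimensions are divisible by kernel
    x = x[:, :, :pool_h * kernel, :pool_w * kernel]
    # split H into (pool_h, kernel) and W into (pool_w, kernel)
    windows = x.reshape(n, c, pool_h, kernel, pool_w, kernel)
    # reduce over the two "position inside window" axes
    return windows.max(axis=(3, 5))

x = np.random.rand(2, 3, 7, 8)
out = max_pool2d_forward(x, 2)
print(out.shape)  # (2, 3, 3, 4)
```

The odd trailing row of the 7-row input is dropped by the crop, which matches the floor division used for poolH and poolW in the original loop.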

Edit:

I've added modifications for retrieving the indices of the max values. However, you should check the index arithmetic, as I've only tested it on a random array.

It calculates output_indices along the second axis in each window and then uses output_indices_selector to select the maximum along the first axis.

def update_output_3(input, kernel):
    pool_h = input.shape[0] // kernel
    pool_w = input.shape[1] // kernel
    input_h = pool_h * kernel
    input_w = pool_w * kernel

    # crop input
    output = input[:input_h, :input_w]

    # calculate max along second axis
    output_tmp = output.reshape((-1, kernel))
    output_indices = output_tmp.argmax(axis=1)
    output_indices += np.arange(output_indices.shape[0]) * kernel
    output_indices = np.unravel_index(output_indices, output.shape)
    output_tmp = output[output_indices]

    # calculate max along first axis
    output_tmp = output_tmp.reshape((pool_h, kernel, pool_w))
    output_indices_selector = (kernel * pool_w * np.arange(pool_h).reshape(pool_h, 1))
    output_indices_selector = output_indices_selector.repeat(pool_w, axis=1)
    output_indices_selector += pool_w * output_tmp.argmax(axis=1)
    output_indices_selector += np.arange(pool_w)
    output_indices_selector = output_indices_selector.flatten()

    output_indices = (output_indices[0][output_indices_selector],
                      output_indices[1][output_indices_selector])
    output = output[output_indices].reshape(pool_h, pool_w)

    return output, output_indices
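For the 4D case from the question, a possibly easier-to-follow alternative to update_output_3 is to lay every window out as one flat row and call argmax once. This is a sketch under the same assumption of non-overlapping windows; max_pool2d_with_indices is a hypothetical name, and the returned index array mimics the (N, C, poolH, poolW, 2) layout of self.index in the question.

```python
import numpy as np

def max_pool2d_with_indices(x, kernel):
    """Max pooling on (N, C, H, W) that also returns the (row, col) of each max."""
    n, c, h, w = x.shape
    pool_h, pool_w = h // kernel, w // kernel
    x = x[:, :, :pool_h * kernel, :pool_w * kernel]
    # (N, C, pool_h, pool_w, kernel, kernel): one kernel x kernel window per output cell
    windows = x.reshape(n, c, pool_h, kernel, pool_w, kernel).transpose(0, 1, 2, 4, 3, 5)
    flat = windows.reshape(n, c, pool_h, pool_w, kernel * kernel)
    arg = flat.argmax(axis=-1)      # offset inside each window, row-major
    out = flat.max(axis=-1)
    # turn the in-window offset into absolute input coordinates
    d_row, d_col = np.divmod(arg, kernel)
    rows = d_row + kernel * np.arange(pool_h)[:, None]
    cols = d_col + kernel * np.arange(pool_w)[None, :]
    index = np.stack((rows, cols), axis=-1)
    return out, index

x = np.random.rand(2, 3, 7, 8)
out, index = max_pool2d_with_indices(x, 2)
print(out.shape, index.shape)  # (2, 3, 3, 4) (2, 3, 3, 4, 2)
```

With index[..., 0] and index[..., 1] available, a backward pass could route each output gradient to the stored input position, though that part is not shown here.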
  • Thanks, your code works great. But how can I keep the indexes of the maximum elements? I need them for backward – annaFerdsf Jun 18 '20 at 17:29
  • Could you tell me how to better find the indices of the maximum elements? I try using np.unravel_index , but nothing comes out :( – annaFerdsf Jun 19 '20 at 09:27
  • You can check the update I've posted. It's a rather obscure solution, so the best way to understand it is whiteboard debugging on a simple example. – Jakub Gąsiewski Jun 19 '20 at 14:48