The `max` and `exp` operations are fundamentally different; `exp` (and other operations like addition, `sin`, etc.) is an elementwise operation that is embarrassingly parallelizable, while `max` is a reduction, which requires a parallel algorithm that basically builds up a tree of pairwise comparisons over the array. It's not impossible to speed up `max`, but it's not as easy as `exp`.
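To make the tree-of-pairwise-comparisons idea concrete, here is a minimal numpy sketch (my own illustration, not theano's actual implementation) that reduces an array by repeated pairwise maximums; each pass halves the array, which is the structure a parallel reduction exploits:

```python
import numpy as np

def tree_max(a):
    # Reduce by repeated pairwise maximums: each pass halves
    # the array, mimicking the comparison tree of a parallel
    # reduction (log2(n) passes instead of n-1 serial steps).
    a = np.asarray(a, dtype=float)
    while a.size > 1:
        if a.size % 2:            # pad odd-length arrays with -inf
            a = np.append(a, -np.inf)
        a = np.maximum(a[0::2], a[1::2])  # one level of the tree
    return a[0]

print(tree_max([3, 1, 4, 1, 5, 9, 2, 6]))  # 9.0
```

On a GPU each `np.maximum` pass would run as one parallel step, which is why the reduction is harder to accelerate than a purely elementwise op like `exp`.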
Anyway, the `theano` implementation of `max` basically consists of the following lines (in `theano/tensor/basic.py`):
try:
    out = max_and_argmax(x, axis)[0]
except Exception:
    out = CAReduce(scal.maximum, axis)(x)
where `max_and_argmax` is a bunch of custom code that, to my eye, implements a combined max+argmax operation using `numpy`, and `CAReduce` is a generic GPU-accelerated reduction used as a fallback (which, according to the comments, doesn't support `grad`, etc.). You could try using the fallback directly and see whether it is faster, maybe something like this:
from theano.tensor.elemwise import CAReduce
from theano.scalar import maximum

def mymax(X, axis=None):
    # the original snippet was missing the return
    return CAReduce(maximum, axis)(X)
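If you do try the fallback, it's worth checking that it produces the same values as the regular path; in plain numpy terms (an illustration for checking results, not theano code), the fallback should agree with an ordinary max reduction along the same axis:

```python
import numpy as np

# A maximum-reduction along axis 1 of a small test matrix;
# evaluating mymax(X, axis=1) on the theano side should
# produce the same values.
X = np.array([[3.0, 1.0, 4.0],
              [1.0, 5.0, 9.0]])
print(np.max(X, axis=1))  # [4. 9.]
```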