First:
def compose(f, g):
    def wrapper(x):
        return f(g(x))
    wrapper.__name__ = f'compose({f.__name__}, {g.__name__})'
    return wrapper

def ntimes(n):
    def wrap(func):
        if n == 1:
            return func
        return compose(func, ntimes(n-1)(func))
    return wrap
That should be obvious, right? ntimes(3) is a function that composes any function with itself 3 times, so ntimes(3)(func)(x) is func(func(func(x))).
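If you want to check that claim, here's a quick sketch. It repeats the two definitions so it runs on its own; add_one is just a throwaway function I made up for the test.

```python
def compose(f, g):
    def wrapper(x):
        return f(g(x))
    wrapper.__name__ = f'compose({f.__name__}, {g.__name__})'
    return wrapper

def ntimes(n):
    def wrap(func):
        if n == 1:
            return func
        return compose(func, ntimes(n-1)(func))
    return wrap

# add_one is a made-up helper, used only to make the call count visible.
def add_one(x):
    return x + 1

thrice = ntimes(3)
print(thrice(add_one)(0))        # 3, i.e. add_one(add_one(add_one(0)))
print(thrice(add_one).__name__)  # compose(add_one, compose(add_one, add_one))
```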
And now, we just need to call ntimes on ntimes with the same n at both levels. I could write an nntimes function that does that the same way ntimes did, but for variety, let's make it flatter:
def nntimes(n, func, arg):
    f = ntimes(n)
    return f(f)(func)(arg)
So nntimes(n, func, arg) calls ntimes(n) on ntimes(n), which gives you a function that composes its argument n**n times, and then calls that function on arg.
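One way to convince yourself of the n**n count is to pass in a function that records each call. This is only a sketch: count_calls and record are names I invented for the check, and the three definitions are repeated so the snippet stands alone.

```python
def compose(f, g):
    def wrapper(x):
        return f(g(x))
    wrapper.__name__ = f'compose({f.__name__}, {g.__name__})'
    return wrapper

def ntimes(n):
    def wrap(func):
        if n == 1:
            return func
        return compose(func, ntimes(n-1)(func))
    return wrap

def nntimes(n, func, arg):
    f = ntimes(n)
    return f(f)(func)(arg)

# count_calls and record are hypothetical helpers, not part of the solution.
def count_calls(n):
    calls = []
    def record(x):
        calls.append(x)  # remember every invocation
        return x         # pass the value through so composition still works
    nntimes(n, record, 'Hi')
    return len(calls)

print(count_calls(2))  # 4
print(count_calls(3))  # 27
```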
And now we just need a function to pass in. print doesn't quite work, because it returns None, so you can't compose it with itself. So:
def printret(x):
    print(x, end=' ')
    return x
And now we just call it:
>>> nntimes(2, printret, 'Hi')
hi hi hi hi
>>> nntimes(3, printret, 'Hi')
hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi hi
If you still can't understand what's happening, maybe this will help. Let's do something a bit simpler than the general nntimes and just hardcode three, and then print out the composition:
>>> thrice = ntimes(3)
>>> print(thrice(thrice)(printret).__name__)
compose(compose(compose(printret, compose(printret, printret)), compose(compose(printret, compose(printret, printret)), compose(printret, compose(printret, printret)))), compose(compose(compose(printret, compose(printret, printret)), compose(compose(printret, compose(printret, printret)), compose(printret, compose(printret, printret)))), compose(compose(printret, compose(printret, printret)), compose(compose(printret, compose(printret, printret)), compose(printret, compose(printret, printret))))))
All those parentheses! It's like I've died and gone to Lisp!
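If that wall of text is too much to trace by eye, the same trick with two instead of three is small enough to read. Again just a sketch, with the definitions repeated so it runs by itself; twice is a name I picked for ntimes(2).

```python
def compose(f, g):
    def wrapper(x):
        return f(g(x))
    wrapper.__name__ = f'compose({f.__name__}, {g.__name__})'
    return wrapper

def ntimes(n):
    def wrap(func):
        if n == 1:
            return func
        return compose(func, ntimes(n-1)(func))
    return wrap

def printret(x):
    print(x, end=' ')
    return x

twice = ntimes(2)
# twice(twice)(printret) is twice(twice(printret)): a doubling of a doubling.
print(twice(twice)(printret).__name__)
# compose(compose(printret, printret), compose(printret, printret))
```

Four printret leaves for 2**2, just as the thrice version has 27 for 3**3.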
If you read up on Church numerals, you'll see that I've sort of cheated here. Write up the trivial functions to Church-encode a number and to exponentiate two Church numerals, then compare it to what my code does. So, have I really avoided calculating the value of n**n?
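If you'd rather not write them yourself, here is one possible sketch. The names church, power and unchurch are mine, not from any library, and this is only one of several ways to encode the idea in Python.

```python
def church(n):
    """Church-encode n >= 0: church(n)(f)(x) applies f to x exactly n times."""
    def numeral(f):
        def applied(x):
            if n == 0:
                return x
            return f(church(n - 1)(f)(x))
        return applied
    return numeral

def power(base, exponent):
    """Church exponentiation: applying the exponent to the base gives base**exponent."""
    return exponent(base)

def unchurch(numeral):
    """Decode a Church numeral by counting how many times it applies its function."""
    return numeral(lambda k: k + 1)(0)

three = church(3)
print(unchurch(power(church(2), three)))  # 8
print(unchurch(power(three, three)))      # 27
```

Notice that power(three, three) is just three(three), which has the same shape as the f(f) inside nntimes.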
Of course you can do this a whole lot more simply with a flat recursion and no higher-order functions, or with itertools (well, you're not allowed to use builtins, but everything in itertools comes with source and/or a "roughly equivalent" function in the docs, so you can just copy that). But what's the fun in that? After all, if you actually wanted a Pythonic, or simple, or efficient version, you'd just loop over range(n**n). I assume the point of this interview question is to force you to think outside the Pythonic box.
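For the record, the flat recursion might look something like this. It's a sketch under my own assumptions: repeat_print is a made-up name, I'm treating the ** operator as fair game, and I return the count only so the result is easy to check.

```python
# repeat_print is a hypothetical helper: plain recursion, no loops,
# no higher-order functions.
def repeat_print(text, count):
    if count == 0:
        return 0
    print(text, end=' ')
    return 1 + repeat_print(text, count - 1)

n = 3
printed = repeat_print('Hi', n**n)  # prints Hi 27 times on one line
```

This computes n**n up front, and the recursion goes n**n levels deep, so it hits Python's recursion limit much sooner than the composed version does.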