I want to write a general-purpose memoization utility in Java. I would like the calling code to look like this:

Util.memoize(() -> longCalculation(1));

where

private Integer longCalculation(Integer x) {
    try { 
        Thread.sleep(1000);
    } catch (InterruptedException ignored) {}
    return x * 2;
}

To do this, I was thinking I could do something like this:

public class Util{
    private static final Map<Object, Object> cache = new ConcurrentHashMap<>();
    public interface Operator<T> {
        T op();
    }

    public static<T> T memoize(Operator<T> o) {
        ConcurrentHashMap<Object, T> memo = cache.containsKey(o.getClass()) ? (ConcurrentHashMap<Object, T>) cache.get(o.getClass()) : new ConcurrentHashMap<>();
        if (memo.containsKey(o)) {
            return memo.get(o);
        } else {
            T val = o.op();
            memo.put(o, val);
            return val;
        }
    }
}

I was expecting this to work, but I see no memoization being done. I have tracked it down to o.getClass() being different for each invocation. I thought I could try to get the run-time type of T, but I cannot figure out a way to do that.

ccoutinho
saidaspen

3 Answers

The answer by Lino points out a couple of flaws in the code, but it only works when the same lambda instance is reused.

This is because o.getClass() returns the class of the lambda itself, not the class of the value the lambda returns. As such, the two calls below pass lambdas of two different classes:

Util.memoize(() -> longCalculation(1));
Util.memoize(() -> longCalculation(1));
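To make the difference visible, here is a minimal sketch that compares two textually identical lambdas. The class name LambdaIdentityDemo and the fields FIRST and SECOND are made up for illustration, and Supplier stands in for the question's Operator interface. The exact class names are a JVM implementation detail, so treat the getClass() comparison as typical OpenJDK behavior rather than a guarantee.

```java
import java.util.function.Supplier;

class LambdaIdentityDemo {
    // Stand-in for the slow calculation from the question (no sleep here).
    static Integer longCalculation(Integer x) {
        return x * 2;
    }

    // Two separate lambda expressions with the same body.
    static final Supplier<Integer> FIRST = () -> longCalculation(1);
    static final Supplier<Integer> SECOND = () -> longCalculation(1);

    public static void main(String[] args) {
        // On OpenJDK each lambda expression typically gets its own synthetic class,
        // so this comparison is expected to print false.
        System.out.println(FIRST.getClass() == SECOND.getClass());

        // Lambdas do not override equals(), so two distinct objects are never equal.
        System.out.println(FIRST.equals(SECOND));

        // The results are equal, but the cache cannot know that without running both.
        System.out.println(FIRST.get().equals(SECOND.get()));
    }
}
```

Neither the lambda's class nor the lambda object itself identifies "the same calculation", which is why a map keyed on either one misses.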

I don't think there is a good way to find out the class of the returned type without actually executing the potentially long running code, which of course is what you want to avoid.

With this in mind I would suggest passing the class as a second parameter to memoize(). This would give you:

@SuppressWarnings("unchecked")
public static <T> T memoize(Operator<T> o, Class<T> clazz) {
  return (T) cache.computeIfAbsent(clazz, k -> o.op());
}

This assumes that you change the type of cache to:

private static final Map<Class<?>, Object> cache = new ConcurrentHashMap<>();

Unfortunately, you have to downcast the Object to a T. The @SuppressWarnings("unchecked") annotation does not make the cast safe; it only silences the compiler warning. The cast is safe here because you are in control of the code and know that the class of the value will match the key in the map.

An alternative would be to use Guava's ClassToInstanceMap:

private static final ClassToInstanceMap<Object> cache = MutableClassToInstanceMap.create(new ConcurrentHashMap<>());

This, however, doesn't allow you to use computeIfAbsent() without casting, since it returns an Object, so the code would become a bit more verbose:

public static <T> T memoize(Operator<T> o, Class<T> clazz) {
  T cachedCalculation = cache.getInstance(clazz);
  if (cachedCalculation != null) {
    return cachedCalculation;
  }
  T calculation = o.op();
  cache.put(clazz, calculation);
  return calculation;
}

As a final side note, you don't need to define your own functional interface; you can use the java.util.function.Supplier interface instead:

@SuppressWarnings("unchecked")
public static <T> T memoize(Supplier<T> o, Class<T> clazz) {
  return (T) cache.computeIfAbsent(clazz, k -> o.get());
}
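
One caveat with keying the cache on the class: every calculation that returns the same type shares a single entry, so longCalculation(1) and longCalculation(2) would both map to Integer.class. A possible variation, sketched below, lets the caller pass an explicit key instead. The key parameter, the KeyedMemo class, and the CALLS counter are assumptions added for illustration and are not part of the original code.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

class KeyedMemo {
    private static final Map<Object, Object> CACHE = new ConcurrentHashMap<>();

    // Counts how often the calculation really runs (demo only).
    static final AtomicInteger CALLS = new AtomicInteger();

    // The caller supplies a key with sensible equals()/hashCode(),
    // for example a List of the method name and its arguments.
    @SuppressWarnings("unchecked")
    static <T> T memoize(Object key, Supplier<T> supplier) {
        return (T) CACHE.computeIfAbsent(key, k -> supplier.get());
    }

    static Integer longCalculation(Integer x) {
        CALLS.incrementAndGet();
        return x * 2;
    }

    public static void main(String[] args) {
        Integer first = memoize(List.of("longCalculation", 1), () -> longCalculation(1));
        Integer again = memoize(List.of("longCalculation", 1), () -> longCalculation(1));
        Integer other = memoize(List.of("longCalculation", 2), () -> longCalculation(2));
        // Prints "2 2 4 calls=2": the second call is served from the cache.
        System.out.println(first + " " + again + " " + other + " calls=" + CALLS.get());
    }
}
```

The trade-off is that the compiler can no longer tie the key to the value type: reusing one key for two different result types would fail with a ClassCastException at the call site.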
Magnilex

The problem you have is in the line:

ConcurrentHashMap<Object, T> memo = cache.containsKey(o.getClass()) ? (ConcurrentHashMap<Object, T>) cache.get(o.getClass()) : new ConcurrentHashMap<>();

You check whether an entry with the key o.getClass() exists. If it does, you get() it; otherwise you use a newly initialized ConcurrentHashMap. The problem is that you never store this newly created map back in the cache.

So either:

  • Place cache.put(o.getClass(), memo); after the line above
  • Or even better use the computeIfAbsent() method:

    ConcurrentHashMap<Object, T> memo = cache.computeIfAbsent(o.getClass(), 
                                                              k -> new ConcurrentHashMap<>());
    

Also, because you know the structure of your cache, you can make it more type-safe, so that you don't have to cast everywhere:

private static final Map<Object, Map<Operator<?>, Object>> cache = new ConcurrentHashMap<>();

You can also shorten your method even more by using the computeIfAbsent() method mentioned earlier:

public static <T> T memoize(Operator<T> o) {
    return (T) cache
        .computeIfAbsent(o.getClass(), k -> new ConcurrentHashMap<>())
        .computeIfAbsent(o, k -> o.op());
}
  1. (T): casts the Object returned from the map to the required output type T
  2. .computeIfAbsent(o.getClass(), k -> new ConcurrentHashMap<>()): invokes the provided lambda k -> new ConcurrentHashMap<>() when there is no mapping for the key o.getClass() in cache, and stores the result
  3. .computeIfAbsent(o, k -> o.op()): this is invoked on the value returned by the computeIfAbsent call in step 2. If o is not yet a key in the nested map, the lambda k -> o.op() is executed, and its return value is stored in the map and returned.
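
The three steps above can be tried out with a small self-contained sketch. It uses Supplier in place of the Operator interface, and the InstanceKeyedMemo class and the CALLS counter are assumptions added only to make the cache hits observable.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

class InstanceKeyedMemo {
    private static final Map<Object, Map<Supplier<?>, Object>> CACHE = new ConcurrentHashMap<>();

    // Counts how often the calculation really runs (demo only).
    static final AtomicInteger CALLS = new AtomicInteger();

    @SuppressWarnings("unchecked")
    static <T> T memoize(Supplier<T> o) {
        return (T) CACHE
            .computeIfAbsent(o.getClass(), k -> new ConcurrentHashMap<>())
            .computeIfAbsent(o, k -> o.get());
    }

    static Integer longCalculation(Integer x) {
        CALLS.incrementAndGet();
        return x * 2;
    }

    public static void main(String[] args) {
        Supplier<Integer> someOp = () -> longCalculation(1);
        memoize(someOp);
        memoize(someOp); // same instance: served from the cache
        System.out.println("calls after reuse: " + CALLS.get());

        // A new lambda expression is a different key, so it is computed again.
        memoize(() -> longCalculation(1));
        System.out.println("calls after a fresh lambda: " + CALLS.get());
    }
}
```

Because the inner map is keyed on the lambda object itself, a hit only happens when the very same instance is passed in again.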
Lino
  • Thanks! This works great in the case where I save the lambda in a named variable like so: `Operator someOp = () -> longCalculation(1);` but not in this case: `memoize(() -> longCalculation(1));`. It seems `o.getClass()` is different each time – saidaspen May 07 '20 at 08:10
  • `o.getClass()` does not return `Integer.class`, but something like: _class some.package$$Lambda$1/1107024580_. This means that the class is not the same for two invocations. So I don't think this solution works, but it points out some flaws in the original code. If you use the same lambda twice, as your first example, it should work. – Magnilex May 07 '20 at 08:18
  • @saidaspen sadly, as mentioned by Magnilex, two lambdas of the form `() -> longCalculation(1)` will never be the same, even when they have the same body. You can only use this `memoize` method if you save the lambda you're invoking into a variable – Lino May 07 '20 at 08:26

Old question, I know, but I didn't see this answer (for Functions):

public static <T, R> Function<T, R> memoize(Function<T, R> f) {
    var cache = new ConcurrentHashMap<T, R>();
    return t -> cache.computeIfAbsent(t, f::apply);
}

cache, in other words, doesn't have to be a static field of some other class; you can just close over it.
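
As a rough usage sketch of this closure-based approach, the example below wraps a counting function and calls the memoized version a few times. The FunctionMemo class, the CALLS counter, and the variable name fast are illustrative assumptions; the memoize method spells out the map type instead of using var and passes f directly, but is otherwise the same idea.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

class FunctionMemo {
    // Counts how often the wrapped function really runs (demo only).
    static final AtomicInteger CALLS = new AtomicInteger();

    static <T, R> Function<T, R> memoize(Function<T, R> f) {
        // Each call to memoize() creates its own private cache.
        ConcurrentHashMap<T, R> cache = new ConcurrentHashMap<>();
        return t -> cache.computeIfAbsent(t, f);
    }

    static Integer longCalculation(Integer x) {
        CALLS.incrementAndGet();
        return x * 2;
    }

    public static void main(String[] args) {
        Function<Integer, Integer> fast = memoize(FunctionMemo::longCalculation);
        fast.apply(1);
        fast.apply(1); // cached, longCalculation is not called again
        fast.apply(2); // new argument, computed once
        System.out.println("calls=" + CALLS.get()); // prints "calls=2"
    }
}
```

Here the argument is the cache key, so the identity of the lambda no longer matters; the memoized Function has to be kept and reused, though. One thing to watch for: the wrapped function should not call back into the same memoized wrapper, because ConcurrentHashMap.computeIfAbsent does not support the mapping function updating the map it is running on.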

For memoizing Suppliers, see https://stackoverflow.com/a/35335467/208288.
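
For completeness, a Supplier can be memoized in the same closing-over style. The following is only a minimal sketch under my own assumptions (the SupplierMemo class and a simple synchronized get()); it is not necessarily what the linked answer does.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

class SupplierMemo {
    static <T> Supplier<T> memoize(Supplier<T> delegate) {
        return new Supplier<T>() {
            private boolean computed; // tracks completion so a null result is cached too
            private T value;

            @Override
            public synchronized T get() {
                if (!computed) {
                    value = delegate.get();
                    computed = true;
                }
                return value;
            }
        };
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        Supplier<Integer> slow = () -> {
            calls.incrementAndGet();
            return 1 * 2;
        };
        Supplier<Integer> fast = memoize(slow);
        fast.get();
        fast.get(); // second call returns the stored value
        System.out.println("calls=" + calls.get()); // prints "calls=1"
    }
}
```

As with the Function version, the caller has to hold on to the returned Supplier; wrapping a fresh lambda on every call would create a fresh, empty cache each time.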

Laird Nelson