Search code examples
javamemoizationgeneric-lambda

Java: Implement recursive cached from cached Function and cached BiFunction


TLDR: How to implement this function?

public static <T, R> Function<T, R> cachedRecursive(final BiFunction<T, Function<T,R>, R> bifunc) {
        
    }

I need to somehow extract the second argument from the BiFunction so I can return a proper result for the function.

This project is for learning purposes, although I'm stuck with the last part of my task.

First part of the task is to create a Cache class extended from the LinkedHashMap, and this is my Implementation:

public class Cache<K,V> extends LinkedHashMap<K, V> {


    private static int MaxSize;
    
    public Cache (int maxSize) {
        super(maxSize,1f,false);
        MaxSize = maxSize;
    }
    
    public Cache () {
        super();
    }
    
    public int getMaximalCacheSize () {
        return MaxSize;
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
        return size() > MaxSize;
    }
}

As for the second part, it is to create a class for which the function definitions will be added:

public class FunctionCache {
    
    private static class Pair<T, U> {
        private T stored_t;
        private U stored_u;
        
        public Pair(T t, U u) {
            stored_t = t;
            stored_u = u;
        }
        
        public boolean equals(Object t) {
            
            if (t == this) {
                return true;
            }
            
            return t == stored_t;
        }
        
        public int hashCode () {
            return stored_t.hashCode();
        }
        
        public T get_first() {
            return stored_t;
        }
        
        public U get_second() {
            return stored_u;
        }
    }
    
    private final static int DEFAULT_CACHE_SIZE = 10000;
    
    public static <T, R> Function<T, R> cached(final Function<T, R> func, int maximalCacheSize) {
        Cache<T, R> cache = new Cache<T,R>(maximalCacheSize);
        return input -> cache.computeIfAbsent(input, func);
    }

    public static <T, R> Function<T, R> cached(final Function<T, R> func) {
        Cache<T, R> cache = new Cache<T,R>(DEFAULT_CACHE_SIZE);
        return input -> cache.computeIfAbsent(input, func);
    }
    
    public static <T, U, R> BiFunction<T, U, R> cached(BiFunction<T, U, R> bifunc, int maximalCacheSize) {
        Cache<T, R> cache = new Cache<T, R>(maximalCacheSize);
        
        return (t, u) -> {
            Pair<T,U> pairKey = new Pair<T,U>(t,u);
            
            Function<Pair<T,U>, R> something = input -> {
                return bifunc.apply(input.get_first(), input.get_second());
            };
            
            if (!cache.containsKey(pairKey.get_first())) {
                R result = something.apply(pairKey);
                cache.put(pairKey.get_first(), result);
                
                return result;
            } else {
                return cache.get(pairKey.get_first());
            }
        };
    }
    
    public static <T, U, R> BiFunction<T, U, R> cached(BiFunction<T, U, R> bifunc) {
        Cache<T, R> cache = new Cache<T, R>(DEFAULT_CACHE_SIZE);
        
        return (t, u) -> {
            Pair<T,U> pairKey = new Pair<T,U>(t,u);
            
            Function<Pair<T,U>, R> something = input -> {
                return bifunc.apply(input.get_first(), input.get_second());
            };
            
            
            if (!cache.containsKey(pairKey.get_first())) {
                R result = something.apply(pairKey);
                cache.put(pairKey.get_first(), result);
                
                return result;
            } else {
                return cache.get(pairKey.get_first());
            }
        };
    }
    
    public static <T, R> Function<T, R> cachedRecursive(final BiFunction<T, Function<T,R>, R> bifunc) {
        
    }
}

This is my problem:

public static <T, R> Function<T, R> cachedRecursive(final BiFunction<T, Function<T,R>, R> bifunc) {
        
    }

I have absolutely no idea how to implement the cachedRecursive function, the previous functions are working with a simple fibonacci test perfectly, However the goal of this task is to implement the cachedRecursive function that takes a BiFunction with the first argument as the input and the second argument a function. Just to complete the code, this is the main class I used to test:

public class cachedFunction extends FunctionCache {


public static void main(String[] args) {
        
        @SuppressWarnings({ "rawtypes", "unchecked" })
        BiFunction<BigInteger, BiFunction, BigInteger> fibHelper = cached((n, f) -> {
            if (n.compareTo(BigInteger.TWO) <= 0) return BigInteger.ONE;
            
            return ((BigInteger) (f.apply(n.subtract(BigInteger.ONE), f)))
                    .add((BigInteger)f.apply(n.subtract(BigInteger.TWO), f));
        }, 50000);
        
        Function<BigInteger, BigInteger> fib = cached((n) -> fibHelper.apply(n,fibHelper));
        
        System.out.println(fib.apply(BigInteger.valueOf(1000L)));
    }
}

Solution

  • There are many drawbacks and mistakes in your code:

    • static size variables shared across different cache instances (therefore breaking it);
    • code duplication;
    • incorrect equals/hashCode contract implementation;
    • suppressing what should be fixed rather than suppressed;
    • the code is overly bloated;
    • and some minor ones (like _-containing lower-cased names, etc).

    If you don't mind, I simplify it:

    final class Functions {
    
        private Functions() {
        }
    
        // memoize a simple "unknown" function -- simply delegates to a private encapsulated method
        static <T, R> Function<T, R> memoize(final Function<? super T, ? extends R> f, final int maxSize) {
            return createCacheFunction(f, maxSize);
        }
    
        // memoize a recursive function
        // note that the bi-function can be converted to an unary function and vice versa
        static <T, R> Function<T, R> memoize(final BiFunction<? super T, ? super Function<? super T, ? extends R>, ? extends R> f, final int maxSize) {
            final Function<UnaryR<T, Function<T, R>>, R> memoizedF = memoize(unaryR -> f.apply(unaryR.t, unaryR.r), maxSize);
            return new Function<T, R>() {
                @Override
                public R apply(final T t) {
                    // this is the "magic"
                    return memoizedF.apply(new UnaryR<>(t, this));
                }
            };
        }
    
        private static <T, R> Function<T, R> createCacheFunction(final Function<? super T, ? extends R> f, final int maxSize) {
            final Map<T, R> cache = new LinkedHashMap<T, R>(maxSize, 1F, false) {
                @Override
                protected boolean removeEldestEntry(final Map.Entry eldest) {
                    return size() > maxSize;
                }
            };
            return t -> cache.computeIfAbsent(t, f);
        }
    
        // these annotations generate proper `equals` and `hashCode`, and a to-string implementation to simplify debugging
        @EqualsAndHashCode
        @ToString
        private static final class UnaryR<T, R> {
    
            @EqualsAndHashCode.Include
            private final T t;
    
            @EqualsAndHashCode.Exclude
            private final R r;
    
            private UnaryR(final T t, final R r) {
                this.t = t;
                this.r = r;
            }
    
        }
    
    }
    

    And the test that tests both results and the memoization contract ("no recalculation, if memoized"):

    public final class FunctionsTest {
    
        @Test
        public void testMemoizeRecursive() {
            final BiFunction<BigInteger, Function<? super BigInteger, ? extends BigInteger>, BigInteger> fib = (n, f) -> n.compareTo(BigInteger.valueOf(2)) <= 0 ? BigInteger.ONE : f.apply(n.subtract(BigInteger.ONE)).add(f.apply(n.subtract(BigInteger.valueOf(2))));
            @SuppressWarnings("unchecked")
            final BiFunction<BigInteger, Function<? super BigInteger, ? extends BigInteger>, BigInteger> mockedFib = Mockito.mock(BiFunction.class, AdditionalAnswers.delegatesTo(fib));
            final Function<BigInteger, BigInteger> memoizedFib = Functions.memoize(mockedFib, 1000);
            final BigInteger memoizedResult = memoizedFib.apply(BigInteger.valueOf(120));
            Mockito.verify(mockedFib, Mockito.times(120))
                    .apply(Matchers.any(), Matchers.any());
            Assertions.assertEquals("5358359254990966640871840", memoizedResult.toString());
            Assertions.assertEquals(memoizedResult, memoizedFib.apply(BigInteger.valueOf(120)));
            Mockito.verifyNoMoreInteractions(mockedFib);
        }
    
    }