0

Trying to solve this problem with recursion and memoization but for input 7168 I'm getting wrong answer.

    public int numSquares(int n) {
        Map<Integer, Integer> memo = new HashMap();
        List<Integer> list = fillSquares(n, memo);
        if (list == null)
            return 1;
        return helper(list.size()-1, list, n, memo);
    }
    
    private int helper(int index, List<Integer> list, int left, Map<Integer, Integer> memo) {
        
        if (left == 0)
            return 0;
        if (left < 0 || index < 0)
            return Integer.MAX_VALUE-1;
        
        if (memo.containsKey(left)) {
            return memo.get(left);
        }
        
        int d1 = 1+helper(index, list, left-list.get(index), memo);
        int d2 = 1+helper(index-1, list, left-list.get(index),  memo);
        int d3 = helper(index-1, list, left, memo);
        
        int d = Math.min(Math.min(d1,d2), d3);
        memo.put(left, d);
        return d;
    }
    
    private List<Integer> fillSquares(int n, Map<Integer, Integer> memo) {
        int curr = 1;
        List<Integer> list = new ArrayList();
        int d = (int)Math.pow(curr, 2);
        while (d < n) {
            list.add(d);
            memo.put(d, 1);
            curr++;
            d = (int)Math.pow(curr, 2);
        }
        if (d == n)
            return null;
        return list;
    }

I'm calling like this:

numSquares(7168)

All test cases pass (even complex cases), but this one fails. I suspect something is wrong with my memoization but cannot pinpoint what exactly. Any help will be appreciated.

maxam
  • 159
  • 1
  • 9

2 Answers2

1

You have the memoization keyed by the value to be attained, but this does not take into account the value of index, which actually puts restrictions on which powers you can use to attain that value. That means that if (in the extreme case) index is 0, you can only reduce what is left with one square (1²), which rarely is the optimal way to form that number. So in a first instance memo.set() will register a non-optimal number of squares, which later will get updated by other recursive calls which are pending in the recursion tree.

If you add some conditional debugging code, you'll see that map.set is called for the same value of left multiple times, and with differing values. This is not good, because that means the if (memo.has(left)) block will execute for cases where that value is not guaranteed to be optimal (yet).

You could solve this by incorporating the index in your memoization key. This increases the space used for memoization, but it will work. I assume you can work this out.

But according to Lagrange's four square theorem every natural number can be written as the sum of at most four squares, so the returned value should never be 5 or more. You can shortcut the recursion when you get passed that number of terms. This reduces the benefit of using memoization.

Finally, there is a mistake in fillSquares: it should add n itself also when it is a perfect square, otherwise you'll not find solutions that should return 1.

trincot
  • 317,000
  • 35
  • 244
  • 286
  • Thank you. Indeed the issue is not taking into account the ```index``` I'm on. So storing a String of ```index + " " + left``` in the memo helped here. Although this solution is too slow. – maxam Nov 04 '20 at 16:49
0
  • Not sure about your bug, here is a short dynamic programming Solution:

Java

public class Solution {
    public static final int numSquares(
        final int n
    ) {
        int[] dp = new int[n + 1];
        Arrays.fill(dp, Integer.MAX_VALUE);
        dp[0] = 0;

        for (int i = 1; i <= n; i++) {
            int j = 1;
            int min = Integer.MAX_VALUE;

            while (i - j * j >= 0) {
                min = Math.min(min, dp[i - j * j] + 1);
                ++j;
            }

            dp[i] = min;
        }

        return dp[n];
    }
}

C++

// Most of headers are already included;
// Can be removed;
#include <iostream>
#include <cstdint>
#include <vector>
#include <algorithm>

// The following block might slightly improve the execution time;
// Can be removed;
static const auto __optimize__ = []() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();


#define MAX INT_MAX

using ValueType = std::uint_fast32_t;

struct Solution {
    static const int numSquares(
        const int n
    ) {
        if (n < 1) {
            return 0;
        }

        static std::vector<ValueType> count_perfect_squares{0};

        while (std::size(count_perfect_squares) <= n) {
            const ValueType len = std::size(count_perfect_squares);
            ValueType count_squares = MAX;

            for (ValueType index = 1; index * index <= len; ++index) {
                count_squares = std::min(count_squares, 1 + count_perfect_squares[len - index * index]);
            }

            count_perfect_squares.emplace_back(count_squares);
        }

        return count_perfect_squares[n];
    }
};

int main() {
    std::cout <<  std::to_string(Solution().numSquares(12) == 3) << "\n";

    return 0;
}

Python

  • Here we can simply use lru_cache:
class Solution:
    dp = [0]
    @functools.lru_cache
    def numSquares(self, n):
        dp = self.dp
        while len(dp) <= n:
            dp += min(dp[-i * i] for i in range(1, int(len(dp) ** 0.5 + 1))) + 1, 
        return dp[n]

Here are LeetCode's official solutions with comments:

Java: DP

class Solution {

  public int numSquares(int n) {
    int dp[] = new int[n + 1];
    Arrays.fill(dp, Integer.MAX_VALUE);
    // bottom case
    dp[0] = 0;

    // pre-calculate the square numbers.
    int max_square_index = (int) Math.sqrt(n) + 1;
    int square_nums[] = new int[max_square_index];
    for (int i = 1; i < max_square_index; ++i) {
      square_nums[i] = i * i;
    }

    for (int i = 1; i <= n; ++i) {
      for (int s = 1; s < max_square_index; ++s) {
        if (i < square_nums[s])
          break;
        dp[i] = Math.min(dp[i], dp[i - square_nums[s]] + 1);
      }
    }
    return dp[n];
  }
}

Java: Greedy

class Solution {
  Set<Integer> square_nums = new HashSet<Integer>();

  protected boolean is_divided_by(int n, int count) {
    if (count == 1) {
      return square_nums.contains(n);
    }

    for (Integer square : square_nums) {
      if (is_divided_by(n - square, count - 1)) {
        return true;
      }
    }
    return false;
  }

  public int numSquares(int n) {
    this.square_nums.clear();

    for (int i = 1; i * i <= n; ++i) {
      this.square_nums.add(i * i);
    }

    int count = 1;
    for (; count <= n; ++count) {
      if (is_divided_by(n, count))
        return count;
    }
    return count;
  }
}

Java: Breadth First Search

class Solution {
  public int numSquares(int n) {

    ArrayList<Integer> square_nums = new ArrayList<Integer>();
    for (int i = 1; i * i <= n; ++i) {
      square_nums.add(i * i);
    }

    Set<Integer> queue = new HashSet<Integer>();
    queue.add(n);

    int level = 0;
    while (queue.size() > 0) {
      level += 1;
      Set<Integer> next_queue = new HashSet<Integer>();

      for (Integer remainder : queue) {
        for (Integer square : square_nums) {
          if (remainder.equals(square)) {
            return level;
          } else if (remainder < square) {
            break;
          } else {
            next_queue.add(remainder - square);
          }
        }
      }
      queue = next_queue;
    }
    return level;
  }
}

Java: Most efficient solution using math

  • Runtime: O(N ^ 0.5)
  • Memory: O(1)
class Solution {

  protected boolean isSquare(int n) {
    int sq = (int) Math.sqrt(n);
    return n == sq * sq;
  }

  public int numSquares(int n) {
    // four-square and three-square theorems.
    while (n % 4 == 0)
      n /= 4;
    if (n % 8 == 7)
      return 4;

    if (this.isSquare(n))
      return 1;
    // enumeration to check if the number can be decomposed into sum of two squares.
    for (int i = 1; i * i <= n; ++i) {
      if (this.isSquare(n - i * i))
        return 2;
    }
    // bottom case of three-square theorem.
    return 3;
  }
}
Emma
  • 27,428
  • 11
  • 44
  • 69