Firstly, let's solve this for just two sets. This is known as the 'two sum' problem. You have two sets a and b, and you want one number from each that together add to l. Since a + b = l we know that l - a = b. This is important because we can determine if l - a is in b in O(1) time, rather than looping through b to find it in O(b) time. This means we can solve the two sum problem in O(a) time.
Note: For brevity the provided code only produces one solution. However, changing two_sum to a generator function lets it return them all.
def two_sum(l, a, b):
    # b must support O(1) membership tests (a set or a dict).
    for i in a:
        if l - i in b:
            return i, l - i
    raise ValueError('No solution found')
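The note above mentions turning two_sum into a generator. A minimal sketch of what that might look like follows; the name two_sum_all is made up for this illustration and is not part of the original code.

```python
def two_sum_all(l, a, b):
    """Yield every pair (i, l - i) with i in a and l - i in b.

    two_sum_all is a hypothetical name used only for this sketch.
    b should be a set (or dict) so each membership test is O(1) on average.
    """
    for i in a:
        if l - i in b:
            yield i, l - i


# Sorted only to get a stable order; sets do not guarantee iteration order.
pairs = sorted(two_sum_all(10, {1, 4, 7}, {3, 6, 9, 20}))
```

Because nothing is raised when there is no solution, an empty result simply means no pair was found.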
Next we can solve the 'four sum' problem. This time we have four sets c, d, e and f. By combining c and d into a, and e and f into b, we can use two_sum to solve the problem in O(cd + ef) space and time. To combine the sets we can just use a Cartesian product, adding the results together.
Note: To get all results, take the Cartesian product of a[i] and b[j] for every matching pair of sums i and j.
import itertools

def combine(*sets):
    # Map each achievable sum to the list of tuples that produce it.
    results = {}
    for keys in itertools.product(*sets):
        results.setdefault(sum(keys), []).append(keys)
    return results

def four_sum(l, c, d, e, f):
    a = combine(c, d)
    b = combine(e, f)
    i, j = two_sum(l, a, b)
    # Only the first tuple for each sum is returned.
    return (*a[i][0], *b[j][0])
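To illustrate the note about getting all results, here is a rough sketch that pairs every way of making a sum i with every way of making l - i. The name four_sum_all is an assumption made for this example, and combine is repeated so the snippet runs on its own.

```python
import itertools


def combine(*sets):
    """Map each achievable sum to the list of tuples that produce it."""
    results = {}
    for keys in itertools.product(*sets):
        results.setdefault(sum(keys), []).append(keys)
    return results


def four_sum_all(l, c, d, e, f):
    """Yield every (w, x, y, z), one element per set, that adds to l.

    four_sum_all is a hypothetical name used only for this sketch.
    """
    a = combine(c, d)
    b = combine(e, f)
    for i in a:
        if l - i in b:
            # Every way of making i pairs with every way of making l - i.
            for left, right in itertools.product(a[i], b[l - i]):
                yield (*left, *right)


solutions = sorted(four_sum_all(10, {1, 2}, {3, 4}, {1, 2}, {3, 4}))
```

Note that listing every solution can take far longer than finding one, since the number of solutions can itself be as large as the full Cartesian product.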
It should be apparent that the 'three sum' problem is just a simplified version of the 'four sum' problem. The difference is that we're given a at the start rather than being asked to calculate it. This runs in O(a + ef) time and O(ef) space.
def three_sum(l, a, e, f):
    b = combine(e, f)
    i, j = two_sum(l, a, b)
    return (i, *b[j][0])
Now we have enough information to solve the 'six sum' problem. The question comes down to how we divide all these sets.
- If we decide to pair them together then we can use the 'three sum' solution to get what we want. But this may not run in the best time, as it runs in O(ab + cdef), or O(n^4) time if they're all the same size.
- If we decide to put them in trios then we can use the 'two sum' solution to get what we want. This runs in O(abc + def), or O(n^3) if they're all the same size.
At this point we should have all the information to make a generic version that runs in O(n^⌈s/2⌉) time and space, where s is the number of sets passed to the function.
def n_sum(l, *sets):
    midpoint = len(sets) // 2
    a = combine(*sets[:midpoint])
    b = combine(*sets[midpoint:])
    i, j = two_sum(l, a, b)
    return (*a[i][0], *b[j][0])
You can further optimize the code. The sizes of both sides of the two sum matter quite a lot.
To see why, imagine 4 sets of 1 number on one side and 4 sets of 1000 numbers on the other. This runs in O(1^4 + 1000^4) time, which is obviously really bad. Instead you can balance both sides of the two sum to make the total much smaller. With 2 sets of 1 number and 2 sets of 1000 numbers on each side, the cost drops to O(1^2×1000^2 + 1^2×1000^2), or simply O(1000^2), which is far smaller than O(1000^4).
Expanding on the previous point, if you have 3 sets of 1000 numbers and 3 sets of 10 numbers then the best solution is to put two 1000s on one side and everything else on the other side:
- Two 1000s on one side, everything else on the other (1000, 1000), (1000, 10, 10, 10)
  1000^2 + 10^3×1000 = 2_000_000
- Interlaced sorted and same size either side (10, 1000, 10), (1000, 10, 1000)
  10^2×1000 + 10×1000^2 = 10_100_000
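The comparison above can be checked mechanically. The sketch below brute-forces every way of placing the sets on two sides and keeps the cheapest one, using the product of the set sizes on each side as the cost. The helper name best_split is an assumption for this illustration, and brute force is only reasonable because the number of sets is small.

```python
from itertools import combinations
from math import prod


def best_split(sizes):
    """Return (cost, left, right) minimising prod(left) + prod(right).

    best_split is a hypothetical helper. sizes holds the length of each
    input set; every non-empty way of choosing the left side is tried.
    """
    indices = range(len(sizes))
    best = None
    for r in range(1, len(sizes)):
        for left_idx in combinations(indices, r):
            left = [sizes[i] for i in left_idx]
            right = [sizes[i] for i in indices if i not in left_idx]
            cost = prod(left) + prod(right)
            if best is None or cost < best[0]:
                best = (cost, left, right)
    return best


cost, left, right = best_split([1000, 1000, 1000, 10, 10, 10])
interlaced_cost = prod([10, 1000, 10]) + prod([1000, 10, 1000])
```

A real implementation would split the sets themselves rather than their sizes, but the sizes are enough to see which arrangement wins.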
Additionally, if every set appears an even number of times then you can cut the run time in half by only calling combine once. For example, if the input is n_sum(l, a, b, c, a, b, c) (without the above optimizations) it should be apparent that the second call to combine is only a waste of time and space.
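As a rough sketch of that idea, assuming the caller already knows both halves are identical and passes only one of them, the combined table can be searched against itself. The name n_sum_mirrored is made up for this example, and combine is repeated so the snippet is self-contained.

```python
import itertools


def combine(*sets):
    """Map each achievable sum to the list of tuples that produce it."""
    results = {}
    for keys in itertools.product(*sets):
        results.setdefault(sum(keys), []).append(keys)
    return results


def n_sum_mirrored(l, *half):
    """Solve n_sum(l, *half, *half) while calling combine only once.

    n_sum_mirrored is a hypothetical name used only for this sketch.
    """
    a = combine(*half)
    for i in a:
        # The same table serves as both sides of the two sum.
        if l - i in a:
            return (*a[i][0], *a[l - i][0])
    raise ValueError('No solution found')


# Equivalent to n_sum(20, a, b, c, a, b, c) with a={1, 2}, b={3, 4}, c={5, 6}.
result = n_sum_mirrored(20, {1, 2}, {3, 4}, {5, 6})
```

Which of the valid solutions comes back depends on iteration order, so only its sum and shape should be relied upon.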