
Given a list of numbers, say [4 5 2 3], I need to maximize the sum obtained according to the following set of rules:

  1. I need to select a number from the list, and that number is removed. E.g. selecting 2 leaves the list as [4 5 3].
  2. If the selected number has two neighbours, the result of this selection is the product of the selected number with one of its neighbours, plus the other neighbour. E.g. if I select 2, the result of this selection can be 2 * 5 + 3.
  3. If the selected number has only one neighbour, the result is the product of the selected number with that neighbour.
  4. When there is only one number left, it is simply added to the running total.

Following these rules, I need to select the numbers in such an order that the result is maximized.

For the above list, selecting in the order 4 -> 2 -> 3 -> 5 gives a sum of 53, which is the maximum.
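To make the rules concrete, here is a small scoring sketch. The name `score_order` is hypothetical, and so is the convention that `order` lists original positions rather than values (chosen so duplicate values are unambiguous). Where a middle element has two neighbours, it assumes you take the better of the two pairings allowed by rule 2.

```python
def score_order(values, order):
    """Total for removing values[k] for each original index k in `order`.

    Hypothetical helper: `order` holds ORIGINAL positions, not values.
    For rule 2 it takes the better of the two possible pairings.
    """
    alive = list(range(len(values)))  # original indices still in the list
    total = 0
    for k in order:
        pos = alive.index(k)
        x = values[k]
        has_left = pos > 0
        has_right = pos < len(alive) - 1
        if has_left and has_right:        # rule 2: two neighbours
            a, b = values[alive[pos - 1]], values[alive[pos + 1]]
            total += max(x * a + b, x * b + a)
        elif has_left:                    # rule 3: only a left neighbour
            total += x * values[alive[pos - 1]]
        elif has_right:                   # rule 3: only a right neighbour
            total += x * values[alive[pos + 1]]
        else:                             # rule 4: last number standing
            total += x
        alive.pop(pos)
    return total


# Order 4 -> 2 -> 3 -> 5 on [4, 5, 2, 3], i.e. original indices 0, 2, 3, 1:
# 4*5 + (2*5 + 3) + 3*5 + 5 = 20 + 13 + 15 + 5
print(score_order([4, 5, 2, 3], [0, 2, 3, 1]))  # 53
```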

I am including a brute-force program: it reads the list from standard input, tries every removal order, prints the running sum for each one, and reports the maximum sum and the number of the permutation that achieved it.


import itertools

l = [int(i) for i in input().split()]
# Permute positions, not values, so duplicate values are handled correctly.
p = itertools.permutations(range(len(l)))

c, cs = 1, -1
mm = float('-inf')
for order in p:
    var, s = list(range(len(l))), 0  # original indices still in the list
    print(c, ':', [l[k] for k in order])
    c += 1

    for k in order:
        x = l[k]
        print(' removing: ', x)
        pos = var.index(k)
        if len(var) == 1:
            # Rule 4: the last number is simply added.
            s += x
        elif pos == 0:
            # Rule 3: only a right neighbour.
            s += x * l[var[pos + 1]]
        elif pos == len(var) - 1:
            # Rule 3: only a left neighbour.
            s += x * l[var[pos - 1]]
        else:
            # Rule 2: multiply by one neighbour and add the other; try both ways.
            a, b = l[var[pos - 1]], l[var[pos + 1]]
            s += max(x * a + b, x * b + a)
        var.remove(k)
        print(' modified list: ', [l[q] for q in var], '\n  sum:', s)

    # Compare only the final sum of each complete removal order.
    if s > mm:
        mm = s
        cs = c - 1

print('MAX SUM was', mm, ' at', cs)
Piyush Ranjan
    With [4, 2, 3, 5] if you remove the elements in order 3, 2, 4, 5 you get 17 + 14 + 20 + 5 = 56. 3 has neighbours 2, 5 so you get 3*5+2=17. 2 has neighbours 4 and 5 (since you removed 3) and 2*5+4=14. 4 has neighbour 5 so you get 4*5=20. 5 remains. Is there a mistake in the question? – Paul Hankin Aug 09 '20 at 15:49
    The given list is [4 5 2 3]. If we remove elements from it in the order 4 -> 2 -> 3 -> 5, we get a sum of 53. You have wrongly assumed the list [4 2 3 5], which is different from my question. – Piyush Ranjan Aug 09 '20 at 15:57
    Yes, sorry. I misread the optimal removal sequence you gave as the original sequence. I agree that [4 5 2 3] gives 53. – Paul Hankin Aug 09 '20 at 17:55

1 Answer

Consider four variants of the problem: one where every element gets consumed, and three where the leftmost element, the rightmost element, or both endpoints are not consumed (they stay in the list as fixed neighbours).
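As an illustration of my own (not part of the answer's code), take the question's list [4 5 2 3] in the variant where both endpoints are fixed: 4 and 3 stay, and only 5 and 2 are removed. There are two orders:

$$
\begin{aligned}
\text{remove } 2 \text{, then } 5:&\quad (2\cdot 5 + 3) + (5\cdot 4 + 3) = 13 + 23 = 36\\
\text{remove } 5 \text{, then } 2:&\quad (5\cdot 4 + 2) + (2\cdot 4 + 3) = 22 + 11 = 33
\end{aligned}
$$

So that variant's value on the whole list is 36, with 4 and 3 still present afterwards.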

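In each case, consider the last element to be removed. At that moment its only neighbours are the unconsumed endpoints (if any), so its contribution is known; and since it separates everything on its left from everything on its right until it is removed, the problem breaks down into 1 or 2 independent subproblems in which it is a fixed endpoint.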
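Read off the answer's four functions, the recurrences look like this (notation mine: $F_{ab}(n,m)$ is the value returned by `solve_ab(seq, n, m, cache)` and $a_k$ is `seq[k]`), with the empty-interval conventions $F_{01}(n,n) = F_{10}(m,m) = F_{11}(k,k+1) = 0$ and the answer for a list of length $N$ being $F_{00}(0, N-1)$:

$$
\begin{aligned}
F_{00}(n,m) &= \max_{n \le i \le m} \bigl[ F_{01}(n,i) + a_i + F_{10}(i,m) \bigr]\\
F_{01}(n,m) &= \max_{n \le i < m} \bigl[ F_{01}(n,i) + a_i a_m + F_{11}(i,m) \bigr]\\
F_{10}(n,m) &= \max_{n < i \le m} \bigl[ F_{11}(n,i) + a_n a_i + F_{10}(i,m) \bigr]\\
F_{11}(n,m) &= \max_{n < i < m} \bigl[ F_{11}(n,i) + \max(a_i a_n + a_m,\; a_i a_m + a_n) + F_{11}(i,m) \bigr]
\end{aligned}
$$

The explicit base cases in the code (for example `seq[n] * seq[m]` when `m == n + 1` in `solve_01`) are what these recurrences give under those conventions.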

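Each of the four variants has O(n^2) intervals and each interval tries O(n) choices of last element, so this solves the problem in O(n^3) time. Here's a Python program that solves the problem. The functions solve_00, solve_01, solve_10 and solve_11 correspond to neither endpoint, the right endpoint, the left endpoint, or both endpoints being fixed. No doubt this program can be reduced (there's a lot of duplication).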
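Since the answer points out the duplication, here is one way the four functions might be folded into one. This is a sketch, not the answer's code. `max_removal_sum` and `best` are hypothetical names, the booleans `lf`/`rf` play the role of the two digits in solve_00 to solve_11, and `functools.lru_cache` stands in for the hand-rolled `cache` dict.

```python
from functools import lru_cache


def max_removal_sum(seq):
    """Maximum total for removing every element of seq under the question's rules.

    best(n, m, lf, rf) covers seq[n..m]; lf / rf say whether seq[n] / seq[m]
    is a fixed neighbour that this subproblem must NOT remove.
    """
    seq = tuple(seq)

    @lru_cache(maxsize=None)
    def best(n, m, lf, rf):
        lo = n + 1 if lf else n          # first index we may remove
        hi = m - 1 if rf else m          # last index we may remove
        if lo > hi:
            return 0                     # nothing to remove in this interval
        result = float('-inf')
        for i in range(lo, hi + 1):      # i = the element removed last
            x = seq[i]
            # When i goes, its only neighbours are the fixed endpoints.
            if lf and rf:
                gain = max(x * seq[n] + seq[m], x * seq[m] + seq[n])
            elif lf:
                gain = x * seq[n]
            elif rf:
                gain = x * seq[m]
            else:
                gain = x
            # The two sides are independent, with i as their fixed endpoint.
            left = best(n, i, lf, True) if i > n else 0
            right = best(i, m, True, rf) if i < m else 0
            result = max(result, left + gain + right)
        return result

    return best(0, len(seq) - 1, False, False)


for c in [[1, 1, 1], [4, 5, 2, 3], [1, 2], [1, 2, 3]]:
    print(c, max_removal_sum(c))
```

The running time is still O(n^3). Like the original, it recurses once per level of interval nesting, so very long lists may need a higher recursion limit.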

def solve_00(seq, n, m, cache):
    # Remove every element of seq[n..m]; neither endpoint is fixed.
    key = ('00', n, m)
    if key in cache:
        return cache[key]
    assert m >= n
    if n == m:
        return seq[n]
    best = -1e9
    for i in range(n, m+1):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[i])
    cache[key] = best
    return best


def solve_01(seq, n, m, cache):
    # Remove seq[n..m-1]; seq[m] is a fixed right neighbour, nothing lies left of seq[n].
    key = ('01', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n, m):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[m])
    cache[key] = best
    return best

def solve_10(seq, n, m, cache):
    # Remove seq[n+1..m]; seq[n] is a fixed left neighbour, nothing lies right of seq[m].
    key = ('10', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n+1, m+1):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[n] * seq[i])
    cache[key] = best
    return best

def solve_11(seq, n, m, cache):
    # Remove seq[n+1..m-1]; seq[n] and seq[m] are both fixed neighbours.
    key = ('11', n, m)
    if key in cache:
        return cache[key]   
    assert m >= n + 2
    if m == n + 2:
        return max(seq[n] * seq[n+1] + seq[n+2], seq[n] + seq[n+1] * seq[n+2])
    best = -1e9
    for i in range(n + 1, m):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[n] + seq[m], left + right + seq[i] * seq[m] + seq[n])
    cache[key] = best
    return best

for c in [[1, 1, 1], [4, 5, 2, 3], [4, 2, 3, 5], [1, 2], [1, 2, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
    print(c, solve_00(c, 0, len(c)-1, dict()))
Paul Hankin