You can do this in O(n^3) time (for any number of bracket types) with dynamic programming. This is not necessarily optimal, but it appears that a greedy approach doesn't work for this problem.
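To make the "greedy doesn't work" remark concrete, here is a sketch of two natural stack-based greedy rules. They are my own illustrative choices, not from the answer, and the function names are hypothetical. Each returns how many brackets it leaves unmatched (the number of additions or deletions it would use), and each has a tiny input where it does worse than the optimum.

```python
PARTNER = {')': '(', ']': '[', '}': '{'}


def greedy_skip_mismatch(s: str) -> int:
    """Hypothetical greedy #1: a closer may only match the top of the stack;
    otherwise it is left unmatched. Returns the number of unmatched brackets."""
    stack, pairs = [], 0
    for ch in s:
        if ch in PARTNER:
            if stack and stack[-1] == PARTNER[ch]:
                stack.pop()
                pairs += 1
            # else: this closer stays unmatched
        else:
            stack.append(ch)
    return len(s) - 2 * pairs


def greedy_pop_until_match(s: str) -> int:
    """Hypothetical greedy #2: if a closer's partner is anywhere on the stack,
    give up on every opener above it and match; otherwise leave the closer
    unmatched. Returns the number of unmatched brackets."""
    stack, pairs = [], 0
    for ch in s:
        if ch in PARTNER:
            if PARTNER[ch] in stack:
                while stack.pop() != PARTNER[ch]:
                    pass  # openers popped here stay unmatched
                pairs += 1
        else:
            stack.append(ch)
    return len(s) - 2 * pairs


# "([)": the optimum is 1 (drop '['), but greedy #1 leaves all 3 unmatched.
print(greedy_skip_mismatch("([)"))       # 3
# "([[)]]": the optimum is 2 (keep "[[]]"), but greedy #2 leaves 4 unmatched.
print(greedy_pop_until_match("([[)]]"))  # 4
```

Each rule happens to be optimal on the other's counterexample (greedy #1 gives 2 on "([[)]]", greedy #2 gives 1 on "([)"), which hints at why a single local rule is hard to get right. This does not rule out every conceivable greedy strategy.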
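First, it's helpful to realize that, for any bracket string, the minimum number of additions needed to balance it equals the minimum number of deletions needed, and deletions are easier to work with. To see why, consider a minimum set of additions. Each added bracket must be matched with an original bracket that was unmatched before (if two added brackets were matched with each other, we could drop both), so we could instead have deleted those original brackets. Conversely, for every bracket in a minimum set of deletions, we could instead insert its partner right next to it.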
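Because deletions are easy to reason about, a brute-force reference is short to write: try deleting 0, 1, 2, ... characters and test each result with the usual stack check. This is an exponential-time sketch meant only for cross-checking faster methods on short inputs; `is_balanced` and `brute_force_min_deletions` are hypothetical names.

```python
from itertools import combinations

PARTNER = {')': '(', ']': '[', '}': '{'}


def is_balanced(t: str) -> bool:
    """Usual stack check for a multi-type bracket string."""
    stack = []
    for ch in t:
        if ch in PARTNER:
            if not stack or stack.pop() != PARTNER[ch]:
                return False
        else:
            stack.append(ch)
    return not stack


def brute_force_min_deletions(s: str) -> int:
    """Try deleting 0, 1, 2, ... characters and return the first count that
    leaves a balanced string. Exponential: only for tiny inputs."""
    n = len(s)
    for k in range(n + 1):
        for removed in combinations(range(n), k):
            dropped = set(removed)
            kept = "".join(ch for i, ch in enumerate(s) if i not in dropped)
            if is_balanced(kept):
                return k
    return n  # not reached: deleting everything leaves "", which is balanced


print(brute_force_min_deletions("([([)}]"))  # 3
```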
The idea is to compute all possible bracket pairs: create a list of all index pairs [i, j], 0 <= i < j < n, where s[i] and s[j] are an opening and a closing bracket of the same type. Then we find the maximum number of intervals [i, j] we can choose such that any two chosen intervals share no endpoint and are either nested or disjoint. This is exactly the requirement for the kept characters to be balanced. (If you're curious, such a family of intervals is laminar, and its intersection graph is trivially perfect.)
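Here is a sketch of the interval construction and of the "no shared endpoint, nested or disjoint" test. The helper names `bracket_intervals`, `compatible` and `is_laminar` are hypothetical. A set of chosen intervals corresponds to a balanced subsequence exactly when `is_laminar` holds for it.

```python
from typing import List, Sequence, Tuple

PARTNER = {')': '(', ']': '[', '}': '{'}
Interval = Tuple[int, int]


def bracket_intervals(s: str) -> List[Interval]:
    """All (i, j), i < j, where s[i] is an opener and s[j] its matching closer."""
    return [(i, j)
            for j, close in enumerate(s) if close in PARTNER
            for i in range(j) if s[i] == PARTNER[close]]


def compatible(a: Interval, b: Interval) -> bool:
    """True if two intervals share no endpoint and are nested or disjoint."""
    (i1, j1), (i2, j2) = sorted([a, b])  # now i1 <= i2
    if len({i1, j1, i2, j2}) < 4:
        return False                      # one character cannot be in two pairs
    return j2 < j1 or j1 < i2             # b nested inside a, or b right of a


def is_laminar(chosen: Sequence[Interval]) -> bool:
    """True if every two chosen intervals are compatible."""
    return all(compatible(a, b)
               for k, a in enumerate(chosen) for b in chosen[k + 1:])


print(sorted(bracket_intervals("([([)}]")))  # [(0, 4), (1, 6), (2, 4), (3, 6)]
print(is_laminar([(1, 6), (2, 4)]))           # True: keeps "[()]"
print(is_laminar([(0, 4), (1, 6)]))           # False: the pairs cross
```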
There are O(n^2) intervals, so any variant of this approach that enumerates them explicitly has an Ω(n^2) lower bound. We sort these intervals (by start, then by end if tied), and use dynamic programming (DP) to find the maximum number of pairwise nested-or-disjoint intervals we can choose.
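The quadratic count is not just an upper estimate: a string of k openers followed by k closers of one type already has k^2 = n^2/4 intervals. The sketch below counts intervals without listing them (`count_intervals` is a hypothetical helper) and shows that Python's default tuple ordering is exactly "by start, then by end".

```python
from collections import Counter

PARTNER = {')': '(', ']': '[', '}': '{'}


def count_intervals(s: str) -> int:
    """Count the (opener, closer) pairs of matching type without listing them."""
    openers_seen = Counter()
    total = 0
    for ch in s:
        if ch in PARTNER:
            # This closer pairs with every earlier opener of its type.
            total += openers_seen[PARTNER[ch]]
        else:
            openers_seen[ch] += 1
    return total


# k openers followed by k closers: n = 2k characters, k*k = n^2/4 intervals.
for k in (1, 5, 10):
    s = "(" * k + ")" * k
    print(len(s), count_intervals(s))  # 2 1, then 10 25, then 20 100

# Python's tuple ordering is already "by start, then by end if tied".
print(sorted([(2, 4), (0, 4), (1, 6), (0, 2)]))  # [(0, 2), (0, 4), (1, 6), (2, 4)]
```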
Our DP has 3 parameters: left, right, and min_index. [left, right] is an inclusive range of indices of s that we are allowed to use, and min_index is the smallest position in the sorted interval list that we are allowed to use. Let [start, end] be the leftmost interval we can still feasibly use (the first one at or after min_index in the list that starts at or after left), and call its position first_index. The answer comes from either using or not using this interval. If we don't use it, we get dp(left, right, first_index+1). If we do use it (which requires end <= right), we get 1 for the interval itself, plus the maximum number of intervals we can nest strictly inside (start, end), plus the maximum number of intervals starting strictly after end. This is 1 + dp(start+1, end-1, first_index+1) + dp(end+1, right, first_index+1).
More formally:

```
dp(left, right, min_index) :=
    the maximum number of intervals from interval_list[min_index:]
    that are contained in [left, right] and are pairwise nested or disjoint
    (sharing no endpoints).

Also, let
    first_index := max(smallest index of an interval starting at or after left,
                       min_index)
so that interval_list[first_index] = (first_start, first_end).

dp(left, right, min_index) =
    0,                                if left >= right,
                                      or first_index >= length(interval_list),
                                      or first_start >= right;
    dp(left, right, first_index+1),   if first_end > right;
    max(dp(left, right, first_index+1),
        1
        + dp(first_start+1, first_end-1, first_index+1)
        + dp(first_end+1, right, first_index+1)),   otherwise.
```
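The recurrence can be transcribed almost line for line. This sketch builds interval_list by brute force and finds first_index with `bisect.bisect_left` over the sorted start points; `max_laminar_pairs` and `min_deletions` are hypothetical names. It is written for readability rather than speed, and deep recursion will hit Python's default recursion limit on long inputs.

```python
import bisect
from functools import lru_cache

PARTNER = {')': '(', ']': '[', '}': '{'}


def max_laminar_pairs(s: str) -> int:
    """Direct transcription of dp(left, right, min_index) from the definition."""
    interval_list = sorted((i, j)
                           for j, c in enumerate(s) if c in PARTNER
                           for i in range(j) if s[i] == PARTNER[c])
    starts = [start for start, _ in interval_list]
    num = len(interval_list)

    @lru_cache(maxsize=None)
    def dp(left: int, right: int, min_index: int) -> int:
        if left >= right:
            return 0
        # first_index := max(first interval starting at or after left, min_index)
        first_index = max(bisect.bisect_left(starts, left), min_index)
        if first_index >= num:
            return 0
        first_start, first_end = interval_list[first_index]
        if first_start >= right:
            return 0                                # nothing fits any more
        best = dp(left, right, first_index + 1)     # don't use this interval
        if first_end <= right:                      # use it, if it fits
            best = max(best,
                       1 + dp(first_start + 1, first_end - 1, first_index + 1)
                       + dp(first_end + 1, right, first_index + 1))
        return best

    return dp(0, len(s) - 1, 0)


def min_deletions(s: str) -> int:
    # Each chosen interval keeps two characters; everything else is deleted.
    return len(s) - 2 * max_laminar_pairs(s)


print(min_deletions("([([)}]"))  # 3
```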
Here's a Python implementation of the algorithm:
```python
import bisect
import collections
import functools


def balance_multi_string(s: str) -> int:
    """Given a multi-paren string s, return the minimum number of deletions
    needed to balance it. 'Balanced' means every bracket is matched with a
    partner of the same type, and any two matched pairs are either nested
    or disjoint.
    Runs in O(n^3) time.
    """
    open_brackets = {'{', '[', '('}
    closed_brackets = {'}', ']', ')'}
    bracket_partners = {'{': '}', '[': ']', '(': ')',
                        '}': '{', ']': '[', ')': '('}
    n = len(s)

    # Every (open index, close index) pair of matching type is a candidate interval.
    bracket_type_to_open_locations = collections.defaultdict(list)
    intervals = []
    for i, x in enumerate(s):
        if x in closed_brackets:
            for j in bracket_type_to_open_locations[bracket_partners[x]]:
                intervals.append((j, i))
        elif x in open_brackets:
            bracket_type_to_open_locations[x].append(i)

    if not intervals:
        return n

    intervals.sort()  # by start, then by end
    num_intervals = len(intervals)

    @functools.lru_cache(None)
    def point_to_first_interval_strictly_after(point: int) -> int:
        """Given a point, return the index of the first interval starting
        strictly after it, or num_intervals if there is none."""
        if point > intervals[-1][0]:
            return num_intervals
        if point < intervals[0][0]:
            return 0
        # (point, n + 2) sorts after every interval that starts at `point`.
        return bisect.bisect_right(intervals, (point, n + 2))

    @functools.lru_cache(None)
    def dp(left: int, right: int, min_index: int) -> int:
        """Given the inclusive range [left, right] and a minimum interval index,
        return the maximum number of intervals we can choose within this range
        so that all chosen intervals are pairwise nested or disjoint."""
        if left >= right or min_index >= num_intervals:
            return 0
        starting_idx = max(point_to_first_interval_strictly_after(left - 1), min_index)
        if starting_idx == num_intervals or intervals[starting_idx][0] >= right:
            return 0
        first_start, first_end = intervals[starting_idx]
        # Option 1: skip the first usable interval.
        best_answer = dp(first_start, right, starting_idx + 1)
        # Option 2: use it, if it fits inside [left, right].
        if first_end <= right:
            best_answer = max(best_answer,
                              1
                              + dp(first_start + 1, first_end - 1, starting_idx + 1)
                              + dp(first_end + 1, right, starting_idx + 1))
        return best_answer

    # Each chosen interval keeps two characters; everything else is deleted.
    return n - 2 * dp(0, n - 1, 0)
```
Examples (input --> minimum deletions):
( [ ( [ ) } ] --> 3
} ( } [ ) [ { } --> 4
} ( } } ) ] ) { --> 6
{ ) { ) { [ } } --> 4
) ] } { } [ ( { --> 6
] ) } } ( [ } { --> 8
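To double-check the example table independently, here is a sketch of a different, commonly used O(n^3) formulation: an interval DP over positions of s, where s[i] is either dropped or matched with some later partner s[k]. This is my own cross-check, not the answer's interval-list method, and `min_deletions_interval_dp` is a hypothetical name.

```python
from functools import lru_cache

PARTNER = {')': '(', ']': '[', '}': '{'}


def min_deletions_interval_dp(s: str) -> int:
    """Alternative O(n^3) formulation (not the answer's interval-list DP):
    best(i, j) = max number of matched pairs using only s[i..j]."""
    n = len(s)

    @lru_cache(maxsize=None)
    def best(i: int, j: int) -> int:
        if i >= j:
            return 0
        result = best(i + 1, j)          # s[i] is not used
        for k in range(i + 1, j + 1):    # or s[i] is matched with s[k]
            if PARTNER.get(s[k]) == s[i]:
                result = max(result, 1 + best(i + 1, k - 1) + best(k + 1, j))
        return result

    return n - 2 * best(0, n - 1)


examples = {
    "([([)}]": 3,
    "}(}[)[{}": 4,
    "}(}})]){": 6,
    "{){){[}}": 4,
    ")]}{}[({": 6,
    "])}}([}{": 8,
}
for s, expected in examples.items():
    assert min_deletions_interval_dp(s) == expected
```

It agrees with all six examples above.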