
Consider two arrays A and B of the same length. The element at index i in A is associated with the element at index i in B, so we can think of them as a pair (A[i], B[i]). We have some queries q of the form (a, b). For each query we need to find the number of indices i for which A[i] > a and B[i] > b.

Constraints - 
n (size of array) <= 10^5
q (count of queries) <= 10^5
 

Example - 
A = [1,  3, 6, 7, 2]
B = [10, 7, 2, 6, 4]
q = [(2, 6), (3, 9), (0, 1)]

Output - 
[1, 0, 5]

Explanation -

For query (2, 6) there is only one pair such that A[i] > 2 and B[i] > 6. The first condition, A[i] > 2, gives three candidates: 3, 6 and 7. Their partners in B are 7, 2 and 6, so only the candidate with value 3 in the first array, the pair (3, 7), also satisfies B[i] > 6.

I have tried a brute-force linear search for each query, but that leads to TLE (time limit exceeded).
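For reference, the brute-force approach described above can be sketched as follows (the function name `count_dominating_brute_force` is hypothetical). It scans every pair for every query, so it costs O(n * q), roughly 10^10 comparisons at the stated limits, which explains the TLE. It is still handy for checking faster solutions on small inputs.

```python
def count_dominating_brute_force(A, B, queries):
    """For each query (a, b), count indices i with A[i] > a and B[i] > b.

    O(n * q) overall: fine for checking small cases, too slow when
    n and q are both around 10^5.
    """
    return [
        sum(1 for x, y in zip(A, B) if x > a and y > b)
        for a, b in queries
    ]


A = [1, 3, 6, 7, 2]
B = [10, 7, 2, 6, 4]
q = [(2, 6), (3, 9), (0, 1)]
print(count_dominating_brute_force(A, B, q))  # [1, 0, 5]
```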

גלעד ברקן

2 Answers


If the queries are provided offline, there's no need for a quadtree and we can solve this in O((n + q) log(n + q)) time. Insert all query pairs and array pairs into one list, sorted by a or A[i] (if an a is equal to an A[i], place the query pair after the array pair, so that in the descending pass the query is answered before that array pair is inserted; such a pair does not satisfy A[i] > a). Process the list in descending order of A[i] or a. If an element is an array pair, insert it into an order-statistic tree ordered by B[i]. If it's a query, look up in the tree the number of nodes (these are array pairs) with B[i] > b; we already know that all pairs in the tree have A[i] > a.
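To see the ordering rule in isolation, here is a minimal sketch of the same sweep that swaps the order-statistic tree for a plain sorted Python list maintained with the standard `bisect` module. The function name `count_offline_sorted_list` is hypothetical. `bisect.insort` shifts list elements, so this stand-in is O(n) per insertion in the worst case; the tree is what gives the O(log n) bound.

```python
import bisect


def count_offline_sorted_list(A, B, queries):
    """Offline sweep over A in descending order, using a sorted list.

    Events are (value, kind, payload): kind 0 = array pair, kind 1 = query.
    Sorting ascending by (value, kind) puts a query after an array pair with
    an equal value, so walking the list in reverse answers the query before
    that pair is inserted (a pair with A[i] == a must not be counted).
    """
    events = [(x, 0, y) for x, y in zip(A, B)]
    events += [(a, 1, (b, idx)) for idx, (a, b) in enumerate(queries)]
    events.sort(key=lambda e: (e[0], e[1]))  # never compares payloads

    seen_b = []  # B values of pairs with A[i] > current a, kept sorted
    result = [0] * len(queries)
    for _, kind, payload in reversed(events):
        if kind == 0:
            bisect.insort(seen_b, payload)
        else:
            b, idx = payload
            # number of stored B values strictly greater than b
            result[idx] = len(seen_b) - bisect.bisect_right(seen_b, b)
    return result


print(count_offline_sorted_list([1, 3, 6, 7, 2], [10, 7, 2, 6, 4],
                                [(2, 6), (3, 9), (0, 1)]))  # [1, 0, 5]
```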

Python code:

# Order statistic treap

import random

class Treap:
  def __init__(self, val=None):
    self.val = val
    self.size = 1
    self.key = random.random()
    self.left = None
    self.right = None

  def __repr__(self):
    return str({"val": self.val, "size": self.size, "key": self.key, "left": self.left, "right": self.right})

def size(t):
  return t.size if t else 0

def update(t):
  if t:
    t.size = 1 + size(t.left) + size(t.right)
  return t

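def split(t, val):
  # Split treap t into (nodes with value <= val, nodes with value > val)
  if not t:
    return None, None
  if t.val <= val:
    t.right, right = split(t.right, val)
    return update(t), right
  else:
    left, t.left = split(t.left, val)
    return left, update(t)

def insert(t, node):
  if not t:
    return node

  # t stays above: descend into the subtree on node's side
  if node.key > t.key:
    if node.val > t.val:
      t.right = insert(t.right, node)
    else:
      t.left = insert(t.left, node)
    return update(t)
  # node goes above: split t's whole subtree around node.val so that every
  # value in node.left is <= node.val and every value in node.right is > node.val
  else:
    node.left, node.right = split(t, node.val)
    return update(node)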

def query(t, val):
  if not t:
    return 0

  if val < t.val:
    return 1 + size(t.right) + query(t.left, val)
  else:
    return query(t.right, val)


def merge_queries(lst, Q):
  result = [None] * (len(lst) + len(Q))

  i = 0
  j = 0

  for k in range(len(result)):
    if i < len(lst) and (j == len(Q) or lst[i][0] <= Q[j][0]):
      result[k] = lst[i]
      i += 1
    else:
      result[k] = Q[j]
      j += 1

  return result


def f(A, B, Q):
  sorted_zip = sorted(zip(A, B))
  sorted_queries = sorted([(a, b, i) for i, (a, b) in enumerate(Q)])
  merged = merge_queries(sorted_zip, sorted_queries)

  result = [None] * len(Q)
  tree = None

  for tpl in reversed(merged):
    if len(tpl) == 3:
      result[tpl[2]] = query(tree, tpl[1])
    else:
      tree = insert(tree, Treap(tpl[1]))

  return result


A = [1,  3, 6, 7, 2]
B = [10, 7, 2, 6, 4]
Q = [(2, 6), (3, 9), (0, 1)]

print(f(A, B, Q))
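The same descending sweep from this answer also works with a Fenwick tree (binary indexed tree) over coordinate-compressed B values instead of a treap; this is a swapped-in data structure, not the one the answer uses. It avoids recursion and random priorities, and each insert or count is O(log n). A minimal sketch, with the hypothetical names `FenwickTree` and `count_offline_fenwick`:

```python
import bisect


class FenwickTree:
    """Binary indexed tree holding counts at positions 1..n."""

    def __init__(self, n):
        self.tree = [0] * (n + 1)

    def add(self, i, delta=1):
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i

    def prefix_sum(self, i):
        """Sum of counts at positions 1..i."""
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total


def count_offline_fenwick(A, B, queries):
    """Descending sweep over A, counting stored B values > b with a Fenwick tree."""
    values = sorted(set(B))                  # coordinate compression of B
    pairs = sorted(zip(A, B), reverse=True)  # array pairs, descending by A
    order = sorted(range(len(queries)), key=lambda k: queries[k][0], reverse=True)

    bit = FenwickTree(len(values))
    result = [0] * len(queries)
    inserted = 0
    p = 0
    for k in order:
        a, b = queries[k]
        # insert every pair with A[i] > a (strict, so ties with a are excluded)
        while p < len(pairs) and pairs[p][0] > a:
            bit.add(bisect.bisect_left(values, pairs[p][1]) + 1)
            inserted += 1
            p += 1
        # stored B values <= b occupy ranks 1..bisect_right(values, b)
        result[k] = inserted - bit.prefix_sum(bisect.bisect_right(values, b))
    return result


print(count_offline_fenwick([1, 3, 6, 7, 2], [10, 7, 2, 6, 4],
                            [(2, 6), (3, 9), (0, 1)]))  # [1, 0, 5]
```

Here the queries are sorted separately and the pairs are consumed with a pointer, which replaces the merged list and its tie rule: the strict `> a` test is what keeps pairs with A[i] == a out of the count.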
גלעד ברקן

Use a quadtree (https://en.m.wikipedia.org/wiki/Quadtree) over the queries q that gives you, for a value x, the number of conditions in q that it meets in log(len(q)) time. Then you can solve your problem in len(A)*log(len(q)) time.

Jean Valj
  • Could you elaborate on this? – TYeung Sep 04 '21 at 14:47
  • Start with an easier problem: L = [x0, x1, ..., xn] is a list of numbers and E = [(a0, b0), (a1, b1), ..., (an, bn)] is a set of intervals. Find the number of xi that fall inside each Ei using an interval tree (https://en.wikipedia.org/wiki/Interval_tree). Your problem is then the two-dimensional version of this problem. – Jean Valj Sep 04 '21 at 15:02
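For the one-dimensional warm-up in the comment above, a minimal sketch that uses plain sorting and binary search instead of an interval tree (a swapped-in technique, not the one the comment suggests) counts how many xi fall in each interval. It assumes closed intervals with a <= b; the function name `count_points_in_intervals` is hypothetical.

```python
import bisect


def count_points_in_intervals(xs, intervals):
    """For each closed interval (a, b) with a <= b, count the xs with a <= x <= b."""
    xs = sorted(xs)
    return [bisect.bisect_right(xs, b) - bisect.bisect_left(xs, a)
            for a, b in intervals]


print(count_points_in_intervals([1, 3, 6, 7, 2], [(2, 6), (0, 1), (8, 9)]))  # [3, 1, 0]
```

After one O(n log n) sort, each interval is answered with two binary searches in O(log n).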