Yes, you can compute the sum of the distances between every pair of nodes in the whole tree with a DP in O(n).
Briefly, you need three arrays:
- cnt[i] is the number of nodes in the subtree of node i.
- dis[i] is the sum of the distances from every node in i's subtree to node i.
- ret[i] is the sum of the distances between every pair of nodes in i's subtree.
Notice that ret[root] is the answer to the problem, so once ret[i] is computed correctly the problem is solved.
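To check ret[root] on small trees, a brute-force reference is handy. The sketch below is not the DP from this answer: it runs a BFS from every node, which is O(n^2). It assumes child arrays in the same layout as the code further down (left[i] / right[i] hold a child index or -1, root is node 0); the class and method names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BruteForcePairDistance {
    // Sums d(u, v) over all unordered pairs {u, v} by running a BFS from every node: O(n^2).
    static long pairDistanceSum(int[] left, int[] right) {
        int n = left.length;
        // Undirected adjacency list, so BFS can walk up to parents as well as down to children.
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n; i++) {
            for (int c : new int[] {left[i], right[i]}) {
                if (c != -1) {
                    adj.get(i).add(c);
                    adj.get(c).add(i);
                }
            }
        }
        long total = 0;
        int[] dist = new int[n];
        for (int s = 0; s < n; s++) {
            Arrays.fill(dist, -1);
            dist[s] = 0;
            ArrayDeque<Integer> queue = new ArrayDeque<>();
            queue.add(s);
            while (!queue.isEmpty()) {
                int u = queue.poll();
                total += dist[u];
                for (int v : adj.get(u)) {
                    if (dist[v] == -1) {
                        dist[v] = dist[u] + 1;
                        queue.add(v);
                    }
                }
            }
        }
        return total / 2; // every unordered pair was counted once from each end
    }

    public static void main(String[] args) {
        // The 5-node tree used as bst2 in the answer's main().
        int[] left = {1, 3, -1, -1, -1};
        int[] right = {2, 4, -1, -1, -1};
        System.out.println(pairDistanceSum(left, right)); // 18
    }
}
```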
How do we calculate ret[i]? We need the help of cnt[i] and dis[i], and we solve it recursively.
The key problem is: given ret[left], ret[right], dis[left], dis[right], cnt[left] and cnt[right], calculate ret[node], dis[node] and cnt[node].
                 (node)
                /      \
      (left subtree)    (right subtree)
         /     \            /     \
    ...(node x_i)...    ...(node y_i)...
Important: x_i is any node in the left subtree (not only the leaves!), and y_i is any node in the right subtree (not only the leaves either!).
cnt[node] is easy: it equals cnt[left] + cnt[right] + 1.
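dis[node] is not so hard: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(x_i -> left) is dis[left], and every x_i is exactly one edge farther from node than from left, so sigma(x_i -> node) is dis[left] + cnt[left]. The right subtree contributes dis[right] + cnt[right] in the same way.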
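Those two recurrences can be checked on their own with a small post-order pass. This is a minimal sketch, assuming the 5-node example tree (bst2 in the code below) hard-coded as child arrays; the class name CntDisDemo and method fill are hypothetical. A missing child simply adds nothing to either sum.

```java
import java.util.Arrays;

class CntDisDemo {
    // Child arrays for the 5-node example tree; -1 means "no child".
    static int[] left = {1, 3, -1, -1, -1};
    static int[] right = {2, 4, -1, -1, -1};
    static int[] cnt = new int[5];
    static long[] dis = new long[5];

    // Post-order: fill each child first, then apply
    //   cnt[node] = cnt[left] + cnt[right] + 1
    //   dis[node] = dis[left] + dis[right] + cnt[left] + cnt[right]
    static void fill(int node) {
        cnt[node] = 1;
        dis[node] = 0;
        for (int child : new int[] {left[node], right[node]}) {
            if (child == -1) continue;
            fill(child);
            cnt[node] += cnt[child];
            // every node below `child` is one edge farther from `node` than from `child`
            dis[node] += dis[child] + cnt[child];
        }
    }

    public static void main(String[] args) {
        fill(0);
        System.out.println(Arrays.toString(cnt)); // [5, 3, 1, 1, 1]
        System.out.println(Arrays.toString(dis)); // [6, 2, 0, 0, 0]
    }
}
```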
ret[node] is the sum of three parts:
- x_i -> x_j and y_i -> y_j (both ends in the same subtree): this equals ret[left] + ret[right].
- x_i -> node and y_i -> node (one end is node itself): this equals dis[node].
- x_i -> y_j (one end in each subtree): every such path goes x_i -> node -> y_j, so it equals sigma(x_i -> node -> y_j). Fix x_i and sum over all cnt[right] nodes y_j: we get cnt[right]*distance(x_i,node) + sigma(node -> y_j) = cnt[right]*distance(x_i,node) + sigma(node -> right -> y_j) = cnt[right]*distance(x_i,node) + cnt[right] + dis[right].
  Now sum over x_i, using sigma(distance(x_i,node)) = cnt[left] + dis[left]: we get cnt[left]*(cnt[right] + dis[right]) + cnt[right]*(cnt[left] + dis[left]), which is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
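The cross term x_i -> y_j is the least obvious step, so here is a small numeric check of the closed form against its definition. This sketch assumes you already have, for each subtree, the list of distances from its nodes to node (dLeft, dRight, all at least 1); dis[child] is then recovered as the sum of (d - 1). The names CrossTermCheck, brute and closedForm are hypothetical.

```java
class CrossTermCheck {
    // Definition: sum over all pairs (x in left, y in right) of d(x, node) + d(node, y).
    static long brute(int[] dLeft, int[] dRight) {
        long sum = 0;
        for (int dx : dLeft)
            for (int dy : dRight)
                sum += dx + dy;
        return sum;
    }

    // Closed form: 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
    // dis[child] is measured to the child itself, i.e. one edge less per node than to `node`.
    static long closedForm(int[] dLeft, int[] dRight) {
        long cntL = dLeft.length, cntR = dRight.length;
        long disL = 0, disR = 0;
        for (int d : dLeft) disL += d - 1;
        for (int d : dRight) disR += d - 1;
        return 2 * cntL * cntR + disL * cntR + disR * cntL;
    }

    public static void main(String[] args) {
        // Root of the 5-node example: left subtree {1, 3, 4} at distances 1, 2, 2; right subtree {2} at distance 1.
        int[] dLeft = {1, 2, 2};
        int[] dRight = {1};
        System.out.println(brute(dLeft, dRight) + " " + closedForm(dLeft, dRight)); // 8 8
    }
}
```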
Summing these three parts gives ret[node]. Doing this recursively from the leaves up, we get ret[root].
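For reference, the recurrence derived above can be collected in one place (v is the current node, L and R its left and right children). Treating a missing child as cnt = dis = ret = 0 reproduces the leaf and one-child cases. The last line plugs in the root of the 5-node example tree (nodes 0 to 4, root 0 with children 1 and 2), using ret[1] = 4, ret[2] = 0, dis[0] = 6, dis[1] = 2, dis[2] = 0, cnt[1] = 3, cnt[2] = 1.

```latex
\begin{aligned}
\mathrm{cnt}[v] &= \mathrm{cnt}[L] + \mathrm{cnt}[R] + 1 \\
\mathrm{dis}[v] &= \mathrm{dis}[L] + \mathrm{dis}[R] + \mathrm{cnt}[L] + \mathrm{cnt}[R] \\
\mathrm{ret}[v] &= \underbrace{\mathrm{ret}[L] + \mathrm{ret}[R]}_{x_i \to x_j,\; y_i \to y_j}
  + \underbrace{\mathrm{dis}[v]}_{x_i \to v,\; y_i \to v}
  + \underbrace{2\,\mathrm{cnt}[L]\,\mathrm{cnt}[R] + \mathrm{dis}[L]\,\mathrm{cnt}[R] + \mathrm{dis}[R]\,\mathrm{cnt}[L]}_{x_i \to y_j} \\[4pt]
\mathrm{ret}[0] &= (4 + 0) + 6 + (2 \cdot 3 \cdot 1 + 2 \cdot 1 + 0 \cdot 3) = 18
\end{aligned}
```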
My code:
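import java.util.Arrays;

public class BSTDistance {
    int[] left;   // left[i]: index of node i's left child, or -1
    int[] right;  // right[i]: index of node i's right child, or -1
    int[] cnt;    // cnt[i]: number of nodes in node i's subtree
    long[] ret;   // ret[i]: sum of distances over all pairs inside node i's subtree
    long[] dis;   // dis[i]: sum of distances from every node of i's subtree to node i
    int nNode;

    public BSTDistance(int n) { // n is the number of nodes, indexed 0..n-1
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new long[n]; // long: the pair-distance sum grows up to about n^3/6 (a path), overflowing int quickly
        dis = new long[n];
        Arrays.fill(left, -1);
        Arrays.fill(right, -1);
        nNode = n;
    }

    // Makes node a a child of node b: the first call fills b's left slot, the second its right slot.
    void add(int a, int b) {
        if (left[b] == -1) {
            left[b] = a;
        } else if (right[b] == -1) {
            right[b] = a;
        } else {
            throw new IllegalStateException("node " + b + " already has two children");
        }
    }

    long cal() {
        _cal(0); // assume the root's index is 0
        return ret[0];
    }

    // Post-order: compute the children first, then combine them with the formulas above.
    void _cal(int idx) {
        if (left[idx] == -1 && right[idx] == -1) {
            // leaf
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        } else if (left[idx] != -1 && right[idx] == -1) {
            // only a left child: there is no x_i -> y_j term
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        } else {
            // left[idx] == -1 && right[idx] != -1 is impossible, guaranteed by add(int, int)
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]]
                    + 2L * cnt[left[idx]] * cnt[right[idx]] // 2L keeps the product in long arithmetic
                    + dis[left[idx]] * cnt[right[idx]] + dis[right[idx]] * cnt[left[idx]];
        }
    }

    public static void main(String[] args) {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //   (0)
        //   / \
        // (1) (2)
        System.out.println(bst1.cal());

        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //       (0)
        //       / \
        //     (1) (2)
        //     / \
        //   (3) (4)
        // 0 -> 1: 1
        // 0 -> 2: 1
        // 0 -> 3: 2
        // 0 -> 4: 2
        // 1 -> 2: 2
        // 1 -> 3: 1
        // 1 -> 4: 1
        // 2 -> 3: 3
        // 2 -> 4: 3
        // 3 -> 4: 2
        // (distance * number of pairs) 2*4 + 3*2 + 1*4 = 18
        System.out.println(bst2.cal());
    }
}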
output:
4
18
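The recursive _cal uses one stack frame per tree level, so on a very deep tree (for example a path of 100000 nodes) it may throw StackOverflowError with the default thread stack size. Below is a sketch of the same recurrence evaluated without recursion: nodes are listed in BFS order from root 0, and walking that list backwards guarantees children are handled before their parents. It keeps all sums in long. The class name BSTDistanceIterative and method pairDistanceSum are hypothetical; the child-array layout is the same as in the code above.

```java
class BSTDistanceIterative {
    // left[i] / right[i] are child indices or -1; node 0 is assumed to be the root.
    static long pairDistanceSum(int[] left, int[] right) {
        int n = left.length;
        long[] cnt = new long[n], dis = new long[n], ret = new long[n];

        // BFS from the root: every node is placed after its parent.
        int[] order = new int[n];
        int head = 0, tail = 0;
        order[tail++] = 0;
        while (head < tail) {
            int v = order[head++];
            if (left[v] != -1) order[tail++] = left[v];
            if (right[v] != -1) order[tail++] = right[v];
        }

        // Walking the BFS order backwards handles children before parents.
        for (int k = tail - 1; k >= 0; k--) {
            int v = order[k];
            // A missing child contributes cnt = dis = ret = 0.
            long cl = 0, dl = 0, rl = 0, cr = 0, dr = 0, rr = 0;
            if (left[v] != -1) { cl = cnt[left[v]]; dl = dis[left[v]]; rl = ret[left[v]]; }
            if (right[v] != -1) { cr = cnt[right[v]]; dr = dis[right[v]]; rr = ret[right[v]]; }
            cnt[v] = cl + cr + 1;
            dis[v] = dl + dr + cl + cr;
            ret[v] = dis[v] + rl + rr + 2 * cl * cr + dl * cr + dr * cl;
        }
        return ret[0];
    }

    public static void main(String[] args) {
        // bst2 from the answer's main()
        System.out.println(pairDistanceSum(new int[] {1, 3, -1, -1, -1}, new int[] {2, 4, -1, -1, -1})); // 18

        // A path of 100000 nodes, each node's only child being its left child.
        int n = 100000;
        int[] left = new int[n];
        int[] right = new int[n];
        for (int i = 0; i < n; i++) {
            left[i] = (i + 1 < n) ? i + 1 : -1;
            right[i] = -1;
        }
        System.out.println(pairDistanceSum(left, right)); // 166666666650000, i.e. (n^3 - n) / 6
    }
}
```

For a path, the sum of all pairwise distances is (n^3 - n) / 6, which for large n is far beyond the int range; that is why every accumulator here is a long.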
To make the solution easier to follow, here are the values of cnt[], dis[] and ret[] (for nodes 0 to 4) after bst2.cal() is called:
cnt[]  5  3  1  1  1
dis[]  6  2  0  0  0
ret[] 18  4  0  0  0
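The cnt[] row above also allows an independent cross-check with a different, well-known counting argument (not the DP of this answer): the edge between a non-root node c and its parent lies on the path of exactly cnt[c] * (n - cnt[c]) pairs, so the total pairwise distance is the sum of these products. A minimal sketch, with the hypothetical names EdgeContributionCheck and edgeSum:

```java
class EdgeContributionCheck {
    // Sum over every non-root node c of cnt[c] * (n - cnt[c]),
    // i.e. the number of pairs whose path uses the edge (c, parent of c).
    static long edgeSum(int[] cnt, int root) {
        int n = cnt[root]; // the root's subtree is the whole tree
        long total = 0;
        for (int c = 0; c < cnt.length; c++) {
            if (c != root) total += (long) cnt[c] * (n - cnt[c]);
        }
        return total;
    }

    public static void main(String[] args) {
        int[] cnt = {5, 3, 1, 1, 1}; // cnt[] of bst2 after cal()
        System.out.println(edgeSum(cnt, 0)); // 18 = 3*2 + 1*4 + 1*4 + 1*4
    }
}
```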
PS:
This is the solution of UESTC_elfness; it's a simple problem for him. I'm sayakiss, and the problem is not so hard for me either.
So you can trust us...