## Sunday, November 2, 2014

### SPOJ - Another Tree Problem (MTREE)

Problem Statement:
SPOJ - Another Tree Problem (MTREE)

Summary:
Define the weight of a path between nodes A and B in a tree as the product of the weights of the edges on that path, and define the weight of the tree as the sum of the weights of all paths, i.e. over all unordered pairs of distinct nodes (inception!). The answer is reported modulo $$10^9+7$$.
(Number of vertices $$N \leq 10^5$$)
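To pin down the definition, here is a brute-force $$O(N^2)$$ reference. It is far too slow for $$N = 10^5$$, but it is handy for checking a faster solution on small trees. It assumes nodes are numbered 1..N and that the answer is reduced modulo $$10^9+7$$. The names `Edge` and `bruteTreeWeight` are illustrative and do not come from the problem statement.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1000000007LL;

// Hypothetical edge type for illustration: an undirected edge u-v with weight w.
struct Edge { int u, v; long long w; };

// Brute-force O(N^2) reference for the definition: from every start node s,
// walk the tree and add the product of edge weights on the path s -> x for
// every node x > s, so each unordered pair {s, x} is counted exactly once.
long long bruteTreeWeight(int n, const vector<Edge>& edges) {
    vector<vector<pair<int, long long> > > adj(n + 1);
    for (size_t i = 0; i < edges.size(); ++i) {
        long long w = edges[i].w % MOD;
        adj[edges[i].u].push_back(make_pair(edges[i].v, w));
        adj[edges[i].v].push_back(make_pair(edges[i].u, w));
    }
    struct Frame { int node, parent; long long prod; };
    long long total = 0;
    for (int s = 1; s <= n; ++s) {
        vector<Frame> stk;
        Frame start = {s, 0, 1};  // node 0 is unused, so it serves as "no parent"
        stk.push_back(start);
        while (!stk.empty()) {
            Frame f = stk.back();
            stk.pop_back();
            if (f.node > s) total = (total + f.prod) % MOD;
            for (size_t i = 0; i < adj[f.node].size(); ++i) {
                int y = adj[f.node][i].first;
                if (y == f.parent) continue;
                Frame next = {y, f.node, f.prod * adj[f.node][i].second % MOD};
                stk.push_back(next);
            }
        }
    }
    return total;
}
```

For example, on the path 1-2-3 with edge weights 2 and 3, the three paths weigh 2, 3 and 6, so `bruteTreeWeight` returns 11.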

Solution:
I think this is a rather nice problem! First, notice that $$N$$ is large, so enumerating all $$O(N^2)$$ pairs of nodes is too slow and we should aim for an $$O(N)$$ solution. That seems intimidating at first, but the following observations make the path to the solution clearer:
1. Suppose the tree is rooted at node r, and let P(r) be the sum of the weights of all paths that start at r and end at another node of its subtree. Then we can recursively define $$P(r) = \sum_{(r,u) \in E} (P(u) + 1)\, w(r,u)$$, where the sum runs over the children u of r and w(r,u) is the weight of edge r-u. Every path from r into the subtree of u is the edge r-u followed either by nothing (the path ending at u, which gives the +1) or by one of the paths counted in P(u). For a leaf, P is 0.
2. How about the weight of the tree, T(r)? Each child u of r roots a subtree with tree weight T(u) and path weight P(u), and each of these T(u) contributes to T(r). The paths that start at r itself contribute P(r). Finally, paths that go from one child subtree to another also contribute: if u and v are both children of r, the sum of the weights of all paths from the subtree of u to the subtree of v is $$(P(u)+1)\,w(u,r)\,w(r,v)\,(P(v)+1)$$. The total contribution of such paths is the sum of these terms over all unordered pairs of distinct children, which (expand the square of the sum and subtract the squared terms) simplifies to $$\frac{1}{2}\left(\left(\sum_{(r,u) \in E} (P(u)+1)\,w(r,u)\right)^2 - \sum_{(r,u) \in E} \left((P(u)+1)\,w(r,u)\right)^2\right)$$. Adding P(r) and all T(u) to this gives the recursive definition of T(r).
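To make the pair-counting step concrete, write $$a_u = (P(u)+1)\,w(r,u)$$ for each child u of r. Expanding the square of the sum gives the identity behind the cross-path term, and with it the full recurrence:

$$\left(\sum_{u} a_u\right)^2 = \sum_{u} a_u^2 + 2\sum_{u<v} a_u a_v \quad\Longrightarrow\quad \sum_{u<v} a_u a_v = \frac{1}{2}\left(\left(\sum_{u} a_u\right)^2 - \sum_{u} a_u^2\right)$$

$$P(r) = \sum_{u} a_u, \qquad T(r) = \sum_{u} T(u) + P(r) + \frac{1}{2}\left(\left(\sum_{u} a_u\right)^2 - \sum_{u} a_u^2\right)$$

As a check, take a star with centre 1 and leaves 2, 3, 4 on edges of weight 2, 3, 4. The leaves have P = T = 0, so a = 2, 3, 4, giving P(1) = 9 and a cross term of (81 - 29)/2 = 26, hence T(1) = 0 + 9 + 26 = 35. Listing the pairs directly gives the same: 2 + 3 + 4 from the centre, plus 2·3 + 2·4 + 3·4 = 26 between leaves.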

The traversal itself can be done with a DFS, which runs in $$O(V+E) = O(N)$$ time on a tree. Computing P(r) and T(r) costs $$O(1)$$ per child of r, so the extra work is proportional to the number of edges, and the whole algorithm runs in linear time :D
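One practical caveat: the recursive `dfs` in the listing can reach depth N on a path-shaped tree, which may exceed the default stack size for $$N = 10^5$$ depending on the judge. Here is a hedged sketch of the same P/T recurrences using an explicit stack instead of recursion. It assumes nodes 1..N, roots the tree at node 1, and reduces modulo $$10^9+7$$; `Edge` and `treeWeight` are illustrative names, not part of the original code.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

const long long MOD = 1000000007LL;
const long long INV2 = 500000004LL;  // inverse of 2 modulo MOD: 2 * INV2 = MOD + 1

// Hypothetical edge type for illustration: an undirected edge u-v with weight w.
struct Edge { int u, v; long long w; };

// Same P/T recurrences as the text, rooted at node 1, but with an explicit
// stack instead of recursion, so a path-shaped tree cannot overflow the call stack.
long long treeWeight(int n, const vector<Edge>& edges) {
    if (n <= 1) return 0;
    vector<vector<pair<int, long long> > > adj(n + 1);
    for (size_t i = 0; i < edges.size(); ++i) {
        long long w = edges[i].w % MOD;
        adj[edges[i].u].push_back(make_pair(edges[i].v, w));
        adj[edges[i].v].push_back(make_pair(edges[i].u, w));
    }
    // Pre-order traversal: every node is appended to `order` after its parent.
    vector<int> order, parent(n + 1, 0), stk(1, 1);
    vector<long long> parentW(n + 1, 0);  // weight of the edge to the parent
    vector<char> seen(n + 1, 0);
    seen[1] = 1;
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (size_t i = 0; i < adj[x].size(); ++i) {
            int y = adj[x][i].first;
            if (seen[y]) continue;
            seen[y] = 1;
            parent[y] = x;
            parentW[y] = adj[x][i].second;
            stk.push_back(y);
        }
    }
    // sum[x] and sq[x] accumulate a_u = (P(u)+1) * w(x,u) and a_u^2 over children u;
    // T[x] starts as the sum of the children's T values.
    vector<long long> P(n + 1, 0), T(n + 1, 0), sum(n + 1, 0), sq(n + 1, 0);
    // Reverse pre-order finalises every child before its parent.
    for (int k = (int)order.size() - 1; k >= 0; --k) {
        int x = order[k];
        P[x] = sum[x];
        long long cross = ((sum[x] * sum[x] - sq[x]) % MOD + MOD) % MOD * INV2 % MOD;
        T[x] = (T[x] + P[x] + cross) % MOD;
        int p = parent[x];
        if (p != 0) {
            long long a = (P[x] + 1) % MOD * parentW[x] % MOD;
            sum[p] = (sum[p] + a) % MOD;
            sq[p] = (sq[p] + a * a) % MOD;
            T[p] = (T[p] + T[x]) % MOD;
        }
    }
    return T[1];
}
```

For the path 1-2-3 with weights 2 and 3, `treeWeight` returns 11, the same as summing the path weights 2, 3 and 6 by hand. Processing nodes in reverse pre-order plays the role of the post-order return in the recursive version.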

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
using namespace std;
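long long MOD = (long long) 1e9 + 7;
long long DIV = (long long) 5e8 + 4; // inverse of 2 modulo MOD (2 * DIV = MOD + 1)
int vis[100003];
int par[100003];
long long path[100003];   // path[u] = P(u)
long long weight[100003]; // weight[u] = T(u)
// adj[u] holds (neighbour, edge weight) pairs; the name adj is assumed, since
// the part of main() that reads the edges and fills it is not shown here
vector<pair<int, long long> > adj[100003];
int N;

void dfs(int u){
    if(vis[u]) return;
    long long ret = 0;    // accumulates T(u)
    long long sum = 0;    // sum of (P(v)+1) * w(u,v) over children v, i.e. P(u)
    long long square = 0; // sum of ((P(v)+1) * w(u,v))^2 over children v
    for(size_t i = 0; i < adj[u].size(); ++i){
        int v = adj[u][i].first;
        long long w = adj[u][i].second % MOD;
        if(par[u] != v) { // skip the edge back to the parent
            par[v] = u;
            dfs(v);
            long long tmp = ((path[v] + 1LL) * w) % MOD;
            sum += tmp;
            sum %= MOD;
            square += (tmp * tmp) % MOD;
            square %= MOD;
            ret += weight[v];
            ret %= MOD;
        }
    }
    ret += sum; // paths that start at u
    ret %= MOD;
    // paths between different child subtrees; "+ MOD" keeps the value non-negative
    ret += ((((sum * sum - square) % MOD + MOD) % MOD) * DIV) % MOD;
    ret %= MOD;
    vis[u] = 1;
    path[u] = sum;
    weight[u] = ret;
}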

void solve(){
for(int i=1;i<=N;++i){
vis[i] = 0;
par[i] = -1;
}
dfs(1);
printf("%lld\n", weight[1]);
}

int main(){
int u, v, w;
scanf("%d", &N);