SPOJ - Another Tree Problem (MTREE)

Summary:

The weight of a path between two nodes A and B of a tree is defined as the product of the weights of the edges on that path, and the weight of the tree is defined as the sum of the weights of all of its paths (inception!). The task is to compute the weight of a given tree.
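
As a small hand-made illustration (this tree is an assumption for the example, not taken from the problem statement), consider the path-shaped tree 1-2-3 whose edges 1-2 and 2-3 have weights 2 and 3. It has exactly three paths, so its weight is

\[ \underbrace{2}_{1 \to 2} + \underbrace{3}_{2 \to 3} + \underbrace{2 \cdot 3}_{1 \to 3} = 11. \]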

(Number of vertices \(N \leq 10^5\))

Solution:

I think this is a rather nice problem! First, notice that \(N\) is quite large: there are about \(N^2/2\) paths, so enumerating them one by one is out of reach and we want a solution that is linear (or close to linear) in \(N\). That seems intimidating at first, but a couple of observations make the link to the solution clearer:

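1. Suppose we have a tree with root node r. Let P(r) be the sum of the weights of all paths that start at r and end at some other node in the tree below r. We can define it recursively as \(P(r) = \sum_{(r,u) \in E} (P(u) + 1) (w(r,u))\), where the sum runs over the children u of r and w(r,u) is the weight of edge r-u. The \(P(u)\) part extends every path that starts at u by the edge r-u, and the \(+1\) accounts for the path that is the edge r-u alone. For a leaf, P is 0.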

2. How about the weight of the tree, T(r)? Each child u of r defines a subtree with weight T(u) and path sum P(u). Every T(u) contributes to T(r), and so do the paths that start at r itself, which sum to P(r). Furthermore, any path that goes from one child's subtree to another's also contributes to T(r). If u and v are distinct children of r, the sum of the weights of all paths from the subtree of u to the subtree of v is \((P(u)+1)(w(u,r))(w(r,v))(P(v)+1)\). Writing \(a_u = (P(u)+1)(w(r,u))\), the total contribution of such paths over all unordered pairs of children is \(\sum_{u < v} a_u a_v\), which simplifies to \(\frac{1}{2}((\sum_{(r,u) \in E} a_u)^2 - \sum_{(r,u) \in E} a_u^2)\); note that \(\sum_{(r,u) \in E} a_u\) is exactly P(r). Adding the three parts gives the recursive definition \(T(r) = \sum_{(r,u) \in E} T(u) + P(r) + \frac{1}{2}(P(r)^2 - \sum_{(r,u) \in E} a_u^2)\).
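
To convince yourself that the two recurrences are right, it helps to compare them against a direct enumeration on tiny trees. The following is only a sketch under a few assumptions: it uses exact `long long` arithmetic with no modulus (so it is meant for small inputs), it assumes C++11, and the names `EdgeInput`, `buildTree`, `solveSubtree`, `treeWeightByRecurrence` and `treeWeightBruteForce` are made up for this illustration.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper types for this sketch (not part of the original solution).
using Edge = std::pair<int, long long>;        // (neighbour, edge weight)
using Graph = std::vector<std::vector<Edge>>;  // 1-indexed adjacency list
struct EdgeInput { int u, v; long long w; };

// Build an undirected tree on nodes 1..n from n-1 weighted edges.
Graph buildTree(int n, const std::vector<EdgeInput>& edges) {
    assert(n >= 1 && static_cast<int>(edges.size()) == n - 1);
    Graph adj(n + 1);
    for (const EdgeInput& e : edges) {
        adj[e.u].push_back({e.v, e.w});
        adj[e.v].push_back({e.u, e.w});
    }
    return adj;
}

// Returns {P(r), T(r)} for the subtree rooted at r, following observations 1 and 2.
std::pair<long long, long long> solveSubtree(const Graph& adj, int r, int parent) {
    long long P = 0, T = 0, squares = 0;
    for (const Edge& e : adj[r]) {
        if (e.first == parent) continue;
        std::pair<long long, long long> child = solveSubtree(adj, e.first, r);
        long long term = (child.first + 1) * e.second;  // (P(u) + 1) * w(r,u)
        P += term;
        squares += term * term;
        T += child.second;                              // T(u)
    }
    // Paths starting at r, plus paths crossing between two child subtrees.
    // P*P - squares is twice the sum of pairwise products, so it is always even.
    T += P + (P * P - squares) / 2;
    return {P, T};
}

long long treeWeightByRecurrence(const Graph& adj) {
    return solveSubtree(adj, 1, 0).second;  // node 1 as root, 0 as "no parent"
}

// Walk away from start node s, carrying the product of the edges seen so far.
void walk(const Graph& adj, int s, int node, int parent, long long product, long long& total) {
    for (const Edge& e : adj[node]) {
        if (e.first == parent) continue;
        long long next = product * e.second;
        if (e.first > s) total += next;  // count each unordered pair once
        walk(adj, s, e.first, node, next, total);
    }
}

// Quadratic reference answer: enumerate every pair of nodes explicitly.
long long treeWeightBruteForce(const Graph& adj) {
    long long total = 0;
    for (int s = 1; s < static_cast<int>(adj.size()); ++s) walk(adj, s, s, 0, 1, total);
    return total;
}
```

For the path 1-2-3 with edge weights 2 and 3, both functions return 11 (2 + 3 + 2·3). For a star with centre 1 and leaf edges of weight 2, 3 and 5, both return 41.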

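The tree traversal itself can be implemented with a DFS that runs in \(O(V+E)\) time. Computing T(r) and P(r) at a node takes time proportional to the number of children of r, which adds up to \(O(V)\) over the whole tree, so we have a linear running time overall :D All arithmetic in the code is done modulo \(10^9+7\), so the division by 2 becomes a multiplication by the modular inverse of 2 (the constant DIV).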
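
One practical caveat with a recursive DFS: on a path-shaped tree with \(N = 10^5\) the recursion depth reaches \(N\), and whether that fits depends on the judge's stack limit. As a precaution, the same recurrences can be evaluated without recursion by visiting nodes in BFS order and then processing that order backwards, so every child is finished before its parent. This is a sketch, not the original implementation: `treeWeightIterative` and `INV2` are names chosen here, node 1 is assumed to be the root, and C++11 is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

const long long MOD = 1000000007LL;
const long long INV2 = 500000004LL;  // inverse of 2 modulo MOD

// adj is a 1-indexed adjacency list of (neighbour, weight) pairs describing a tree.
long long treeWeightIterative(const std::vector<std::vector<std::pair<int, int>>>& adj) {
    int n = static_cast<int>(adj.size()) - 1;
    assert(n >= 1);

    // BFS from node 1: every parent appears in `order` before its children.
    std::vector<int> order, parent(n + 1, 0);
    std::vector<long long> parentWeight(n + 1, 0);
    order.reserve(n);
    order.push_back(1);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (const auto& e : adj[u]) {
            if (e.first == parent[u]) continue;
            parent[e.first] = u;
            parentWeight[e.first] = e.second % MOD;
            order.push_back(e.first);
        }
    }
    assert(order.size() == static_cast<std::size_t>(n));  // the input must be connected

    // P[u], T[u] as in the recurrences; squares[u] = sum of ((P(v)+1) * w(u,v))^2 over children v.
    std::vector<long long> P(n + 1, 0), T(n + 1, 0), squares(n + 1, 0);
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        // All children of u have already pushed their contributions into P, T, squares.
        long long cross = (P[u] * P[u] % MOD - squares[u] + MOD) % MOD;
        T[u] = (T[u] + P[u] + cross * INV2 % MOD) % MOD;
        if (parent[u] != 0) {
            int p = parent[u];
            long long term = (P[u] + 1) % MOD * parentWeight[u] % MOD;
            P[p] = (P[p] + term) % MOD;
            squares[p] = (squares[p] + term * term) % MOD;
            T[p] = (T[p] + T[u]) % MOD;
        }
    }
    return T[1];
}
```

The work per node is still proportional to its degree, so the running time stays linear; the only extra cost is the explicit `order` array.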

```cpp
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

long long MOD = (long long) 1e9 + 7;
long long DIV = (long long) 5e8 + 4;  // modular inverse of 2 modulo MOD

vector<vector<pair<int,int> > > adj;
int vis[100003];
int par[100003];
long long path[100003];    // path[u]   = P(u)
long long weight[100003];  // weight[u] = T(u)
int N;

void dfs(int u){
    if(vis[u]) return;
    long long ret = 0;     // accumulates T(u)
    long long sum = 0;     // sum of (P(v)+1)*w(u,v) over children v, i.e. P(u)
    long long square = 0;  // sum of the squares of those terms
    for(int i=0;i<(int)adj[u].size();++i){
        int v = adj[u][i].first;
        long long w = adj[u][i].second;
        if(par[u] != v) {
            par[v] = u;
            dfs(v);
            long long tmp = ((path[v] + 1LL) * w) % MOD;
            sum += tmp; sum %= MOD;
            square += (tmp * tmp) % MOD; square %= MOD;
            ret += weight[v]; ret %= MOD;
        }
    }
    ret += sum; ret %= MOD;  // paths that start at u
    // Paths crossing between two child subtrees. Adding MOD keeps the
    // difference non-negative, since % can return a negative value in C++.
    long long cross = ((sum * sum) % MOD - square + MOD) % MOD;
    ret += (cross * DIV) % MOD;
    ret %= MOD;
    vis[u] = 1;
    path[u] = sum;
    weight[u] = ret;
}

void solve(){
    for(int i=1;i<=N;++i){
        vis[i] = 0;
        par[i] = -1;
    }
    dfs(1);
    printf("%lld\n", weight[1]);
}

int main(){
    int u, v, w;
    scanf("%d", &N);
    adj = vector<vector<pair<int,int> > >(N+3);
    for(int i=0;i<N-1;++i){
        scanf("%d %d %d", &u, &v, &w);
        adj[u].push_back(make_pair(v,w));
        adj[v].push_back(make_pair(u,w));
    }
    solve();
    return 0;
}
```
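
The constant `DIV` in the program is 500000004, which is the modular inverse of 2 modulo \(10^9+7\): multiplying by it plays the role of the \(\frac{1}{2}\) in the formula. Since \(10^9+7\) is prime, Fermat's little theorem gives the inverse of \(a\) as \(a^{p-2} \bmod p\). The sketch below shows one way such a constant could be derived or checked; `modPow` and `modInverse` are helper names invented for this example.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;  // prime modulus used by the solution

// Fast exponentiation: computes base^exp modulo mod in O(log exp) multiplications.
long long modPow(long long base, long long exp, long long mod) {
    assert(mod > 1 && exp >= 0);
    long long result = 1;
    base %= mod;
    if (base < 0) base += mod;  // normalise negative inputs
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// Inverse of a modulo the prime MOD via Fermat's little theorem.
// Only valid when a is not a multiple of MOD.
long long modInverse(long long a) {
    return modPow(a, MOD - 2, MOD);
}
```

Here `modInverse(2)` evaluates to 500000004, which equals \((10^9+7+1)/2\), matching the hard-coded `DIV`.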