Sunday, November 2, 2014

SPOJ - Another Tree Problem (MTREE)

Problem Statement:
SPOJ - Another Tree Problem (MTREE)

Summary:
Define the weight of a path between nodes A and B in a tree as the product of the weights of the edges on that path, and define the weight of the tree as the sum of the weights of all its paths, one per unordered pair of distinct nodes (inception!). The answer is taken modulo \(10^9+7\).
(Number of vertices \(N \leq 10^5\))
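To pin the definition down, here is a hypothetical \(O(N^2)\) brute-force sketch (the names bruteTreeWeight and BRUTE_MOD are made up here). It assumes nodes numbered 1..N, an adjacency list of (neighbour, weight) pairs like the one in the solution code, and non-negative integer weights. For every node a it walks the whole tree while carrying the product of edge weights, and it counts each unordered pair {a, b} once by only adding when b > a. It is far too slow for \(N = 10^5\), but handy for cross-checking a linear solution on small trees.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical O(N^2) reference for the tree weight, modulo 1e9+7.
// Assumes nodes 1..n, adj[u] = list of (neighbour, weight), non-negative weights.
const long long BRUTE_MOD = 1000000007LL;

long long bruteTreeWeight(int n, const std::vector<std::vector<std::pair<int, int>>>& adj) {
    long long total = 0;
    for (int a = 1; a <= n; ++a) {
        std::vector<long long> prod(n + 1, 0);  // product of edge weights on the path a..v
        std::vector<char> seen(n + 1, 0);
        std::vector<int> stack = {a};
        prod[a] = 1;
        seen[a] = 1;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (const auto& edge : adj[u]) {
                int v = edge.first;
                if (seen[v]) continue;
                seen[v] = 1;
                prod[v] = prod[u] * edge.second % BRUTE_MOD;
                // Count the pair {a, v} only from its smaller endpoint.
                if (v > a) total = (total + prod[v]) % BRUTE_MOD;
                stack.push_back(v);
            }
        }
    }
    return total;
}
```

For the path 1-2-3 with weights 2 and 3, the three paths weigh 2, 3 and 6, so the tree weighs 11.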

Solution:
I think this is a rather nice problem! First, notice that \(N\) can be as large as \(10^5\), so enumerating all \(\Theta(N^2)\) paths is too slow and we should aim for an \(O(N)\) solution. It seems intimidating at first, but the following observations make the route to the solution clearer:
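1. Root the tree at a node r, and let C(r) be the set of children of r. Define P(r) as the sum of the weights of all paths that start at r and end at another node of r's subtree. Then \(P(r) = \sum_{u \in C(r)} (P(u) + 1)\, w(r,u)\), where w(r,u) is the weight of edge r-u: every such path begins with the edge r-u for some child u and then either stops at u (the \(+1\) term) or continues along one of the paths counted in P(u). For a leaf, P = 0.
2. How about the weight of the tree, T(r)? Each child u of r roots a subtree with weight T(u), and P(u) is the total weight of the paths starting at u. Every path in r's subtree falls into exactly one of three groups. Paths lying entirely inside one child's subtree contribute \(\sum_{u \in C(r)} T(u)\). Paths with r as an endpoint contribute P(r). The remaining paths start in the subtree of one child u, pass through r, and end in the subtree of a different child v; for a fixed pair u, v their weights sum to \((P(u)+1)(w(u,r))(w(r,v))(P(v)+1)\). Writing \(a_u = (P(u)+1)\,w(r,u)\), the total over all unordered pairs of distinct children is \(\sum_{u<v} a_u a_v = \frac{1}{2}\left(\left(\sum_{u \in C(r)} a_u\right)^2 - \sum_{u \in C(r)} a_u^2\right)\), because expanding the square of the sum yields every product \(a_u a_v\) twice plus every \(a_u^2\) once. Since \(\sum_{u \in C(r)} a_u = P(r)\), adding the three groups gives the recursive definition \(T(r) = \sum_{u \in C(r)} T(u) + P(r) + \frac{1}{2}\left(P(r)^2 - \sum_{u \in C(r)} a_u^2\right)\), with T = 0 for a leaf. When working modulo \(10^9+7\), the division by 2 becomes multiplication by the modular inverse of 2.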
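As a small check of both observations, here is a hypothetical sketch of the combination step at a single node r: given P(u), T(u) and w(r,u) for every child u, it returns P(r) and T(r). The names ChildInfo, combine and powMod are made up for illustration, the modulus \(10^9+7\) is assumed, and the inverse of 2 is computed with Fermat's little theorem instead of being hard-coded.

```cpp
#include <cassert>
#include <utility>
#include <vector>

const long long MOD2 = 1000000007LL;

// Values known for one child u of the current node r (hypothetical helper type).
struct ChildInfo {
    long long P;  // P(u): total weight of paths starting at u inside u's subtree
    long long T;  // T(u): weight of the subtree rooted at u
    long long w;  // w(r,u): weight of the edge between r and u
};

// base^exp modulo MOD2 by repeated squaring.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD2;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD2;
        base = base * base % MOD2;
        exp >>= 1;
    }
    return result;
}

// Returns {P(r), T(r)} from the children's values, following observations 1 and 2.
std::pair<long long, long long> combine(const std::vector<ChildInfo>& children) {
    const long long inv2 = powMod(2, MOD2 - 2);  // Fermat: 2^(p-2) is the inverse of 2 mod p
    long long sum = 0, square = 0, subtrees = 0;
    for (const ChildInfo& c : children) {
        long long a = (c.P + 1) % MOD2 * (c.w % MOD2) % MOD2;  // a_u = (P(u)+1) w(r,u)
        sum = (sum + a) % MOD2;
        square = (square + a * a) % MOD2;
        subtrees = (subtrees + c.T) % MOD2;
    }
    // Sum over u<v of a_u a_v = (sum^2 - square) / 2, shifted by MOD2 to stay non-negative.
    long long pairs = (sum * sum % MOD2 - square + MOD2) % MOD2 * inv2 % MOD2;
    return {sum, (subtrees + sum + pairs) % MOD2};
}
```

For a star whose centre r has leaf edges of weight 2, 3 and 5, every leaf has P = T = 0, so a = 2, 3, 5, P(r) = 10 and T(r) = 0 + 10 + (100 - 38)/2 = 41, which matches adding the six paths 2, 3, 5, 6, 10 and 15 by hand.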

The tree traversal itself is a DFS, which runs in \(O(V+E) = O(N)\) time because a tree has \(N-1\) edges. Each node's P and T are computed in time proportional to its number of children, and these counts add up to \(N-1\), so the overall running time is linear :D

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
using namespace std;
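long long MOD = (long long) 1e9 + 7;
long long DIV = (long long) 5e8 + 4;      // modular inverse of 2: 2 * DIV == 1 (mod MOD)
vector<vector<pair<int,int> > > adj;      // adj[u] = list of (neighbour, edge weight)
int vis[100003];
int par[100003];
long long path[100003];                   // path[u] = P(u)
long long weight[100003];                 // weight[u] = T(u)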
int N;


// Post-order DFS: fills path[u] = P(u) and weight[u] = T(u) for the subtree rooted at u.
void dfs(int u){
    if(vis[u]) return;
    long long ret = 0;
    long long sum = 0;
    long long square = 0;
    for(int i=0;i<adj[u].size();++i){
        int v = adj[u][i].first;
        long long w = adj[u][i].second;
        if(par[u] != v) {
            par[v] = u;
            dfs(v);
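            // tmp = (P(v) + 1) * w(u,v): total weight of paths from u that go down through v
            long long tmp = ((path[v] + 1LL) * w) % MOD;
            sum += tmp;
            sum %= MOD;
            square += (tmp * tmp) % MOD;
            square %= MOD;
            ret += weight[v];   // paths lying entirely inside v's subtree: T(v)
            ret %= MOD;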
        }
    }
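    ret += sum;         // paths with u as an endpoint: P(u)
    ret %= MOD;
    // Paths through u between two different child subtrees: (sum^2 - square) / 2.
    // sum * sum can be smaller than square after reduction, so add MOD before taking %
    // to keep the value non-negative (% on a negative long long gives a negative result).
    ret += ((((sum * sum) % MOD - square + MOD) % MOD) * DIV) % MOD;
    ret %= MOD;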
    vis[u] = 1;
    path[u] = sum;
    weight[u] = ret;
}

void solve(){
    for(int i=1;i<=N;++i){
        vis[i] = 0;
        par[i] = -1;
    }
    dfs(1);
    printf("%lld\n", weight[1]);
}

int main(){
    int u, v, w;
    scanf("%d", &N);
    adj = vector<vector<pair<int,int> > >(N+3);
    for(int i=0;i<N-1;++i){
        scanf("%d %d %d", &u, &v, &w);
        adj[u].push_back(make_pair(v,w));
        adj[v].push_back(make_pair(u,w));
    }
    solve();
    return 0;
}
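The recursive dfs above goes one call deeper per tree level, so a path-shaped tree with \(10^5\) nodes recurses \(10^5\) levels deep; whether that fits depends on the judge's stack limit. If it does not, the same recurrences can be evaluated without recursion. The sketch below is hypothetical (treeWeightIterative, MOD3 and INV2 are names chosen here), assumes a connected tree on nodes 1..n rooted at node 1, and takes an adjacency list of the same shape as adj.

```cpp
#include <cassert>
#include <utility>
#include <vector>

const long long MOD3 = 1000000007LL;
const long long INV2 = 500000004LL;  // 2 * INV2 == 1 (mod MOD3)

// Hypothetical iterative variant of dfs(): same recurrences for P and T, but nodes
// are ordered with an explicit stack, so call depth does not grow with tree height.
long long treeWeightIterative(int n, const std::vector<std::vector<std::pair<int, int>>>& adj) {
    std::vector<int> parent(n + 1, 0), parentW(n + 1, 0), order;
    std::vector<long long> sum(n + 1, 0);     // sum of a_v over children v; equals P(u) at the end
    std::vector<long long> square(n + 1, 0);  // sum of a_v^2 over children v
    std::vector<long long> T(n + 1, 0);       // children's T accumulated first, then T(u)
    order.reserve(n);
    std::vector<int> stack = {1};
    parent[1] = -1;
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (const auto& edge : adj[u]) {
            int v = edge.first;
            if (v == parent[u]) continue;
            parent[v] = u;
            parentW[v] = edge.second;
            stack.push_back(v);
        }
    }
    // A node is always appended after its parent, so sweeping `order` backwards
    // finishes every child before its parent.
    for (int i = (int)order.size() - 1; i >= 0; --i) {
        int u = order[i];
        long long pairs = (sum[u] * sum[u] % MOD3 - square[u] + MOD3) % MOD3 * INV2 % MOD3;
        T[u] = (T[u] + sum[u] + pairs) % MOD3;
        if (parent[u] > 0) {
            int p = parent[u];
            long long a = (sum[u] + 1) * parentW[u] % MOD3;  // a_u = (P(u)+1) w(p,u)
            sum[p] = (sum[p] + a) % MOD3;
            square[p] = (square[p] + a * a) % MOD3;
            T[p] = (T[p] + T[u]) % MOD3;
        }
    }
    return T[1];
}
```

With the global adj built in main, printing treeWeightIterative(N, adj) would be a drop-in alternative to calling solve(); this swap is only a suggestion and not the submitted code.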