Sunday, March 1, 2015

Codeforces Round #294 Div. 2 Problem E - A and B and Lecture Rooms

Problem Statement:
519E - A and B and Lecture Rooms

Solution:
A very innocent looking problem, but really we shouldn't judge a book by its cover :P. Very difficult to solve as it requires a lot of data structures to implement, and several problems are layered very nicely together. Let's see what I mean.

First of all, we know that we are dealing with a tree. Between two rooms x and y, there exists exactly one simple path (otherwise, if there existed two distinct paths, there would be a cycle in the tree, a contradiction).



What we want is to find the number of rooms that are equidistant from both x and y. To do this, let's assume we can find the room p located on the simple path x -> y which is equidistant from x and y (such a p exists only when the distance between x and y is even; otherwise the answer is 0). If there is such a p, then we can compute the total number of desired rooms as follows:
1. Imagine p as the root of the whole tree. Let cnt be the number of desired rooms.
2. for each u a child of p, we check if the subtree rooted at u contains x or y. If it does, we exclude this subtree from consideration. Otherwise, we add the size of this subtree to cnt.

[PS: We can also (and actually must) improve the above procedure to O(1). The idea can be deciphered from my implementation below, but for now just keep this in mind.]

Now we arrive at the question of how to find p efficiently given x and y. I think there might be a lot of ways to do this. The solution that I will describe here uses an implementation of heavy light decomposition (HLD), and an implementation of lowest common ancestor (LCA) queries using a segment tree. Both of these techniques are described in great detail in the following blogs and articles:
Heavy Light Decomposition: http://blog.anudeep2011.com/heavy-light-decomposition/ (Very, very clear, concise, and intuitive. The author did such a good job.)
LCA: http://www.topcoder.com/tc?d1=tutorials&d2=lowestCommonAncestor&module=Static (Pretty well done and thorough)

If you already know those data structures, the problem merely reduces to a careful implementation of the above two techniques.

Implementation:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;

// ---- Tree description ------------------------------------------------------
// Number of vertices; read in main.
int N;
// adj[u]: neighbours of vertex u (main inserts every edge in both directions).
vector<vector<int> > adj;

// ---- Per-vertex data filled in by dfs --------------------------------------
// vis[u]: non-zero once dfs has entered u.
int vis[100005];
// par[u]: vertex from which dfs first reached u (stays 0 for the dfs root).
int par[100005];
// depth[u]: depth value passed to dfs for u (the root is entered with 0).
int depth[100005];
// sz[u]: number of vertices in the dfs subtree of u, including u itself.
int sz[100005];
// sc[u]: child of u with the largest subtree, or -1 when u has no child.
int sc[100005];
// a[u]: index in `euler` at which u was appended when dfs entered it.
int a[100005];
// euler: vertex sequence recorded by dfs (on entry and after each child returns).
vector<int> euler;
// segtree: node p stores the minimum-depth vertex of its range of `euler`.
vector<int> segtree;

// ---- Heavy-light decomposition data filled in by HLD -----------------------
// chains[c]: vertices of chain c, listed from the chain head downwards.
vector<vector<int> > chains;
// cmp[u]: index of the chain containing u (-1 until HLD assigns it).
int cmp[100005];
// pos[u]: index of u inside chains[cmp[u]].
int pos[100005];
// par_chain: written by HLD only; nothing in this file reads it.
int par_chain[100005];

// ---- Per-query scratch state -----------------------------------------------
// lca: lowest common ancestor of the current query pair; read by dist and find.
int lca;
// parent_used: set by find when the midpoint is reached by climbing from one
// endpoint only (i.e. the midpoint is not the lca itself).
bool parent_used;
// right_below[s]: last vertex stepped through by search_up just before the
// vertex it returned, recorded in slot s; 0 means "nothing recorded".
int right_below[2];
// state: slot of right_below that search_up writes to.
int state;

// Absolute value of x.
int mabs(int x) {
    if (x < 0) {
        return -x;
    }
    return x;
}

// Depth-first traversal from u, entered at depth d.
//
// Side effects for every vertex reached:
//   - marks it in vis and stores its depth,
//   - appends it to `euler` on entry and again after each child returns,
//     remembering the entry index in a[],
//   - records its parent in par[], its subtree size in sz[] and its
//     largest-subtree child in sc[] (-1 if it has no child).
//
// Returns the size of the subtree rooted at u.
int dfs(int u, int d) {
    vis[u] = 1;
    depth[u] = d;
    euler.push_back(u);
    a[u] = static_cast<int>(euler.size()) - 1;

    int subtree_size = 1;   // counts u itself
    int heavy_child = -1;
    int heavy_size = 0;
    for (size_t idx = 0; idx < adj[u].size(); ++idx) {
        const int v = adj[u][idx];
        if (!vis[v]) {
            par[v] = u;
            subtree_size += dfs(v, d + 1);
            // Re-visit u in the tour so that any two tour positions enclose
            // their common ancestor.
            euler.push_back(u);
            // Strict comparison: the first of several equally large children wins.
            if (sz[v] > heavy_size) {
                heavy_child = v;
                heavy_size = sz[v];
            }
        }
    }

    sz[u] = subtree_size;
    sc[u] = heavy_child;
    return subtree_size;
}

void HLD(int cur, int chain) {
    cmp[cur] = chain;
    chains[chain].push_back(cur);
    pos[cur] = chains[chain].size()-1;
    if(sc[cur] == -1) return;
    HLD(sc[cur], chain);
    for(int i=0;i<adj[cur].size();++i){
        int v = adj[cur][i];
        if(sc[cur] == v) continue;
        if(cmp[v] != -1) continue;
        chains.push_back(vector<int>());
        cmp[v] = chain;
        par_chain[cmp[v]] = cmp[cur];
        HLD(v, (int) chains.size()-1);
    }
}

// Builds segment-tree node p covering euler[L..R].
// Each node stores the vertex of minimum depth within its range; when both
// halves tie, the vertex from the right half is kept.
void build(int p, int L, int R) {
    if (L != R) {
        const int mid = (L + R) / 2;
        const int left_node = 2 * p;
        const int right_node = left_node + 1;
        build(left_node, L, mid);
        build(right_node, mid + 1, R);

        const int left_best = segtree[left_node];
        const int right_best = segtree[right_node];
        if (depth[left_best] < depth[right_best]) {
            segtree[p] = left_best;
        } else {
            segtree[p] = right_best;
        }
    } else {
        // Leaf: the range is a single tour entry.
        segtree[p] = euler[L];
    }
}

// Range-minimum query on the Euler tour.
// Node p covers euler[L..R]; the query range is [S..T].
// Returns the minimum-depth vertex in the intersection of the two ranges,
// or -1 when they do not intersect.
int rmq(int p, int L, int R, int S, int T) {
    const bool disjoint = (T < L) || (R < S);
    if (disjoint) {
        return -1;
    }
    const bool fully_covered = (L >= S) && (T >= R);
    if (fully_covered) {
        return segtree[p];
    }

    const int mid = (L + R) / 2;
    const int from_left = rmq(2 * p, L, mid, S, T);
    const int from_right = rmq(2 * p + 1, mid + 1, R, S, T);

    // A side that misses the query range contributes nothing.
    if (from_left == -1) {
        return from_right;
    }
    if (from_right == -1) {
        return from_left;
    }
    if (depth[from_left] < depth[from_right]) {
        return from_left;
    }
    return from_right;
}

// Lowest common ancestor of u and v: the minimum-depth vertex on the Euler
// tour between the positions at which u and v were first entered.
int LCA(int u, int v) {
    const int iu = a[u];
    const int iv = a[v];
    const int first = (iv < iu) ? iv : iu;
    const int last = (iu < iv) ? iv : iu;
    return rmq(1, 0, euler.size()-1, first, last);
}

// Sum of the depth differences between the global `lca` and each of u and v.
// This equals the path length between u and v when `lca` currently holds
// their lowest common ancestor; callers also pass `lca` itself as one of the
// arguments to obtain the depth difference of the other.
int dist(int u, int v) {
    const int anchor_depth = depth[lca];
    const int up_from_u = mabs(anchor_depth - depth[u]);
    const int up_from_v = mabs(anchor_depth - depth[v]);
    return up_from_u + up_from_v;
}

// Climbs d edges towards the root starting at u and returns the vertex reached.
//
// Whole chains are skipped at once: while the target lies above the current
// chain, jump to the parent of the chain head. Whenever at least one edge is
// climbed, right_below[state] ends up holding the vertex visited immediately
// before the returned one (its child on the climbed path); with d == 0 it is
// left untouched.
//
// v is accepted for interface compatibility and is not used.
int search_up(int u, int v, int d) {
    (void)v;
    while (d > pos[u]) {
        const int chain_head = chains[cmp[u]][0];
        right_below[state] = chain_head;
        d -= pos[u] + 1;            // edges up to the head, plus one to its parent
        u = par[chain_head];
    }

    // The target now lies inside u's chain, d slots above u.
    const int target_slot = pos[u] - d;
    if (d != 0) {
        right_below[state] = chains[cmp[u]][target_slot + 1];
    }
    return chains[cmp[u]][target_slot];
}

// Locates the vertex on the u-v path that lies d edges away from u and
// returns it. Relies on the global `lca` holding the LCA of u and v; main
// calls it with d equal to half of the u-v distance.
//
// Three cases, by comparing d with the depth difference between u and lca:
//   - larger:  the target is d edges above u; climb from u.
//   - smaller: the target is on v's side; climb the remaining distance from v.
//   - equal:   the target is lca itself; climb from both ends so that
//              right_below[0] and right_below[1] record the vertices just
//              below lca towards u and v respectively.
// In the first two cases parent_used is set and right_below[0] is recorded.
int find(int u, int v, int d) {
    const int u_lca = dist(u, lca);

    if (u_lca > d) {
        parent_used = true;
        state = 0;
        return search_up(u, lca, d);
    }

    if (u_lca < d) {
        const int v_lca = dist(v, lca);
        parent_used = true;
        state = 0;
        // Total path length minus d is the distance measured from v.
        return search_up(v, lca, v_lca + u_lca - d);
    }

    state = 0;
    search_up(u, lca, d);
    state = 1;
    search_up(v, lca, d);
    return lca;
}

// Number of vertices equidistant from the two query endpoints, given the
// midpoint `pivot` returned by find().
//
// - If find() set parent_used (pivot was reached by climbing from a single
//   endpoint) and pivot is not vertex 1: everything in pivot's subtree except
//   the child subtree recorded in right_below[0].
// - Otherwise: all N vertices except the two child subtrees recorded in
//   right_below[0] and right_below[1]. An unset slot is 0, and main keeps
//   sz[0] == 0, so it contributes nothing.
int number_of_child(int pivot) {
    if(parent_used && pivot != 1) {
        return sz[pivot] - sz[right_below[0]];
    }
    return N - (sz[right_below[0]] + sz[right_below[1]]);
}

// Reports malformed input on stderr and yields the failure exit code.
static int bad_input(const char* what) {
    fprintf(stderr, "invalid input: %s\n", what);
    return 1;
}

// Reads the tree and the queries from stdin and prints, for each query pair,
// the number of vertices equidistant from both endpoints.
//
// Input format: N, then N-1 edges "u v", then Q, then Q pairs "x y"
// (vertices are 1-based). Returns 0 on success and 1 on malformed input.
int main(){
    if(scanf("%d",&N) != 1) return bad_input("missing vertex count");
    // The per-vertex arrays are fixed-size and indexed 0..N.
    const int capacity = static_cast<int>(sizeof(cmp)/sizeof(cmp[0]));
    if(N < 1 || N >= capacity) return bad_input("vertex count out of range");

    adj = vector<vector<int> > (N+3);
    for(int i=0;i<N-1;++i){
        int u, v;
        if(scanf("%d%d",&u,&v) != 2) return bad_input("missing edge");
        if(u < 1 || u > N || v < 1 || v > N) return bad_input("edge endpoint out of range");
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i=0;i<=N;++i){
        cmp[i] = -1;
        sc[i] = -1;
    }

    // Root the tree at vertex 1. With N-1 edges, reaching all N vertices
    // means the input really is a tree.
    if(dfs(1, 0) != N) return bad_input("graph is not connected");
    segtree = vector<int>(4*(int)euler.size());
    build(1, 0, euler.size()-1);
    chains.push_back(vector<int>());
    HLD(1, 0);

    int Q;
    if(scanf("%d",&Q) != 1) return bad_input("missing query count");
    // right_below slots default to 0, so sz[0] must count as an empty subtree.
    sz[0] = 0;
    for(int qq=0;qq<Q;++qq){
        right_below[0] = right_below[1] = 0;
        parent_used = false;
        int x, y;
        if(scanf("%d%d",&x,&y) != 2) return bad_input("missing query");
        if(x < 1 || x > N || y < 1 || y > N) return bad_input("query vertex out of range");
        lca = LCA(x,y);
        int d = dist(x,y);
        // An odd path length has no midpoint vertex, hence no equidistant vertex.
        if(d%2) {
            printf("0\n");
            continue;
        }
        int pivot = find(x, y, d/2);
        printf("%d\n",number_of_child(pivot));
    }
    return 0;
}