Thursday, January 28, 2016

[Square Root Decomposition] Codeforces 617E. XOR and Favorite Number

Problem Statement:
617E. XOR and Favorite Number

Solution:
I spent quite a long time trying to solve this problem with segment trees and suffix/prefix ideas... It turns out that the problem is solvable with a technique that was new to me!

Square Root Decomposition
The type of problem that can be solved using this technique: given an array A of N numbers and a set of M queries of the form (i, j), find the maximum, minimum, sum, mean, or mode (or some other aggregate function) of A[i], A[i+1], ..., A[j].



For example, let's say that in each query we want to find the sum of A[i..j].
The naive solution will give us O(MN) complexity.

// O(N) for each of the M queries
long long sum = 0;
for (int k = i; k <= j; ++k) {
    sum += A[k];
}

With Sqrt Decomposition, we get \(O((M+N)\sqrt{N})\).

The core of this technique is to reorder the queries by grouping them into boxes. Let K = \( \sqrt{N} \); the indices 0..N-1 are split into about \( \sqrt{N} \) boxes of K consecutive indices each. The steps are:
1. First, place each query into its box by the following rule: query (i, j) goes into box \( \lfloor i/K \rfloor \). Hence the left endpoint i alone determines the box.
2. Next, within each box, sort the queries by j in increasing order. The boxes themselves are processed from the first to the last.
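A minimal C++ sketch of this ordering. The Query struct (which also records each query's input position, needed later to report answers in the original order), the name sortQueries, and the choice K = max(1, floor(sqrt(n))) are illustrative assumptions, not taken from the original post.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical query record: range [i, j] plus its position in the input.
struct Query {
    int i, j, id;
};

// Reorders queries for an array of length n: box = i / K with
// K = max(1, floor(sqrt(n))), boxes in increasing order, increasing j inside a box.
void sortQueries(std::vector<Query>& qs, int n) {
    const int K = std::max(1, (int)std::sqrt((double)n));
    std::sort(qs.begin(), qs.end(), [K](const Query& a, const Query& b) {
        if (a.i / K != b.i / K) return a.i / K < b.i / K;  // earlier box first
        return a.j < b.j;                                  // then by right endpoint
    });
}
```

For n = 9 (so K = 3), the queries (7,8), (1,5), (2,3), (4,6), (0,8) come out as (2,3), (1,5), (0,8), (4,6), (7,8): the three box-0 queries by increasing j, then box 1, then box 2.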

Let's say we have sorted the queries as described above. Modify the above code as follows:

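// M queries have been sorted with sqrt decomposition
// for each query, from the first box till the last box:
// the window [L..R] starts as [0..0]
int L = 0, R = 0;
long long sum = A[0];
for (int qid = 0; qid < M; ++qid) {
    int i = query[qid].i;
    int j = query[qid].j;
    while (R < j) {   // extend the right end
        R++;
        sum += A[R];
    }
    while (R > j) {   // shrink the right end
        sum -= A[R];
        R--;
    }
    while (L < i) {   // shrink the left end
        sum -= A[L];
        L++;
    }
    while (L > i) {   // extend the left end
        L--;
        sum += A[L];
    }
    // the queries were reordered, so store the answer under its original index
    ans[query[qid].id] = sum;
}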
    

This transformation does not change much. The idea is to update our variable of interest, sum, incrementally.

L and R delimit the current window: sum always holds the sum of A[L..R]. To answer a query, we "slide" the window with the four while loops until it covers [i, j] exactly.
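For reference, here is a self-contained sketch of the whole procedure for range sums: sort, then slide. SumQuery and rangeSums are hypothetical names; the ids are assumed to be 0..M-1, the array non-empty, and every query to satisfy 0 ≤ i ≤ j < N.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical query type: inclusive range [i, j] and original position id.
struct SumQuery {
    int i, j, id;
};

// Answers range-sum queries with a sliding window [L..R] over sorted queries.
std::vector<long long> rangeSums(const std::vector<int>& A, std::vector<SumQuery> qs) {
    assert(!A.empty());
    const int n = (int)A.size();
    const int K = std::max(1, (int)std::sqrt((double)n));
    std::sort(qs.begin(), qs.end(), [K](const SumQuery& a, const SumQuery& b) {
        if (a.i / K != b.i / K) return a.i / K < b.i / K;
        return a.j < b.j;
    });
    std::vector<long long> ans(qs.size());
    int L = 0, R = 0;
    long long sum = A[0];  // the window starts as [0..0]
    for (const SumQuery& q : qs) {
        while (R < q.j) sum += A[++R];  // extend the right end
        while (R > q.j) sum -= A[R--];  // shrink the right end
        while (L < q.i) sum -= A[L++];  // shrink the left end
        while (L > q.i) sum += A[--L];  // extend the left end
        ans[q.id] = sum;                // answers go back in input order
    }
    return ans;
}
```

For example, with A = {5, -2, 7, 1, 3, 8, -4, 6, 2, 9}, the queries (0,9), (3,5), (2,2) yield 35, 12 and 7 in input order, even though they are processed as (2,2), (0,9), (3,5).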

On its own, this sliding loop is still O(NM) when the M queries come in arbitrary order. But once the queries are sorted following the Sqrt Decomposition recipe, it magically runs in \(O((M+N)\sqrt{N})\)!

How come? Observe that the four while loops account for essentially all of the running time. The first two move R one step at a time, while the last two move L.

Within a box, the queries are sorted by j, so from one query to the next R only increases. Hence R moves O(N) steps in total per box (including the jump into the box), and since there are \(\sqrt{N}\) boxes, the total movement of R is \(O(N\sqrt{N})\).

On the other hand, within a box the values of i may come in any order, so L may move back and forth. However, between two queries of the same box, L moves at most \(O(\sqrt{N})\) steps, because that is the size of the box. Since there are M queries, moving L takes \(O(M\sqrt{N} + N)\) in total, where the extra N accounts for the movement from one box to the next.

Hence overall we have \(O((M+N)\sqrt{N})\). That is a cool and neat technique!
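The two bounds can be added up explicitly. With box size K there are about N/K boxes; inside a box R moves at most N in total, plus at most N for the jump into the box, while L moves at most K per query plus about 2N in total across box boundaries (a rough count, constants not optimized):

\[
\underbrace{\frac{N}{K}\cdot 2N}_{\text{moves of } R} \;+\; \underbrace{MK + 2N}_{\text{moves of } L}
\;\overset{K=\sqrt{N}}{=}\; 2N\sqrt{N} + M\sqrt{N} + 2N \;=\; O\big((M+N)\sqrt{N}\big).
\]

When M is about N, K = \(\sqrt{N}\) is exactly the value that balances the \(N^2/K\) term from R against the \(MK\) term from L.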


Back to Problem
With that, we are now able to solve the problem. We store the prefix array P, where P[i] = a[1]^a[2]^a[3]^...^a[i] (with P[0] = 0) and ^ is the XOR operator. Since a[i]^a[i+1]^...^a[j] = P[j]^P[i-1], a subarray has XOR k exactly when its two bounding prefixes XOR to k. We also use the fact that the values of a[i] (and hence P[i]) are less than \(2^{20}\), so we can keep an array cnt[p] that counts the prefixes in the window [L..R] whose value equals p. Finally, a variable cur keeps track of the number of subarrays within [L..R] whose XOR equals k.
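To see why prefix counts are enough, here is a sketch of the counting step on the whole array in O(n). It is not Mo's algorithm yet; it just scans the prefixes left to right. countXorSubarrays is a hypothetical name, and all values and k are assumed to be in [0, 2^20).

```cpp
#include <cassert>
#include <vector>

// Counts subarrays whose XOR equals k, using prefix XORs and a count table.
long long countXorSubarrays(const std::vector<int>& a, int k) {
    assert(k >= 0 && k < (1 << 20));
    std::vector<int> cnt(1 << 20, 0);  // cnt[p] = number of prefixes seen with value p
    cnt[0] = 1;                        // the empty prefix P[0] = 0
    int P = 0;
    long long total = 0;
    for (int v : a) {
        assert(v >= 0 && v < (1 << 20));
        P ^= v;                // P is now the prefix XOR ending at v
        total += cnt[P ^ k];   // earlier prefixes Q with Q ^ P == k
        cnt[P]++;
    }
    return total;
}
```

For a = {1, 2, 1, 1, 0, 3} and k = 3 it returns 7: the prefixes are 0, 1, 3, 2, 3, 3, 0, and seven pairs of them XOR to 3.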

Given the window [L..R], define a set of incremental update rules for cnt[p] and cur, one for each way of moving L or R by one step. You should be able to come up with these yourself by thinking pretty hard.
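One possible set of rules, shown as a sketch rather than the code used below: let the window run over prefix indices instead of array indices. A 1-indexed query (l, r) becomes the prefix window [l-1..r] of P (with P[0] = 0), and each subarray with XOR k is a pair of prefix indices u < v in that window with P[u] ^ P[v] = k. Adding or removing a single prefix index x then changes cur by cnt[P[x]^k], so the rules become symmetric. The name xorQueries and the grow-before-shrink loop order are illustrative choices; values and k are assumed to be in [0, 2^20).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Answers 1-indexed queries (l, r): how many subarrays of a[l..r] have XOR k.
std::vector<long long> xorQueries(const std::vector<int>& a, int k,
                                  const std::vector<std::pair<int, int>>& qs) {
    assert(k >= 0 && k < (1 << 20));
    const int n = (int)a.size();
    std::vector<int> P(n + 1, 0);  // P[i] = a[0] ^ ... ^ a[i-1]
    for (int i = 0; i < n; ++i) {
        assert(a[i] >= 0 && a[i] < (1 << 20));
        P[i + 1] = P[i] ^ a[i];
    }

    const int m = (int)qs.size();
    std::vector<int> order(m);
    for (int q = 0; q < m; ++q) order[q] = q;
    const int K = std::max(1, (int)std::sqrt((double)(n + 1)));
    std::sort(order.begin(), order.end(), [&](int x, int y) {
        int bx = (qs[x].first - 1) / K, by = (qs[y].first - 1) / K;
        if (bx != by) return bx < by;           // box of the left end
        return qs[x].second < qs[y].second;     // then by right end
    });

    std::vector<int> cnt(1 << 20, 0);  // counts of prefix values in the window
    long long cur = 0;                 // pairs u < v in the window with P[u]^P[v] == k
    auto addIdx = [&](int x) { cur += cnt[P[x] ^ k]; cnt[P[x]]++; };
    auto removeIdx = [&](int x) { cnt[P[x]]--; cur -= cnt[P[x] ^ k]; };

    std::vector<long long> ans(m);
    int L = 0, R = -1;  // empty window
    for (int q : order) {
        int l = qs[q].first - 1, r = qs[q].second;  // prefix window [l..r]
        while (R < r) addIdx(++R);     // grow first ...
        while (L > l) addIdx(--L);
        while (R > r) removeIdx(R--);  // ... then shrink
        while (L < l) removeIdx(L++);
        ans[q] = cur;
    }
    return ans;
}
```

For a = {1, 2, 1, 1, 0, 3}, k = 3 and the queries (1, 6), (3, 5), this returns {7, 0}. Growing the window before shrinking it means the window never has to represent an "inverted" range, which is the usual safe order for this kind of sliding window.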

Finally, sort the queries and apply these rules for each query. Here is the implementation:


#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
#include <utility>
using namespace std;
int n, m, k;
int cnt[1<<21], a[100010];
int N;
vector<pair<pair<int,int>,int> > query;
long long ans[100010];

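// Sort order: box = left endpoint / N, where N is the box size (about sqrt(n));
// inside a box, queries are sorted by their right endpoint.
bool sqdec(const pair<pair<int,int>,int>& L, const pair<pair<int,int>,int>& R) {
    if (L.first.first/N == R.first.first/N) {
        return L.first.second < R.first.second;
    } 
    return L.first.first < R.first.first;
}

// a[i] holds the prefix XOR of the input up to i, cnt[v] counts the prefixes
// a[L..R] equal to v, and cur is the number of subarrays of [L..R] with XOR k.
// A subarray [x..y] has XOR k iff a[x-1] ^ a[y] == k, where a[-1] means 0.

// Move L by one step: val == 1 shrinks the window, anything else grows it.
void updleft(int& L, int val, long long& cur) {
    if (val == 1) {
        // remove the subarrays starting at L: pairs (a[L-1], a[y]), y in [L..R]
        int p = (L>0?a[L-1]:0);
        cur -= cnt[p^k];
        cnt[a[L]]--;
        L++;
    } else {
        // add the subarrays starting at the new L: pairs (a[L-1], a[y]), y in [L..R]
        L--;
        cnt[a[L]]++;
        int p = (L>0?a[L-1]:0);
        cur += cnt[p^k];
    }
}

// Move R by one step: val == 1 grows the window, anything else shrinks it.
// Subarrays ending at R pair a[R] with a[x-1], x in [L..R]: cnt covers
// x-1 in [L..R-1], and the extra term handles x-1 == L-1 (a[-1] = 0).
void updRight(int& R, int val, long long& cur, int L) {
    if (val == 1) {
        R++;
        cur += cnt[a[R]^k]+(L>0?(a[L-1]==(a[R]^k)?1:0):(k==a[R]?1:0));
        cnt[a[R]]++;
    } else {
        cnt[a[R]]--;  // cnt must cover [L..R-1] before counting
        cur -= cnt[a[R]^k]+(L>0?(a[L-1]==(a[R]^k)?1:0):(k==a[R]?1:0));
        R--;
    }
}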

int main(){
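    scanf("%d%d%d",&n,&m,&k);
    N = sqrt(n);  // box size
    // read the array and turn it into prefix XORs in place
    for(int i=0;i<n;++i){
        scanf("%d",&a[i]);
        if (i > 0) a[i] ^= a[i-1];
    }
    int l,r;
    // store each query 0-indexed, together with its input position
    for(int i=0;i<m;++i){
        scanf("%d%d",&l,&r);
        query.push_back(make_pair(make_pair(l-1,r-1),i));
    }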
    sort(query.begin(), query.end(), sqdec);

    // start from the window [0..0]: its single subarray counts iff a[0] == k
    int L = 0, R = 0;
    cnt[a[0]]=1;
    long long cur = cnt[k];
    for(int i=0;i<m;++i){
        int l = query[i].first.first;
        int r = query[i].first.second;
        int idx = query[i].second;
        while (R < r) {
            updRight(R, 1, cur, L);
        }
        while (R > r) {
            updRight(R, -1, cur, L);
        }
        while (L < l) {
            updleft(L, 1, cur);
        }
        while (L > l) {
            updleft(L, -1, cur);
        }
        ans[idx] = cur;
    }
    for (int i=0;i<m;++i){
        cout << ans[i] << endl;
    }
    return 0;
}