Wednesday, December 24, 2014

Codeforces Round 271 Div. 2 Problem E - Pillars

Problem Statement:
474E - Pillars

Solution:
This problem has a very interesting use of segment tree in combination with DP. The DP idea is very standard: let f[i] be the maximum length of jumps using pillars from [1..i], ending at i. Then f[i] = max of f[j] + 1, for all j such that \(|h_i-h_j| \geq d\). Then what's left is to find an efficient way to find those j that satisfies the requirement, and at the same time, find the maximum f[j] amongst all such j. This is where the segment tree comes in.



Before that, we need to sort the pillars according to their heights, and build a segment tree on that sorted pillars. The segment tree will answer this query: in [L, R], return the index of the pillar which has the maximum f[i]. Initially, all f[i] = 0.

Next, we iterate for i = 1 .. N:
1 Initially, set f[i] = 1, which is the case if we cannot find any \(h_j\) such that \(|h_j-h_i| \geq d\).
2. We have \(h_i\). Using binary search, we can find the pillar j closest to i which has \(h_j \leq h_i-d\), and similarly pillar k closest to i which has \(h_k \geq h_i + d\).
3. Then we query the segment tree to find the maximum of f in [1..j] and [k..N]
4. Update f[i] accordingly.

I have an accepted implementation, but it's quite messy... I'll put it here more for my own self reference.

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <utility>
#include <vector>
#include <cstring>
#include <cassert>
using namespace std;

int segtree[400005];
int par[100005];
vector<long long> val;
int tot[100005];
int idx[100005];
vector<pair<long long,int> > st;
int n, d;
int N;

int lowerbound(long long u){
 int lo = 0, hi = N-1, mid;
 while(lo <= hi){
  mid = (lo+hi)/2;
  if(val[mid] < u){
   lo = mid+1;
  } else {
   hi = mid-1;
  }
 }
 return lo;
}

int upperbound(long long u){
 int lo=0, hi=N-1, mid;
 while(lo<=hi){
  mid = (lo+hi)/2;
  if(val[mid] > u){
   hi = mid-1;
  } else {
   lo = mid+1;
  }
 }
 return hi;
}

void build(int p, int L, int R){
 if(L==R){
  segtree[p] = L;
  return;
 }
 int M = (L+R)/2;
 build(2*p, L, M);
 build(2*p+1, M+1, R);
 segtree[p] = segtree[2*p];
}

int update(int p, int L, int R, int S, int T, long long v) {

 if(R < S || T < L) {
  return segtree[p];
 }
 if(S <= L && R <= T){
  tot[segtree[p]] = v;
  return segtree[p];
 }
 int M = (L+R)/2;
 int left = update(2*p, L, M, S, T, v);
 int right = update(2*p+1, M+1, R, S, T, v);

 if(left == -1) segtree[p] = right;
 else if(right == -1) segtree[p] = left;
 else segtree[p] = (tot[left] > tot[right] ? left : right);
 return segtree[p];
}

int rmq(int p, int L, int R, int S, int T){
 if(R < S || T < L) {
  return -1;
 }
 if(S <= L && R <= T){
  return segtree[p];
 }
 int M = (L+R)/2;
 int left = rmq(2*p, L, M, S, T);
 int right = rmq(2*p+1, M+1, R, S, T);
 if(left == -1) return right;
 else if(right == -1) return left;
 return (tot[left] > tot[right] ? left : right);
}

void printout(int u){
 if(par[u] == -1){
  printf("%d", u+1);
  return;
 }
 printout(par[u]);
 printf(" %d", u+1);
}

int main(){
 scanf("%d%d", &n,&d);
 long long u;
 for(int i=0;i<n;++i){
  scanf("%I64d", &u);
  val.push_back(u);
  st.push_back(make_pair(u,i));
 }
 if(d==0){
  printf("%d\n", n);
  for(int i=1;i<=n;++i){
   if(i!=1)printf(" ");
   printf("%d",i);
  }
  printf("\n");
  return 0;
 }
 sort(val.begin(), val.end());
 int it = 0;
 while(it+1 < val.size()){
  if(val[it] == val[it+1]) val.erase(val.begin()+it);
  else ++it;
 }
 memset(par, -1, sizeof par);
 N = val.size();
 int ans = -1;
 int maxans = 0;
 build(1, 0, N-1);
 for(int i=0;i<n;++i){
  int cur = lowerbound(st[i].first);
  assert(st[i].first == val[cur]);
  int tmp = 1;
  idx[cur] = st[i].second;
  int k = upperbound(st[i].first-d);
  if(k >= 0){
   int p = rmq(1,0,N-1,0,k);
   if(tot[p]+1 > tmp){
    par[idx[cur]] = idx[p];
    tmp = tot[p] + 1;
   }
  }
  k = lowerbound(st[i].first+d);
  if(k < N){
   int p = rmq(1,0,N-1,k,N-1);
   if(tot[p]+1 > tmp){
    par[idx[cur]] = idx[p];
    tmp = tot[p] + 1;
   }
  }
  if(maxans < tmp){
   ans = idx[cur];
   maxans = tmp;
  }
  update(1,0,N-1,cur,cur,tmp);

 }
 printf("%d\n", maxans);
 printout(ans);
 printf("\n");
 return 0;
}

No comments:

Post a Comment