Wednesday, July 23, 2014

a bit of dp : Longest Increasing Subsequence

Given a sequence of numbers, find the longest subsequence (not necessarily contiguous) whose elements are strictly increasing.

Eg:
S = 2, -1, 3, -3, 1, 8, 9

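Longest increasing subsequence: 2, 3, 8, 9 (length 4)
(the 1st, 3rd, 6th and 7th elements of S; the answer is not unique, e.g. -1, 3, 8, 9 is equally long)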

There are many ways to do this. For a small search space we can do a complete search over all subsequences, but a sequence of \(N\) elements has \(2^N\) subsequences, so the cost doubles with every element added.
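To make the complete search concrete, here is a minimal sketch that enumerates every subset of indices with a bitmask and keeps the longest strictly increasing one. The function name `lisBruteForce` is made up for illustration, and the sketch assumes the sequence is short (well under 64 elements) so that the mask fits in a 64-bit integer.

```cpp
#include <cstddef>
#include <vector>

// Complete search: try all 2^n subsets of indices and return the length of
// the longest one whose elements are strictly increasing.
// Assumption: n is small (well under 64), otherwise the mask overflows
// long before the running time becomes bearable anyway.
int lisBruteForce(const std::vector<int>& a) {
    const std::size_t n = a.size();
    int best = 0;
    for (unsigned long long mask = 0; mask < (1ULL << n); ++mask) {
        int len = 0;
        bool increasing = true;
        bool havePrev = false;
        int prev = 0;
        for (std::size_t i = 0; i < n && increasing; ++i) {
            if (((mask >> i) & 1ULL) == 0) continue;  // a[i] not chosen
            if (havePrev && a[i] <= prev) increasing = false;
            prev = a[i];
            havePrev = true;
            ++len;
        }
        if (increasing && len > best) best = len;
    }
    return best;
}
```

For S = 2, -1, 3, -3, 1, 8, 9 this checks 128 subsets and returns 4.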

Dynamic programming solves this problem in \(O(N^2)\) time, and both the top-down and the bottom-up approach work. Suppose that we are given \(|s_1|, |s_2|, \ldots, |s_k|\), where \(s_i\) is the longest increasing subsequence of \(a_1, a_2, \ldots, a_i\) that has \(a_i\) as its last element and \(|s_i|\) is its length. To find \(s_{k+1}\), we go through the whole list \(s_1, s_2, \ldots, s_k\) and append \(a_{k+1}\) to \(s_i\) if and only if \(a_i < a_{k+1}\). Then \(s_{k+1}\) is the longest of these candidates, or just \(a_{k+1}\) on its own if no \(a_i\) qualifies. The answer to the whole problem is the longest of \(s_1, s_2, \ldots, s_N\).
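A bottom-up version of this recurrence might look like the sketch below. The names `lisQuadratic`, `len` and `prev` are chosen here for illustration: `len[k]` plays the role of \(|s_k|\), and `prev[k]` remembers which \(s_i\) was extended so that the subsequence itself can be rebuilt.

```cpp
#include <algorithm>
#include <vector>

// O(N^2) dynamic programming.
// len[k]  = length of the longest increasing subsequence ending at a[k]
// prev[k] = index of the element before a[k] in that subsequence, or -1
// Returns one longest increasing subsequence (empty for empty input).
std::vector<int> lisQuadratic(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<int> len(n, 1), prev(n, -1);
    int bestEnd = -1;
    for (int k = 0; k < n; ++k) {
        for (int i = 0; i < k; ++i) {
            // a[k] may be appended to s_i only if a[i] < a[k].
            if (a[i] < a[k] && len[i] + 1 > len[k]) {
                len[k] = len[i] + 1;
                prev[k] = i;
            }
        }
        if (bestEnd == -1 || len[k] > len[bestEnd]) bestEnd = k;
    }
    // Follow the prev links back from the best endpoint, then reverse.
    std::vector<int> result;
    for (int k = bestEnd; k != -1; k = prev[k]) result.push_back(a[k]);
    std::reverse(result.begin(), result.end());
    return result;
}
```

On S = 2, -1, 3, -3, 1, 8, 9 the `len` array ends up as 1, 1, 2, 1, 2, 3, 4 and the function returns 2, 3, 8, 9.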

Another approach uses binary search (wow!). As we go through \(i = 1, 2, 3, \ldots, N\), we maintain an array whose entry at position \(\ell\) is the smallest last element of any increasing subsequence of length \(\ell\) seen so far; this array is always sorted. For each \(a_i\) we check whether we can build a longer subsequence using it, that is, whether \(a_i\) is bigger than every element in the array. If so, we append \(a_i\) to the array and continue. Otherwise, let \(a_j\) be the first element in the array that is greater than or equal to \(a_i\). We update the array by replacing \(a_j\) with \(a_i\), since we can build a subsequence of equal length but with a smaller (or equal) last element. Because the array is sorted, binary search can tell the two cases apart and locate \(a_j\) in \(O(\log{N})\) time. Overall, we can solve the problem in \(O(N\log{N})\) time.
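If only the length is needed, the idea fits in a few lines using `std::lower_bound`, which returns the first position whose value is not less than the key. This is a sketch with made-up names (`lisLength`, `tails`), not the solution to any particular judge problem.

```cpp
#include <algorithm>
#include <vector>

// O(N log N) length of the longest strictly increasing subsequence.
// tails[j] = smallest last element of any increasing subsequence of
// length j + 1 seen so far. tails is always sorted.
int lisLength(const std::vector<int>& a) {
    std::vector<int> tails;
    for (int x : a) {
        // First entry that is >= x, found by binary search.
        auto it = std::lower_bound(tails.begin(), tails.end(), x);
        if (it == tails.end()) {
            tails.push_back(x);  // x extends the longest subsequence
        } else {
            *it = x;             // same length, smaller last element
        }
    }
    return static_cast<int>(tails.size());
}
```

On S = 2, -1, 3, -3, 1, 8, 9 the array evolves as [2], [-1], [-1, 3], [-3, 3], [-3, 1], [-3, 1, 8], [-3, 1, 8, 9], giving length 4. In general the array holds the best last elements for each length and need not itself be a valid subsequence, which is why printing the actual subsequence requires extra bookkeeping.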

As an example, UVa 481 - What Goes Up can be solved using binary search:



#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
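vector<int> par; // par[i]: index of the element before num[i] in its subsequence, or -1
vector<int> dp;  // dp[k]: index into num of the smallest last element for length k+1
vector<int> num; // the input sequence

// Print the subsequence ending at num[idx] by following par back to the start.
void printout(int idx){
 if(idx == -1) return;
 printout(par[idx]);
 printf("%d\n", num[idx]);
}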

int main(){
 int N,cur=0;
 while(scanf("%d", &N) != EOF){
  num.push_back(N);
  if(dp.empty()){
   dp.push_back(cur);
   par.push_back(-1);
   ++cur;
   continue;
  }
  
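  // Binary search for the first position whose last element is >= N.
  int lo = 0, hi = (int)dp.size()-1, mid;
  while(lo <= hi){
   mid = lo + (hi-lo)/2;
   if(num[dp[mid]] < N){
    lo = mid + 1;
   } else {
    hi = mid - 1;
   }
  }
  int it = lo;
  if(it == (int)dp.size()){
   // N is bigger than every last element: extend the longest subsequence.
   par.push_back(dp[it-1]);
   dp.push_back(cur);
  } else {
   // N becomes the new last element for length it+1; its predecessor is
   // the last element for length it, if there is one.
   if(it > 0) par.push_back(dp[it-1]);
   else par.push_back(-1);
   dp[it] = cur;
  }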

  ++cur;
 }
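 cout << dp.size() << endl;
 printf("-\n");
 // Start from the last element of the longest subsequence (nothing to print on empty input).
 if(!dp.empty()) printout(dp.back());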
 return 0;
}