Saturday, January 31, 2015

Codeforces 441E - Valera and Number

Problem Statement:
441E - Valera and Number

Solution:
Another tedious DP problem, but one with some interesting ideas. The official editorial for this problem is fairly easy to follow.

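There are two types of operation applied to the bit string of the number: a left shift (performed with probability P/100 in the code below) and an increment by one (performed otherwise). We want the expected number of trailing zeros after K steps. Since there are at most 200 increment operations (fewer than 2^8 = 256), one can show that it suffices to keep track of the last 8 bits of the bit string explicitly, together with a small summary of the bits above them. This is because each increment falls into one of two cases: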

1. The last 8 bits are less than 11111111. Whatever the increment changes is then contained within the last 8 bits.
2. The last 8 bits are exactly 11111111 (all 8 bits set). In this case the increment carries into the 9th bit (counting from 1 at the least significant bit) and possibly further. This is the only case that needs extra bookkeeping, so it makes sense to also keep track of the value of the 9th bit, 0 or 1 (call it last). We also need the length of the run of consecutive bits, starting at the 9th bit, whose values equal the 9th bit (call it cnt).
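A quick way to check the two cases: x + 1 flips exactly the bits that are set in x ^ (x + 1), namely the trailing ones of x and the 0-bit just above them. The helper names in this sketch are hypothetical and not part of the solution below.

```cpp
#include <cstdint>

// Hypothetical helper: the bits flipped by x -> x + 1 are the set bits of x ^ (x + 1),
// i.e. the trailing 1-bits of x together with the 0-bit just above them.
uint64_t bitsFlippedByIncrement(uint64_t x) {
    return x ^ (x + 1);
}

// Case 1 from the list: the increment stays inside the last 8 bits
// exactly when those 8 bits are not all ones.
bool incrementStaysInLast8Bits(uint64_t x) {
    return (bitsFlippedByIncrement(x) >> 8) == 0;
}
```

For example, 0x1FF + 1 = 0x200 flips ten bits: the nine trailing ones and the bit above them, so the change escapes the last 8 bits (case 2).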

Hence the DP state is (k, bitmask, last, cnt), where
1. k is the number of steps performed so far,
2. bitmask is the last 8 bits of the current bit string,
3. last is the value of the 9th bit, and
4. cnt is the length of the run of consecutive bits, starting at (and including) the 9th bit, whose values equal the 9th bit; cnt = 0 means that every bit above the last 8 is zero.
For each state, we store the probability of reaching it.
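To make the state concrete, here is a sketch that compresses a number into (bitmask, last, cnt) using the same convention as the initialisation in the code below, and then recovers the number of trailing zeros from the state alone. The names State, encode and trailingZeros are hypothetical.

```cpp
#include <cstdint>

// Hypothetical compressed state, same convention as the solution's initialisation.
struct State {
    int bm;    // last 8 bits
    int last;  // 9th bit (bit index 8)
    int cnt;   // run of bits equal to `last`, starting at the 9th bit; 0 = nothing above
};

State encode(uint64_t x) {
    State s{static_cast<int>(x & 255), static_cast<int>((x >> 8) & 1), 0};
    uint64_t rest = x >> 8;
    while (rest && static_cast<int>(rest & 1) == s.last) {
        ++s.cnt;
        rest >>= 1;
    }
    return s;
}

// Number of trailing zeros of x, computed from the state only.
int trailingZeros(const State& s) {
    if (s.bm != 0) {                    // lowest set bit lies in the last 8 bits
        int tz = 0;
        while (((s.bm >> tz) & 1) == 0) ++tz;
        return tz;
    }
    if (s.last == 1) return 8;          // last 8 bits zero, 9th bit set
    return s.cnt > 0 ? 8 + s.cnt : 0;   // cnt zeros from the 9th bit, then a 1
}
```

For example, 1536 (binary 110 0000 0000) has bitmask 0, last = 0 and cnt = 1, so it has 8 + 1 = 9 trailing zeros.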

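From here, we "simply" need to work out the transitions for a left shift and for an increment, which is a tedious exercise. At the end, the answer is the sum over all states after K steps of (probability of the state) × (number of trailing zeros of a number in that state). If the last 8 bits are nonzero, the trailing zeros lie entirely within them; otherwise there are exactly 8 trailing zeros when last = 1, and 8 + cnt when last = 0.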
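The two transitions can be written as pure functions on the compressed state. This sketch mirrors the inner loop of the code below; the State struct and the function names are hypothetical. Exactly as in that code, when a carry turns the 9th bit from 0 into 1, cnt is set to 1 without looking at the bits above.

```cpp
// Hypothetical compressed state: bm = last 8 bits, last = 9th bit,
// cnt = run of bits equal to `last` starting at the 9th bit (0: nothing above).
struct State {
    int bm, last, cnt;
};

// x -> 2x: the top bit of bm moves into the 9th position.
State shiftLeft(State s) {
    int shifted = s.bm << 1;
    State t{shifted & 255, (shifted >> 8) & 1, 0};
    if (t.last != s.last) {
        t.cnt = 1;                              // a new run of length 1 starts at the 9th bit
    } else if (t.last == 1) {
        t.cnt = s.cnt + 1;                      // the run of ones grows by one
    } else {
        t.cnt = (s.cnt == 0) ? 0 : s.cnt + 1;   // zeros grow, unless nothing is above
    }
    return t;
}

// x -> x + 1: only a carry out of the last 8 bits touches last and cnt.
State increment(State s) {
    int sum = s.bm + 1;
    State t{sum & 255, s.last, s.cnt};
    if (sum & 256) {
        if (s.last == 1) {
            t.last = 0;   // the run of ones becomes a run of zeros of the same length
        } else {
            t.last = 1;   // the 9th bit flips 0 -> 1; cnt is set to 1 as in the solution
            t.cnt = 1;
        }
    }
    return t;
}
```

For example, 511 = 1 1111 1111 is the state (255, 1, 1); incrementing gives (0, 0, 1), which is exactly the state of 512.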


#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;

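int X,K,P;
// dp[k][bm][last][cnt] = probability of being in state (bm, last, cnt) after k steps
double dp[203][260][2][260];
int main(){
    scanf("%d%d%d",&X,&K,&P);
    double prob = 0.01*P;      // probability of a left shift at each step
    int bm = X & 255;          // last 8 bits
    int last = (X>>8) & 1;     // 9th bit
    int tmp = X>>8;
    int cnt = 0;               // run of bits equal to the 9th bit, starting at the 9th bit
    while(tmp && (tmp&1) == last){
        ++cnt;
        tmp >>= 1;
    }
    dp[0][bm][last][cnt] = 1.0;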

    for(int k=0;k<K;++k){
        for(int i=0;i<256;++i){
            for(int j=0;j<2;++j){
                for(int cnt=0;cnt<240;++cnt){ // cnt starts below 24 and grows by at most 1 per step
                    //left shift (probability prob)
                    int new_bm = i << 1;
                    int new_cnt, new_last;
                    if(new_bm & 256) {
                        new_last = 1;
                        if(j==1) {
                            new_cnt = cnt+1;
                        } else {
                            new_cnt = 1;
                        }
                    } else {
                        new_last = 0;
                        if(j==1) {
                            new_cnt = 1;
                        } else {
                            new_cnt = (cnt==0 ? 0 : cnt+1);
                        }
                    }
                    new_bm &= 255;
                    dp[k+1][new_bm][new_last][new_cnt] += prob * dp[k][i][j][cnt];
                    //add 1 (probability 1 - prob)
                    new_bm = i+1;
                    if(new_bm&256) {
                        if(j == 1) {
                            new_last = 0;
                            new_cnt = cnt;
                        } else {
                            new_last = 1;
                            new_cnt = 1;
                        }
                    } else {
                        new_last = j;
                        new_cnt = cnt;
                    }
                    new_bm &= 255;
                    dp[k+1][new_bm][new_last][new_cnt] += (1.0 - prob) * dp[k][i][j][cnt];
                }
            }
        }
    }
    double ans = 0;
    for(int i=0;i<256;++i){
        for(int j=0;j<2;++j){
            for(int cnt=0;cnt<240;++cnt){
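                // trailing zeros of any number in state (i, j, cnt)
                int val = 0;
                if(i != 0) {
                    while(((i >> val) & 1) == 0) ++val;   // lowest set bit is in the last 8 bits
                } else if(j == 1) {
                    val = 8;                              // last 8 bits zero, 9th bit set
                } else {
                    val = (cnt > 0 ? cnt + 8 : 0);        // cnt zeros starting at the 9th bit
                }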
                ans += dp[K][i][j][cnt] * val;
            }
        }
    }
    printf("%.12lf\n",ans);
    return 0;
}
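To sanity-check the DP on small inputs, we can enumerate all 2^K sequences of operations directly. This brute force is only a testing aid, not part of the solution above; the function name is hypothetical, and it assumes X >= 1 and a small K (say K <= 20) so that 2^K paths are cheap. As in the code above, a step is a left shift with probability P/100 and an increment otherwise.

```cpp
#include <cstdint>

// Hypothetical brute-force reference: expected trailing zeros after k steps,
// enumerating every sequence of operations (bit i of mask = 1 means "shift" at step i).
double bruteExpectedTrailingZeros(uint64_t x, int k, int p) {
    double q = 0.01 * p;   // probability of a left shift
    double expected = 0.0;
    for (uint32_t mask = 0; mask < (1u << k); ++mask) {
        uint64_t a = x;
        double pr = 1.0;
        for (int step = 0; step < k; ++step) {
            if ((mask >> step) & 1) { a <<= 1; pr *= q; }
            else                    { a += 1;  pr *= 1.0 - q; }
        }
        int tz = 0;                              // a >= 1 because x >= 1
        while ((a & 1) == 0) { a >>= 1; ++tz; }
        expected += pr * tz;
    }
    return expected;
}
```

By hand, X = 5, K = 3, P = 25 gives 1.921875; for instance, the all-increment path reaches 8 (3 trailing zeros) with probability 0.75^3. Comparing this function with the DP output on random small inputs is a cheap way to catch a wrong transition.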