Tuesday, March 17, 2015

UVa 10364 - Square

Problem Statement:

Solution:
This is an innocent-looking problem, but it is actually a variant of set partitioning, an NP-complete problem. To solve it, I used bitmasking combined with dynamic programming to obtain an $$O(N \cdot 2^N)$$ solution. While I like the problem itself, passing the time limit on UVa may require some optimisations on top of the plain DP implementation, which makes the experience a little unappealing to me.

First of all, let a[i] be the length of the i-th stick, and let $$b$$ be a bitmask of N bits, where the i-th stick is included in b if b[i] equals 1 and excluded otherwise. We can first precompute val[b], the sum of all a[i] for which b[i] is 1. Furthermore, let X be the sum of all a[i] divided by 4. As you can guess, X is the length of each side of the resulting square. Let D[b] be the largest k such that the sticks in b can be laid end to end in some order whose running total hits X, 2X, ..., kX exactly. D[b] = k means that the sticks in b can be divided into k groups of total length X each, with the remaining sticks of b (if any) forming one more group that is not yet complete. Equivalently, D[b] is the largest number of disjoint groups of total length X that can be taken from b. Therefore we have the following relationship:
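let K be the best value obtained by dropping any one stick j from b,
$$K = \max_{j \,:\, b[j] = 1} D[b \oplus 2^j]$$
then
$$D[b] = K + \begin{cases} 1 & \text{if } val[b] = (K+1)X \\ 0 & \text{otherwise} \end{cases}$$
with D[0] = 0. Taking the maximum is enough: if dropping some stick leaves K complete sides, then val[b] > KX, so a choice of j with fewer complete sides can never finish a side exactly at val[b].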

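In the end, $$D[2^N - 1] = 4$$ (the mask containing every stick) if and only if the sticks can be partitioned into 4 groups of total length X each, i.e. exactly when a square can be formed. If the total length is not divisible by 4, the answer is no immediately.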
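To make the recurrence concrete, here is a minimal sketch of the plain DP over all $$2^N$$ masks, without the symmetry trick. The function name canFormSquarePlain and its std::vector interface are my own assumptions for illustration, not part of the original solution, and it assumes every stick length is positive, as in the problem. val is built by clearing the lowest set bit of each mask, and D follows the relationship above.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: plain O(N * 2^N) DP over all masks (no symmetry trick).
// Assumes positive stick lengths.
bool canFormSquarePlain(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    int sum = 0;
    for (int len : a) sum += len;
    if (sum % 4 != 0) return false;           // perimeter must split into 4 equal sides
    const int X = sum / 4;                    // side length
    const int full = (1 << n) - 1;            // mask containing every stick

    std::vector<int> val(full + 1, 0), D(full + 1, 0);
    for (int b = 1; b <= full; ++b) {
        int i = 0;
        while (!(b & (1 << i))) ++i;          // i = index of the lowest set bit of b
        val[b] = val[b & (b - 1)] + a[i];     // b & (b - 1) clears that bit
    }
    for (int b = 1; b <= full; ++b) {
        int K = 0;                            // best D after dropping one stick j
        for (int j = 0; j < n; ++j)
            if (b & (1 << j)) K = std::max(K, D[b ^ (1 << j)]);
        D[b] = K + (val[b] == (K + 1) * X ? 1 : 0);  // does b close side K + 1?
    }
    return D[full] == 4;
}
```

For example, sticks 1 7 2 6 4 4 3 5 can form a square of side 8 ({1,7}, {2,6}, {4,4}, {3,5}), while 3 3 3 3 3 1 cannot, even though its total of 16 is divisible by 4: only one group {3,1} of length 4 can ever be formed.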

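To pass the UVa time limit, you need to make use of a symmetry: the four sides are interchangeable and every stick must belong to some group, so without loss of generality the first stick (stick 0) belongs to the first group, and we can even place it first in the ordering. The DP then only needs the masks that contain stick 0, which halves the number of states from $$2^N$$ to $$2^{N-1}$$. This optimisation alone roughly halves the running time.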
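A minimal sketch of the same DP with the symmetry applied, again as a hypothetical standalone function (canFormSquareHalved is my name, not from the original). It keeps D indexed by the full mask and simply visits odd masks only, never removing stick 0. The Implementation below stores the value for mask bm at index bm >> 1 instead, which is the same idea with a compressed index.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: the DP restricted to masks that contain stick 0 (odd masks).
// Stick 0 is assumed to start the first side, so D is only needed for 2^(N-1) masks.
// Assumes positive stick lengths.
bool canFormSquareHalved(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    int sum = 0;
    for (int len : a) sum += len;
    if (sum % 4 != 0) return false;           // perimeter must split into 4 equal sides
    const int X = sum / 4;
    const int full = (1 << n) - 1;

    // val is still computed for every mask; it is cheap compared to the D loop.
    std::vector<int> val(full + 1, 0), D(full + 1, 0);
    for (int b = 1; b <= full; ++b) {
        int i = 0;
        while (!(b & (1 << i))) ++i;          // index of the lowest set bit
        val[b] = val[b & (b - 1)] + a[i];
    }
    for (int b = 1; b <= full; b += 2) {      // odd masks only: stick 0 is in b
        int K = 0;
        for (int j = 1; j < n; ++j)           // never drop stick 0, so b ^ (1 << j) stays odd
            if (b & (1 << j)) K = std::max(K, D[b ^ (1 << j)]);
        D[b] = K + (val[b] == (K + 1) * X ? 1 : 0);
    }
    return D[full] == 4;
}
```

Since every odd mask minus a stick other than stick 0 is again odd and smaller, each D value the loop reads has already been computed; even masks are never touched.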

Implementation:

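#include <cstdio>
#include <algorithm>
using namespace std;

int dp[1<<20];   // dp[b]: sides completed for the real mask (b << 1) | 1 (stick 0 always in)
int val[1<<20];  // val[b]: total length of the sticks in mask b
int a[23];       // stick lengths
int N;           // number of sticks

void solve() {
    for (int b = 0, sz = (1 << N); b < sz; ++b) {
        val[b] = -1;  // -1 marks "not computed yet"
        dp[b] = 0;
    }
    int sum = 0;
    val[0] = 0;
    for (int i = 0; i < N; ++i) {
        sum += a[i];
        val[1 << i] = a[i];  // single-stick masks
    }
    // The perimeter must split evenly into four sides.
    if (sum % 4 != 0) {
        printf("no\n");
        return;
    }
    int X = sum / 4;  // side length
    // val[b] = val[lowest set bit of b] + val[b without that bit]; both are smaller than b.
    for (int b = 0, sz = (1 << N); b < sz; ++b) {
        if (val[b] != -1) continue;
        int lsb = b & (-b);  // lowest set bit of b
        val[b] = val[lsb] + val[b ^ lsb];
    }
    // Symmetry: stick 0 is always in the first group, so only masks containing it
    // are processed. Bit i of b stands for stick i + 1.
    for (int b = 0, sz = (1 << (N - 1)); b < sz; ++b) {
        int bm = (b << 1) | 1;  // the real mask over all N sticks
        for (int i = 0; i < N - 1; ++i) {
            if (b & (1 << i)) {
                dp[b] = max(dp[b], dp[b ^ (1 << i)]);  // drop stick i + 1
            }
        }
        if (val[bm] == (dp[b] + 1) * X) dp[b]++;  // bm closes one more side
    }
    if (dp[(1 << (N - 1)) - 1] == 4) {
        printf("yes\n");
    } else {
        printf("no\n");
    }
}

int main() {
    int TC;
    scanf("%d", &TC);
    while (TC--) {
        scanf("%d", &N);
        for (int i = 0; i < N; ++i) scanf("%d", &a[i]);
        solve();
    }
    return 0;
}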