Monday, December 21, 2015

Educational Codeforces Round 1 C. Nearest Vectors

Problem Statement:
http://codeforces.com/contest/598/problem/C

Solution:
[Before you read any further, I think there is an official editorial for this round which may explain the approach to this problem better.]

The problem looks innocent enough, but it is actually quite tedious. Still, it ties together several interesting techniques, and in the end I found it a very interesting problem.

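First of all, we need a general strategy for finding the two vectors with the smallest angle between them. The dot product is the most natural tool for measuring that angle, but comparing every pair takes O(N^2) time, which is too slow for large N. A better idea is to first sort the vectors by their anticlockwise angle from a common axis, here the positive x-axis. In that circular order, the pair with the smallest angle must be adjacent (any vector lying between them would form an even smaller angle with one of them). So after sorting, a single linear pass over all pairs of consecutive vectors suffices, including the pair formed by the last and the first vector, since the order wraps around. Finally we output the pair with the smallest angle.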
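To make the sort-then-scan strategy concrete, here is a minimal floating-point sketch. It is not the exact method developed below: it uses `std::atan2` on `long double` for readability, and the function name `closest_pair_sketch` and its input format (a vector of (x, y) pairs) are assumptions made for illustration. It also assumes n ≥ 2 and that no two vectors share a direction.

```cpp
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Sketch (floating point, hypothetical helper): sort by angle, then scan
// consecutive pairs. Input: vectors as (x, y); output: 1-based indices of a
// closest pair. Assumes n >= 2 and that no two vectors share a direction.
std::pair<int, int> closest_pair_sketch(const std::vector<std::pair<long long, long long>>& v) {
    const long double two_pi = 2 * std::acos(-1.0L);
    const int n = static_cast<int>(v.size());
    std::vector<std::pair<long double, int>> by_angle;  // (angle in [0, 2*pi), index)
    for (int i = 0; i < n; ++i) {
        long double t = std::atan2(static_cast<long double>(v[i].second),
                                   static_cast<long double>(v[i].first));
        if (t < 0) t += two_pi;  // atan2 returns (-pi, pi]; shift to [0, 2*pi)
        by_angle.push_back(std::make_pair(t, i + 1));
    }
    std::sort(by_angle.begin(), by_angle.end());  // anticlockwise order
    long double best = 10;  // larger than any angle between two vectors
    std::pair<int, int> ans(-1, -1);
    for (int i = 0; i < n; ++i) {
        int j = (i + 1) % n;  // wrap around: the last and first are neighbours
        long double gap = by_angle[j].first - by_angle[i].first;
        if (gap < 0) gap += two_pi;  // only happens for the wrap-around pair
        long double angle = std::min(gap, two_pi - gap);  // non-oriented angle
        if (angle < best) {
            best = angle;
            ans = std::make_pair(by_angle[i].second, by_angle[j].second);
        }
    }
    return ans;
}
```

For example, with the vectors (-1,0), (0,-1), (1,0), (1,1) the sketch picks indices 3 and 4, the two vectors 45 degrees apart. The wrap-around step matters because the closest pair may straddle the positive x-axis.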

How do we perform the sorting? For a given vector v, the cosine of its angle \(\theta\) with the positive x-axis is \(\cos\theta = \frac{v[0]}{\text{length}(v)}\), where \(\text{length}(v) = \sqrt{v[0]^2+v[1]^2}\). The cosine alone cannot tell \(\theta\) from \(2\pi - \theta\), so the sort compares the cosine values together with a simple check of whether v lies in the upper or the lower half-plane.
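A floating-point sketch of this cosine-plus-half-plane sort follows; it still carries the precision risk discussed next. The names `angle_key` and `sort_anticlockwise` are hypothetical. The key is \(\cos\theta\) in the upper half-plane and \(-2 - \cos\theta\) in the lower one, which makes it strictly decreasing for \(\theta \in [0, 2\pi)\).

```cpp
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Hypothetical key: cos(theta) in [-1, 1] for y >= 0, and -2 - cos(theta)
// in (-3, -1) for y < 0. It strictly decreases as theta goes from 0 to 2*pi.
double angle_key(long long x, long long y) {
    double c = x / std::sqrt(static_cast<double>(x * x + y * y));  // x / length(v)
    return y >= 0 ? c : -2.0 - c;
}

// Hypothetical helper: anticlockwise order from the positive x-axis,
// obtained by sorting the keys in decreasing order.
void sort_anticlockwise(std::vector<std::pair<long long, long long>>& v) {
    std::sort(v.begin(), v.end(),
              [](const std::pair<long long, long long>& l,
                 const std::pair<long long, long long>& r) {
                  return angle_key(l.first, l.second) > angle_key(r.first, r.second);
              });
}
```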

However, using the square root and double precision arithmetic can lead to precision errors (trust me, been there, done that). So we will do the computation exactly instead. In place of the cosine, we define a function \( G: \mathbb Z^2 \to \mathbb Z^2 \) by \(G(v) = (\text{sign}(v[0]) \times v[0]^2,\ v[0]^2+v[1]^2) \) if \(v[1] \geq 0\), and \(G(v) = (-\text{sign}(v[0]) \times v[0]^2 - 2(v[0]^2+v[1]^2),\ v[0]^2+v[1]^2) \) otherwise. G maps a vector v to a pair of integers (p, q), which we read as the rational number \(\frac{p}{q}\); note that q > 0 for every non-zero vector.
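A direct transcription of G, as a sketch: the name `G` and the `std::pair` return type are choices made here, while the implementation further down computes the same p and q inline in its comparator `fn`.

```cpp
#include <utility>

// Sketch of G: maps v = (x, y) to (p, q), read as the rational p / q.
// q = x^2 + y^2 is positive for any non-zero vector.
std::pair<long long, long long> G(long long x, long long y) {
    long long q = x * x + y * y;
    long long p = (x < 0 ? -1 : 1) * x * x;  // sign(x) * x^2 (0 when x == 0)
    if (y < 0) p = -p - 2 * q;               // lower half-plane maps below -1
    return std::make_pair(p, q);
}
```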

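This amounts to replacing \(\cos(\theta)\) by \(f(\theta) = \cos^2(\theta)\), and then building \(g(\theta)\) from \(f(\theta)\) piecewise, where \(\theta \in [0, 2\pi)\) is the anticlockwise angle of v:
1. \(g(\theta) = f(\theta)\) when \( \theta \in [0, \frac{\pi}{2}]\)
2. \(g(\theta) = -f(\theta)\) when \( \theta \in [\frac{\pi}{2}, \pi]\)
3. \(g(\theta) = f(\theta)-2\) when \( \theta \in [\pi, \frac{3\pi}{2}]\)
4. \(g(\theta) = -f(\theta)-2\) when \( \theta \in [\frac{3\pi}{2}, 2\pi)\)
The pieces agree where the intervals meet. You can check that \(g(\theta) = \frac{p}{q}\) for \((p, q) = G(v)\), and that g is strictly decreasing on \([0, 2\pi)\), falling from 1 at \(\theta = 0\) towards -3. Hence it can be used as the comparison key for sorting. Sorting by increasing g lists the vectors in clockwise order (decreasing \(\theta\)) rather than anticlockwise, which makes no difference when looking for neighbouring pairs.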
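As a quick check that g agrees with G, here are two worked cases from the lower half-plane, where the shift by -2 matters (the vectors are chosen only for illustration):

```latex
\begin{aligned}
&v = (-1,-1),\ \theta = \tfrac{5\pi}{4}:\quad
G(v) = \bigl(-\text{sign}(-1)\cdot(-1)^2 - 2\cdot 2,\ 2\bigr) = (-3,\ 2),\quad
g(\theta) = \cos^2\theta - 2 = \tfrac{1}{2} - 2 = -\tfrac{3}{2}\\
&v = (1,-1),\ \theta = \tfrac{7\pi}{4}:\quad
G(v) = \bigl(-\text{sign}(1)\cdot 1^2 - 2\cdot 2,\ 2\bigr) = (-5,\ 2),\quad
g(\theta) = -\cos^2\theta - 2 = -\tfrac{1}{2} - 2 = -\tfrac{5}{2}
\end{aligned}
```

The larger angle gets the smaller value (-5/2 < -3/2), as expected from a decreasing g.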

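How do we compare two rational numbers (a, b) and (p, q), i.e. \(\frac{a}{b}\) and \(\frac{p}{q}\), without resorting to floating point? Easy: since both denominators are positive, \(\frac{a}{b} < \frac{p}{q}\) holds exactly when aq < pb. With coordinates bounded by \(10^4\) in absolute value, each denominator is at most \(2 \cdot 10^8\) and each numerator at most \(6 \cdot 10^8\) in absolute value, so the products stay below about \(1.2 \cdot 10^{17}\) and fit in a 64-bit long long.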
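A minimal sketch of the exact comparison and of sorting by G with it. The names `rational_less` and `sort_by_G` are hypothetical, and the bound in the `static_assert` assumes |x|, |y| ≤ 10^4.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Sketch: true iff a/b < p/q. Requires b > 0 and q > 0, which holds here
// because the denominators are squared lengths of non-zero vectors.
bool rational_less(long long a, long long b, long long p, long long q) {
    return a * q < p * b;  // multiplying both sides by b*q > 0 keeps the direction
}

// Assumed bound: with |x|, |y| <= 10^4, q <= 2*10^8 and |p| <= 3*q <= 6*10^8.
static_assert(600000000LL * 200000000LL < 9000000000000000000LL,
              "cross products of G values fit in a signed 64-bit integer");

// Hypothetical helper: sorts (x, y) vectors by increasing G(v) = p/q.
// Since g decreases with the angle, the result is in clockwise order.
void sort_by_G(std::vector<std::pair<long long, long long>>& v) {
    auto key = [](const std::pair<long long, long long>& w) {
        long long q = w.first * w.first + w.second * w.second;
        long long p = (w.first < 0 ? -1 : 1) * w.first * w.first;
        if (w.second < 0) p = -p - 2 * q;
        return std::make_pair(p, q);
    };
    std::sort(v.begin(), v.end(),
              [&key](const std::pair<long long, long long>& l,
                     const std::pair<long long, long long>& r) {
                  std::pair<long long, long long> kl = key(l), kr = key(r);
                  return rational_less(kl.first, kl.second, kr.first, kr.second);
              });
}
```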

Now that the vectors are sorted, we need a way to decide which pair of consecutive vectors forms the smallest angle. We can use the dot product, but again the vector lengths bring in the square root, so we get rid of it by squaring. We use the squared version of the dot product relation: \(T(v,w) = (\text{sign}(v[0]w[0]+v[1]w[1]) \times (v[0]w[0]+v[1]w[1])^2,\ (v[0]^2+v[1]^2)(w[0]^2+w[1]^2)) \). T maps a pair of vectors (v, w) to a pair of integers (p, q) where \(\frac{p}{q}\) is the squared cosine of the angle between v and w, with the sign of the cosine kept in the numerator.
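T transcribed directly, as a sketch; the name `T` and the four-argument signature are assumptions made here, while the implementation below computes the same p and q inline in `main`.

```cpp
#include <utility>

// Sketch of T: returns (p, q) with p / q = sign(v.w) * (v.w)^2 / (|v|^2 |w|^2),
// i.e. the squared cosine of the angle between v and w, keeping the cosine's sign.
std::pair<long long, long long> T(long long vx, long long vy, long long wx, long long wy) {
    long long dot = vx * wx + vy * wy;
    long long p = dot * dot * (dot < 0 ? -1 : 1);
    long long q = (vx * vx + vy * vy) * (wx * wx + wy * wy);
    return std::make_pair(p, q);
}
```

For example, T(1,0,1,1) = (1, 2), i.e. cos² of 45 degrees, which is larger than T(1,0,-1,1) = (-1, 2) for the 135-degree pair.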

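T preserves the ordering we need: \(\frac{p}{q} = \text{sign}(\cos\varphi)\cos^2\varphi\), where \(\varphi \in [0, \pi]\) is the angle between v and w. This value increases with \(\cos\varphi\), which in turn decreases as \(\varphi\) grows, so a larger T(v,w) means a smaller angle between v and w. Hence our aim is to find the consecutive pair (v,w) with the largest T. This again reduces to comparing rational numbers (a,b) and (p,q), but now a, b, p and q can each be as large as about \(4 \cdot 10^{16}\) in absolute value, so the products aq and pb (up to about \(1.6 \cdot 10^{33}\)) no longer fit in a 64-bit integer. We handle this with a special case of big integer arithmetic: split each 64-bit operand into a high part and a low 31-bit part, and multiply them as two-digit polynomials in base \(2^{31}\), propagating carries. (The implementation uses 31 rather than 32 bits so that every partial product fits in a signed long long.) Comparing aq < pb then comes down to comparing the signs and, when the signs agree, a lexicographic comparison of the digits from the most significant one down.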
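The implementation below splits signed values into a high part and a 31-bit low part. As an alternative sketch (not the author's code), the same big-integer idea also works on unsigned 64-bit words with base 2^32, since unsigned 32x32-bit partial products always fit in 64 bits. The names `mul_u64`, `magnitude` and `frac_less` are hypothetical.

```cpp
#include <cstdint>
#include <utility>

// Sketch: exact 64x64 -> 128-bit unsigned product as (high word, low word).
// Each operand is split into two 32-bit digits (base 2^32).
std::pair<std::uint64_t, std::uint64_t> mul_u64(std::uint64_t a, std::uint64_t b) {
    const std::uint64_t mask = 0xFFFFFFFFull;
    std::uint64_t a_lo = a & mask, a_hi = a >> 32;
    std::uint64_t b_lo = b & mask, b_hi = b >> 32;
    std::uint64_t ll = a_lo * b_lo;  // weight 2^0
    std::uint64_t lh = a_lo * b_hi;  // weight 2^32
    std::uint64_t hl = a_hi * b_lo;  // weight 2^32
    std::uint64_t hh = a_hi * b_hi;  // weight 2^64
    // Middle column: at most 3 * (2^32 - 1), so it cannot overflow.
    std::uint64_t mid = (ll >> 32) + (lh & mask) + (hl & mask);
    std::uint64_t low = (mid << 32) | (ll & mask);
    std::uint64_t high = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    return std::make_pair(high, low);
}

// |v| as an unsigned value (well defined even for the most negative int64_t).
std::uint64_t magnitude(std::int64_t v) {
    return v < 0 ? 0 - static_cast<std::uint64_t>(v) : static_cast<std::uint64_t>(v);
}

// Sketch: true iff a/b < p/q, assuming b > 0 and q > 0.
bool frac_less(std::int64_t a, std::int64_t b, std::int64_t p, std::int64_t q) {
    // With positive denominators, a*q has the sign of a and p*b the sign of p.
    int sa = (a > 0) - (a < 0);
    int sp = (p > 0) - (p < 0);
    if (sa != sp) return sa < sp;
    if (sa == 0) return false;  // both fractions are zero
    std::pair<std::uint64_t, std::uint64_t> lhs = mul_u64(magnitude(a), static_cast<std::uint64_t>(q));
    std::pair<std::uint64_t, std::uint64_t> rhs = mul_u64(magnitude(p), static_cast<std::uint64_t>(b));
    // std::pair compares lexicographically: high word first, then low word.
    // For negative fractions the larger magnitude is the smaller number.
    return sa > 0 ? lhs < rhs : rhs < lhs;
}
```

On GCC and Clang the non-standard `__int128` type would make `mul_u64` unnecessary; the split keeps the sketch within standard C++.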

Implementation:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
using namespace std;

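int n;
// Each entry: ((x, y), original 1-based index).
vector<pair<pair<long long,long long>,int> > a;

// Sort comparator: orders vectors by increasing G(v) = p/q (see above).
// Both denominators are positive, so p0/q0 < p1/q1 is tested as p0*q1 < p1*q0.
bool fn(const pair<pair<long long,long long>,int>& lhs,
        const pair<pair<long long,long long>,int>& rhs) {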
    long long x[] = {lhs.first.first, rhs.first.first};
    long long y[] = {lhs.first.second, rhs.first.second};
    long long p[2], q[2];
    for(int i=0;i<2;++i){
        p[i] = x[i]*x[i]*(x[i]<0?-1:1);          // sign(x) * x^2
        q[i] = x[i]*x[i] + y[i]*y[i];            // squared length
        if (y[i] < 0) p[i] = -p[i] - q[i]*2;     // lower half-plane
    }
    return p[0]*q[1] < p[1]*q[0];
}

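// Computes |p| * |q| exactly as three base-2^31 digits res[0..2] (least
// significant first) and stores the sign of p * q in *sign.
void mult(long long p, long long q, long long res[3], int* sign) {
    *sign = (p >= 0 ? 1 : -1) * (q >= 0 ? 1 : -1);
    p *= (p < 0 ? -1 : 1);
    q *= (q < 0 ? -1 : 1);
    // Split each operand into a high part and a low 31-bit part.
    long long a[] = {p>>31, q>>31};
    long long b[] = {p&((1LL<<31)-1LL), q&((1LL<<31)-1LL)};
    // Two-digit schoolbook multiplication; for the input sizes here every
    // partial product fits in a signed long long.
    res[0] = b[0]*b[1];
    res[1] = a[0]*b[1]+a[1]*b[0];
    res[2] = a[0]*a[1];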
    // Propagate carries so that res[0] and res[1] are proper digits in [0, 2^31).
    for(int i=0;i<2;i++){
        if(res[i] >= (1LL<<31)) {
            res[i+1] += res[i] >> 31;
            res[i] &= (1LL<<31)-1LL;
        }
    }
}

// Lexicographic comparison of two base-2^31 numbers, most significant digit first.
bool cmp(long long a[3], long long b[3]) {
    return a[2] != b[2] ? a[2] < b[2] : (a[1] != b[1] ? a[1] < b[1] : a[0] < b[0]);
}

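// Exact test of p0/q0 < p1/q1 for q0, q1 > 0: compares the products p0*q1 and
// p1*q0 first by sign, then by magnitude.
bool less_than(long long p0, long long q0, long long p1, long long q1) {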
    int sign[2];
    long long res[2][3];
    mult(p0,q1, res[0], &sign[0]);
    mult(p1,q0, res[1], &sign[1]);
    
    if(sign[0] != sign[1]) return sign[0] < sign[1];
    // Same sign: for negative products the larger magnitude is the smaller value.
    if(sign[0]>0) {
        return cmp(res[0], res[1]);
    } else return cmp(res[1], res[0]);
}

int main(){
    scanf("%d",&n);
    int x,y;
    for(int i=0;i<n;++i){
        scanf("%d%d",&x,&y);
        a.push_back(make_pair(make_pair(x,y),i+1));
    }
    sort(a.begin(), a.end(), fn);
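    // Best T so far, as the fraction mp/mq. Every T value is at least -1,
    // so -2/1 works as "minus infinity".
    long long mp = -2, mq = 1;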
    int ai = -1, aj = -1;
    for(int i=0;i<n;++i){
        int j = (i==n-1?0:i+1);  // wrap around: the last and first vectors are neighbours
        long long x[] = {a[i].first.first, a[j].first.first};
        long long y[] = {a[i].first.second, a[j].first.second};
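        // T(v, w): p = sign(v.w) * (v.w)^2, q = |v|^2 |w|^2; both fit in 64 bits.
        long long p = x[0]*x[1]+y[0]*y[1];
        p *= p * (p < 0 ? -1 : 1);
        long long q = (x[0]*x[0]+y[0]*y[0]) * (x[1]*x[1]+y[1]*y[1]);
        // Earlier attempts: mp*q can overflow 64 bits, and long double
        // division risks rounding errors, so compare exactly instead.
        //if (mp*q < mq*p) {
        //if(1.0L*mp/mq < 1.0L*p/q){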
        if(less_than(mp, mq, p, q)){
            mp = p;
            mq = q;
            ai = a[i].second;
            aj = a[j].second;
        }
    }
    printf("%d %d\n", ai, aj);
    return 0;
}
