Monday, October 19, 2015

Count the number of set bits

Found something cool today. The problem is simple: count the number of set bits in the binary representation of an integer. The simplest solution would be to go through the bits one by one and count the 1s. Its time complexity is O(n), where n is the number of bits in the representation. It turns out we can do better.
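
As a baseline, here is a minimal Python sketch of that linear approach. The name count_bits_naive and the default width of 32 are my own assumptions. It inspects one bit per iteration, so it needs n iterations for an n-bit integer.

```python
def count_bits_naive(i: int, width: int = 32) -> int:
    """Count set bits by checking each of the `width` lowest bits in turn."""
    count = 0
    for _ in range(width):
        count += i & 1  # add the current lowest bit (0 or 1)
        i >>= 1         # shift the next bit into the lowest position
    return count
```

For example, count_bits_naive(0xFFFFFFFF) returns 32 after 32 iterations of the loop.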


Here is the function as described in a Stack Overflow thread (assuming a 32-bit representation):

# i is an unsigned 32-bit integer; all arithmetic wraps modulo 2^32
function count_bits(i):
    i = i - ((i>>1) & 0x55555555)
    i = (i & 0x33333333) + ((i>>2) & 0x33333333)
    return (((i + (i>>4)) & 0x0F0F0F0F) * 0x01010101) >> 24
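
For anyone who wants to run it, here is a minimal Python sketch of the same function. Python integers are unbounded, so the sketch assumes a 32-bit input. It masks with 0xFFFFFFFF (MASK32 is a name I introduce here) to imitate the wraparound that unsigned 32-bit multiplication gives in C. The check uses bin(x).count("1") as a reference.

```python
MASK32 = 0xFFFFFFFF  # hypothetical helper constant: keep results to 32 bits

def count_bits(i: int) -> int:
    """Parallel set-bit count of a 32-bit unsigned integer (sketch)."""
    i &= MASK32  # treat the input as an unsigned 32-bit value
    # Line 1: each 2-bit field now holds the set-bit count of that field
    i = i - ((i >> 1) & 0x55555555)
    # Line 2: add adjacent 2-bit counts into 4-bit counts
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
    # Line 3: add adjacent 4-bit counts into byte counts, sum all bytes into
    # the top byte via multiplication (truncated to 32 bits), then extract it
    return ((((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) & MASK32) >> 24

for x in (0, 1, 0x65D2D3F4, 0xFFFFFFFF):
    print(hex(x), count_bits(x), bin(x).count("1"))
```

The explicit & MASK32 after the multiplication is the only change from the pseudocode. Without it, Python would keep the overflow bytes above bit 31, and the final shift would return them too.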

How cryptic... and magical! Sometimes I wonder how something as simple as 0s and 1s can be so magical. But then again, the whole modern technological era is built on top of 0s and 1s.

How does it work? It is actually based on a bottom-up divide and conquer idea. Let's focus on an example.

Say we have the following integer:
i = 01100101 11010010 11010011 11110100

What does i - ((i>>1) & 0x55555555) do? It counts the number of set bits in each group of 2 bits, handling all sixteen groups in parallel.

The goal is to produce the following mapping for each 2-bit value i:


   i   |   no. of set bits
   ----+-------------------
   00  |   0   [00]
   01  |   1   [01]
   10  |   1   [01]
   11  |   2   [10]

This is achieved as follows. Let p = 01. First we shift i 1 bit to the right (i.e. i>>1) and AND the result with p, giving (i>>1) & p. Then we subtract that from the original i. It turns out this is exactly the mapping we want, as the following table shows:


   i    | (i>>1) & 01    | i - [(i>>1) & 01]
   -----+----------------+-------------------
   00   |    00          |    00  [0]
   01   |    00          |    01  [1]
   10   |    01          |    01  [1]
   11   |    01          |    10  [2]

So for every 2-bit integer, subtracting its most significant bit from it gives the number of set bits in it. A 2-bit value with bits a and b equals 2a + b, so subtracting a leaves a + b. This is the most magical idea in the whole function, and it serves as the base case of our divide-and-conquer strategy.
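
A quick exhaustive check of the 2-bit table above, as a small Python sketch. The helper name pair_popcount is hypothetical.

```python
def pair_popcount(x: int) -> int:
    """Set-bit count of a 2-bit value x (0..3), computed as x - ((x >> 1) & 0b01)."""
    return x - ((x >> 1) & 0b01)

# Reproduce the table: 00 -> 00, 01 -> 01, 10 -> 01, 11 -> 10
for x in range(4):
    print(f"{x:02b} -> {pair_popcount(x):02b}")
```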

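Now, 0x55555555 is 01010101 01010101 01010101 01010101. That means i - ((i>>1) & 0x55555555) computes the number of set bits for every consecutive group of 2 bits, all in parallel! The mask throws away the bit that the shift moves in from the neighbouring group. Because every group's result is non-negative, the subtraction never borrows across groups. So for our example, we can update i to the computed value: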


i  | 01 10 01 01  11 01 00 10  11 01 00 11  11 11 01 00
---+-----------------------------------------------------
i' |  1  1  1  1   2  1  0  1   2  1  0  2   2  2  1  0   [number of set bits]
i' | 01 01 01 01  10 01 00 01  10 01 00 10  10 10 01 00   [bit representation]
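
To double-check the table, here is a short Python sketch that applies the first line to the example integer. Python's underscore separators in binary literals are used only for readability. It then reads back the sixteen 2-bit fields from high to low.

```python
i = 0b01100101_11010010_11010011_11110100  # the example integer (0x65D2D3F4)

# First line of the function: per-pair set-bit counts
i1 = i - ((i >> 1) & 0x55555555)

print(f"{i1:032b}")                                # 01010101100100011001001010100100
print([(i1 >> s) & 0b11 for s in range(30, -1, -2)])  # the i' counts, highest pair first
```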

So i' is like an array holding the number of set bits in each group of 2 bits. Our goal is to aggregate the entries of this array into a single sum.

That brings us to the second line of the code, (i & 0x33333333) + ((i>>2) & 0x33333333). To understand what it does, let's see what happens to a 4-bit integer.


Take i = 0110

    i    & 0011  | 0010
    i>>2 & 0011  | 0001
    -------------+------ +
    sum          | 0011

As demonstrated above, it takes the upper two bits of i and adds them to the lower two bits. In effect, it accumulates (i.e. sums) the two 2-bit counts and stores the result in 4 bits.
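
The same 4-bit example as a tiny Python sketch. The helper name sum_pairs_4bit is hypothetical.

```python
def sum_pairs_4bit(x: int) -> int:
    """Add the lower and upper 2-bit fields of a 4-bit value x."""
    return (x & 0b0011) + ((x >> 2) & 0b0011)

print(f"{sum_pairs_4bit(0b0110):04b}")  # fields 10 and 01: 2 + 1 = 3, printed as 0011
```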

Realise that 0x33333333 is 00110011 00110011 00110011 00110011. Hence (i & 0x33333333) + ((i>>2) & 0x33333333) adds each pair of adjacent 2-bit groups, giving the number of set bits in every group of 4 bits, again in parallel! So, from our previous i', after performing this operation, we get:


  i'   0101 0101  1001 0001  1001 0010  1010 0100 [bit rep]
  i'    1 1  1 1   2 1  0 1   2 1  0 2   2 2  1 0 [decimal array rep]
  ---------------[after operation]------------------------
  i''     2    2     3    1     3    2     4    1 [sum of each pair of entries]
  i''  0010 0010  0011 0001  0011 0010  0100 0001 [bit rep]

Note that since each entry in i' is at most 2, every entry in i'' is at most 4, so the 4-bit fields never overflow.
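
A Python sketch of the second line applied to i'. Here i' is hard-coded in hex as 0x559192A4, which is the bit pattern 0101 0101 1001 0001 1001 0010 1010 0100 from the table. The sketch then reads back the eight 4-bit counts from high to low.

```python
i1 = 0x559192A4  # i' from the table above, written in hex

# Second line of the function: per-nibble set-bit counts
i2 = (i1 & 0x33333333) + ((i1 >> 2) & 0x33333333)

print(f"{i2:032b}")                                # 00100010001100010011001001000001
print([(i2 >> s) & 0xF for s in range(28, -1, -4)])  # [2, 2, 3, 1, 3, 2, 4, 1]
```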

Lastly, what does the last line do? It has three distinct parts: the first is (i + (i>>4)) & 0x0F0F0F0F, the second is the multiplication by 0x01010101, and the third is a right shift by 24 bits.

The first part is similar to what the previous line does. It adds up the number of set bits in each pair of adjacent 4-bit groups and stores the result in 8 bits. Since 0x0F is 0000 1111, (i + (i>>4)) adds the upper 4 bits of each byte to its lower 4 bits. The AND with 0x0F0F0F0F then keeps only the lower 4 bits of each byte, discarding the leftovers in the upper halves. Masking once after the addition is enough here because a sum of at most 8 still fits in the lower 4 bits. The following table demonstrates the result:


  i''   00100010  00110001  00110010  01000001 [bit rep]
  i''      2   2     3   1     3   2     4   1 [decimal array rep]
  -------------------------------------------------------------------
  tmp          4         4         5         5 [sum for each byte]
  tmp   00000100  00000100  00000101  00000101 [bit rep]

Since each entry in i'' is at most 4, each entry in tmp is at most 8, so it fits comfortably in 8 bits.
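
A Python sketch of this first part, starting from i'' written in hex as 0x22313241 (the bit pattern in the table). It prints each byte from high to low.

```python
i2 = 0x22313241  # i'' from the table above, written in hex

# First part of the last line: per-byte set-bit counts
tmp = (i2 + (i2 >> 4)) & 0x0F0F0F0F

print(f"{tmp:032b}")                              # 00000100000001000000010100000101
print([(tmp >> s) & 0xFF for s in (24, 16, 8, 0)])  # [4, 4, 5, 5]
```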

Next, what does the multiplication by 0x01010101 do? It accumulates all four bytes of tmp into the highest byte! To see why, note that 0x01010101 = (1<<24) + (1<<16) + (1<<8) + 1. Hence, in 32-bit arithmetic, tmp * 0x01010101 = (tmp<<24) + (tmp<<16) + (tmp<<8) + tmp, with any bits shifted past position 31 discarded. The highest byte of that sum is exactly the sum of all four bytes of tmp.


 tmp   00000100  00000100  00000101  00000101
 <<8   00000100  00000101  00000101  00000000
 <<16  00000101  00000101  00000000  00000000
 <<24  00000101  00000000  00000000  00000000
 ---------------------------------------------+
 res   00010010  00001110  00001010  00000101
         [18]    ^---------irrelevant-------^
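
The same multiplication as a Python sketch. Since Python integers never overflow, the sketch truncates to 32 bits with an explicit mask, which is my stand-in for C's unsigned wraparound. It also checks that the product equals the sum of the four shifted copies.

```python
MASK32 = 0xFFFFFFFF  # emulate 32-bit unsigned wraparound
tmp = 0x04040505     # tmp from the previous table

res = (tmp * 0x01010101) & MASK32
shifted_sum = ((tmp << 24) + (tmp << 16) + (tmp << 8) + tmp) & MASK32

print(f"{res:032b}")       # 00010010000011100000101000000101
print(res == shifted_sum)  # True
print(res >> 24)           # 18, the total number of set bits
```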

Lastly, we only want the highest 8 bits of res, so we return res >> 24, which is the total number of set bits in the initial integer! (As a side note, each byte of tmp is at most 8, so each byte of res is at most 8*4 = 32. That fits into 8 bits, so no byte sum ever carries into the next byte and the last operation does not overflow.)
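
To see the side note in action, here is a Python sketch of the worst case, an input with all 32 bits set. It traces each step, using an explicit 32-bit mask as before. Every byte sum in the product stays at or below 32, so nothing carries across byte boundaries.

```python
MASK32 = 0xFFFFFFFF  # emulate 32-bit unsigned wraparound

i = 0xFFFFFFFF  # all 32 bits set: the largest possible count
i = i - ((i >> 1) & 0x55555555)                 # every 2-bit field holds 2
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)  # every 4-bit field holds 4
tmp = (i + (i >> 4)) & 0x0F0F0F0F               # every byte holds 8
res = (tmp * 0x01010101) & MASK32               # bytes become 32, 24, 16, 8

print(hex(tmp), hex(res), res >> 24)  # 0x8080808 0x20181008 32
```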

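For a fixed 32-bit word, the function runs in constant time. It uses 12 basic operations (shifts, ANDs, additions, one subtraction and one multiplication) with no loops or branches. The linear approach, by contrast, needs a shift, a mask and an addition for every one of the 32 bits. More generally, each step doubles the width of the groups being summed, so an n-bit word needs O(log n) steps instead of O(n). The exact speed-up will depend on the hardware implementation of the ALU (multiplication in particular) and any parallelisation involved. Pretty awesome.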