Here is the function as described in a Stack Overflow thread (assuming a 32-bit representation):
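```
function count_bits(i):
    # i is an unsigned 32-bit integer; all arithmetic wraps modulo 2^32
    i = i - ((i >> 1) & 0x55555555)                             # set bits per 2-bit group
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)              # set bits per 4-bit group
    return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24   # per byte, then total
```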
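For readers who want to run it, here is a minimal Python sketch of the same function. Python integers are unbounded, so this sketch assumes the input is a 32-bit unsigned value and emulates the wraparound by masking with 0xFFFFFFFF (the name `MASK32` is introduced here for illustration). In C with `uint32_t` arithmetic this masking typically happens on its own.

```python
MASK32 = 0xFFFFFFFF  # illustrative constant: emulates unsigned 32-bit wraparound


def count_bits(i: int) -> int:
    """Count the set bits of a 32-bit unsigned integer, one bit-group size per step."""
    i &= MASK32
    # Step 1: each 2-bit field now holds the number of set bits in that field.
    i = i - ((i >> 1) & 0x55555555)
    # Step 2: each 4-bit field holds the sum of two adjacent 2-bit counts.
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
    # Step 3: per-byte sums, then gather all four bytes into the top byte.
    # The product is masked so it behaves like 32-bit unsigned multiplication.
    return ((((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) & MASK32) >> 24


example = 0b01100101_11010010_11010011_11110100
print(count_bits(example))  # 18
```

Without the mask on the product, the bits above position 31 would survive in Python and the final shift would return the wrong value.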
How cryptic... and magical! Sometimes I wonder how something as simple as 0s and 1s can be so magical. But then again, the whole modern technological era is built on top of 0s and 1s.
How does it work? It is actually based on a bottom-up divide and conquer idea. Let's focus on an example.
Say we have the following integer:
i = 01100101 11010010 11010011 11110100
What does i - ((i>>1) & 0x55555555) do? It counts the number of set bits in each 2-bit group, for all 16 groups in parallel.
The goal is to produce the following mapping for each 2-bit value i:

```
 i  | no. of set bits
----+----------------
 00 | 0  [00]
 01 | 1  [01]
 10 | 1  [01]
 11 | 2  [10]
```
This is achieved as follows. Let p = 01. First we shift i 1 bit to the right (i.e. i>>1) and then AND it with p, giving (i>>1) & p. Then we subtract this from the original i. It turns out this is exactly the mapping we want, as the following table shows:
```
 i  | (i>>1) & 01 | i - [(i>>1) & 01]
----+-------------+------------------
 00 |     00      |     00  [0]
 01 |     00      |     01  [1]
 10 |     01      |     01  [1]
 11 |     01      |     10  [2]
```
So, for every 2-bit integer, subtracting its most significant bit from itself gives the number of set bits in that integer. This is the most magical idea in the whole function, and it serves as the base case of our divide-and-conquer strategy.
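The table can also be read as a one-line algebraic argument. Writing the 2-bit value as i = 2h + l, where h is its high bit and l its low bit (symbols introduced here just for this derivation):

```latex
\begin{aligned}
i &= 2h + l, \qquad h, l \in \{0, 1\} \\
(i \gg 1) \mathbin{\&} 01_2 &= h \\
i - \big((i \gg 1) \mathbin{\&} 01_2\big) &= 2h + l - h = h + l
\end{aligned}
```

Since h + l is exactly the number of 1s among the two bits, the subtraction yields the count directly.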
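Now, 0x55555555 is 01010101 01010101 01010101 01010101, i.e. the pattern 01 in every 2-bit group. That means (i>>1) & 0x55555555 moves the high bit of each 2-bit group into that group's low position, and i - [(i>>1) & 0x55555555] computes the number of set bits for all 16 groups at once. The groups do not interfere with each other: within each group the value being subtracted (its high bit) is never larger than the group's own value, so the subtraction never borrows from the neighbouring group. So for our example, we can update i to the computed value: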
```
 i   | 01 10 01 01 11 01 00 10 11 01 00 11 11 11 01 00
-----+------------------------------------------------
 i'  |  1  1  1  1  2  1  0  1  2  1  0  2  2  2  1  0   [number of set bits]
 i'  | 01 01 01 01 10 01 00 01 10 01 00 10 10 10 01 00   [bit representation]
```
So i' is like an array holding the number of set bits in every 2-bit group. Our goal is to aggregate the entries of this array into a single sum.
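To see i' as an array, here is a small Python sketch that applies only the first step and then reads out the 16 two-bit fields, most significant first. The helper name `pair_counts` is hypothetical, and 32-bit unsigned input is assumed.

```python
def pair_counts(i: int) -> list:
    """Apply step 1 and return the 16 two-bit fields, most significant first."""
    i &= 0xFFFFFFFF  # assume a 32-bit unsigned input
    i = i - ((i >> 1) & 0x55555555)
    # Read each 2-bit field, from bits 31..30 down to bits 1..0.
    return [(i >> shift) & 0b11 for shift in range(30, -2, -2)]


example = 0b01100101_11010010_11010011_11110100
print(pair_counts(example))
# [1, 1, 1, 1, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 1, 0]
```

Calling it on the values 0 to 3 and looking at the last field reproduces the 2-bit mapping table: 0, 1, 1, 2.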
That brings us to the second line of the code, i = (i & 0x33333333) + ((i>>2) & 0x33333333). To understand what it does, let's see what happens to a 4-bit integer.
Take i = 0110:

```
 i & 0011        | 0010
 (i>>2) & 0011   | 0001
-----------------+------ +
 sum             | 0011
```
As demonstrated above, it takes the upper two bits of i and adds them to the lower two bits. In effect, it accumulates (i.e. sums) the two 2-bit counts and stores the result in a 4-bit field.
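The 4-bit case above can be reproduced with a tiny Python sketch; `add_pairs_4bit` is a hypothetical helper that handles a single 4-bit value rather than a whole word.

```python
def add_pairs_4bit(i: int) -> int:
    """Step 2 on one 4-bit value: add the upper 2-bit field to the lower one."""
    low = i & 0b0011          # lower two bits (first addend)
    high = (i >> 2) & 0b0011  # upper two bits, shifted down (second addend)
    return low + high


print(format(add_pairs_4bit(0b0110), "04b"))  # 0011
```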
Realise that 0x33333333 is 00110011 00110011 00110011 00110011, hence (i & 0x33333333) + ((i>>2) & 0x33333333) adds each pair of adjacent 2-bit counts, and as a result we get the number of set bits in every group of 4 bits, again in parallel. So, from our previous i', after performing this operation, we get:
```
i'   0101 0101 1001 0001 1001 0010 1010 0100   [bit rep]
i'    1 1  1 1  2 1  0 1  2 1  0 2  2 2  1 0   [decimal array rep]
----------------[after operation]----------------
i''     2    2    3    1    3    2    4    1   [set bits in each group of 4 bits]
i''  0010 0010 0011 0001 0011 0010 0100 0001   [bit rep]
```
Note that since every entry in i' is at most 2, every entry in i'' is at most 4, which fits comfortably in 4 bits, so no sum carries into the neighbouring group.
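Applying the first two lines to the full 32-bit example in Python gives the i'' row directly. The helper `nibble_counts` is a hypothetical name, and the input is again assumed to be 32-bit unsigned.

```python
def nibble_counts(i: int) -> list:
    """Apply steps 1 and 2; return the eight 4-bit fields, most significant first."""
    i &= 0xFFFFFFFF
    i = i - ((i >> 1) & 0x55555555)                 # 2-bit counts
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)  # 4-bit counts
    return [(i >> shift) & 0xF for shift in range(28, -4, -4)]


example = 0b01100101_11010010_11010011_11110100
print(nibble_counts(example))  # [2, 2, 3, 1, 3, 2, 4, 1]
```

With all 32 bits set, every field comes out as 4, the largest value this step can produce.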
Lastly, what does the last line do? It has three distinct parts: the first is (i + (i>>4)) & 0x0F0F0F0F, the second is the multiplication by 0x01010101, and the third is a right shift by 24 bits.
The first part is similar to what the previous line did: it adds each pair of adjacent 4-bit counts and stores the result in an 8-bit field. Since 0x0F is 0000 1111, i + (i>>4) adds the upper 4 bits of each byte to its lower 4 bits, and the AND with 0x0F0F0F0F keeps only the lower 4 bits of each byte, where that sum lives. Unlike the previous line, the mask is applied once, after the addition; this is safe because each sum is at most 8, so it never carries out of the lower 4 bits. The following table shows the result, with tmp denoting the value after this first part:
```
i''   00100010 00110001 00110010 01000001   [bit rep]
i''      2   2    3   1    3   2    4   1   [decimal array rep]
------------------------------------------------------------------
tmp          4        4        5        5   [decimal]
tmp   00000100 00000100 00000101 00000101   [bit rep]
```
Since each entry in i'' is at most 4, each entry in tmp is at most 8, which easily fits in an 8-bit field.
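The first part of the last line can be checked on the example the same way. This Python sketch (hypothetical helper `byte_counts`, 32-bit unsigned input assumed) returns the four bytes of tmp.

```python
def byte_counts(i: int) -> list:
    """Apply steps 1-2 and the first part of step 3; return the four bytes of tmp."""
    i &= 0xFFFFFFFF
    i = i - ((i >> 1) & 0x55555555)
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
    # Add each byte's upper nibble to its lower nibble, then clear the upper nibbles.
    tmp = (i + (i >> 4)) & 0x0F0F0F0F
    return [(tmp >> shift) & 0xFF for shift in (24, 16, 8, 0)]


example = 0b01100101_11010010_11010011_11110100
print(byte_counts(example))  # [4, 4, 5, 5]
```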
Next, what does the multiplication by 0x01010101 do? It sums all four bytes of tmp into the highest byte. To see why, note that 0x01010101 = (1<<24) + (1<<16) + (1<<8) + 1. Hence tmp * 0x01010101 = (tmp<<24) + (tmp<<16) + (tmp<<8) + tmp, keeping only the low 32 bits, and the highest byte of that sum is exactly the sum of all four bytes of tmp:
```
tmp         00000100 00000100 00000101 00000101
tmp<<8      00000100 00000101 00000101 00000000
tmp<<16     00000101 00000101 00000000 00000000
tmp<<24     00000101 00000000 00000000 00000000
----------------------------------------------- +
res         00010010 00001110 00001010 00000101
            [= 18]   ^-------irrelevant-------^
```
Lastly, we only want the highest 8 bits of res, so we return res >> 24, which is the total number of set bits in the original integer! (As a side note, the total is at most 8 * 4 = 32, which fits in 8 bits; likewise no byte-wise column sum in the multiplication exceeds 32, so nothing carries into the byte above and the top byte holds the exact count.)
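The multiplication step can be checked directly in Python, masking to 32 bits by hand to mimic unsigned 32-bit arithmetic. This sketch uses the tmp value from the example and also tries the worst case, where every byte of tmp is 8.

```python
MASK32 = 0xFFFFFFFF  # illustrative constant: emulates unsigned 32-bit wraparound

tmp = 0x04040505  # per-byte counts 4, 4, 5, 5 from the example

# Multiplying by 0x01010101 ...
res = (tmp * 0x01010101) & MASK32
# ... equals the sum of four shifted copies of tmp (both taken mod 2**32).
shifted = (tmp + (tmp << 8) + (tmp << 16) + (tmp << 24)) & MASK32

print(hex(res), res == shifted)  # 0x120e0a05 True
print(res >> 24)                 # 18

# Worst case: all 32 bits set, so every byte of tmp holds 8.
worst = (0x08080808 * 0x01010101) & MASK32
print(worst >> 24)               # 32
```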
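In terms of the word size n, the linear approach that inspects one bit at a time takes O(n) steps, while this function doubles the group size at every step and so needs only O(log n) steps. For 32 bits that is a fixed sequence of 12 arithmetic and logical operations by my count (3 + 4 + 5 across the three lines), with no loop or branch. The exact speed-up depends on the hardware implementation of the ALU (in particular how cheap the multiplication is) and on any parallelisation involved, but a loop spending a few operations on each of the 32 bits needs many times more work. Pretty awesome.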
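To make the comparison with the linear approach concrete, here is a minimal Python sketch of a naive loop that inspects one bit per iteration, cross-checked against the constant-step version on random inputs. Both function names are hypothetical, and in Python the interpreter overhead would dominate any timing, so this shows correctness and operation structure rather than the hardware speed-up.

```python
import random

MASK32 = 0xFFFFFFFF  # emulate unsigned 32-bit wraparound


def count_bits_naive(i: int) -> int:
    """Linear approach: look at one bit per iteration, 32 iterations in total."""
    count = 0
    for _ in range(32):
        count += i & 1  # add the lowest bit
        i >>= 1         # move on to the next bit
    return count


def count_bits_swar(i: int) -> int:
    """The fixed-step version from this post, with 32-bit wraparound emulated."""
    i &= MASK32
    i = i - ((i >> 1) & 0x55555555)
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
    return ((((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) & MASK32) >> 24


# Cross-check both versions on random 32-bit inputs (fixed seed for repeatability).
rng = random.Random(0)
samples = [rng.getrandbits(32) for _ in range(10_000)]
assert all(count_bits_naive(x) == count_bits_swar(x) for x in samples)
```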