Rounding integer division by multiply-and-shift

I learned about this trick the other day by inspecting the machine code generated by gcc. Dividing an integer a by a constant b can be optimized as follows:
x = a / b
x = a * (1 / b)
x = (a * (((1 << n) / b) + 1)) >> n
The reciprocal can be evaluated at compile-time, resulting in a multiply-and-shift operation that is more efficient than division.
c = ((1 << n) / b) + 1
x = (a * c) >> n
This has the same effect as plain integer division: it truncates the result (rounding towards zero). Is it possible to modify the algorithm to round to the nearest value instead?

I came up with:
c = (1 << n) / b
d = a * c
x = ((d >> (n - 1)) & 1) + (d >> n)
This does the trick, but I wonder if there are more efficient methods.
Edit: I asked the same question on Reddit and got a better answer:
https://www.reddit.com/r/AskProgramming/comments/9cx9dl/rounding_integer_division_by_multiplyandshift/
c = ((1 << n) + b - 1) / b;
x = (((a * c) >> (n - 1)) + 1) >> 1
One more shift can be removed by defining another constant:
e = 1 << (n - 1)
x = (a * c + e) >> n
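A minimal C++ sketch of the rounded variant, under stated assumptions: n is fixed at 32, the dividend a is a 16-bit value and the divisor b is at most 255. With those limits a * c fits in 64 bits, and the error from rounding c up stays too small to push a quotient across a rounding boundary. The name RoundingDivider is hypothetical, and wider operands would need a larger n and a fresh error analysis.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: round-to-nearest division by a fixed divisor b,
// computed as (a * c + e) >> n with n = 32, c = ceil(2^n / b), e = 2^(n-1).
// Assumes 1 <= b <= 255 and a 16-bit dividend.
struct RoundingDivider {
    std::uint64_t c;  // precomputed reciprocal, ceil(2^32 / b)

    explicit RoundingDivider(std::uint32_t b) : c(0) {
        assert(b >= 1 && b <= 255);
        c = ((std::uint64_t{1} << 32) + b - 1) / b;
    }

    std::uint32_t divide(std::uint16_t a) const {
        const std::uint64_t e = std::uint64_t{1} << 31;  // one half, scaled by 2^32
        return static_cast<std::uint32_t>((a * c + e) >> 32);
    }
};
```

For example, RoundingDivider(10).divide(15) gives 2, where plain integer division gives 1.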

Related

Why the function is failing in case of numbers greater than 48 digits?

I am trying to find
(a^b) % mod
where b and mod are up to 10^9, while a can be really large (I have tested up to 48 digits with success),
using this relation
(a^b) % mod = (a%mod)^b % mod
#define ll long long int
ll powerLL(ll x, ll n,ll MOD)
{
ll result = 1;
while (n) {
if (n & 1)
result = result * x % MOD;
n = n / 2;
x = x * x % MOD;
}
return result;
}
ll powerStrings(string sa, string sb,ll MOD)
{
ll a = 0, b = 0;
for (size_t i = 0; i < sa.length(); i++)
a = (a * 10 + (sa[i] - '0')) % MOD;
for (size_t i = 0; i < sb.length(); i++)
b = (b * 10 + (sb[i] - '0')) % (MOD - 1);
return powerLL(a, b,MOD);
}
powerStrings("5109109785634228366587086207094636370893763284000","362323789",354252525) returns 208624800 but it should return 323419500. In this case a is 49 digits
powerStrings("300510498717329829809207642824818434714870652000","362323489",354255221) returns 282740484 , which is correct. In this case a is 48 digits
Is something wrong with the code, or will I have to use another method?
It does not work because it is not mathematically correct.
In general, pow(a, n, m) = pow(a, n % λ(m), m) holds when a is coprime to m, where λ is the Carmichael function. When m is a prime number, λ(m) = m - 1, which is the case covered by Fermat's little theorem. Reducing the exponent modulo m - 1 is valid only in that special case; it does not work for a composite modulus in general.
λ(354252525) = 2146980, if I hack that in then the right result comes out. (the base is not actually coprime to the modulus though)
In general you would need to compute the Carmichael function for the modulus, which is non-trivial, but feasible for small moduli.
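For small moduli, λ(m) can be computed by factoring m with trial division. The following is only a sketch of that idea; the function names are hypothetical, and it assumes m is small enough (around 10^9, as in the question) for the trial-division loop to be acceptable.

```cpp
#include <cstdint>

// Greatest common divisor, used to build the least common multiple.
static std::uint64_t gcd_u64(std::uint64_t a, std::uint64_t b) {
    while (b != 0) {
        const std::uint64_t t = a % b;
        a = b;
        b = t;
    }
    return a;
}

// Carmichael function lambda(m), computed by trial-division factoring.
// lambda(p^k) = p^(k-1) * (p - 1) for odd p and for 2, 4;
// lambda(2^k) = 2^(k-2) for k >= 3; lambda(m) is the lcm over prime powers.
std::uint64_t carmichael(std::uint64_t m) {
    std::uint64_t result = 1;
    for (std::uint64_t p = 2; p * p <= m; ++p) {
        if (m % p != 0) continue;
        std::uint64_t pk = 1;  // p^k, the full power of p dividing m
        while (m % p == 0) {
            m /= p;
            pk *= p;
        }
        std::uint64_t lambda = pk / p * (p - 1);
        if (p == 2 && pk >= 8) lambda /= 2;
        result = result / gcd_u64(result, lambda) * lambda;
    }
    if (m > 1) {  // one prime factor larger than sqrt(m) may remain
        const std::uint64_t lambda = m - 1;
        result = result / gcd_u64(result, lambda) * lambda;
    }
    return result;
}
```

In powerStrings, the exponent would then be reduced modulo carmichael(MOD) instead of MOD - 1, keeping in mind that the identity is only guaranteed when the base is coprime to the modulus.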

Branch free saturation

I have the following calculation:
unsigned int a;
unsigned b = (a < 4) ? a : 4;
Is it possible to convert the second line to a branch free format?
Thanks!
Try this:
b = (a >= 4) * 4 + (a < 4) * ((a >> 1) & 1) * 2 + (a < 4) * (a & 1);
Explanation: we are returning 4 by "zeroing" the 2 least significant bits if a >= 4. If a < 4, we use these 2 least significant bits.
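You could use a conditionally applied mask (this relies on a - 4, reinterpreted as int, being negative exactly when a < 4, so it assumes a < 2^31 + 4 for a 32-bit int):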
unsigned int a, b, t, m;
t = a - 4;
m = 0 - ((int)t < 0); // mask of all 0s or all 1s
b = (t & m) + 4; // mask all 1s: b=a-4+4; mask all 0s: b=4
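Another way to write the same saturation is the XOR-select form of min, which works for the full unsigned range because the mask comes from an unsigned comparison. This is a sketch with a hypothetical function name; whether the compiler emits a branch for it, or for the original ternary, depends on the compiler and target.

```cpp
// Hypothetical helper: min(a, 4) without a conditional expression.
unsigned saturate4(unsigned a) {
    // mask is all 1s when a < 4 and all 0s otherwise
    const unsigned mask = 0u - static_cast<unsigned>(a < 4);
    // a < 4:  4 ^ (a ^ 4) == a
    // a >= 4: 4 ^ 0       == 4
    return 4u ^ ((a ^ 4u) & mask);
}
```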

Code or Algo for Modulo Division

I need to find Res = (A / B) % P. (P is prime).
I have Num = A % P, and Den = B % P.
Is there any way to find Res just by using Num, Den and P?
I came across this :
(a / b) mod p = ((a mod p) * (b^(-1) mod p)) mod p
i.e. Res = (Num * b ^ (p - 2) % p) % p
Now how can I find b^(p-2) % p using Den?
If you can provide me with a C++/C code, I would be more than happy, as I can then directly use it in my game, otherwise, please help me in finding a formula, so that I can obtain the Res on my own.
You can get b^(p-2) % p using fast exponentiation. Because b ≡ Den (mod p), raising Den to the power p - 2 gives the same result.
int qexp (int b, int e) {
if (e == 0) return 1;
long long h = qexp(b, e/2); //long long to prevent overflow in next step
h *= h;
h %= p;
if (e % 2 == 0) return h;
else return (h*b)%p;
}
Call it as qexp(Den, p - 2); p is assumed to be visible inside qexp, for example as a global.
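Putting the pieces together, here is a sketch that takes p as a parameter instead of a global and computes Res directly from Num and Den. The names pow_mod and div_mod are hypothetical; the code assumes p is prime and below 2^32 (so that products fit in 64 bits) and that Den is not 0 modulo p.

```cpp
#include <cassert>
#include <cstdint>

// base^exp mod p by iterative square-and-multiply.
// Assumes p < 2^32 so that every product fits in 64 bits.
std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t p) {
    std::uint64_t result = 1 % p;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// (A / B) mod p, given num = A % p and den = B % p, for prime p.
// Uses Fermat's little theorem: den^(p-2) is the inverse of den mod p.
std::uint64_t div_mod(std::uint64_t num, std::uint64_t den, std::uint64_t p) {
    assert(den % p != 0);  // no inverse exists otherwise
    return (num % p) * pow_mod(den, p - 2, p) % p;
}
```

For instance, with p = 7, Num = 6 and Den = 3, div_mod returns 2.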

overflow possibilities in modular exponentiation by squaring

I am looking to implement Fermat's little theorem for primality testing. Here's the code I have written:
lld expo(lld n, lld p) //2^p mod n
{
if(p==0)
return 1;
lld exp=expo(n,p/2);
if(p%2==0)
return (exp*exp)%n;
else
return (((exp*exp)%n)*2)%n;
}
bool ifPseudoPrime(lld n)
{
if(expo(n,n)==2)
return true;
else
return false;
}
NOTE: I took the value of a (<= n-1) to be 2.
Now, the number n can be as large as 10^18. This means the variable exp can reach values near 10^18, so the expression (exp*exp) can reach 10^36 and overflow. How do I avoid this?
I tested this and it ran fine up to 10^9. I am using C++.
If the modulus is close to the limit of the largest integer type you can use, things get somewhat complicated. If you can't use a library that implements biginteger arithmetic, you can roll a modular multiplication yourself by splitting the factors in low-order and high-order parts.
If the modulus m is so large that 2*(m-1) overflows, things get really fussy, but if 2*(m-1) doesn't overflow, it's bearable.
Let us suppose you have and use a 64-bit unsigned integer type.
You can calculate the modular product by splitting the factors into low and high 32 bits, the product then splits into
a = a1 + (a2 << 32) // 0 <= a1, a2 < (1 << 32)
b = b1 + (b2 << 32) // 0 <= b1, b2 < (1 << 32)
a*b = a1*b1 + (a1*b2 << 32) + (a2*b1 << 32) + (a2*b2 << 64)
To calculate a*b (mod m) with m <= (1 << 63), reduce each of the four products modulo m,
p1 = (a1*b1) % m;
p2 = (a1*b2) % m;
p3 = (a2*b1) % m;
p4 = (a2*b2) % m;
and the simplest way to incorporate the shifts is
for(i = 0; i < 32; ++i) {
p2 *= 2;
if (p2 >= m) p2 -= m;
}
the same for p3 and with 64 iterations for p4. Then
s = p1+p2;
if (s >= m) s -= m;
s += p3;
if (s >= m) s -= m;
s += p4;
if (s >= m) s -= m;
return s;
That way is not very fast, but for the few multiplications needed here, it may be fast enough. A small speedup should be obtained by reducing the number of shifts; first calculate (p4 << 32) % m,
for(i = 0; i < 32; ++i) {
p4 *= 2;
if (p4 >= m) p4 -= m;
}
then all of p2, p3 and the current value of p4 need to be multiplied by 2^32 modulo m,
p4 += p3;
if (p4 >= m) p4 -= m;
p4 += p2;
if (p4 >= m) p4 -= m;
for(i = 0; i < 32; ++i) {
p4 *= 2;
if (p4 >= m) p4 -= m;
}
s = p4+p1;
if (s >= m) s -= m;
return s;
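Assembled into one function, the split-and-shift scheme might look like the sketch below. It assumes a 64-bit unsigned type and m <= 2^63, as in the text; the names add_mod, shift32_mod, mul_mod_split and pow_mod_big are hypothetical, and the exponentiation routine is added only to show where the multiplication would be used.

```cpp
#include <cassert>
#include <cstdint>

// (a + b) mod m for a, b < m <= 2^63, so a + b cannot overflow 64 bits.
static std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    std::uint64_t s = a + b;
    if (s >= m) s -= m;
    return s;
}

// (x * 2^32) mod m for x < m, by 32 modular doublings.
static std::uint64_t shift32_mod(std::uint64_t x, std::uint64_t m) {
    for (int i = 0; i < 32; ++i) x = add_mod(x, x, m);
    return x;
}

// (a * b) mod m using 32-bit halves; requires 0 < m <= 2^63.
std::uint64_t mul_mod_split(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
    assert(m != 0 && m <= (std::uint64_t{1} << 63));
    const std::uint64_t a1 = a & 0xFFFFFFFFu, a2 = a >> 32;
    const std::uint64_t b1 = b & 0xFFFFFFFFu, b2 = b >> 32;
    const std::uint64_t p1 = (a1 * b1) % m;  // each product fits in 64 bits
    const std::uint64_t p2 = (a1 * b2) % m;
    const std::uint64_t p3 = (a2 * b1) % m;
    std::uint64_t p4 = (a2 * b2) % m;
    p4 = shift32_mod(p4, m);   // p4 * 2^32
    p4 = add_mod(p4, p3, m);
    p4 = add_mod(p4, p2, m);
    p4 = shift32_mod(p4, m);   // (p4 * 2^32 + p3 + p2) * 2^32
    return add_mod(p4, p1, m);
}

// base^exp mod m built on mul_mod_split, so no intermediate product overflows.
std::uint64_t pow_mod_big(std::uint64_t base, std::uint64_t exp, std::uint64_t m) {
    std::uint64_t result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1) result = mul_mod_split(result, base, m);
        base = mul_mod_split(base, base, m);
        exp >>= 1;
    }
    return result;
}
```

A base-2 Fermat check for an odd n > 2 would then be pow_mod_big(2, n - 1, n) == 1, which covers n up to 10^18 because 10^18 < 2^63.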
You can perform your multiplications in several stages. For example, say you want to compute X*Y mod n. Write X = 10^9*X_1 + X_0 and Y = 10^9*Y_1 + Y_0. Then compute all four products X_i*Y_j mod n, and finally combine them as X*Y ≡ 10^18*(X_1*Y_1 mod n) + 10^9*((X_0*Y_1 + X_1*Y_0) mod n) + X_0*Y_0 (mod n). Note that in this case you are operating on numbers half the size of the maximum allowed.
If splitting into two parts does not suffice (I suspect this is the case), split into three parts using the same scheme. Splitting into three should work.
A simpler approach is to multiply the schoolbook way. It corresponds to the previous approach, but with one number written in as many parts as it has digits.
Good luck!

How to get the smallest n such that 2 ^ n >= x for a given integer x in O(1)?

Given an unsigned integer x, how can I find the smallest n such that 2 ^ n ≥ x in O(1)? In other words, I want to find the index of the highest set bit in the binary representation of x (plus 1 if x is not a power of 2) in O(1) (not dependent on the size of the integer or the size of a byte).
If you have no memory constraints, then you can use a lookup table (one entry for each possible value of x) to achieve O(1) time.
If you want a practical solution, most processors will have some kind of "find highest bit set" opcode. On x86, for instance, it's BSR. Most compilers will have a mechanism to write raw assembler.
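Instead of raw assembler, compilers usually expose such an instruction through a builtin. The sketch below uses __builtin_clz, which is a GCC/Clang extension rather than standard C++; the function name ceil_log2 is hypothetical.

```cpp
#include <limits>

// Smallest n with 2^n >= x, using the count-leading-zeros builtin.
// __builtin_clz is GCC/Clang-specific and undefined for an argument of 0.
int ceil_log2(unsigned x) {
    if (x <= 1) return 0;  // 2^0 = 1 covers x = 0 and x = 1; also avoids clz(0)
    // The bit length of x - 1 is the answer for every x >= 2.
    return std::numeric_limits<unsigned>::digits - __builtin_clz(x - 1);
}
```

In C++20, std::bit_width(x - 1) from the <bit> header expresses the same computation portably for x >= 1.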
Ok, since so far nobody has posted a compile-time solution, here's mine. The precondition is that your input value is a compile-time constant. If you have that, it's all done at compile-time.
#include <iostream>
#include <iomanip>
// This should really come from a template meta lib, no need to reinvent it here,
// but I wanted this to compile as is.
namespace templ_meta {
// A run-of-the-mill compile-time if.
template<bool Cond, typename T, typename E> struct if_;
template< typename T, typename E> struct if_<true , T, E> {typedef T result_t;};
template< typename T, typename E> struct if_<false, T, E> {typedef E result_t;};
// This so we can use a compile-time if tailored for types, rather than integers.
template<int I>
struct int2type {
static const int result = I;
};
}
// This does the actual work.
template< int I, unsigned int Idx = 0>
struct index_of_high_bit {
static const unsigned int result =
templ_meta::if_< I==0
, templ_meta::int2type<Idx>
, index_of_high_bit<(I>>1),Idx+1>
>::result_t::result;
};
// just some testing
namespace {
template< int I >
void test()
{
const unsigned int result = index_of_high_bit<I>::result;
std::cout << std::setfill('0')
<< std::hex << std::setw(2) << std::uppercase << I << ": "
<< std::dec << std::setw(2) << result
<< '\n';
}
}
int main()
{
test<0>();
test<1>();
test<2>();
test<3>();
test<4>();
test<5>();
test<7>();
test<8>();
test<9>();
test<14>();
test<15>();
test<16>();
test<42>();
return 0;
}
'twas fun to do that.
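With C++11 or later, roughly the same compile-time computation can be sketched as a constexpr function. This is not a drop-in replacement for the templates above, and the name bit_length is hypothetical; like index_of_high_bit, it yields the number of bits needed to represent the value (0 for 0).

```cpp
// Number of bits needed to represent x; evaluated at compile time
// when the argument is a constant expression.
constexpr unsigned bit_length(unsigned x) {
    return x == 0 ? 0u : 1u + bit_length(x >> 1);
}

// Checked by the compiler, no run-time cost.
static_assert(bit_length(0) == 0, "zero needs no bits");
static_assert(bit_length(42) == 6, "42 is 101010 in binary");
```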
In <cmath> there are logarithm functions that will perform this computation for you.
ceil(log(x) / log(2));
Some math to transform the expression:
int n = ceil(log(x)/log(2));
This is obviously O(1).
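Floating-point logarithms can be off by one ulp, so a cautious variant checks the result with integer arithmetic afterwards. The following is a sketch under the assumption of a 32-bit input; the name ceil_log2_fp is hypothetical.

```cpp
#include <cmath>
#include <cstdint>

// Smallest n with 2^n >= x, via log2 plus an integer correction step.
int ceil_log2_fp(std::uint32_t x) {
    if (x <= 1) return 0;
    int n = static_cast<int>(std::ceil(std::log2(static_cast<double>(x))));
    // Fix up the estimate in case the floating-point result was slightly off.
    while (n < 32 && (std::uint64_t{1} << n) < x) ++n;
    while (n > 0 && (std::uint64_t{1} << (n - 1)) >= x) --n;
    return n;
}
```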
It's a question about finding the highest bit set (as lshtar and Oli Charlesworth pointed out). Bit Twiddling Hacks gives a solution which takes about 7 operations for 32 Bit Integers and about 9 operations for 64 Bit Integers.
You can use precalculated tables.
If your number is in [0,255] interval, simple table look up will work.
If it's bigger, then you may split it by bytes and check them from high to low.
Perhaps this link will help.
Warning : the code is not exactly straightforward and seems rather unmaintainable.
uint64_t v; // Input value to find position with rank r.
unsigned int r; // Input: bit's desired rank [1-64].
unsigned int s; // Output: Resulting position of bit with rank r [1-64]
uint64_t a, b, c, d; // Intermediate temporaries for bit count.
unsigned int t; // Bit count temporary.
// Do a normal parallel bit count for a 64-bit integer,
// but store all intermediate steps.
// a = (v & 0x5555...) + ((v >> 1) & 0x5555...);
a = v - ((v >> 1) & ~0UL/3);
// b = (a & 0x3333...) + ((a >> 2) & 0x3333...);
b = (a & ~0UL/5) + ((a >> 2) & ~0UL/5);
// c = (b & 0x0f0f...) + ((b >> 4) & 0x0f0f...);
c = (b + (b >> 4)) & ~0UL/0x11;
// d = (c & 0x00ff...) + ((c >> 8) & 0x00ff...);
d = (c + (c >> 8)) & ~0UL/0x101;
t = (d >> 32) + (d >> 48);
// Now do branchless select!
s = 64;
// if (r > t) {s -= 32; r -= t;}
s -= ((t - r) & 256) >> 3; r -= (t & ((t - r) >> 8));
t = (d >> (s - 16)) & 0xff;
// if (r > t) {s -= 16; r -= t;}
s -= ((t - r) & 256) >> 4; r -= (t & ((t - r) >> 8));
t = (c >> (s - 8)) & 0xf;
// if (r > t) {s -= 8; r -= t;}
s -= ((t - r) & 256) >> 5; r -= (t & ((t - r) >> 8));
t = (b >> (s - 4)) & 0x7;
// if (r > t) {s -= 4; r -= t;}
s -= ((t - r) & 256) >> 6; r -= (t & ((t - r) >> 8));
t = (a >> (s - 2)) & 0x3;
// if (r > t) {s -= 2; r -= t;}
s -= ((t - r) & 256) >> 7; r -= (t & ((t - r) >> 8));
t = (v >> (s - 1)) & 0x1;
// if (r > t) s--;
s -= ((t - r) & 256) >> 8;
s = 65 - s;
As has been mentioned, the length of the binary representation of x is the n you're looking for (unless x is itself a power of two, meaning 10.....0 in binary, in which case n is one less).
I seriously doubt there exists a true solution in O(1), unless you consider translations to binary representation to be O(1).
For a 32 bit int, the following pseudocode will be O(1).
highestBit(x)
bit = 1
highest = 0
for i 1 to 32
if x & bit != 0
highest = i
bit = bit * 2
return highest
It doesn't matter how big x is, it always checks all 32 bits. Thus constant time.
If the input can be of any size, say n digits long, then any solution that reads the input must read n digits and is therefore at least O(n). Unless someone comes up with a solution that does not read the input, an O(1) solution is impossible.
After some searching on the internet I found these two versions for 32-bit unsigned integers. I have tested them and they work. It is clear to me why the second one works, but I'm still thinking about the first one...
1.
unsigned int RoundUpToNextPowOf2(unsigned int v)
{
unsigned int r = 1;
if (v > 1)
{
float f = (float)v;
unsigned int const t = 1U << ((*(unsigned int *)&f >> 23) - 0x7f);
r = t << (t < v);
}
return r;
}
2.
unsigned int RoundUpToNextPowOf2(unsigned int v)
{
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v++;
return v;
}
edit: The first one is clear as well.
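For reference, the first version works because converting v to float stores floor(log2(v)) in the exponent field (or one more, when the conversion rounds up to a power of two), and the final shift doubles the candidate only if it is still below v. The pointer cast it uses is formally undefined behaviour in C++, so here is a sketch of the same idea using memcpy. It assumes float is IEEE-754 binary32 and that v is at most 2^31; the function name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");

// Round v up to the next power of two by reading the float exponent.
// Assumes IEEE-754 binary32 and v <= 2^31 (the result must fit in 32 bits).
std::uint32_t round_up_pow2_float(std::uint32_t v) {
    assert(v <= 0x80000000u);
    std::uint32_t r = 1;
    if (v > 1) {
        const float f = static_cast<float>(v);
        std::uint32_t bits;
        std::memcpy(&bits, &f, sizeof bits);  // well-defined way to read the bit pattern
        const std::uint32_t exponent = (bits >> 23) - 0x7f;  // remove the bias; sign bit is 0
        const std::uint32_t t = std::uint32_t{1} << exponent;
        r = t << (t < v);  // double t only if it is still smaller than v
    }
    return r;
}
```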
An interesting question. What do you mean by not depending on the size
of int or the number of bits in a byte? To encounter a different number
of bits in a byte, you'll have to use a different machine, with
a different set of machine instructions, which may or may not affect the
answer.
Anyway, based sort of vaguely on the first solution proposed by Mihran,
I get:
int
topBit( unsigned x )
{
int r = 1;
if ( x > 1 ) {
if ( frexp( static_cast<double>( x ), &r ) != 0.5 ) {
++ r;
}
}
return r - 1;
}
This works within the constraint that the input value must be exactly
representable in a double; if the input is unsigned long long, this
might not be the case, and on some of the more exotic platforms, it
might not even be the case for unsigned.
The only other constant time (with respect to the number of bits) I can
think of is:
int
topBit( unsigned x )
{
return x == 0 ? 0.0 : ceil( log2( static_cast<double>( x ) ) );
}
, which has the same constraint with regards to x being exactly
representable in a double, and may also suffer from rounding errors
inherent in the floating point operations (although if log2 is
implemented correctly, I don't think that this should be the case). If
your compiler doesn't support log2 (a C++11 feature, but also present
in C99, so I would expect most compilers to already have implemented
it), then of course, log( x ) / log( 2 ) could be used, but I suspect
that this will increase the risk of a rounding error resulting in
a wrong result.
FWIW, I find the O(1) on the number of bits a bit illogical, for the
reasons I specified above: the number of bits is just one of the many
"constant factors" which depend on the machine on which you run.
Anyway, I came up with the following purely integer solution, which is
O(lg n) for the number of bits n, and O(1) for everything else:
template< int k >
struct TopBitImpl
{
static int const k2 = k / 2;
static unsigned const m = ~0U << k2;
int operator()( unsigned x ) const
{
unsigned r = ((x & m) != 0) ? k2 : 0;
return r + TopBitImpl<k2>()(r == 0 ? x : x >> k2);
}
};
template<>
struct TopBitImpl<1>
{
int operator()( unsigned x ) const
{
return 0;
}
};
int
topBit( unsigned x )
{
return TopBitImpl<std::numeric_limits<unsigned>::digits>()(x)
+ (((x & (x - 1)) != 0) ? 1 : 0);
}
A good compiler should be able to inline the recursive calls, resulting
in close to optimal code.
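The same halving search can also be written as a plain loop, which some readers may find easier to follow than the template recursion. This is a sketch with a hypothetical name, assuming the number of value bits in unsigned is a power of two (32 on common platforms).

```cpp
#include <limits>

// Smallest n with 2^n >= x, by binary search for the highest set bit.
int top_bit_loop(unsigned x) {
    int index = 0;   // index of the highest set bit found so far
    unsigned v = x;
    for (int k = std::numeric_limits<unsigned>::digits / 2; k > 0; k /= 2) {
        if ((v >> k) != 0) {  // something is set in the upper part: move into it
            v >>= k;
            index += k;
        }
    }
    // Add one when x is not a power of two (x & (x - 1) clears the lowest set bit).
    return index + (((x & (x - 1)) != 0) ? 1 : 0);
}
```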