Avoiding overflow working modulo p - c++

As part of a university assignment, I have to implement in C scalar multiplication on an elliptic curve modulo p = 2^255 - 19. Since all computations are made modulo p, it seems enough to work with the primitive type (unsigned long).
However, if a and b are two integers modulo p, there is a risk of overflow when computing a*b. I am not sure how to avoid that. Is the following code correct?
long a = ...;
long b = ...;
long c = (a * b) % p;
Or should I rather cast a and b first ?
long a = ...;
long b = ...;
long long a1 = (long long) a;
long long b1 = (long long) b;
long c = (long) ((a1 * b1) % p);
I was also thinking of working with long long all along.

The multiplication is carried out in the type of its operands. You multiplied two long variables, so if the result is greater than what a long can hold, it overflows.
((a%p)*(b%p))%p gives some protection, because the operands are reduced modulo p first, but the point above still holds: (a%p)*(b%p) can still overflow (assuming a and b are of type long).
If you store the long values in long long variables, no cast is needed. But the result will still overflow when the product is greater than what long long can hold.
To clarify:
long a,b;
..
long long p = (a*b)%m;
This won't help. The multiplication is done in long arithmetic; it doesn't matter where the end result is stored. Only the type of the operands counts.
Now look at this
long c = (long) ((a1 * b1) % p);
Here the multiplication is done in long long and overflows only if the product exceeds what long long can hold, but there is still a chance of overflow when the result is assigned to long (if p does not fit in a long).
If p is 255 bits wide, you can't do what you want with the built-in types long or long long on a 32- or 64-bit system. Down the line, when we have 512-bit systems, this would surely be possible. Also note that with p = 2^255 - 19, doing the modular arithmetic with built-in types is hardly practical.
If sizeof(long) equals sizeof(long long), as on LP64 and ILP64 systems, switching from long to long long gains nothing. But if long long is wider than long, storing the operands in long long does help to prevent overflow in the multiplication.
Another way around this is to write your own big-integer (multiple-precision) library or use an existing one (maybe like this). The idea is that larger types are built from arrays of something as simple as char, and the operations are then implemented on those arrays. This is an implementation issue, and there are many implementations of this same theme.
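The widening idea described above can be sketched for operands that fit in 32 bits. This is only an illustration under that assumption (the name mulmod_u32 is made up here), and it does not scale to a 255-bit modulus:

```c
#include <assert.h>
#include <stdint.h>

/* (a * b) % m for 32-bit operands. The product of two 32-bit values
   always fits in 64 bits, so widening one operand before the
   multiplication is enough to avoid overflow. */
uint32_t mulmod_u32(uint32_t a, uint32_t b, uint32_t m) {
    assert(m != 0);
    return (uint32_t)(((uint64_t)a * b) % m);
}
```

The cast is applied to an operand, not to the finished product, because the type of the operands decides how wide the multiplication is.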

With a requirement of 255 or more bits per integer, the standard operations and the C library are insufficient.
What follows is a general algorithm for writing your own modular multiplication.
myint mod(myint a, myint m);
myint add(myint a, myint b); // this may overflow
int cmp(myint a, myint b);
int isodd(myint a);
myint halve(myint a);
// (a+b)%mod
myint addmodmax(myint a, myint b, myint m) {
myint sum = add(a,b);
if (cmp(sum,a) < 0) {
sum = add(mod(add(sum, 1),m), mod(myint_MAX,m)); // These additions do not overflow
}
return mod(sum, m);
}
// (a*b)%mod
myint mulmodmax(myint a, myint b, myint m) {
myint prod = 0;
while (cmp(b,0) > 0) {
if (isodd(b)) {
prod = addmodmax(prod, a, m);
}
b = halve(b);
a = addmodmax(a, a, m);
}
return prod;
}
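As a concrete instance of the double-and-add scheme above, here is a minimal sketch that uses uint64_t in place of the hypothetical myint type. The names addmod_u64 and mulmod_u64 are illustrative, and the overflow handling in the addition is a simplified variant that assumes both inputs are already reduced:

```c
#include <assert.h>
#include <stdint.h>

/* (a + b) % m without overflow; requires a < m and b < m. */
static uint64_t addmod_u64(uint64_t a, uint64_t b, uint64_t m) {
    /* a + b >= m  <=>  a >= m - b, and m - b cannot wrap because b < m. */
    return (a >= m - b) ? a - (m - b) : a + b;
}

/* (a * b) % m using only additions and halving of b. */
uint64_t mulmod_u64(uint64_t a, uint64_t b, uint64_t m) {
    assert(m != 0);
    uint64_t prod = 0;
    a %= m;
    while (b > 0) {
        if (b & 1)
            prod = addmod_u64(prod, a, m);  /* add current multiple of a */
        a = addmod_u64(a, a, m);            /* double a modulo m */
        b >>= 1;                            /* halve b */
    }
    return prod;
}
```

The same structure carries over to a multi-word integer type once add, compare, and halve are available for it.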

I recently came to this same problem.
First of all I'm going to assume you mean 32-bit integers (after reading your comments), but I think this applies to Big Integers as well (because doing a naive multiplication means doubling the word size and is going to be slow as well).
Option 1
We use the following property:
Proposition. a*b mod m = (a - m)*(b - m) mod m
Proof.
(a - m)*(b - m) mod m =
(a*b - (a+b)*m + m^2) mod m =
(a*b mod m - ((a+b) - m)*m mod m) mod m =
(a*b mod m) mod m = a*b mod m
q.e.d.
Moreover, if a and b are close to m, then (a - m)*(b - m) is small, so (a - m)*(b - m) mod m = (a - m)*(b - m). You will need to address the case a,b > m; however, I think the validity of (m - a)*(m - b) mod m = a*b mod m is a corollary of the above Proposition. And of course don't do this when the difference is very big (small modulus, big a or b; or vice versa) or it will overflow.
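A small sketch of Option 1 for 32-bit values, assuming both operands lie within 65535 of m so that the product of the differences fits in 32 bits. The function name and the 65535 bound are choices made for this illustration only:

```c
#include <assert.h>
#include <stdint.h>

/* a*b mod m computed as (m - a)*(m - b) mod m.
   Assumes a <= m, b <= m and both differences fit in 16 bits. */
uint32_t mulmod_near_m(uint32_t a, uint32_t b, uint32_t m) {
    assert(m != 0 && a <= m && b <= m);
    uint32_t da = m - a;
    uint32_t db = m - b;
    assert(da <= 0xFFFFu && db <= 0xFFFFu);  /* da * db < 2^32 */
    return (da * db) % m;
}
```

For example, with m = 2^32 - 5, a = m - 2 and b = m - 3, the direct product a*b needs 64 bits, while (m - a)*(m - b) is just 6.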
Option 2
From Wikipedia
uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t m)
{
uint64_t d = 0, mp2 = m >> 1;
int i;
if (a >= m) a %= m;
if (b >= m) b %= m;
for (i = 0; i < 64; ++i)
{
d = (d > mp2) ? (d << 1) - m : d << 1;
if (a & 0x8000000000000000ULL)
d += b;
if (d >= m) d -= m;
a <<= 1;
}
return d;
}
And also, assuming long double and 32- or 64-bit integers (not arbitrary precision), you can exploit the fact that floating-point and integer multiplication keep different ends of the product:
On computer architectures where an extended precision format with at least 64 bits of mantissa is available (such as the long double type of most x86 C compilers), the following routine is faster than any algorithmic solution, by employing the trick that, by hardware, floating-point multiplication results in the most significant bits of the product kept, while integer multiplication results in the least significant bits kept
And do:
uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t m)
{
long double x;
uint64_t c;
int64_t r;
if (a >= m) a %= m;
if (b >= m) b %= m;
x = a;
c = x * b / m;
r = (int64_t)(a * b - c * m) % (int64_t)m;
return r < 0 ? r + m : r;
}
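Both routines avoid overflow in the intermediate product, although the long double version depends on an extended-precision format with at least 64 bits of mantissa, as noted in the quote.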
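A different technique from the two routines above, shown only for comparison: GCC and Clang provide a 128-bit integer type on 64-bit targets, which lets the full product be formed directly. This sketch assumes that compiler extension is available; unsigned __int128 is not part of standard C, and the name mulmod_u128 is made up here.

```c
#include <assert.h>
#include <stdint.h>

/* (a * b) % m via a 128-bit intermediate product.
   Relies on the GCC/Clang extension type unsigned __int128. */
uint64_t mulmod_u128(uint64_t a, uint64_t b, uint64_t m) {
    assert(m != 0);
    return (uint64_t)(((unsigned __int128)a * b) % m);
}
```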

Related

Carry bits in incidents of overflow

/*
* isLessOrEqual - if x <= y then return 1, else return 0
* Example: isLessOrEqual(4,5) = 1.
* Legal ops: ! ~ & ^ | + << >>
* Max ops: 24
* Rating: 3
*/
int isLessOrEqual(int x, int y)
{
int msbX = x>>31;
int msbY = y>>31;
int sum_xy = (y+(~x+1));
int twoPosAndNegative = (!msbX & !msbY) & sum_xy; //isLessOrEqual is FALSE.
// if = true, twoPosAndNegative = 1; Overflow true
// twoPos = Negative means y < x which means that this
int twoNegAndPositive = (msbX & msbY) & !sum_xy;//isLessOrEqual is FALSE
//We started with two negative numbers, and subtracted X, resulting in positive. Therefore, x is bigger.
int isEqual = (!x^!y); //isLessOrEqual is TRUE
return (twoPosAndNegative | twoNegAndPositive | isEqual);
}
Currently, I am trying to work through how to carry bits in this operator.
The purpose of this function is to identify whether or not int y >= int x.
This is part of a class assignment, so there are restrictions on casting and which operators I can use.
I'm trying to account for a carried bit by applying a mask of the complement of the MSB, to try and remove the most significant bit from the equation, so that they may overflow without causing an issue.
I am under the impression that, ignoring cases of overflow, the returned operator would work.
EDIT: Here is my adjusted code, still not working. But, I think this is progress? I feel like I'm chasing my own tail.
int isLessOrEqual(int x, int y)
{
int msbX = x >> 31;
int msbY = y >> 31;
int sign_xy_sum = (y + (~x + 1)) >> 31;
return ((!msbY & msbX) | (!sign_xy_sum & (!msbY | msbX)));
}
I figured it out with the assistance of one of my peers, alongside the commentators here on StackOverflow.
The solution is as seen above.
The asker has self-answered their question (a class assignment), so providing alternative solutions seems appropriate at this time. The question clearly assumes that integers are represented as two's complement numbers.
One approach is to consider how CPUs compute predicates for conditional branching by means of a compare instruction. "signed less than" as expressed in processor condition codes is SF ≠ OF. SF is the sign flag, a copy of the sign-bit, or most significant bit (MSB) of the result. OF is the overflow flag which indicates overflow in signed integer operations. This is computed as the XOR of the carry-in and the carry-out of the sign-bit or MSB. With two's complement arithmetic, a - b = a + ~b + 1, and therefore a <= b is equivalent to a + ~b < 0. It remains to separate computation on the sign bit (MSB) sufficiently from the lower order bits. This leads to the following code:
int isLessOrEqual (int a, int b)
{
int nb = ~b;
int ma = a & ((1U << (sizeof(a) * CHAR_BIT - 1)) - 1);
int mb = nb & ((1U << (sizeof(b) * CHAR_BIT - 1)) - 1);
// for the following, only the MSB is of interest, other bits are don't care
int cyin = ma + mb;
int ovfl = (a ^ cyin) & (a ^ b);
int sign = (a ^ nb ^ cyin);
int lteq = sign ^ ovfl;
// desired predicate is now in the MSB (sign bit) of lteq, extract it
return (int)((unsigned int)lteq >> (sizeof(lteq) * CHAR_BIT - 1));
}
The casting to unsigned int prior to the final right shift is necessary because right-shifting of signed integers with negative value is implementation-defined, per the ISO-C++ standard, section 5.8. Asker has pointed out that casts are not allowed. When right shifting signed integers, C++ compilers will generate either a logical right shift instruction, or an arithmetic right shift instruction. As we are only interested in extracting the MSB, we can isolate ourselves from the choice by shifting then masking out all other bits besides the LSB, at the cost of one additional operation:
return (lteq >> (sizeof(lteq) * CHAR_BIT - 1)) & 1;
The above solution requires a total of eleven or twelve basic operations. A significantly more efficient solution is based on the 1972 MIT HAKMEM memo, which contains the following observation:
ITEM 23 (Schroeppel): (A AND B) + (A OR B) = A + B = (A XOR B) + 2 (A AND B).
This is straightforward, as A AND B represent the carry bits, and A XOR B represent the sum bits. In a newsgroup posting to comp.arch.arithmetic on February 11, 2000, Peter L. Montgomery provided the following extension:
If XOR is available, then this can be used to average
two unsigned variables A and B when the sum might overflow:
(A+B)/2 = (A AND B) + (A XOR B)/2
In the context of this question, this allows us to compute (a + ~b) / 2 without overflow, then inspect the sign bit to see if the result is less than zero. While Montgomery only referred to unsigned integers, the extension to signed integers is straightforward by use of an arithmetic right shift, keeping in mind that right shifting is an integer division which rounds towards negative infinity, rather than towards zero as regular integer division.
int isLessOrEqual (int a, int b)
{
int nb = ~b;
// compute avg(a,~b) without overflow, rounding towards -INF; lteq(a,b) = SF
int lteq = (a & nb) + arithmetic_right_shift (a ^ nb, 1);
return (int)((unsigned int)lteq >> (sizeof(lteq) * CHAR_BIT - 1));
}
Unfortunately, C++ itself provides no portable way to code an arithmetic right shift, but we can emulate it fairly efficiently using this answer:
int arithmetic_right_shift (int a, int s)
{
unsigned int mask_msb = 1U << (sizeof(mask_msb) * CHAR_BIT - 1);
unsigned int ua = a;
ua = ua >> s;
mask_msb = mask_msb >> s;
return (int)((ua ^ mask_msb) - mask_msb);
}
When inlined, this adds just a couple of instructions to the code when the shift count is a compile-time constant. If the compiler documentation indicates that the implementation-defined handling of signed integers of negative value is accomplished via arithmetic right shift instruction, it is safe to simplify to this six-operation solution:
int isLessOrEqual (int a, int b)
{
int nb = ~b;
// compute avg(a,~b) without overflow, rounding towards -INF; lteq(a,b) = SF
int lteq = (a & nb) + ((a ^ nb) >> 1);
return (int)((unsigned int)lteq >> (sizeof(lteq) * CHAR_BIT - 1));
}
The previously made comments regarding use of a cast when converting the sign bit into a predicate apply here as well.
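The overflow-free averaging identity quoted above is easy to check on unsigned values, where no arithmetic shift is involved. A minimal sketch, with avg_floor_u32 as an illustrative name:

```c
#include <stdint.h>

/* floor((a + b) / 2) without forming the possibly overflowing sum:
   (a & b) holds the bits set in both operands (they count in full),
   (a ^ b) holds the bits set in exactly one operand (they count half). */
uint32_t avg_floor_u32(uint32_t a, uint32_t b) {
    return (a & b) + ((a ^ b) >> 1);
}
```

The naive (a + b) / 2 would wrap for inputs such as 0xFFFFFFFF and 0xFFFFFFFD, while the identity yields 0xFFFFFFFE.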

Saturating subtract/add for unsigned bytes

Imagine I have two unsigned bytes b and x. I need to calculate bsub as b - x and badd as b + x. However, I don't want underflow/overflow to occur during these operations. For example (pseudo-code):
b = 3; x = 5;
bsub = b - x; // bsub must be 0, not 254
and
b = 250; x = 10;
badd = b + x; // badd must be 255, not 4
The obvious way to do this includes branching:
bsub = b - min(b, x);
badd = b + min(255 - b, x);
I just wonder if there are any better ways to do this, i.e. by some hacky bit manipulations?
The article Branchfree Saturating Arithmetic provides strategies for this:
Their addition solution is as follows:
u32b sat_addu32b(u32b x, u32b y)
{
u32b res = x + y;
res |= -(res < x);
return res;
}
modified for uint8_t:
uint8_t sat_addu8b(uint8_t x, uint8_t y)
{
uint8_t res = x + y;
res |= -(res < x);
return res;
}
and their subtraction solution is:
u32b sat_subu32b(u32b x, u32b y)
{
u32b res = x - y;
res &= -(res <= x);
return res;
}
modified for uint8_t:
uint8_t sat_subu8b(uint8_t x, uint8_t y)
{
uint8_t res = x - y;
res &= -(res <= x);
return res;
}
A simple method is to detect overflow and reset the value accordingly as below
bsub = b - x;
if (bsub > b)
{
bsub = 0;
}
badd = b + x;
if (badd < b)
{
badd = 255;
}
GCC can optimize the overflow check into a conditional assignment when compiling with -O2.
I measured this against the other solutions. With 1000000000+ operations on my PC, this solution and that of @ShafikYaghmour averaged 4.2 seconds, and that of @chux averaged 4.8 seconds. This solution is more readable as well.
For subtraction:
diff = (a - b)*(a >= b);
Addition:
sum = (a + b) | -(a > (255 - b));
Evolution
// sum = (a + b)*(a <= (255-b)); this fails
// sum = (a + b) | -(a <= (255 - b)); this fails too
Thanks to @R_Kapp
Thanks to @NathanOliver
This exercise shows the value of simply coding.
sum = b + min(255 - b, a);
If you are using a recent enough version of gcc or clang (maybe also some others) you could use built-ins to detect overflow.
if (__builtin_add_overflow(a,b,&c))
{
c = UINT_MAX;
}
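For the unsigned-byte case in the question, the same built-ins can be applied with a uint8_t result, since they report overflow relative to the type of the destination. A minimal sketch assuming GCC 5 or later, or a recent Clang; the function names are illustrative:

```c
#include <stdint.h>

/* Saturating add: clamp to 255 when the true sum does not fit in uint8_t. */
uint8_t sat_add_u8(uint8_t a, uint8_t b) {
    uint8_t c;
    if (__builtin_add_overflow(a, b, &c))
        c = UINT8_MAX;
    return c;
}

/* Saturating subtract: clamp to 0 when the true difference is negative. */
uint8_t sat_sub_u8(uint8_t a, uint8_t b) {
    uint8_t c;
    if (__builtin_sub_overflow(a, b, &c))
        c = 0;
    return c;
}
```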
For addition:
unsigned temp = a+b; // temp>>8 will be 1 if overflow else 0
unsigned char c = temp | -(temp >> 8);
For subtraction:
unsigned temp = a-b; // temp>>8 will be 0xFF if neg-overflow else 0
unsigned char c = temp & ~(temp >> 8);
No comparison operators or multiplies required.
All can be done in unsigned byte arithmetic
// Addition without overflow
return (b > 255 - a) ? 255 : a + b;
// Subtraction without underflow
return (b > a) ? 0 : a - b;
If you want to do this with two bytes, use the simplest code possible.
If you want to do this with twenty billion bytes, check what vector instructions are available on your processor and whether they can be used. You might find that your processor can do 32 of these operations with a single instruction.
You could also use the safe numerics library at Boost Library Incubator. It provides drop-in replacements for int, long, etc... which guarantee that you'll never get an undetected overflow, underflow, etc.
If you are willing to use assembly or intrinsics, I think I have an optimal solution.
For subtraction:
We can use the sbb instruction
In MSVC we can use the intrinsic function _subborrow_u64 (also available in other bit sizes).
Here is how it is used:
// *c = a - (b + borrow)
// borrow_flag is set to 1 if (a < (b + borrow))
borrow_flag = _subborrow_u64(borrow_flag, a, b, c);
Here is how we could apply it to your situation
uint64_t sub_no_underflow(uint64_t a, uint64_t b){
    uint64_t result;
    // borrow_flag is 1 if a < b, in which case the result is forced to 0
    unsigned char borrow_flag = _subborrow_u64(0, a, b, &result);
    return result * !borrow_flag;
}
For addition:
We can use the adc instruction
In MSVC we can use the intrinsic function _addcarry_u64 (also available in other bit sizes).
Here is how it is used:
// *c = a + b + carry
// carry_flag is set to 1 if there is a carry bit
carry_flag = _addcarry_u64(carry_flag, a, b, c);
Here is how we could apply it to your situation
uint64_t add_no_overflow(uint64_t a, uint64_t b){
    uint64_t result;
    // carry_flag is 1 if a + b does not fit in 64 bits
    unsigned char carry_flag = _addcarry_u64(0, a, b, &result);
    return !carry_flag * result - carry_flag;
}
I don't like this one as much as the subtraction one, but I think it is pretty nifty.
If the add overflows, carry_flag = 1. Not-ing carry_flag yields 0, so !carry_flag * result = 0 when there is overflow. And since 0 - 1 will set the unsigned integral value to its max, the function will return the result of the addition if there is no carry and return the max of the chosen integral value if there is carry.
what about this:
bsum = a + b;
bsum = (bsum < a || bsum < b) ? 255 : bsum;
bsub = a - b;
bsub = (bsub > a || bsub > b) ? 0 : bsub;
If you will call those methods a lot, the fastest way would be not bit manipulation but probably a look-up table. Define an array of length 511 for each operation.
Example for minus (subtraction)
static unsigned char maxTable[511];
memset(maxTable, 0, 255); // If smaller, emulates cutoff at zero
maxTable[255]=0; // If equal - return zero
for (int i=0; i<256; i++)
maxTable[255+i] = i; // If greater - return the difference
The array is static and initialized only once. Now your subtraction can be defined as inline method or using pre-compiler:
#define MINUS(A,B) maxTable[(A)-(B)+255]
How does it work? You want to pre-calculate all possible subtractions for unsigned chars. The differences range from -255 to +255, a total of 511 different results. We define an array of all possible results, but because in C we cannot index it with negative values, we add 255 (in [A-B+255]). You can remove this addition by defining a pointer to the center of the array.
const unsigned char *result = maxTable+255;
#define MINUS(A,B) result[(A)-(B)]
use it like:
bsub = MINUS(13,15); // i.e 13-15 with zero cutoff as requested
Note that the execution is extremely fast: only one subtraction and one pointer dereference to get the result. No branching. The static arrays are very short, so they will be fully loaded into the CPU's cache to further speed up the calculation.
The same would work for addition but with a slightly different table (the first 256 elements are equal to their indices and the last 255 elements are equal to 255, to emulate the cutoff beyond 255).
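The addition table described above might look like the following sketch. The names sat_add_table, init_sat_add_table and sat_add_lookup are illustrative and not part of the original answer:

```c
#include <stdint.h>

/* Index is a + b, which ranges over 0..510 for two unsigned bytes. */
static uint8_t sat_add_table[511];

/* Fill the table once: entries 0..255 equal their index, entries 256..510 are 255. */
void init_sat_add_table(void) {
    for (int i = 0; i < 511; i++)
        sat_add_table[i] = (uint8_t)(i < 255 ? i : 255);
}

/* Saturating addition by a single table lookup. */
uint8_t sat_add_lookup(uint8_t a, uint8_t b) {
    return sat_add_table[a + b];  /* a + b is computed as int, so it cannot wrap */
}
```

No offset is needed here because the sum of two unsigned bytes is never negative.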
If you insist on bit operations, the answers that use (a>b) do not guarantee that: the comparison might still be implemented as a branch. Use the sign-bit technique:
// (num1>num2) ? 1 : 0
#define is_int_biggerNotEqual( num1,num2) ((((__int32)((num2)-(num1)))&0x80000000)>>31)
Now you can use it for calculation of subtraction and addition.
If you want to emulate the functions max(), min() without branching use:
inline __int32 MIN_INT(__int32 x, __int32 y){ __int32 d=x-y; return y+(d&(d>>31)); }
inline __int32 MAX_INT(__int32 x, __int32 y){ __int32 d=x-y; return x-(d&(d>>31)); }
My examples above use 32-bit integers. You can change them to 64-bit, though I believe that 32-bit calculations run a bit faster. Up to you.

How do I calculate combinations of n taken by k for large numbers in modulo 100003?

The maximum value of n is 100 000 and k can be anywhere from 0 to 100 000. The problem asks to calculate the value modulo 100 003. So I've used a function to calculate the factorial of n,n-k and k and then print fact(n)/(fact(n-k)*fact(k))% 100 003. What am I doing wrong and what would be the solution?
long long int fact (int z)
{
long long int r;
if(z<=1)return 1;
r=1LL*z*fact(z-1);
return r;
}
A long long is not big enough to hold fact(n) for interesting n, so you need a smarter algorithm.
Applying the mod 100003 as you multiply is an easy way to keep things in range. But modular division is messy and in this case unnecessary.
Think about how to compute fact(n)/( fact(n-k)*fact(k) ) without ever needing to divide any big or modular numbers.
It will overflow for most z (z = 10^5 certainly overflows; in fact 21! already exceeds the range of a 64-bit long long).
Fortunately the integers modulo 100003 form a field (because 100003 is prime), so the entire calculation (even though it includes a division) can be done modulo 100003, thus preventing any overflow.
Most operations will be the same (except the extra modulo operation), but division becomes multiplication by the modular multiplicative inverse, which you can find using the extended Euclidean algorithm.
nCr = n! / ((n-r)! * r!)
Note that (a/b) % p != ((a%p) / (b%p)) % p, so the division cannot simply be carried out on the reduced values.
Using Fermat's little theorem we can compute it as
nCr % p = (fac[n] * modInverse(fac[r]) % p * modInverse(fac[n-r]) % p) % p;
Here fac[n] means n! reduced modulo p, and modInverse() means the modular inverse modulo p.
If p is prime, the modular inverse can be calculated as follows:
// a^b % mod by repeated squaring
long long expo(long long a, long long b, long long mod) {
    long long res = 1;
    while (b > 0) {
        if (b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b = b >> 1;
    }
    return res;
}
// inverse of n modulo a prime p, by Fermat's little theorem
long long modInverse(long long n, int p)
{
    return expo(n, p - 2, p);
}
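Putting the pieces above together for the modulus in the question, a minimal sketch could look like this. It assumes n < 100003, which the question guarantees, so no factor in the denominator is a multiple of the modulus. The names binom_mod and pow_mod are illustrative, and the running products replace a precomputed fac[] table:

```c
#include <assert.h>
#include <stdint.h>

#define MOD 100003u  /* prime modulus from the question */

/* a^e % m by repeated squaring; operands stay below MOD, so products fit in 64 bits. */
static uint64_t pow_mod(uint64_t a, uint64_t e, uint64_t m) {
    uint64_t res = 1 % m;
    a %= m;
    while (e > 0) {
        if (e & 1) res = res * a % m;
        a = a * a % m;
        e >>= 1;
    }
    return res;
}

/* C(n, k) % MOD for n < MOD. */
uint64_t binom_mod(unsigned n, unsigned k) {
    assert(n < MOD);
    if (k > n) return 0;
    if (k > n - k) k = n - k;           /* use the smaller of k and n - k */
    uint64_t num = 1, den = 1;
    for (unsigned i = 1; i <= k; i++) {
        num = num * (n - k + i) % MOD;  /* (n-k+1) * ... * n */
        den = den * i % MOD;            /* k! */
    }
    /* Division becomes multiplication by the inverse: den^(MOD-2) mod MOD. */
    return num * pow_mod(den, MOD - 2, MOD) % MOD;
}
```

For instance, C(20, 10) = 184756, and 184756 mod 100003 = 84753.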

(C++) Implementing Exponential Function Evaluator without Recursion?

I'm working on creating an exponential function evaluator (i.e., a function EXPO(int q, int p) that evaluates q^p) that does not use recursion, and I'm a little stuck on how to do so. Would you just multiply q by q p times or am I missing something?
Assuming that the exponent is non-negative:
long long int exp(int b, int e)
{ long long int r = 1;
long long int b_ = 1ll * b;
while(e > 0)
{ if(e & 1) r *= b_;
b_ *= b_;
e >>= 1;
}
return r;
}
This takes logarithmic time because we go through the bits of the exponent.
Unless p is negative, that's all there is to it.
I definitely would not use "Would you just multiply q by q p times" - it is unnecessarily inefficient.
On the other hand, many values will quickly overflow, even with unsigned long long.
The following runs in O(log2(b)) time.
Not much different than @saadtaame's answer, but I prefer to deal with unsigned math.
// return `a` raised to the `b` power.
unsigned long long ipower(unsigned a, unsigned b) {
unsigned long long y = 1;
unsigned long long power = a;
while (b) {
if (b % 2) y *= power;
b /= 2;
power *= power;
}
return y;
}
Note, this returns ipower(0,0) --> 1 which is a common expected result of the 0,0 special case. Mathematically an argument could be made for a result of 0, 1 or other results including an error. 1 suits many needs.
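Since many inputs overflow even unsigned long long, one option is to report overflow instead of returning a wrapped value. The following sketch assumes the GCC/Clang built-in __builtin_mul_overflow is available; the name ipower_checked and the bool return convention are choices made for this illustration:

```c
#include <stdbool.h>

/* Computes base^exp into *out. Returns false (leaving *out untouched)
   if the exact result does not fit in unsigned long long. */
bool ipower_checked(unsigned long long base, unsigned exp, unsigned long long *out) {
    unsigned long long result = 1;
    while (exp) {
        if (exp & 1) {
            if (__builtin_mul_overflow(result, base, &result))
                return false;
        }
        exp >>= 1;
        /* Square only if another bit remains, so an unused final
           squaring cannot trigger a spurious overflow report. */
        if (exp && __builtin_mul_overflow(base, base, &base))
            return false;
    }
    *out = result;
    return true;
}
```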
For unsigned exponents, what you have is (mostly) correct; you just have to handle the edge case of zero, since n^0 = 1. Pseudo-code follows:
def power(base,power):
    result = 1
    while power > 0:
        result = result * base
        power = power - 1
    return result
For negative powers (if you're so inclined), you just have to realise that n^(-x) = 1 / n^x:
def power(base,power):
    pneg = false
    if power < 0:
        power = -power
        pneg = true
    result = 1
    while power > 0:
        result = result * base
        power = power - 1
    if pneg:
        result = 1 / result
    return result

Calculating pow(a,b) mod n

I want to calculate a^b mod n for use in RSA decryption. My code (below) returns incorrect answers. What is wrong with it?
unsigned long int decrypt2(int a,int b,int n)
{
unsigned long int res = 1;
for (int i = 0; i < (b / 2); i++)
{
res *= ((a * a) % n);
res %= n;
}
if (b % n == 1)
res *=a;
res %=n;
return res;
}
You can try this C++ code. I've used it with 32 and 64-bit integers. I'm sure I got this from SO.
template <typename T>
T modpow(T base, T exp, T modulus) {
base %= modulus;
T result = 1;
while (exp > 0) {
if (exp & 1) result = (result * base) % modulus;
base = (base * base) % modulus;
exp >>= 1;
}
return result;
}
You can find this algorithm and related discussion in the literature on p. 244 of
Schneier, Bruce (1996). Applied Cryptography: Protocols, Algorithms, and Source Code in C, Second Edition (2nd ed.). Wiley. ISBN 978-0-471-11709-4.
Note that the multiplications result * base and base * base are subject to overflow in this simplified version. If the modulus is more than half the width of T (i.e. more than the square root of the maximum T value), then one should use a suitable modular multiplication algorithm instead - see the answers to Ways to do modulo multiplication with primitive types.
In order to calculate pow(a,b) % n to be used for RSA decryption, the best algorithm I came across is the one from Primality Testing 1), which is as follows:
int modulo(int a, int b, int n){
long long x=1, y=a;
while (b > 0) {
if (b%2 == 1) {
x = (x*y) % n; // multiplying with base
}
y = (y*y) % n; // squaring the base
b /= 2;
}
return x % n;
}
See below reference for more details.
1) Primality Testing : Non-deterministic Algorithms – topcoder
Usually it's something like this:
while (b)
{
if (b % 2) { res = (res * a) % n; }
a = (a * a) % n;
b /= 2;
}
return res;
The only actual logic error that I see is this line:
if (b % n == 1)
which should be this:
if (b % 2 == 1)
But your overall design is problematic: your function performs O(b) multiplications and modulus operations, but your use of b / 2 and a * a implies that you were aiming to perform O(log b) operations (which is usually how modular exponentiation is done).
Doing the raw power operation is very costly, hence you can apply the following logic to simplify the decryption.
From here,
Now say we want to encrypt the message m = 7, c = m^e mod n = 7^3 mod 33
= 343 mod 33 = 13. Hence the ciphertext c = 13.
To check decryption we compute m' = c^d mod n = 13^7 mod 33 = 7. Note
that we don't have to calculate the full value of 13 to the power 7
here. We can make use of the fact that a = bc mod n = (b mod n).(c mod
n) mod n so we can break down a potentially large number into its
components and combine the results of easier, smaller calculations to
calculate the final value.
One way of calculating m' is as follows:- Note that any number can be
expressed as a sum of powers of 2. So first compute values of 13^2,
13^4, 13^8, ... by repeatedly squaring successive values modulo 33. 13^2
= 169 ≡ 4, 13^4 = 4.4 = 16, 13^8 = 16.16 = 256 ≡ 25. Then, since 7 = 4 + 2 + 1, we have m' = 13^7 = 13^(4+2+1) = 13^4.13^2.13^1 ≡ 16 x 4 x 13 = 832
≡ 7 mod 33
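The worked example above can be reproduced with a short square-and-multiply routine. This is a sketch for 32-bit inputs that uses 64-bit intermediates so the products cannot overflow; the name modpow_u32 is illustrative:

```c
#include <assert.h>
#include <stdint.h>

/* base^exp mod m by repeated squaring. */
uint32_t modpow_u32(uint32_t base, uint32_t exp, uint32_t m) {
    assert(m != 0);
    uint64_t result = 1 % m;    /* 0 when m == 1 */
    uint64_t b = base % m;
    while (exp > 0) {
        if (exp & 1)
            result = result * b % m;  /* this bit of exp contributes b */
        b = b * b % m;                /* next power: b^2, b^4, b^8, ... */
        exp >>= 1;
    }
    return (uint32_t)result;
}
```

With the numbers from the quoted text, modpow_u32(7, 3, 33) gives the ciphertext 13 and modpow_u32(13, 7, 33) recovers the message 7.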
Are you trying to calculate (a^b)%n, or a^(b%n) ?
If you want the first one, then your code only works when b is an even number, because of that b/2. The "if b%n==1" is incorrect because you don't care about b%n here, but rather about b%2.
If you want the second one, then the loop is wrong because you're looping b/2 times instead of (b%n)/2 times.
Either way, your function is unnecessarily complex. Why do you loop until b/2 and try to multiply in 2 a's each time? Why not just loop until b and multiply in one a each time? That would eliminate a lot of unnecessary complexity and thus eliminate potential errors. Are you thinking that you'll make the program faster by cutting the number of times through the loop in half? Frankly, that's a bad programming practice: micro-optimization. It doesn't really help much: you still multiply by a the same number of times; all you do is cut down on the number of times the loop condition is tested. If b is typically small (like one or two digits), it's not worth the trouble. If b is large -- if it can be in the millions -- then this is insufficient; you need a much more radical optimization.
Also, why do the %n each time through the loop? Why not just do it once at the end?
A key problem with OP's code is a * a. This is int overflow (undefined behavior) when a is large enough. The type of res is irrelevant in the multiplication of a * a.
The solution is to ensure either:
the multiplication is done with 2x wide math or
with modulus n, n*n <= type_MAX + 1
There is no reason to return a wider type than the type of the modulus, as the result is always representable in that type.
// unsigned long int decrypt2(int a,int b,int n)
int decrypt2(int a,int b,int n)
Using unsigned math is certainly more suitable for OP's RSA goals.
Also see Modular exponentiation without range restriction
// (a^b)%n
// n != 0
// Test whether unsigned long long has at least twice as many value bits as unsigned
#if ULLONG_MAX/UINT_MAX - 1 > UINT_MAX
unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
unsigned long long result = 1u % n; // Ensure result < n, even when n==1
while (b > 0) {
if (b & 1) result = (result * a) % n;
a = (1ULL * a * a) %n;
b >>= 1;
}
return (unsigned) result;
}
#else
unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
// Detect if UINT_MAX + 1 < n*n
if (UINT_MAX/n < n-1) {
return TBD_code_with_wider_math(a,b,n);
}
a %= n;
unsigned result = 1u % n;
while (b > 0) {
if (b & 1) result = (result * a) % n;
a = (a * a) % n;
b >>= 1;
}
return result;
}
#endif
ints are generally not enough for RSA (unless you are dealing with small simplified examples);
you need a data type that can store integers up to 2^256 (for 256-bit RSA keys) or 2^512 for 512-bit keys, etc.
Here is another way. Remember that for the modular multiplicative inverse of a modulo m to exist, a and m must be coprime.
We can use the extended gcd to calculate the modular multiplicative inverse.
Computing a^b mod m when a and b can have more than 10^5 digits is tricky.
The code below does the computation:
#include <iostream>
#include <string>
using namespace std;
/*
* May this code live long.
*/
long pow(string,string,long long);
long pow(long long ,long long ,long long);
int main() {
string _num,_pow;
long long _mod;
cin>>_num>>_pow>>_mod;
//cout<<_num<<" "<<_pow<<" "<<_mod<<endl;
cout<<pow(_num,_pow,_mod)<<endl;
return 0;
}
long pow(string n,string p,long long mod){
long long num=0,_pow=0;
for(char c: n){
num=(num*10+c-48)%mod;
}
for(char c: p){
_pow=(_pow*10+c-48)%(mod-1);
}
return pow(num,_pow,mod);
}
long pow(long long a,long long p,long long mod){
long res=1;
if(a==0)return 0;
while(p>0){
if((p&1)==0){
p/=2;
a=(a*a)%mod;
}
else{
p--;
res=(res*a)%mod;
}
}
return res;
}
This code works because, when m is prime and a is not a multiple of m, a^b mod m can be written as (a mod m)^(b mod (m-1)) mod m (Fermat's little theorem).
Hope it helped { :)
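The exponent reduction used above can be sketched in C as follows. This assumes p is prime and fits in 32 bits; the names decimal_mod and pow_mod_decimal are illustrative. As in the code above, a base that is a multiple of p returns 0.

```c
#include <assert.h>
#include <stdint.h>

/* Reduce a string of decimal digits modulo m, one digit at a time. */
static uint64_t decimal_mod(const char *digits, uint64_t m) {
    uint64_t acc = 0;
    for (; *digits >= '0' && *digits <= '9'; digits++)
        acc = (acc * 10 + (uint64_t)(*digits - '0')) % m;
    return acc;
}

/* a^e mod p for decimal strings a and e; p is assumed prime, p < 2^32. */
uint64_t pow_mod_decimal(const char *a, const char *e, uint64_t p) {
    assert(p >= 2 && p <= 0xFFFFFFFFu);
    uint64_t base = decimal_mod(a, p);
    if (base == 0)
        return 0;                        /* a is a multiple of p */
    uint64_t exp = decimal_mod(e, p - 1); /* Fermat: exponents repeat mod p-1 */
    uint64_t result = 1;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}
```

For example, 3^1000000007 mod 1000000007 reduces the exponent to 1 and returns 3.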
Maybe use fast exponentiation; it gives the same O(log n) as the template above.
int power(int base, int exp,int mod)
{
if(exp == 0)
return 1;
int p=power(base, exp/2,mod);
p=(p*p)% mod;
return (exp%2 == 0)?p:(base * p)%mod;
}
This (encryption) is more of an algorithm design problem than a programming one. The important missing part is familiarity with modern algebra. I suggest that you look for a huge optimization in group theory and number theory.
If n is a prime number and a is not a multiple of n, pow(a,n-1)%n==1 (assuming infinite-digit integers). So, basically, you need to calculate pow(a,b%(n-1))%n. According to group theory, you can find e such that every other number is equivalent to a power of e modulo n. Therefore the range [1..n-1] can be represented as a permutation of powers of e. Given an algorithm to find e for n and the logarithm of a to base e, calculations can be significantly simplified. Cryptography needs a ton of math background; I'd rather stay off that ground without enough background.
My code for a^k mod n in PHP:
function pmod($a, $k, $n)
{
    if ($n == 1) return 0;
    $power = 1;
    for ($i = 1; $i <= $k; $i++)
    {
        $power = ($power * $a) % $n;
    }
    return $power;
}
#include <cmath>
...
static_cast<int>(std::pow(a,b))%n
but my best bet is that you are overflowing int (i.e. the number is too large for an int) in the power. I had the same problem creating the exact same function.
I'm using this function:
int CalculateMod(int base, int exp ,int mod){
int result;
result = (int) pow(base,exp);
result = result % mod;
return result;
}
I cast the result because pow gives you back a double, and the % operator needs two operands of integer type. Anyway, in RSA decryption you should use only integers.