Purpose of custom BitSet implementation in this Leetcode answer - c++

Last night I was working on the "Longest Palindromic Subsequence" problem on leetcode. After completing it I took a look at the fastest answer, and to my surprise it was a giant custom bitset implementation. I decided to try and reverse engineer it a bit and see if I could implement it using std::bitset, but I've run into some issues.
Here's the code:
#if __cplusplus>199711L //c++11
#include<unordered_map>
#endif
const int N=1005;
template<int S>
struct BitSet{
#define W 6
#define mask 63
#define get_size(n) ((n)<1?0:((n)+mask)>>W)
typedef unsigned long long uint; //typedef unsigned int uint;
uint a[get_size(S)];int size;
void reset(){memset(a,0,sizeof(uint)*size);}
BitSet():size(get_size(S)){reset();}
BitSet(uint x):size(get_size(S)){reset();a[0]=x;}
BitSet(const BitSet<S> &x):size(get_size(S)){*this=x;}
BitSet& set(int x,int y){
//if (y<0||y>1){printf("error!\n");return *this;}
int X=x>>W,Y=x&mask;
if (y)a[X]|=1ull<<Y;else a[X]&=~(1ull<<Y);
return *this;
}
int find(int x){int X=x>>W,Y=x&mask;return (a[X]>>Y)&1ull;}
int operator [](int x){return find(x);}
BitSet& operator =(const BitSet &y){
memcpy(a,y.a,sizeof(uint)*size);
return *this;
}
BitSet<S> operator |(const BitSet<S> &y)const{return BitSet<S>(*this)|=y;}
BitSet<S> operator &(const BitSet<S> &y)const{return BitSet<S>(*this)&=y;}
BitSet<S> operator ^(const BitSet<S> &y)const{return BitSet<S>(*this)^=y;}
BitSet<S> operator +(const BitSet<S> &y)const{return BitSet<S>(*this)+=y;}
BitSet<S> operator -(const BitSet<S> &y)const{return BitSet<S>(*this)-=y;}
BitSet<S> operator <<(int x)const{return BitSet<S>(*this)<<=x;}
BitSet<S> operator >>(int x)const{return BitSet<S>(*this)>>=x;}
BitSet<S> operator ~()const{return BitSet<S>(*this).flip();}
BitSet<S>& operator =(const char *s){
memset(a,0,sizeof(uint)*size);
for (int i=0;i<S;++i){
if (s[i]!='0'&&s[i]!='1')break;
int X=i>>W,Y=i&mask;
if (s[i]=='1')a[X]|=1ull<<Y;
}
return *this;
}
BitSet<S>& operator =(const int *s){
memset(a,0,sizeof(uint)*size);
for (int i=0;i<S;++i){
if (s[i]!=0&&s[i]!=1)break;
int X=i>>W,Y=i&mask;
if (s[i]==1)a[X]|=1ull<<Y;
}
return *this;
}
BitSet<S>& operator <<=(int x){
int shift=x>>W; int delta=x&mask,delta1=mask+1-delta;
if (!x)return *this;
if (delta==0)for (uint *p=a+size-1,*q=p-shift,*end=a+shift-1;p!=end;--p,--q)*p=*q;
else {
for (uint *p=a+size-1,*q1=p-shift,*q2=p-shift-1,*end=a+shift;p!=end;--p,--q1,--q2)*p=(*q1<<delta)|(*q2>>delta1);
a[shift]=a[0]<<delta;
}
memset(a,0,sizeof(uint)*shift); //for (uint *p=a,*end=a+shift;p!=end;++p)*p=0;
return *this;
}
BitSet<S>& operator >>=(int x){
int shift=x>>W; int delta=x&mask,delta1=mask+1-delta;
if (!x)return *this;
correction();
if (delta==0)for (uint *p=a,*q=p+shift,*end=a+size-shift;p!=end;++p,++q)*p=*q;
else {
for (uint *p=a,*q1=p+shift,*q2=p+shift+1,*end=a+size-shift-1;p!=end;++p,++q1,++q2)*p=(*q1>>delta)|(*q2<<delta1);
a[size-shift-1]=a[size-1]>>delta;
}
memset(a+size-shift,0,sizeof(uint)*shift);
return *this;
}
BitSet<S>& operator |=(const BitSet<S> &y){
uint *startA=a;const uint *startB=y.a,*endA=a+size;
while (startA!=endA){*startA|=*startB;++startA;++startB;}
//for (int i=0;i<size;++i)a[i]|=y.a[i];
return *this;
}
/*BitSet<S>& operator |=(const BitSet<S> &y){
uint *p0=a,*p1=p0+1,*p2=p0+2,*p3=p0+3;const uint *q0=y.a,*q1=q0+1,*q2=q0+2,*q3=q0+3,*pend=a+((size>>2)<<2);
while (p0!=pend){
*p0|=*q0; p0+=4; q0+=4;
*p1|=*q1; p1+=4; q1+=4;
*p2|=*q2; p2+=4; q2+=4;
*p3|=*q3; p3+=4; q3+=4;
}
for (int i=0;i<(size&3);++i)*p0++|=*q0++;
return *this;
}*/
BitSet<S>& operator &=(const BitSet<S> &y){
uint *startA=a;const uint *startB=y.a,*endA=a+size;
while (startA!=endA){*startA&=*startB;++startA;++startB;}
return *this;
}
BitSet<S>& operator ^=(const BitSet<S> &y){
uint *startA=a;const uint *startB=y.a,*endA=a+size;
while (startA!=endA){*startA^=*startB;++startA;++startB;}
return *this;
}
BitSet<S>& operator +=(const BitSet<S> &y){
uint t=0,*p=a,*end=a+size; const uint *q=y.a;
while (p!=end){
uint p1=*p; *p=p1+*q+t;
t=(*p<p1)||(p1+t<t);
++p; ++q;
}
return *this;
}
BitSet<S>& operator -=(const BitSet<S> &y){
uint t=0,*p=a,*end=a+size; const uint *q=y.a;
while (p!=end){
uint p1=*p; *p=p1-*q-t;
t=(*p>p1)||(p1+t<t);
++p; ++q;
}
return *this;
}
operator bool(){return count()>0;}
BitSet<S>& flip(){
//for (uint *start=a,*end=a+size;start!=end;*start=~*start,++start);
uint *p0=a,*p1=p0+1,*p2=p0+2,*p3=p0+3,*pend=a+((size>>2)<<2);
while (p0!=pend){
*p0=~*p0; p0+=4;
*p1=~*p1; p1+=4;
*p2=~*p2; p2+=4;
*p3=~*p3; p3+=4;
}
for (int i=0;i<(size&3);++i,++p0)*p0=~*p0;
return *this;
}
//void flip(){*this=~*this;}
void flip(int x){a[x>>W]^=1ull<<(x&mask);}
int popcount(uint x)const{
x-=(x&0xaaaaaaaaaaaaaaaaull)>>1;
x=((x&0xccccccccccccccccull)>>2)+(x&0x3333333333333333ull);
x=((x>>4)+x)&0x0f0f0f0f0f0f0f0full;
return (x*0x0101010101010101ull)>>56;
}
int count(){
int res=0;
correction();
for (int i=0;i<size;++i)res+=__builtin_popcountll(a[i]); //popcount
return res;
}
int clz(){
correction();
int res=0;
if (a[size-1])res=__builtin_clzll(a[size-1])-(mask+1-(S&mask));
else {
res+=S&mask;
for (int i=size-2;i>=0;--i)
if (a[i]){res+=__builtin_clzll(a[i]); break;}
else res+=mask+1;
}
return res;
}
int ctz(){
correction();
int res=0;
for (int i=0;i<size;++i)
if (a[i]){res+=__builtin_ctzll(a[i]); break;}
else res+=mask+1;
return min(res,S);
}
int ffs(){
int res=ctz()+1;
if (res==S+1)res=0;
return res;
}
uint to_uint(){
correction();
return a[0];
}
void print(){
for (int i=0;i<size;++i)
for (int j=0;j<=mask&&(i<<W)+j+1<=S;++j)printf("%I64d",(a[i]>>j)&1ull);
printf("\n");
}
void correction(){if (S&mask)a[size-1]&=(1ull<<(S&mask))-1;}
#undef mask
#undef W
#undef get_size
};
int a[N],b[N];
BitSet<N> row[2],X,Y;
unordered_map<int,vector<int> > S;
unordered_map<int,BitSet<N> > match;
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n=s.size(),m=n;
S.clear();match.clear();row[1].reset();
for (int i=0;i<n;++i)a[i]=int(s[i]),S[a[i]].push_back(i);
for (int i=0;i<m;++i)b[i]=int(s[n-1-i]);
for (int i=0;i<m;++i)if (match.find(b[i])==match.end()){
unordered_map<int,BitSet<N> >::iterator x=match.insert(make_pair(b[i],BitSet<N>())).first;
for (vector<int>::iterator j=S[b[i]].begin();j!=S[b[i]].end();++j)x->second.set(*j,1);
}
for (int i=0,now=0;i<m;++i,now^=1)
X=(row[now^1]|match[b[i]]).set(n,1),row[now]=(X&((X-(row[now^1]<<1).set(0,1))^X)).set(n,0);
return row[(m-1)&1].count();
}
};
And here's my attempt at cleaning it up/understanding it:
Bitset.h:
#pragma once
//#if __cplusplus>199711L //c++11
#include<unordered_map>
//#endif
#include <intrin.h>
const int N = 1005; //this size doesn't even matter?? It's just enigmatically assigning a type of int, I guess??
template<int S>
struct BitSet {
#define W 6
#define mask 63
#define get_size(n) ((n)<1?0:((n)+mask)>>W)
//members
typedef unsigned long long uint; //typedef unsigned int uint; //not sure why they didn't just use size_t here...
uint a[get_size(S)]; //represents the "BitSet" I think...
int size;
//constructors
BitSet() :size(get_size(S)) { reset(); }
BitSet(uint x) :size(get_size(S)) { reset(); a[0] = x; }
BitSet(const BitSet<S>& x) :size(get_size(S)) { *this = x; }
//utility functions
BitSet& set(int x, int y) {
//if (y<0||y>1){printf("error!\n");return *this;}
int X = x >> W, Y = x & mask; //bit shift with the magic numbers??? //What is the significance of 6, 63, and 1005?
if (y)a[X] |= 1ull << Y; else a[X] &= ~(1ull << Y);
return *this;
}
void reset() { memset(a, 0, sizeof(uint) * size); }
int find(int x) { int X = x >> W, Y = x & mask; return (a[X] >> Y) & 1ull; }
//void flip(){*this=~*this;}
void flip(int x) { a[x >> W] ^= 1ull << (x & mask); }
int popcount(uint x)const {
x -= (x & 0xaaaaaaaaaaaaaaaaull) >> 1;
x = ((x & 0xccccccccccccccccull) >> 2) + (x & 0x3333333333333333ull);
x = ((x >> 4) + x) & 0x0f0f0f0f0f0f0f0full;
return (x * 0x0101010101010101ull) >> 56;
}
int count() {
int res = 0;
correction();
for (int i = 0; i < size; ++i)res += __popcnt(a[i]); //popcount // __builtin_popcountll is only available on GCC apparently. Substituted with library version.
return res;
}
int clz() {
correction();
int res = 0;
if (a[size - 1])res = __builtin_clzll(a[size - 1]) - (mask + 1 - (S & mask));
else {
res += S & mask;
for (int i = size - 2; i >= 0; --i)
if (a[i]) { res += __builtin_clzll(a[i]); break; }
else res += mask + 1;
}
return res;
}
int ctz() {
correction();
int res = 0;
for (int i = 0; i < size; ++i)
if (a[i]) { res += __builtin_ctzll(a[i]); break; }
else res += mask + 1;
return min(res, S);
}
int ffs() {
int res = ctz() + 1;
if (res == S + 1)res = 0;
return res;
}
uint to_uint() {
correction();
return a[0];
}
void print() {
for (int i = 0; i < size; ++i)
for (int j = 0; j <= mask && (i << W) + j + 1 <= S; ++j)printf("%I64d", (a[i] >> j) & 1ull);
printf("\n");
}
void correction() { if (S & mask)a[size - 1] &= (1ull << (S & mask)) - 1; }
BitSet<S>& flip() {
//for (uint *start=a,*end=a+size;start!=end;*start=~*start,++start);
uint* p0 = a, * p1 = p0 + 1, * p2 = p0 + 2, * p3 = p0 + 3, * pend = a + ((size >> 2) << 2);
while (p0 != pend) {
*p0 = ~*p0; p0 += 4;
*p1 = ~*p1; p1 += 4;
*p2 = ~*p2; p2 += 4;
*p3 = ~*p3; p3 += 4;
}
for (int i = 0; i < (size & 3); ++i, ++p0)*p0 = ~*p0;
return *this;
}
//operators
int operator [](int x) { return find(x); }
BitSet& operator =(const BitSet& y) {
memcpy(a, y.a, sizeof(uint) * size);
return *this;
}
BitSet<S>& operator =(const char* s) {
memset(a, 0, sizeof(uint) * size);
for (int i = 0; i < S; ++i) {
if (s[i] != '0' && s[i] != '1')break;
int X = i >> W, Y = i & mask;
if (s[i] == '1')a[X] |= 1ull << Y;
}
return *this;
}
BitSet<S>& operator =(const int* s) {
memset(a, 0, sizeof(uint) * size);
for (int i = 0; i < S; ++i) {
if (s[i] != 0 && s[i] != 1)break;
int X = i >> W, Y = i & mask;
if (s[i] == 1)a[X] |= 1ull << Y;
}
return *this;
}
BitSet<S> operator |(const BitSet<S>& y)const { return BitSet<S>(*this) |= y; }
BitSet<S> operator &(const BitSet<S>& y)const { return BitSet<S>(*this) &= y; }
BitSet<S> operator ^(const BitSet<S>& y)const { return BitSet<S>(*this) ^= y; }
BitSet<S> operator +(const BitSet<S>& y)const { return BitSet<S>(*this) += y; }
BitSet<S> operator -(const BitSet<S>& y)const { return BitSet<S>(*this) -= y; }
BitSet<S> operator <<(int x)const { return BitSet<S>(*this) <<= x; }
BitSet<S> operator >>(int x)const { return BitSet<S>(*this) >>= x; }
BitSet<S> operator ~()const { return BitSet<S>(*this).flip(); }
BitSet<S>& operator <<=(int x) {
int shift = x >> W; int delta = x & mask, delta1 = mask + 1 - delta;
if (!x)return *this;
if (delta == 0)for (uint* p = a + size - 1, *q = p - shift, *end = a + shift - 1; p != end; --p, --q)*p = *q;
else {
for (uint* p = a + size - 1, *q1 = p - shift, *q2 = p - shift - 1, *end = a + shift; p != end; --p, --q1, --q2)*p = (*q1 << delta) | (*q2 >> delta1);
a[shift] = a[0] << delta;
}
memset(a, 0, sizeof(uint) * shift); //for (uint *p=a,*end=a+shift;p!=end;++p)*p=0; //if this is a left shift... why is it setting the beginning of the "a" array to 0???... unless the damn array is "backwards". Sheesh.
//wait... this is a right shift according to https://orthallelous.wordpress.com/2019/10/24/magic-numbers-encoding-truth-tables-into-giant-single-values/ ... fuck.
return *this;
}
BitSet<S>& operator >>=(int x) {
int shift = x >> W; int delta = x & mask, delta1 = mask + 1 - delta;
if (!x)return *this;
correction();
if (delta == 0)for (uint* p = a, *q = p + shift, *end = a + size - shift; p != end; ++p, ++q)*p = *q;
else {
for (uint* p = a, *q1 = p + shift, *q2 = p + shift + 1, *end = a + size - shift - 1; p != end; ++p, ++q1, ++q2)*p = (*q1 >> delta) | (*q2 << delta1);
a[size - shift - 1] = a[size - 1] >> delta;
}
memset(a + size - shift, 0, sizeof(uint) * shift);
return *this;
}
BitSet<S>& operator |=(const BitSet<S>& y) {
uint* startA = a; const uint* startB = y.a, * endA = a + size;
while (startA != endA) { *startA |= *startB; ++startA; ++startB; }
//for (int i=0;i<size;++i)a[i]|=y.a[i];
return *this;
}
/*BitSet<S>& operator |=(const BitSet<S> &y){
uint *p0=a,*p1=p0+1,*p2=p0+2,*p3=p0+3;const uint *q0=y.a,*q1=q0+1,*q2=q0+2,*q3=q0+3,*pend=a+((size>>2)<<2);
while (p0!=pend){
*p0|=*q0; p0+=4; q0+=4;
*p1|=*q1; p1+=4; q1+=4;
*p2|=*q2; p2+=4; q2+=4;
*p3|=*q3; p3+=4; q3+=4;
}
for (int i=0;i<(size&3);++i)*p0++|=*q0++;
return *this;
}*/
BitSet<S>& operator &=(const BitSet<S>& y) {
uint* startA = a; const uint* startB = y.a, * endA = a + size;
while (startA != endA) { *startA &= *startB; ++startA; ++startB; }
return *this;
}
BitSet<S>& operator ^=(const BitSet<S>& y) {
uint* startA = a; const uint* startB = y.a, * endA = a + size;
while (startA != endA) { *startA ^= *startB; ++startA; ++startB; }
return *this;
}
BitSet<S>& operator +=(const BitSet<S>& y) {
uint t = 0, * p = a, * end = a + size; const uint* q = y.a;
while (p != end) {
uint p1 = *p; *p = p1 + *q + t;
t = (*p < p1) || (p1 + t < t);
++p; ++q;
}
return *this;
}
BitSet<S>& operator -=(const BitSet<S>& y) {
uint t = 0, * p = a, * end = a + size; const uint* q = y.a;
while (p != end) {
uint p1 = *p; *p = p1 - *q - t;
t = (*p > p1) || (p1 + t < t);
++p; ++q;
}
return *this;
}
operator bool() { return count() > 0; }
#undef mask
#undef W
#undef get_size
};
LeetCode516.cpp
// LeetCode516.cpp : This file contains the 'main' function. Program execution begins and ends there.
//
#include <iostream>
#include "BitSet.h"
#include <bitset>
// If parameter is not true, test fails
// This check function would be provided by the test framework
#define IS_TRUE(x) { if (!x) std::cout << __FUNCTION__ << " failed on line " << __LINE__ << std::endl; else std::cout << __FUNCTION__ << " passed" << std:: endl;}
int longestPalindromeSubseq(std::string s) {
int a[N] = { 0 }, b[N] = { 0 }; //2 integer arrays - initialzing to 0 isn't necessary, but seems to make the debug output more readable?
BitSet<N> row[2], X, Y; //3 bitsets... one of them is actually 2 (lol)
std::unordered_map<int, std::vector<int>> S; //This map tracks the number of occurrences of each letter?
std::unordered_map<int, BitSet<N>> match;
int n = s.size(), m = n;
S.clear();
match.clear();
//row[1].reset(); //WTF is this? Garbage.
//For each piece of the string, cast it from a char to an int and shove it in the first array, then push back a copy of each array into the UOMap S.
for (int i = 0; i < n; ++i) {
a[i] = static_cast<int>(s[i]); //changed C-style cast to static_cast
S[a[i]].push_back(i);
}
//Set the second array to a backwards copy of A I guess? (while converting the characters to integers ofc)
for (int i = 0; i < m; ++i) {
b[i] = int(s[n - 1 - i]);
}
for (int i = 0; i < m; ++i) {
//so if it's the first loop iteration or the find (working on the "backwards" array) matches the end, do this:
if (match.find(b[i]) == match.end()) { //how does this even run the first time? Match is getting cleared and never set? I guess it's working because "find(b[i])" and "end()" are both 0? - Kind of - turns out that if the find function fails, it returns "end"...
//Insert a new bitset with a "key" of "b[i]" into the match map.
auto x = match.insert(std::make_pair(b[i], BitSet<N>())).first; // replaced std::unordered_map<int, BitSet<N> >::iterator with auto... that's nice.
//Then using the iterator we just created above, loop through all entries in UOMap S under that key and ???
//Since we found it in "match", lets check for it(?) in UOMap S... and then set some bit in match based on how many entries there are in S?
for (auto j = S[b[i]].begin(); j != S[b[i]].end(); ++j) {
x->second.set(*j, 1); //this is setting some bit value in "match" to 1 I guess.
}
}
}
for (int i = 0, now = 0; i < m; ++i, now ^= 1) { //Is "*=" being intentionally obfuscated here to "^=" or is it actually different?? Ugh... nasty.
X = (row[now ^ 1] | match[b[i]]).set(n, 1); // if the caret is supposed to represent XOR, why does this still work whenever I replace it with a *... very funky.
row[now] = (X & ((X - (row[now ^ 1] << 1).set(0, 1)) ^ X)).set(n, 0);
}
return row[(m - 1) & 1].count();
}
template<int S>
std::bitset<S>& operator-(const std::bitset<S>&y) { return std::bitset<S>(*this) -= y; }
#define W 6
#define mask 63
#define get_size(n) ((n)<1?0:((n)+mask)>>W)
template<size_t S>
std::bitset<S>& operator-(std::bitset<S>& y, const std::bitset<S>& z) {
//return std::bitset<S>(*this) -= y;
size_t thing[get_size(S)];
size_t t = 0, * p = thing, * end = thing + get_size(y);
const size_t* q = z[0]; //z.a
while (p != end) {
size_t p1 = *p;
*p = p1 - *q - t;
t = (*p > p1) || (p1 + t < t);
++p;
++q;
}
return y;
}
template<size_t S>
std::bitset<S>& operator+(std::bitset<S>& y, const std::bitset<S>& z) {
size_t thing[get_size(S)];
size_t t = 0, * p = thing, * end = thing + get_size(y);
const size_t* q = z[0];
while (p != end) {
size_t p1 = *p; *p = p1 + *q + t;
t = (*p < p1) || (p1 + t < t);
++p;
++q;
}
return y;
}
template<size_t S>
std::bitset<S>& operator -=(const std::bitset<S>& y) {
uint t = 0, * p = a, * end = a + size;
const uint* q = y.a;
while (p != end) {
uint p1 = *p; *p = p1 - *q - t;
t = (*p > p1) || (p1 + t < t);
++p; ++q;
}
return *this;
}
int LPS_STL(std::string s) {
const int M = 1005;
int c[M], d[M];
std::bitset<M> row[2], X, Y;
std::unordered_map<int, std::vector<int> > S;
std::unordered_map<int, std::bitset<M> > match;
int n = s.size(), m = n;
S.clear();
match.clear();
row[1].reset();
for (int i = 0; i < n; ++i)c[i] = int(s[i]), S[c[i]].push_back(i);
for (int i = 0; i < m; ++i)d[i] = int(s[n - 1 - i]);
for (int i = 0; i < m; ++i)if (match.find(d[i]) == match.end()) {
std::unordered_map<int, std::bitset<M> >::iterator x = match.insert(std::make_pair(d[i], std::bitset<M>())).first;
for (std::vector<int>::iterator j = S[d[i]].begin(); j != S[d[i]].end(); ++j)x->second.set(*j, 1);
}
for (int i = 0, now = 0; i < m; ++i, now ^= 1)
X = (row[now ^ 1] | match[d[i]]).set(n, 1), row[now] = (X & (( X - (row[now ^ 1] << 1).set(0, 1)) ^ X)).set(n, 0);
return row[(m - 1) & 1].count();
}
void test1() {
int test = longestPalindromeSubseq("");
IS_TRUE(test == 0);
}
void test2() {
int test = longestPalindromeSubseq("bbbab");
IS_TRUE(test == 4);
}
void test3() {
int test = longestPalindromeSubseq("cbbd");
IS_TRUE(test == 2);
}
void test4() {
int test = longestPalindromeSubseq("cbbd");
IS_TRUE(test == 0); //purposefully fail
}
void test5() {
int test = longestPalindromeSubseq("cacbcbba");
IS_TRUE(test == 5);
}
void test6() {
int test = longestPalindromeSubseq("eeeecdeabfbeeb");
IS_TRUE(test == 7);
}
//Std::bitset tests
//void test7() {
// int test = LPS_STL("");
// IS_TRUE(test == 0);
//}
//void test8() {
// int test = LPS_STL("bbbab")
// IS_TRUE(test == 4);
//}
int main()
{
std::cout << "Longest Palindrome Subsequence\n";
test1();
test2();
test3();
test4();
test5();
test6();
test7();
test8();
}
Main question: Is there a core reason this person chose to implement a custom bitset instead of the std library?
Extraneous questions:
Is some of this code intentionally obfuscated?
It feels a bit naive to ask this, but do some of you intentionally code this way?
Why are they using the value of N (an int) as the type for their bitsets?
Some of their formulas are incredibly long.
Attempting to implement the necessary operators for the std::bitset implementation didn't go well. Can someone elucidate the reason these don't exist? I feel like there's a good reason for it regarding expected output, but I can't articulate it.
Is this question more appropriate for another SO site, such as Code Review?

Is some of this code intentionally obfuscated?
It is just faster to type
#define W 6
than
constexpr size_t bitcoverage { std::bitwidth(sizeof(int64_t)*CHAR_BIT)-1 }; // bits needed to represent 8*8=64 2^7 so 7-1=6 which 0-63 bits.
(if I got it right ...)
It feels a bit naive to ask this, but do some of you intentionally code this way?
When I started programming I did, now I have realized that code is read more than written and it could confuse people, including me, who read it later.
Why are they using the value of N (an int) as the type for their bitsets?
N is just how many entries there is max, see the link you provided.
It actually uses an int64_t to store the bits. Which for many operaton on most larger CPU's is far more efficient than the same doing the same operation on an byte as you can do or/and/etc on a lot of bits with just one instuction.
Some of their formulas are incredibly long.
I am not sure which your are speculating about.
Attempting to implement the necessary operators for the std::bitset implementation didn't go well. Can someone elucidate the reason these don't exist? I feel like there's a good reason for it regarding expected output, but I can't articulate it.
Again, what exact are you referring to?

Related

C++ Bitset algorithm

I am given a nxn grid with filled with 1 or 0. I want to count the number of subgrids where the corner tiles are all 1s. My solution goes through all pairs of rows and counts the number of matching 1s then it uses the formula numOf1s * (numOf1s-1)/2 and adds to the result. However, when I submit my solution on https://cses.fi/problemset/task/2137, there is no output on inputs with n = 3000 (probably caused by some error). What could the error be?
int main()
{
int n; cin>> n;
vector<bitset<3000>> grid(n);
for(int i=0;i<n;i++){
cin >> grid[i];
}
long result = 0;
for(int i=0;i<n-1;i++){
for(int j=i+1;j<n;j++){
int count = (grid[i]&grid[j]).count();
result += (count*(count-1))/2;
}
}
cout << result;
}
This solution will cause a time limit exceeded. bitset::count() is O(n) in worst case. The total complexity of your code is O(n^3). In the worst-case the number of operations would be 3000^3 > 10^10 which is too large.
I'm not sure this solution is the best you can come up with, but it is based on the original solution, with a homebrew alternative for the bitset. This allows me to work with 64 bits blocks, and using a fast popcnt(). An hardware version would be even better, as it would be to work with AVX registers, but this should be more portable and it works on cses.fi. Basically instead of generating a long intersection bitset and later count the number of ones, the function count_common() makes a piece of the intersection and immediately uses it just to count the ones.
The stream extractor could be probably improved, saving some more time.
#include <iostream>
#include <array>
#include <cstdint>
#include <climits>
uint64_t popcnt(uint64_t v) {
v = v - ((v >> 1) & (uint64_t)~(uint64_t)0 / 3);
v = (v & (uint64_t)~(uint64_t)0 / 15 * 3) + ((v >> 2) & (uint64_t)~(uint64_t)0 / 15 * 3);
v = (v + (v >> 4)) & (uint64_t)~(uint64_t)0 / 255 * 15;
uint64_t c = (uint64_t)(v * ((uint64_t)~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT;
return c;
}
struct line {
uint64_t cells_[47] = { 0 }; // 3000/64 = 47
uint64_t& operator[](int pos) { return cells_[pos]; }
const uint64_t& operator[](int pos) const { return cells_[pos]; }
};
uint64_t count_common(const line& a, const line& b) {
uint64_t u = 0;
for (int i = 0; i < 47; ++i) {
u += popcnt(a[i] & b[i]);
}
return u;
}
std::istream& operator>>(std::istream& is, line& ln) {
is >> std::ws;
int pos = 0;
uint64_t val = 0;
while (true) {
char ch = is.get();
if (is && ch == '\n') {
break;
}
if (ch == '1') {
val |= 1LL << (63 - pos % 64);
}
if ((pos + 1) % 64 == 0) {
ln[pos / 64] = val;
val = 0;
}
++pos;
}
if (pos % 64 != 0) {
ln[pos / 64] = val;
}
return is;
}
struct grid {
int n_;
std::array<line, 3000> data_;
line& operator[](int r) {
return data_[r];
}
};
std::istream& operator>>(std::istream& is, grid& g) {
is >> g.n_;
for (int r = 0; r < g.n_; ++r) {
is >> g[r];
}
return is;
}
int main()
{
grid g;
std::cin >> g;
uint64_t count = 0;
for (int r1 = 0; r1 < g.n_; ++r1) {
for (int r2 = r1 + 1; r2 < g.n_; ++r2) {
uint64_t n = count_common(g[r1], g[r2]);
count += n * (n - 1) / 2;
}
}
std::cout << count << '\n';
return 0;
}

BVH Tree Construction - Compiling gives Random mistakes

Much thanks for the help in additionally.
I'm trying to build a BVH Tree with Surface Area Heuristic, but everytime I compile my code it gives me random errors like:
"Access violation reading location"
"Run-Time Check Failure #2 - Stack around the variable 'x' was
corrupted."
"Stack overflow "
The errors happen in the BVH::buildSAH() function.
And I have tried to find a solution for the whole day, meaningless. Could it be something from the std::partition function or from sending variables with pointers to a recursion?
I'm reading from the book "Physically Based Rendering: From Theory to Implementation
By Matt Pharr, Greg Humphreys"
It works for 2 primitives in the area, but thats trivial...
If you would like to clone: https://github.com/vkaytsanov/MortonCode-BVH-KD
My BVH.hpp:
#include <vector>
#include <cassert>
#include <algorithm>
#include "memory.hpp"
#include "Screen.hpp"
#include "Point3D.hpp"
#include "BoundBox.hpp"
#pragma once
enum Axis{
X, Y, Z
};
struct MortonPrimitive{
int primitiveIndex;
uint32_t mortonCode;
};
struct BVHPrimitiveInfo {
BVHPrimitiveInfo() {}
BVHPrimitiveInfo(int primitiveNumber, const BoundBox& box) : primitiveNumber(primitiveNumber), box(box),
centroid(Point3D(box.pMin.x* 0.5f + box.pMax.x * 0.5f, box.pMin.y* 0.5f + box.pMax.y * 0.5f, box.pMin.z* 0.5f + box.pMax.z * 0.5f)) {}
int primitiveNumber;
BoundBox box;
Point3D centroid;
};
struct BVHNode {
void InitLeaf(int first, int n, const BoundBox& b) {
firstPrimOffset = first;
nPrimitives = n;
box = b;
children[0] = children[1] = nullptr;
}
void InitInterior(int axis, BVHNode* c0, BVHNode* c1) {
assert(c0 != NULL || c1 != NULL);
children[0] = c0;
children[1] = c1;
this->box = Union(c0->box, c1->box);
splitAxis = axis;
nPrimitives = 0;
}
BoundBox box;
BVHNode* children[2];
int splitAxis, firstPrimOffset, nPrimitives;
};
struct LinearBVHNode {
BoundBox bounds;
union {
int primitivesOffset; // leaf
int secondChildOffset; // interior
};
uint16_t nPrimitives; // 0 -> interior node
uint8_t axis; // interior node: xyz
uint8_t pad[1]; // ensure 32 byte total size
};
struct BVHLittleTree {
int startIndex;
int numPrimitives;
BVHNode* nodes;
};
struct BVH {
BVH(std::vector<std::shared_ptr<Primitive>> p) : primitives(std::move(p)) {
std::vector<BVHPrimitiveInfo> BVHPrimitives;
BVHPrimitives.reserve(primitives.size());
for (int i = 0; i < primitives.size(); i++) {
BVHPrimitives.push_back({ i, primitives[i]->box });
}
MemoryArena arena(1024 * 1024);
int totalNodes = 0;
std::vector<std::shared_ptr<Primitive>> orderedPrimitives;
orderedPrimitives.reserve(primitives.size());
BVHNode* root;
root = HLBVHBuild(arena, BVHPrimitives, &totalNodes, orderedPrimitives);
primitives.swap(orderedPrimitives);
BVHPrimitives.resize(0);
printf("BVH created with %d nodes for %d "
"primitives (%.4f MB), arena allocated %.2f MB\n",
(int)totalNodes, (int)primitives.size(),
float(totalNodes * sizeof(LinearBVHNode)) /
(1024.f * 1024.f),
float(arena.TotalAllocated()) /
(1024.f * 1024.f));
assert(root != NULL);
nodes = AllocAligned<LinearBVHNode>(totalNodes);
int offset = 0;
flattenBVHTree(root, &offset);
}
~BVH() { FreeAligned(nodes); }
BVHNode* build(std::vector<MortonPrimitive>&, std::vector<Primitive>&);
BVHNode* HLBVHBuild(MemoryArena& arena, const std::vector<BVHPrimitiveInfo>& BVHPrimitives, int* totalNodes, std::vector<std::shared_ptr<Primitive>>& orderedPrims);
BVHNode* emit(BVHNode*& nodes, const std::vector<BVHPrimitiveInfo>& BVHPrimitives, MortonPrimitive* mortonPrimitives, std::vector<std::shared_ptr<Primitive>>&, int, int*, int*, int);
BVHNode* buildSAH(MemoryArena& arena, std::vector<BVHNode*>& treeRoots, int start, int end, int* total) const;
int flattenBVHTree(BVHNode*, int*);
std::vector<std::shared_ptr<Primitive>> primitives;
LinearBVHNode* nodes = nullptr;
int maxPrimsInNode = 1;
};
inline uint32_t LeftShift3(uint32_t x) {
if (x == (1 << 10)) --x;
x = (x | (x << 16)) & 0b00000011000000000000000011111111;
x = (x | (x << 8)) & 0b00000011000000001111000000001111;
x = (x | (x << 4)) & 0b00000011000011000011000011000011;
x = (x | (x << 2)) & 0b00001001001001001001001001001001;
return x;
}
uint32_t EncodeMorton3(const Point3D& p) {
return (LeftShift3(p.z) << 2) |
(LeftShift3(p.y) << 1) |
(LeftShift3(p.x) << 0);
}
short bitValue(uint32_t& number, uint32_t& mask) {
return number & mask ? 1 : 0;
}
static void radixSort(std::vector<MortonPrimitive>* v)
{
std::vector<MortonPrimitive> tempVector(v->size());
const int bitsPerPass = 6;
const int nBits = 30;
static_assert((nBits % bitsPerPass) == 0,
"Radix sort bitsPerPass must evenly divide nBits");
const int nPasses = nBits / bitsPerPass;
for (int pass = 0; pass < nPasses; ++pass) {
// Perform one pass of radix sort, sorting _bitsPerPass_ bits
int lowBit = pass * bitsPerPass;
// Set in and out vector pointers for radix sort pass
std::vector<MortonPrimitive>& in = (pass & 1) ? tempVector : *v;
std::vector<MortonPrimitive>& out = (pass & 1) ? *v : tempVector;
// Count number of zero bits in array for current radix sort bit
const int nBuckets = 1 << bitsPerPass;
int bucketCount[nBuckets] = { 0 };
const int bitMask = (1 << bitsPerPass) - 1;
for (const MortonPrimitive& mp : in) {
int bucket = (mp.mortonCode >> lowBit) & bitMask;
++bucketCount[bucket];
}
// Compute starting index in output array for each bucket
int outIndex[nBuckets];
outIndex[0] = 0;
for (int i = 1; i < nBuckets; ++i)
outIndex[i] = outIndex[i - 1] + bucketCount[i - 1];
// Store sorted values in output array
for (const MortonPrimitive& mp : in) {
int bucket = (mp.mortonCode >> lowBit) & bitMask;
out[outIndex[bucket]++] = mp;
}
}
// Copy final result from _tempVector_, if needed
if (nPasses & 1) std::swap(*v, tempVector);
}
//BVHNode* BVH::build(std::vector<MortonPrimitive>& mortonPrimitives, std::vector<Primitive>& prims) {
//
//
//}
struct BucketInfo {
int count = 0;
BoundBox bounds;
};
BVHNode* BVH::HLBVHBuild(MemoryArena& arena, const std::vector<BVHPrimitiveInfo>& BVHPrimitives, int* totalNodes, std::vector<std::shared_ptr<Primitive>>& orderedPrims) {
BoundBox box;
for (const BVHPrimitiveInfo& pi : BVHPrimitives) {
box = box.Union(box, pi.centroid); // maybe it should be UNION #TODO
}
std::vector<MortonPrimitive> mortonPrims(BVHPrimitives.size());
for (int i = 0; i < BVHPrimitives.size(); i++) {
const int mortonBits = 10;
const int mortonScale = 1 << mortonBits;
mortonPrims[i].primitiveIndex = BVHPrimitives[i].primitiveNumber;
Point3D p = box.offset(BVHPrimitives[i].centroid);
p.x = p.x * mortonScale;
p.y = p.y * mortonScale;
p.z = p.z * mortonScale;
mortonPrims[i].mortonCode = EncodeMorton3(p);
}
radixSort(&mortonPrims);
//for (MortonPrimitive mp : mortonPrims) {
// std::cout << mp.primitiveIndex << " " << mp.mortonCode << std::endl;
//}
std::vector<BVHLittleTree> treesToBuild;
uint32_t mask = 0b00111111111111000000000000000000; // first 12 bits describe the position of the primitive
for (int start = 0, end = 1; end <= (int)mortonPrims.size(); end++) {
if (end == mortonPrims.size() || ((mortonPrims[start].mortonCode & mask) != (mortonPrims[end].mortonCode & mask))) {
int n = end - start;
int maxNodes = 2 * n;
BVHNode* nodes = arena.Alloc<BVHNode>(maxNodes, false);
treesToBuild.push_back({ start, n, nodes });
start = end;
}
}
int orderedPrimsOffset = 0;
orderedPrims.resize(primitives.size());
int nodesCreated = 0;
int firstBitIndex = 29 - 12;
for (int i = 0; i < treesToBuild.size(); i++) {
treesToBuild[i].nodes = BVH::emit(treesToBuild[i].nodes, BVHPrimitives, &mortonPrims[treesToBuild[i].startIndex], orderedPrims, treesToBuild[i].numPrimitives, &nodesCreated, &orderedPrimsOffset, firstBitIndex);
*totalNodes += nodesCreated;
}
totalNodes += nodesCreated;
std::vector<BVHNode*> finishedTrees;
finishedTrees.reserve(treesToBuild.size());
for (BVHLittleTree& tr : treesToBuild) {
finishedTrees.emplace_back(tr.nodes);
}
return buildSAH(arena, finishedTrees, 0, finishedTrees.size(), totalNodes);
}
BVHNode* BVH::emit(BVHNode*& nodes, const std::vector<BVHPrimitiveInfo>& BVHPrimitive, MortonPrimitive* mortonPrimitives, std::vector<std::shared_ptr<Primitive>>& orderedPrimitives, int primitivesCount, int* totalNodes, int* orderedPrimsOffset, int bitIndex) {
if (bitIndex == -1 || primitivesCount < maxPrimsInNode) {
(*totalNodes)++;
BVHNode* tmp = nodes++;
BoundBox box;
int firstPrimOffset = *orderedPrimsOffset;
for (int i = 0; i < primitivesCount; i++) {
int index = mortonPrimitives[i].primitiveIndex;
orderedPrimitives[firstPrimOffset + i] = primitives[index];
box = box.Union(box, BVHPrimitive[index].box);
}
tmp->InitLeaf(0, primitivesCount, box);
return tmp;
}
else {
int mask = 1 << bitIndex;
if ((mortonPrimitives[0].mortonCode & mask) == (mortonPrimitives[primitivesCount - 1].mortonCode & mask)){ // Next tree if nothing to split for this bit
return emit(nodes, BVHPrimitive, mortonPrimitives, orderedPrimitives, primitivesCount, totalNodes, orderedPrimsOffset, bitIndex - 1);
}
int start = 0;
int end = primitivesCount - 1;
while (start + 1 != end) {
int mid = (end - start) / 2 + start; // (start-end)/2
if ((mortonPrimitives[start].mortonCode & mask) == (mortonPrimitives[mid].mortonCode & mask)) {
start = mid;
}
else {
end = mid;
}
}
int split = end;
(*totalNodes)++;
BVHNode* tmp = nodes++;
BVHNode* lbvh[2];
lbvh[0] = emit(nodes, BVHPrimitive, mortonPrimitives, orderedPrimitives, split, totalNodes, orderedPrimsOffset, bitIndex-1);
lbvh[1] = emit(nodes, BVHPrimitive, &mortonPrimitives[split], orderedPrimitives, primitivesCount - split, totalNodes, orderedPrimsOffset, bitIndex - 1);
int axis = bitIndex % 3;
tmp->InitInterior(axis, lbvh[0], lbvh[1]);
return tmp;
}
}
BVHNode* BVH::buildSAH(MemoryArena& arena, std::vector<BVHNode*>& treeRoots, int start, int end, int* total) const {
int nodesCount = end - start;
if (nodesCount == 1) {
return treeRoots[start];
}
assert(nodesCount > 1);
(*total)++;
BVHNode* node = arena.Alloc<BVHNode>();
BoundBox box;
for (int i = start; i < end; i++) {
box = Union(box, treeRoots[i]->box);
}
BoundBox centroidBox;
for (int i = start; i < end; i++) {
Point3D centroid = Point3D((treeRoots[i]->box.pMin.x + treeRoots[i]->box.pMax.x) * 0.5f, (treeRoots[i]->box.pMin.y + treeRoots[i]->box.pMax.y) * 0.5f, (treeRoots[i]->box.pMin.z + treeRoots[i]->box.pMax.z) * 0.5f);
centroidBox = Union(centroidBox, centroid);
}
const int dimension = centroidBox.MaximumExtent();
const int nBuckets = 12;
struct Buckets {
int count = 0;
BoundBox box;
};
Buckets buckets[nBuckets];
for (int i = start; i < end; i++) {
float centroid = (treeRoots[i]->box.pMin[dimension] * 0.5f + treeRoots[i]->box.pMax[dimension] * 0.5f) ;
int b = nBuckets * ((centroid - centroidBox.pMin[dimension]) / (centroidBox.pMax[dimension] - centroidBox.pMin[dimension]));
if (b == nBuckets) b = nBuckets - 1;
//assert(b < nBuckets);
buckets[b].count++;
buckets[b].box = Union(buckets[b].box, treeRoots[i]->box);
}
float cost[nBuckets - 1];
for (int i = 0; i < nBuckets - 1; i++) {
BoundBox b0, b1;
int count0 = 0, count1 = 0;
for (int j = 0; j <= i; j++) {
b0 = Union(b0, buckets[j].box);
count0 += buckets[j].count;
}
for (int j = i+1; j < nBuckets; j++) {
b1 = Union(b1, buckets[j].box);
count1 += buckets[j].count;
}
cost[i] = (.125f + (count0 * b0.surfaceArea() + count1 * b1.surfaceArea())) / box.surfaceArea();
}
double minCost = cost[0];
int minCostSplitBucket = 0;
for (int i = 1; i < nBuckets - 1; ++i) {
if (cost[i] < minCost) {
minCost = cost[i];
minCostSplitBucket = i;
}
}
BVHNode** pmid = std::partition(&treeRoots[start], &treeRoots[end - 1] + 1, [=](const BVHNode* node) {
double centroid = (node->box.pMin[dimension]*0.5f + node->box.pMax[dimension] * 0.5f) ;
int b = nBuckets * ((centroid - centroidBox.pMin[dimension]) / (centroidBox.pMax[dimension] - centroidBox.pMin[dimension]));
if (b == nBuckets) b = nBuckets - 1;
return b <= minCostSplitBucket;
});
assert(pmid != NULL);
//std::cout << pmid << " " << &treeRoots[0];
int mid = pmid - &treeRoots[0];
//std::cout << start << " " << mid << std::endl;
//std::cout << mid << " " << end << std::endl;
std::cout << dimension << std::endl;
//assert(dimension < 3);
node->InitInterior(dimension, this->buildSAH(arena, treeRoots, start, mid, total), this->buildSAH(arena, treeRoots, mid, end, total));
return node;
}
int BVH::flattenBVHTree(BVHNode* node, int* offset) {
LinearBVHNode* linearNode = &nodes[*offset];
linearNode->bounds = node->box;
int myOffset = (*offset)++;
if (node->nPrimitives > 0) {
linearNode->primitivesOffset = node->firstPrimOffset;
linearNode->nPrimitives = node->nPrimitives;
}
else {
// Create interior flattened BVH node
linearNode->axis = node->splitAxis;
linearNode->nPrimitives = 0;
flattenBVHTree(node->children[0], offset);
linearNode->secondChildOffset = flattenBVHTree(node->children[1], offset);
}
return myOffset;
}
My Point3D.hpp
#include <cstdint>
#pragma once
struct Point3D {
float x;
float y;
float z;
Point3D(uint32_t, uint32_t, uint32_t);
Point3D();
int operator[](int);
int operator[](int) const;
Point3D operator+(int);
Point3D operator-(int);
Point3D operator-(Point3D&);
};
Point3D::Point3D() {
x = 0;
y = 0;
z = 0;
}
Point3D::Point3D(uint32_t x, uint32_t y, uint32_t z) {
this->x = x;
this->y = y;
this->z = z;
}
bool operator<(Point3D a, Point3D b) {
uint32_t xSquare = a.x * a.x;
uint32_t ySquare = a.y * a.y;
uint32_t zSquare = a.z * a.z;
uint32_t x2Square = b.x * b.x;
uint32_t y2Square = b.y * b.y;
uint32_t z2Square = b.z * b.z;
int64_t sum = std::sqrt(xSquare + ySquare + z2Square) - std::sqrt(x2Square + y2Square + z2Square);
return sum < 0 ||
sum == 0 && xSquare < x2Square ||
sum == 0 && xSquare == x2Square && ySquare < y2Square ||
sum == 0 && xSquare == x2Square && ySquare == y2Square && zSquare < z2Square;
}
bool operator>(Point3D a, Point3D b) {
uint32_t xSquare = a.x * a.x;
uint32_t ySquare = a.y * a.y;
uint32_t zSquare = a.z * a.z;
uint32_t x2Square = b.x * b.x;
uint32_t y2Square = b.y * b.y;
uint32_t z2Square = b.z * b.z;
int32_t sum = std::sqrt(xSquare + ySquare + z2Square) - std::sqrt(x2Square + y2Square + z2Square);
return sum > 0 ||
sum == 0 && xSquare > x2Square ||
sum == 0 && xSquare == x2Square && ySquare > y2Square ||
sum == 0 && xSquare == x2Square && ySquare == y2Square && zSquare > z2Square;
}
int Point3D::operator[](int i) {
if (i == 0) return x;
if (i == 1) return y;
return z;
}
Point3D Point3D::operator+(int i) {
this->x += i;
this->y += i;
this->z += i;
return *this;
}
Point3D Point3D::operator-(const int i) {
this->x -= i;
this->y -= i;
this->z -= i;
return *this;
}
Point3D Point3D::operator-(Point3D& p) {
this->x -= p.x;
this->y -= p.y;
this->z -= p.z;
return *this;
}
int Point3D::operator[](const int i) const {
if (i == 0) return x;
if (i == 1) return y;
return z;
}
My BoundBox.hpp
#include "Point3D.hpp"
#include "Vector3D.hpp"
#pragma once
struct BoundBox {
Point3D pMin;
Point3D pMax;
BoundBox(Point3D);
BoundBox(Point3D, Point3D);
BoundBox();
void setBounds(BoundBox);
void Union(BoundBox);
BoundBox Union(BoundBox&, Point3D&);
BoundBox Union(BoundBox, BoundBox);
BoundBox unite(BoundBox, BoundBox);
BoundBox unite(BoundBox);
const Point3D offset(const Point3D&);
Point3D diagonal();
const int MaximumExtent();
float surfaceArea();
};
BoundBox::BoundBox() {
float minNum = 0;
pMin = Point3D(800, 600, 300);
pMax = Point3D(minNum, minNum, minNum);
}
BoundBox::BoundBox(Point3D p){
pMin = p;
pMax = p;
}
BoundBox::BoundBox(Point3D p1, Point3D p2) {
pMin = Point3D(std::min(p1.x, p2.x), std::min(p1.y, p2.y), std::min(p1.z, p2.z));
pMax = Point3D(std::max(p1.x, p2.x), std::max(p1.y, p2.y), std::max(p1.z, p2.z));
}
BoundBox BoundBox::Union(BoundBox& box, Point3D& p) {
BoundBox newBox;
newBox.pMin = Point3D(std::min(box.pMin.x, p.x), std::min(box.pMin.y, p.y), std::min(box.pMin.z, p.z));
newBox.pMax = Point3D(std::max(box.pMax.x, p.x), std::max(box.pMax.y, p.y), std::max(box.pMax.z, p.z));
return newBox;
}
BoundBox BoundBox::Union(BoundBox box1, BoundBox box2) {
BoundBox newBox;
newBox.pMin = std::min(box1.pMin, box2.pMin);
newBox.pMax = std::max(box1.pMax, box2.pMax);
return newBox;
}
BoundBox Union(BoundBox box1, BoundBox box2) {
BoundBox newBox;
newBox.pMin = std::min(box1.pMin, box2.pMin);
newBox.pMax = std::max(box1.pMax, box2.pMax);
return newBox;
}
BoundBox BoundBox::unite(BoundBox b1, BoundBox b2) {
bool x = (b1.pMax.x >= b2.pMin.x) && (b1.pMin.x <= b2.pMax.x);
bool y = (b1.pMax.y >= b2.pMin.y) && (b1.pMin.y <= b2.pMax.y);
bool z = (b1.pMax.z >= b2.pMin.z) && (b1.pMin.z <= b2.pMax.z);
if (x && y && z) {
return Union(b1, b2);
}
}
BoundBox BoundBox::unite(BoundBox b2) {
bool x = (this->pMax.x >= b2.pMin.x) && (this->pMin.x <= b2.pMax.x);
bool y = (this->pMax.y >= b2.pMin.y) && (this->pMin.y <= b2.pMax.y);
bool z = (this->pMax.z >= b2.pMin.z) && (this->pMin.z <= b2.pMax.z);
if (x && y && z) {
return Union(*this, b2);
}
else return *this;
}
const int BoundBox::MaximumExtent() {
Point3D d = Point3D(this->pMax.x - this->pMin.x, this->pMax.y - this->pMin.y, this->pMax.z - this->pMin.z); // diagonal
if (d.x > d.y && d.x > d.z) {
return 0;
}
else if (d.y > d.z) {
return 1;
}
else {
return 2;
}
}
float BoundBox::surfaceArea() {
Point3D d = Point3D(this->pMax.x - this->pMin.x, this->pMax.y - this->pMin.y, this->pMax.z - this->pMin.z); // diagonal
return 2 * (d.x * d.y + d.x * d.z + d.y * d.z);
}
const Point3D BoundBox::offset(const Point3D& p) {
Point3D o = Point3D(p.x - pMin.x, p.y - pMin.y, p.z - pMin.z);
if (pMax.x > pMin.x) o.x /= pMax.x - pMin.x;
if (pMax.y > pMin.y) o.y /= pMax.y - pMin.y;
if (pMax.z > pMin.z) o.z /= pMax.z - pMin.z;
return o;
}
My memory.hpp
#include <list>
#include <cstddef>
#include <algorithm>
#include <malloc.h>
#include <stdlib.h>
#pragma once
#define ARENA_ALLOC(arena, Type) new ((arena).Alloc(sizeof(Type))) Type
void* AllocAligned(size_t size);
template <typename T>
T* AllocAligned(size_t count) {
return (T*)AllocAligned(count * sizeof(T));
}
void FreeAligned(void*);
class
#ifdef PBRT_HAVE_ALIGNAS
alignas(PBRT_L1_CACHE_LINE_SIZE)
#endif // PBRT_HAVE_ALIGNAS
MemoryArena {
public:
// MemoryArena Public Methods
MemoryArena(size_t blockSize = 262144) : blockSize(blockSize) {}
~MemoryArena() {
FreeAligned(currentBlock);
for (auto& block : usedBlocks) FreeAligned(block.second);
for (auto& block : availableBlocks) FreeAligned(block.second);
}
void* Alloc(size_t nBytes) {
// Round up _nBytes_ to minimum machine alignment
#if __GNUC__ == 4 && __GNUC_MINOR__ < 9
// gcc bug: max_align_t wasn't in std:: until 4.9.0
const int align = alignof(::max_align_t);
#elif !defined(PBRT_HAVE_ALIGNOF)
const int align = 16;
#else
const int align = alignof(std::max_align_t);
#endif
#ifdef PBRT_HAVE_CONSTEXPR
static_assert(IsPowerOf2(align), "Minimum alignment not a power of two");
#endif
nBytes = (nBytes + align - 1) & ~(align - 1);
if (currentBlockPos + nBytes > currentAllocSize) {
// Add current block to _usedBlocks_ list
if (currentBlock) {
usedBlocks.push_back(
std::make_pair(currentAllocSize, currentBlock));
currentBlock = nullptr;
currentAllocSize = 0;
}
// Get new block of memory for _MemoryArena_
// Try to get memory block from _availableBlocks_
for (auto iter = availableBlocks.begin();
iter != availableBlocks.end(); ++iter) {
if (iter->first >= nBytes) {
currentAllocSize = iter->first;
currentBlock = iter->second;
availableBlocks.erase(iter);
break;
}
}
if (!currentBlock) {
currentAllocSize = std::max(nBytes, blockSize);
currentBlock = AllocAligned<uint8_t>(currentAllocSize);
}
currentBlockPos = 0;
}
void* ret = currentBlock + currentBlockPos;
currentBlockPos += nBytes;
return ret;
}
template <typename T>
T* Alloc(size_t n = 1, bool runConstructor = true) {
T* ret = (T*)Alloc(n * sizeof(T));
if (runConstructor)
for (size_t i = 0; i < n; ++i) new (&ret[i]) T();
return ret;
}
void Reset() {
currentBlockPos = 0;
availableBlocks.splice(availableBlocks.begin(), usedBlocks);
}
size_t TotalAllocated() const {
size_t total = currentAllocSize;
for (const auto& alloc : usedBlocks) total += alloc.first;
for (const auto& alloc : availableBlocks) total += alloc.first;
return total;
}
private:
MemoryArena(const MemoryArena&) = delete;
MemoryArena & operator=(const MemoryArena&) = delete;
// MemoryArena Private Data
const size_t blockSize;
size_t currentBlockPos = 0, currentAllocSize = 0;
uint8_t * currentBlock = nullptr;
std::list<std::pair<size_t, uint8_t*>> usedBlocks, availableBlocks;
};
template <typename T, int logBlockSize>
class BlockedArray {
public:
// BlockedArray Public Methods
BlockedArray(int uRes, int vRes, const T* d = nullptr)
: uRes(uRes), vRes(vRes), uBlocks(RoundUp(uRes) >> logBlockSize) {
int nAlloc = RoundUp(uRes) * RoundUp(vRes);
data = AllocAligned<T>(nAlloc);
for (int i = 0; i < nAlloc; ++i) new (&data[i]) T();
if (d)
for (int v = 0; v < vRes; ++v)
for (int u = 0; u < uRes; ++u) (*this)(u, v) = d[v * uRes + u];
}
const int BlockSize() const { return 1 << logBlockSize; }
int RoundUp(int x) const {
return (x + BlockSize() - 1) & ~(BlockSize() - 1);
}
int uSize() const { return uRes; }
int vSize() const { return vRes; }
~BlockedArray() {
for (int i = 0; i < uRes * vRes; ++i) data[i].~T();
FreeAligned(data);
}
int Block(int a) const { return a >> logBlockSize; }
int Offset(int a) const { return (a & (BlockSize() - 1)); }
T& operator()(int u, int v) {
int bu = Block(u), bv = Block(v);
int ou = Offset(u), ov = Offset(v);
int offset = BlockSize() * BlockSize() * (uBlocks * bv + bu);
offset += BlockSize() * ov + ou;
return data[offset];
}
const T & operator()(int u, int v) const {
int bu = Block(u), bv = Block(v);
int ou = Offset(u), ov = Offset(v);
int offset = BlockSize() * BlockSize() * (uBlocks * bv + bu);
offset += BlockSize() * ov + ou;
return data[offset];
}
void GetLinearArray(T * a) const {
for (int v = 0; v < vRes; ++v)
for (int u = 0; u < uRes; ++u) * a++ = (*this)(u, v);
}
private:
// BlockedArray Private Data
T * data;
const int uRes, vRes, uBlocks;
};
void* AllocAligned(size_t size) {
return _aligned_malloc(size, 32);
}
void FreeAligned(void* ptr) {
if (!ptr) return;
_aligned_free(ptr);
}
and My Source.cpp
#include <iostream>
#include <vector>
#include <chrono>
#include "Point3D.hpp"
#include "Screen.hpp"
#include "BVH.hpp"
#define N 150
int main(){
auto startTime = std::chrono::high_resolution_clock::now();
Screen* screen = new Screen(800, 600, 300);
screen->generatePoints(N);
//for (MortonPrimitive m : mortonPrims) {
// std::cout << m.mortonCode << std::endl;
//}
std::vector<std::shared_ptr<Primitive>> primitives;
primitives.reserve(N);
for (int i = 0; i < N; i++) {
primitives.emplace_back(screen->castPointToPrimitive(i));
}
BVH test(primitives);
auto endTime = std::chrono::high_resolution_clock::now();
std::cout << "Time spent: " << std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count() << "ms\n";
getchar();
delete screen;
}
Probably it would be wise to first cleanup your github. This mean update stuff to the recent c++ standard. It seems that you can use c++17 so use it. Also please look at some names. For example 'nodes' is used as member variable as well as parameter name, this is confusion. Please also initialize relevant (all) member variables.
Now it seems that the code in buildSAH override memory. It seems that it it can write over the end of buckets array.

Divide a large number represented in string by 3

I have a very large number represented by a string. Say String n = "64772890123784224827" . I want to divide the number by 3 efficiently. How can I do it? Some implementations are given below which can find out remainder. But how to get the quotient efficiently?
In Java, the number can be represented with BigInteger and the division operation can be done on BigInteger. But that takes too much time. Please help me find out the efficient way to divide this large number by 3.
Well following is a very basic implementation to find out the remainder:
#include <bits/stdc++.h>
using namespace std;
int divideByN(string, int);
int main()
{
string str = "-64772890123784224827";
//string str = "21";
int N = 3;
int remainder = divideByN(str, N);
cout << "\nThe remainder = " << remainder << endl;
return 0;
}
int divideByN(string s, int n)
{
int carry = 0;
int remainder = 0;
for(int i = 0; i < s.size(); i++)
{
if(i == 0 && s.at(i) == '-')
{
cout << "-";
continue;
}
//Check for any illegal characters here. If any, throw exception.
int tmp = (s.at(i) - '0') + remainder * carry;
cout << (tmp / n);
if(tmp % n == 0)
{
carry = 0;
remainder = 0;
}
else
{
remainder = tmp % n;
carry = 10;
}
}
return remainder;
}
Based on some good answers, here is a minimal implementation using lookup table to find out the remainder:
#include <bits/stdc++.h>
using namespace std;
int divideByN_Lookup(string, int);
int lookup[] = { 0, 1, 2, 0, 1, 2, 0, 1, 2, 0 }; //lookup considering 3 as divisor.
int main() {
string str = "64772890123784224827";
int N = 3;
int remaninder_lookup = divideByN_Lookup(str, N);
cout << "Look up implementation of remainder = " << remaninder_lookup
<< endl;
return 0;
}
int divideByN_Lookup(string s, int n) {
int rem = 0;
int start = 0;
if (s.at(start) == '-')
start = 1;
for (unsigned int i = start; i < s.size(); i++)
rem = (rem + lookup[s.at(i) - '0']) % n;
return rem;
}
What about quotient? I know I can process all characters one by one and add the quotient to a char array or string. But what is the most efficient way to find out the quotient?
If all you need is the remainder after dividing by 3, make a look up table or function that converts each string character digit to an int, which is the remainder when you divide the digit by 3, and add up the ints across all digits in the string, and then there is a fact that the remainder when you divide your original number by 3 is the same as the remainder when you divide the sum of digits by 3. It would be virtually impossible to not be able to fit the sum of 0,1,2 values into a 32 or 64 byte integer. The input would simply have to be too large. And if it does start to become almost too large when you're summing the digits, then just take the remainder when you divide by 3 when you start getting close to the maximum value for an int. Then you can process any length number, using very few division remainder (modulus) operations (which is important because they are much slower than addition).
The reason why the sum-of-digits fact is true is that the remainder when you divide any power of 10 by 3 is always 1.
This is actually very simple. Since every power of 10 is equivalent to 1 modulo 3, all you have to do is add the digits together. The resulting number will have the same remainder when divided by 3 as the original large number.
For example:
3141592654 % 3 = 1
3+1+4+1+5+9+2+6+5+4 = 40
40 % 3 = 1
I think you can start processing from the left, dividing each digit by 3, and adding the remainder to the next one.
In your example you divide the 6, write 2, then divide the 4, write 1 and add the remainder of 1 to the 7 to get 17... Divide the 17... and so on.
EDIT:
I've just verified my solution works using this code. Note you may get a leading zero:
int main( int argc, char* argv[] )
{
int x = 0;
for( char* p = argv[1]; *p; p++ ) {
int d = x*10 + *p-'0';
printf("%d", d/3);
x = d % 3;
}
printf("\n");
return 0;
}
It's not optimal using so many divs and muls, but CS-wise it's O(N) ;-)
I wrote this a while ago.. Doesn't seem slow :S
I've only included the necessary parts for "division"..
#include <string>
#include <cstring>
#include <algorithm>
#include <stdexcept>
#include <iostream>
class BigInteger
{
public:
char sign;
std::string digits;
const std::size_t base = 10;
short toDigit(std::size_t index) const {return index >= 0 && index < digits.size() ? digits[index] - '0' : 0;}
protected:
void Normalise();
BigInteger& divide(const BigInteger &Divisor, BigInteger* Remainder);
public:
BigInteger();
BigInteger(const std::string &value);
inline bool isNegative() const {return sign == '-';}
inline bool isPositive() const {return sign == '+';}
inline bool isNeutral() const {return sign == '~';}
inline std::string toString() const
{
std::string digits = this->digits;
std::reverse(digits.begin(), digits.end());
if (!isNeutral())
{
std::string sign;
sign += this->sign;
return sign + digits;
}
return digits;
}
bool operator < (const BigInteger &other) const;
bool operator > (const BigInteger &other) const;
bool operator <= (const BigInteger &other) const;
bool operator >= (const BigInteger &other) const;
bool operator == (const BigInteger &other) const;
bool operator != (const BigInteger &other) const;
BigInteger& operator /= (const BigInteger &other);
BigInteger operator / (const BigInteger &other) const;
BigInteger Remainder(const BigInteger &other) const;
};
BigInteger::BigInteger() : sign('~'), digits(1, '0') {}
BigInteger::BigInteger(const std::string &value) : sign('~'), digits(value)
{
sign = digits.empty() ? '~' : digits[0] == '-' ? '-' : '+';
if (digits[0] == '+' || digits[0] == '-') digits.erase(0, 1);
std::reverse(digits.begin(), digits.end());
Normalise();
for (std::size_t I = 0; I < digits.size(); ++I)
{
if (!isdigit(digits[I]))
{
sign = '~';
digits = "0";
break;
}
}
}
void BigInteger::Normalise()
{
for (int I = digits.size() - 1; I >= 0; --I)
{
if (digits[I] != '0') break;
digits.erase(I, 1);
}
if (digits.empty())
{
digits = "0";
sign = '~';
}
}
bool BigInteger::operator < (const BigInteger &other) const
{
if (isNeutral() || other.isNeutral())
{
return isNeutral() ? other.isPositive() : isNegative();
}
if (sign != other.sign)
{
return isNegative();
}
if (digits.size() != other.digits.size())
{
return (digits.size() < other.digits.size() && isPositive()) || (digits.size() > other.digits.size() && isNegative());
}
for (int I = digits.size() - 1; I >= 0; --I)
{
if (toDigit(I) < other.toDigit(I))
return isPositive();
if (toDigit(I) > other.toDigit(I))
return isNegative();
}
return false;
}
bool BigInteger::operator > (const BigInteger &other) const
{
if (isNeutral() || other.isNeutral())
{
return isNeutral() ? other.isNegative() : isPositive();
}
if ((sign != other.sign) && !(isNeutral() || other.isNeutral()))
{
return isPositive();
}
if (digits.size() != other.digits.size())
{
return (digits.size() > other.digits.size() && isPositive()) || (digits.size() < other.digits.size() && isNegative());
}
for (int I = digits.size() - 1; I >= 0; --I)
{
if (toDigit(I) > other.toDigit(I))
return isPositive();
if (toDigit(I) < other.toDigit(I))
return isNegative();
}
return false;
}
bool BigInteger::operator <= (const BigInteger &other) const
{
return (*this < other) || (*this == other);
}
bool BigInteger::operator >= (const BigInteger &other) const
{
return (*this > other) || (*this == other);
}
bool BigInteger::operator == (const BigInteger &other) const
{
if (sign != other.sign || digits.size() != other.digits.size())
return false;
for (int I = digits.size() - 1; I >= 0; --I)
{
if (toDigit(I) != other.toDigit(I))
return false;
}
return true;
}
bool BigInteger::operator != (const BigInteger &other) const
{
return !(*this == other);
}
BigInteger& BigInteger::divide(const BigInteger &Divisor, BigInteger* Remainder)
{
if (Divisor.isNeutral())
{
throw std::overflow_error("Division By Zero Exception.");
}
char rem_sign = sign;
bool neg_res = sign != Divisor.sign;
if (!isNeutral()) sign = '+';
if (*this < Divisor)
{
if (Remainder)
{
Remainder->sign = this->sign;
Remainder->digits = this->digits;
}
sign = '~';
digits = "0";
return *this;
}
if (this == &Divisor)
{
if (Remainder)
{
Remainder->sign = this->sign;
Remainder->digits = this->digits;
}
sign = '+';
digits = "1";
return *this;
}
BigInteger Dvd(*this);
BigInteger Dvr(Divisor);
BigInteger Quotient("0");
Dvr.sign = '+';
std::size_t len = std::max(Dvd.digits.size(), Dvr.digits.size());
std::size_t diff = std::max(Dvd.digits.size(), Dvr.digits.size()) - std::min(Dvd.digits.size(), Dvr.digits.size());
std::size_t offset = len - diff - 1;
Dvd.digits.resize(len, '0');
Dvr.digits.resize(len, '0');
Quotient.digits.resize(len, '0');
memmove(&Dvr.digits[diff], &Dvr.digits[0], len - diff);
memset(&Dvr.digits[0], '0', diff);
while(offset < len)
{
while (Dvd >= Dvr)
{
int borrow = 0, total = 0;
for (std::size_t I = 0; I < len; ++I)
{
total = Dvd.toDigit(I) - Dvr.toDigit(I) - borrow;
borrow = 0;
if (total < 0)
{
borrow = 1;
total += 10;
}
Dvd.digits[I] = total + '0';
}
Quotient.digits[len - offset - 1]++;
}
if (Remainder && offset == len - 1)
{
Remainder->digits = Dvd.digits;
Remainder->sign = rem_sign;
Remainder->Normalise();
if (Remainder == this)
{
return *this;
}
}
memmove(&Dvr.digits[0], &Dvr.digits[1], len - 1);
memset(&Dvr.digits[len - 1], '0', 1);
++offset;
}
Quotient.sign = neg_res ? '-' : '+';
Quotient.Normalise();
this->sign = Quotient.sign;
this->digits = Quotient.digits;
return *this;
}
BigInteger& BigInteger::operator /= (const BigInteger &other)
{
return divide(other, nullptr);
}
BigInteger BigInteger::operator / (const BigInteger &other) const
{
return BigInteger(*this) /= other;
}
BigInteger BigInteger::Remainder(const BigInteger &other) const
{
BigInteger remainder;
BigInteger(*this).divide(other, &remainder);
return remainder;
}
int main()
{
BigInteger a{"-64772890123784224827"};
BigInteger b{"3"};
BigInteger result = a/b;
std::cout<<result.toString();
}

Storing a Big Number in a Variable and Looping

How can i store a big number in a variable and use a for loop?
I have a very big number 75472202764752234070123900087933251 and i need to loop from 0 to this number!
Is it even possible to do this? how much time will it take to end?
EDIT: i am trying to solve a hard problem by brute force. its a combination problem.the bruteforcing cases may reach 470C450.
so i guess i should use a different algorithm...
This might take
0.23 x 10^23 years if C++ processed 100,000 loops per second :|
http://www.wolframalpha.com/input/?i=75472202764752234070123900087933251%2F%28100000*1*3600*24*365%29
It looks that this number fits into 128 bit. So you could use a modern system and a modern compiler that implements such numbers. This would e.g be the case for a 64bit linux system with gcc as a compiler. This has something like __uint128_t that you could use.
Obviously you can't use such a variable as a for-loop variable, others have give you the calculations. But you could use it to store some of your calculations.
Well, you would need an implementation that can handle at least a subset of the initialization, boolean, and arithmetic functions on very large integers. Something like: https://mattmccutchen.net/bigint/.
For something that would give a bit better performance than a general large integer math library, you could use specialized operations specifically to allow use of a large integer as a counter. For an example of this, see dewtell's updated answer to this question.
As for it being possible for you to loop from 0 to that number: well, yes, it is possible to write the code for it with one of the above solutions, but I think the answer is no, you personally will not be able to do it because you will not be alive to see it finish.
[edit: Yes, I would definitely recommend you find a different algorithm. :D]
If you need to loop a certain number of times, and that number is greater than 2^64, just use while(1) because your computer will break before it counts up to 2^64 anyway.
There's no need for a complete bignum package - if all you need is a loop counter, here's a simple byte counter that uses an array of bytes as a counter. It stops when the byte array wraps around to all zeros again. If you wanted to count to some other value than 2^(bytesUsed*CHAR_BITS), you could just compute the two's complement value of the negative of the number of iterations you wanted, and let it count up to 0, keeping in mind that bytes[0] is the low-order byte (or use the positive value and count down instead of up).
#include <stdio.h>
#define MAXBYTES 20
/* Simple byte counter - note it uses argc as # of bytes to use for convenience */
int main(int argc, char **argv) {
unsigned char bytes[MAXBYTES];
const int bytesUsed = argc < MAXBYTES? argc : MAXBYTES;
int i;
unsigned long counter = (unsigned long)-1; /* to give loop something to do */
for (i = 0; i < bytesUsed; i++) bytes[i] = 0; /* Initialize bytes */
do {
for (i = 0; i < bytesUsed && !++bytes[i]; i++) ; /* NULL BODY - this is the byte counter */
counter++;
} while (i < bytesUsed);
printf("With %d bytes used, final counter value = %lu\n", bytesUsed, counter);
}
Run times for the first 4 values (under Cygwin, on a Lenovo T61):
$ time ./bytecounter
With 1 bytes used, final counter value = 255
real 0m0.078s
user 0m0.031s
sys 0m0.046s
$ time ./bytecounter a
With 2 bytes used, final counter value = 65535
real 0m0.063s
user 0m0.031s
sys 0m0.031s
$ time ./bytecounter a a
With 3 bytes used, final counter value = 16777215
real 0m0.125s
user 0m0.015s
sys 0m0.046s
$ time ./bytecounter a a a
With 4 bytes used, final counter value = 4294967295
real 0m6.578s
user 0m0.015s
sys 0m0.047s
At this rate, five bytes should take around half an hour, and six bytes should take the better part of a week. Of course the counter value will be inaccurate for those - it's mostly just there to verify the number of iterations for the smaller byte values and give the loop something to do.
Edit: And here's the time for five bytes, around half an hour as I predicted:
$ time ./bytecounter a a a a
With 5 bytes used, final counter value = 4294967295
real 27m22.184s
user 0m0.015s
sys 0m0.062s
Ok, here's code to take an arbitrary decimal number passed as the first arg and count down from it to zero. I set it up to allow the counter to use different size elements (just change the typedef for COUNTER_BASE), but it turns out that bytes are actually somewhat faster than either short or long on my system.
#include <stdio.h>
#include <limits.h> // defines CHAR_BIT
#include <ctype.h>
#include <vector>
using std::vector;
typedef unsigned char COUNTER_BASE;
typedef vector<COUNTER_BASE> COUNTER;
typedef vector<unsigned char> BYTEVEC;
const unsigned long byteMask = (~0ul) << CHAR_BIT;
const size_t MAXBYTES=20;
void mult10(BYTEVEC &val) {
// Multiply value by 10
unsigned int carry = 0;
int i;
for (i = 0; i < val.size(); i++) {
unsigned long value = val[i]*10ul+carry;
carry = (value & byteMask) >> CHAR_BIT;
val[i] = value & ~byteMask;
}
if (carry > 0) val.push_back(carry);
}
void addDigit(BYTEVEC &val, const char digit) {
// Add digit to the number in BYTEVEC.
unsigned int carry = digit - '0'; // Assumes ASCII char set
int i;
for (i = 0; i < val.size() && carry; i++) {
unsigned long value = static_cast<unsigned long>(val[i])+carry;
carry = (value & byteMask) >> CHAR_BIT;
val[i] = value & ~byteMask;
}
if (carry > 0) val.push_back(carry);
}
BYTEVEC Cstr2Bytevec(const char *str) {
// Turn a C-style string into a BYTEVEC. Only the digits in str apply,
// so that one can use commas, underscores, or other non-digits to separate
// digit groups.
BYTEVEC result;
result.reserve(MAXBYTES);
result[0]=0;
unsigned char *res=&result[0]; // For debugging
while (*str) {
if (isdigit(static_cast<int>(*str))) {
mult10(result);
addDigit(result, *str);
}
str++;
}
return result;
}
void packCounter(COUNTER &ctr, const BYTEVEC &val) {
// Pack the bytes from val into the (possibly larger) datatype of COUNTER
int i;
ctr.erase(ctr.begin(), ctr.end());
COUNTER_BASE value = 0;
for (i = 0; i < val.size(); i++) {
int pos = i%sizeof(COUNTER_BASE); // position of this byte in the value
if (i > 0 && pos == 0) {
ctr.push_back(value);
value = val[i];
} else {
value |= static_cast<COUNTER_BASE>(val[i]) << pos*CHAR_BIT;
}
}
ctr.push_back(value);
}
inline bool decrementAndTest(COUNTER &ctr) {
// decrement value in ctr and return true if old value was not all zeros
int i;
for (i = 0; i < ctr.size() && !(ctr[i]--); i++) ; // EMPTY BODY
return i < ctr.size();
}
inline bool decrementAndTest2(COUNTER_BASE *ctr, const size_t size) {
// decrement value in ctr and return true if old value was not all zeros
int i;
for (i = 0; i < size && !(ctr[i]--); i++) ; // EMPTY BODY
return i < size;
}
/* Vector counter - uses first arg (if supplied) as the count */
int main(int argc, const char *argv[]) {
BYTEVEC limit = Cstr2Bytevec(argc > 1? argv[1] : "0");
COUNTER ctr;
packCounter(ctr, limit);
COUNTER_BASE *ctr_vals = ctr.size() > 0 ? &ctr[0] : NULL;
size_t ctr_size = ctr.size();
unsigned long ul_counter = 0ul; /* to give loop something to do */
while(decrementAndTest2(ctr_vals, ctr_size)) {
ul_counter++;
};
printf("With %d bytes used, final ul_counter value = %lu\n", limit.size(), ul_counter);
return 0;
}
Examples of use:
$ time ./bigcounter 5
With 1 bytes used, final ul_counter value = 5
real 0m0.094s
user 0m0.031s
sys 0m0.047s
$ time ./bigcounter 5,000
With 2 bytes used, final ul_counter value = 5000
real 0m0.062s
user 0m0.015s
sys 0m0.062s
$ time ./bigcounter 5,000,000
With 3 bytes used, final ul_counter value = 5000000
real 0m0.093s
user 0m0.015s
sys 0m0.046s
$ time ./bigcounter 1,000,000,000
With 4 bytes used, final ul_counter value = 1000000000
real 0m2.688s
user 0m0.015s
sys 0m0.015s
$ time ./bigcounter 2,000,000,000
With 4 bytes used, final ul_counter value = 2000000000
real 0m5.125s
user 0m0.015s
sys 0m0.046s
$ time ./bigcounter 3,000,000,000
With 4 bytes used, final ul_counter value = 3000000000
real 0m7.485s
user 0m0.031s
sys 0m0.047s
$ time ./bigcounter 4,000,000,000
With 4 bytes used, final ul_counter value = 4000000000
real 0m9.875s
user 0m0.015s
sys 0m0.046s
$ time ./bigcounter 5,000,000,000
With 5 bytes used, final ul_counter value = 705032704
real 0m12.594s
user 0m0.046s
sys 0m0.015s
$ time ./bigcounter 6,000,000,000
With 5 bytes used, final ul_counter value = 1705032704
real 0m14.813s
user 0m0.015s
sys 0m0.062s
Unwrapping the counter vector into C-style data structures (i.e., using decrementAndTest2 instead of decrementAndTest) sped things up by around 20-25%, but the code is still about twice as slow as my previous C program for similar-sized examples (around 4 billion). This is with MS Visual C++ 6.0 as the compiler in release mode, optimizing for speed, on a 2GHz dual-core system, for both programs. Inlining the decrementAndTest2 function definitely makes a big difference (around 12 sec. vs. 30 for the 5 billion loop), but I'll have to see whether physically inlining the code as I did in the C program can get similar performance.
the variable in main function can Store even 100 factorial
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <map>
#include <functional>
#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <stack>
#include <queue>
#include <deque>
#include <limits>
#include <cmath>
#include <numeric>
#include <set>
using namespace std;
//template for BIGINIT
// base and base_digits must be consistent
const int base = 10;
const int base_digits = 1;
struct bigint {
vector<int> a;
int sign;
bigint() :
sign(1) {
}
bigint(long long v) {
*this = v;
}
bigint(const string &s) {
read(s);
}
void operator=(const bigint &v) {
sign = v.sign;
a = v.a;
}
void operator=(long long v) {
sign = 1;
if (v < 0)
sign = -1, v = -v;
for (; v > 0; v = v / base)
a.push_back(v % base);
}
bigint operator+(const bigint &v) const {
if (sign == v.sign) {
bigint res = v;
for (int i = 0, carry = 0; i < (int) max(a.size(), v.a.size()) || carry; ++i) {
if (i == (int) res.a.size())
res.a.push_back(0);
res.a[i] += carry + (i < (int) a.size() ? a[i] : 0);
carry = res.a[i] >= base;
if (carry)
res.a[i] -= base;
}
return res;
}
return *this - (-v);
}
bigint operator-(const bigint &v) const {
if (sign == v.sign) {
if (abs() >= v.abs()) {
bigint res = *this;
for (int i = 0, carry = 0; i < (int) v.a.size() || carry; ++i) {
res.a[i] -= carry + (i < (int) v.a.size() ? v.a[i] : 0);
carry = res.a[i] < 0;
if (carry)
res.a[i] += base;
}
res.trim();
return res;
}
return -(v - *this);
}
return *this + (-v);
}
void operator*=(int v) {
if (v < 0)
sign = -sign, v = -v;
for (int i = 0, carry = 0; i < (int) a.size() || carry; ++i) {
if (i == (int) a.size())
a.push_back(0);
long long cur = a[i] * (long long) v + carry;
carry = (int) (cur / base);
a[i] = (int) (cur % base);
//asm("divl %%ecx" : "=a"(carry), "=d"(a[i]) : "A"(cur), "c"(base));
}
trim();
}
bigint operator*(int v) const {
bigint res = *this;
res *= v;
return res;
}
friend pair<bigint, bigint> divmod(const bigint &a1, const bigint &b1) {
int norm = base / (b1.a.back() + 1);
bigint a = a1.abs() * norm;
bigint b = b1.abs() * norm;
bigint q, r;
q.a.resize(a.a.size());
for (int i = a.a.size() - 1; i >= 0; i--) {
r *= base;
r += a.a[i];
int s1 = r.a.size() <= b.a.size() ? 0 : r.a[b.a.size()];
int s2 = r.a.size() <= b.a.size() - 1 ? 0 : r.a[b.a.size() - 1];
int d = ((long long) base * s1 + s2) / b.a.back();
r -= b * d;
while (r < 0)
r += b, --d;
q.a[i] = d;
}
q.sign = a1.sign * b1.sign;
r.sign = a1.sign;
q.trim();
r.trim();
return make_pair(q, r / norm);
}
bigint operator/(const bigint &v) const {
return divmod(*this, v).first;
}
bigint operator%(const bigint &v) const {
return divmod(*this, v).second;
}
void operator/=(int v) {
if (v < 0)
sign = -sign, v = -v;
for (int i = (int) a.size() - 1, rem = 0; i >= 0; --i) {
long long cur = a[i] + rem * (long long) base;
a[i] = (int) (cur / v);
rem = (int) (cur % v);
}
trim();
}
bigint operator/(int v) const {
bigint res = *this;
res /= v;
return res;
}
int operator%(int v) const {
if (v < 0)
v = -v;
int m = 0;
for (int i = a.size() - 1; i >= 0; --i)
m = (a[i] + m * (long long) base) % v;
return m * sign;
}
void operator+=(const bigint &v) {
*this = *this + v;
}
void operator-=(const bigint &v) {
*this = *this - v;
}
void operator*=(const bigint &v) {
*this = *this * v;
}
void operator/=(const bigint &v) {
*this = *this / v;
}
bool operator<(const bigint &v) const {
if (sign != v.sign)
return sign < v.sign;
if (a.size() != v.a.size())
return a.size() * sign < v.a.size() * v.sign;
for (int i = a.size() - 1; i >= 0; i--)
if (a[i] != v.a[i])
return a[i] * sign < v.a[i] * sign;
return false;
}
bool operator>(const bigint &v) const {
return v < *this;
}
bool operator<=(const bigint &v) const {
return !(v < *this);
}
bool operator>=(const bigint &v) const {
return !(*this < v);
}
bool operator==(const bigint &v) const {
return !(*this < v) && !(v < *this);
}
bool operator!=(const bigint &v) const {
return *this < v || v < *this;
}
void trim() {
while (!a.empty() && !a.back())
a.pop_back();
if (a.empty())
sign = 1;
}
bool isZero() const {
return a.empty() || (a.size() == 1 && !a[0]);
}
bigint operator-() const {
bigint res = *this;
res.sign = -sign;
return res;
}
bigint abs() const {
bigint res = *this;
res.sign *= res.sign;
return res;
}
long long longValue() const {
long long res = 0;
for (int i = a.size() - 1; i >= 0; i--)
res = res * base + a[i];
return res * sign;
}
friend bigint gcd(const bigint &a, const bigint &b) {
return b.isZero() ? a : gcd(b, a % b);
}
friend bigint lcm(const bigint &a, const bigint &b) {
return a / gcd(a, b) * b;
}
void read(const string &s) {
sign = 1;
a.clear();
int pos = 0;
while (pos < (int) s.size() && (s[pos] == '-' || s[pos] == '+')) {
if (s[pos] == '-')
sign = -sign;
++pos;
}
for (int i = s.size() - 1; i >= pos; i -= base_digits) {
int x = 0;
for (int j = max(pos, i - base_digits + 1); j <= i; j++)
x = x * 10 + s[j] - '0';
a.push_back(x);
}
trim();
}
friend istream& operator>>(istream &stream, bigint &v) {
string s;
stream >> s;
v.read(s);
return stream;
}
friend ostream& operator<<(ostream &stream, const bigint &v) {
if (v.sign == -1)
stream << '-';
stream << (v.a.empty() ? 0 : v.a.back());
for (int i = (int) v.a.size() - 2; i >= 0; --i)
stream << setw(base_digits) << setfill('0') << v.a[i];
return stream;
}
static vector<int> convert_base(const vector<int> &a, int old_digits, int new_digits) {
vector<long long> p(max(old_digits, new_digits) + 1);
p[0] = 1;
for (int i = 1; i < (int) p.size(); i++)
p[i] = p[i - 1] * 10;
vector<int> res;
long long cur = 0;
int cur_digits = 0;
for (int i = 0; i < (int) a.size(); i++) {
cur += a[i] * p[cur_digits];
cur_digits += old_digits;
while (cur_digits >= new_digits) {
res.push_back(int(cur % p[new_digits]));
cur /= p[new_digits];
cur_digits -= new_digits;
}
}
res.push_back((int) cur);
while (!res.empty() && !res.back())
res.pop_back();
return res;
}
typedef vector<long long> vll;
static vll karatsubaMultiply(const vll &a, const vll &b) {
int n = a.size();
vll res(n + n);
if (n <= 32) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
res[i + j] += a[i] * b[j];
return res;
}
int k = n >> 1;
vll a1(a.begin(), a.begin() + k);
vll a2(a.begin() + k, a.end());
vll b1(b.begin(), b.begin() + k);
vll b2(b.begin() + k, b.end());
vll a1b1 = karatsubaMultiply(a1, b1);
vll a2b2 = karatsubaMultiply(a2, b2);
for (int i = 0; i < k; i++)
a2[i] += a1[i];
for (int i = 0; i < k; i++)
b2[i] += b1[i];
vll r = karatsubaMultiply(a2, b2);
for (int i = 0; i < (int) a1b1.size(); i++)
r[i] -= a1b1[i];
for (int i = 0; i < (int) a2b2.size(); i++)
r[i] -= a2b2[i];
for (int i = 0; i < (int) r.size(); i++)
res[i + k] += r[i];
for (int i = 0; i < (int) a1b1.size(); i++)
res[i] += a1b1[i];
for (int i = 0; i < (int) a2b2.size(); i++)
res[i + n] += a2b2[i];
return res;
}
bigint operator*(const bigint &v) const {
vector<int> a6 = convert_base(this->a, base_digits, 6);
vector<int> b6 = convert_base(v.a, base_digits, 6);
vll a(a6.begin(), a6.end());
vll b(b6.begin(), b6.end());
while (a.size() < b.size())
a.push_back(0);
while (b.size() < a.size())
b.push_back(0);
while (a.size() & (a.size() - 1))
a.push_back(0), b.push_back(0);
vll c = karatsubaMultiply(a, b);
bigint res;
res.sign = sign * v.sign;
for (int i = 0, carry = 0; i < (int) c.size(); i++) {
long long cur = c[i] + carry;
res.a.push_back((int) (cur % 1000000));
carry = (int) (cur / 1000000);
}
res.a = convert_base(res.a, 6, base_digits);
res.trim();
return res;
}
};
//use : bigint var;
//template for biginit over
int main()
{
bigint var=10909000890789;
cout<<var;
return 0;
}

Modular Exponentiation for high numbers in C++

So I've been working recently on an implementation of the Miller-Rabin primality test. I am limiting it to a scope of all 32-bit numbers, because this is a just-for-fun project that I am doing to familiarize myself with c++, and I don't want to have to work with anything 64-bits for awhile. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly what witnesses to test for.
So for low numbers, the algorithm works exceptionally well. However, part of the process relies upon modular exponentiation, that is (num ^ pow) % mod. so, for example,
3 ^ 2 % 5 =
9 % 5 =
4
here is the code I have been using for this modular exponentiation:
unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
unsigned test;
for(test = 1; pow; pow >>= 1)
{
if (pow & 1)
test = (test * num) % mod;
num = (num * num) % mod;
}
return test;
}
As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:
(2 ^ 168277) % 673109
now 2 ^ 168277 is an exceptionally large number, and somewhere in the process it overflows test, which results in an incorrect evaluation.
on the reverse side, arguments such as
4000111222 ^ 3 % 1608
also evaluate incorrectly, for much the same reason.
Does anyone have suggestions for modular exponentiation in a way that can prevent this overflow and/or manipulate it to produce the correct result? (the way I see it, overflow is just another form of modulo, that is num % (UINT_MAX+1))
Exponentiation by squaring still "works" for modulo exponentiation. Your problem isn't that 2 ^ 168277 is an exceptionally large number, it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.
So I think the following will do. It's possible I've missed a detail, but the basic idea works, and this is how "real" crypto code might do large mod-exponentiation (although not with 32 and 64 bit numbers, rather with bignums that never have to get bigger than 2 * log (modulus)):
Start with exponentiation by squaring, as you have.
Perform the actual squaring in a 64-bit unsigned integer.
Reduce modulo 673109 at each step to get back within the 32-bit range, as you do.
Obviously that's a bit awkward if your C++ implementation doesn't have a 64 bit integer, although you can always fake one.
There's an example on slide 22 here: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.
Your other example, 4000111222 ^ 3 % 1608 would work in your current code if you just reduce 4000111222 modulo 1608 before you start. 1608 is small enough that you can safely multiply any two mod-1608 numbers in a 32 bit int.
I wrote something for this recently for RSA in C++, bit messy though.
#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>
BigInteger::BigInteger() {
digits.push_back(0);
negative = false;
}
BigInteger::~BigInteger() {
}
void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
int sum_n_carry = 0;
int n = (int)a.digits.size();
if (n < (int)b.digits.size()) {
n = b.digits.size();
}
c.digits.resize(n);
for (int i = 0; i < n; ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size()) {
a_digit = a.digits[i];
}
if (i < (int)b.digits.size()) {
b_digit = b.digits[i];
}
sum_n_carry += a_digit + b_digit;
c.digits[i] = (sum_n_carry & 0xFFFF);
sum_n_carry >>= 16;
}
if (sum_n_carry != 0) {
putCarryInfront(c, sum_n_carry);
}
while (c.digits.size() > 1 && c.digits.back() == 0) {
c.digits.pop_back();
}
//std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}
void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
int sub_n_borrow = 0;
int n = a.digits.size();
if (n < (int)b.digits.size())
n = (int)b.digits.size();
c.digits.resize(n);
for (int i = 0; i < n; ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
if (i < (int)b.digits.size())
b_digit = b.digits[i];
sub_n_borrow += a_digit - b_digit;
if (sub_n_borrow >= 0) {
c.digits[i] = sub_n_borrow;
sub_n_borrow = 0;
} else {
c.digits[i] = 0x10000 + sub_n_borrow;
sub_n_borrow = -1;
}
}
while (c.digits.size() > 1 && c.digits.back() == 0) {
c.digits.pop_back();
}
//std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}
int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
int n = (int)a.digits.size();
if (n < (int)b.digits.size())
n = (int)b.digits.size();
//std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
for (int i = n-1; i >= 0; --i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
if (i < (int)b.digits.size())
b_digit = b.digits[i];
if (a_digit < b_digit) {
//std::cout << "-1" << std::endl;
return -1;
} else if (a_digit > b_digit) {
//std::cout << "+1" << std::endl;
return +1;
}
}
//std::cout << "0" << std::endl;
return 0;
}
void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
unsigned int mult_n_carry = 0;
c.digits.clear();
c.digits.resize(a.digits.size());
for (int i = 0; i < (int)a.digits.size(); ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = b;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
mult_n_carry += a_digit * b_digit;
c.digits[i] = (mult_n_carry & 0xFFFF);
mult_n_carry >>= 16;
}
if (mult_n_carry != 0) {
putCarryInfront(c, mult_n_carry);
}
//std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}
void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
b.digits.resize(a.digits.size() + times);
for (int i = 0; i < times; ++i) {
b.digits[i] = 0;
}
for (int i = 0; i < (int)a.digits.size(); ++i) {
b.digits[i + times] = a.digits[i];
}
}
void BigInteger::shiftRight(BigInteger& a) {
//std::cout << "shr " << a.toString() << " == ";
for (int i = 0; i < (int)a.digits.size(); ++i) {
a.digits[i] >>= 1;
if (i+1 < (int)a.digits.size()) {
if ((a.digits[i+1] & 0x1) != 0) {
a.digits[i] |= 0x8000;
}
}
}
//std::cout << a.toString() << std::endl;
}
void BigInteger::shiftLeft(BigInteger& a) {
bool lastBit = false;
for (int i = 0; i < (int)a.digits.size(); ++i) {
bool bit = (a.digits[i] & 0x8000) != 0;
a.digits[i] <<= 1;
if (lastBit)
a.digits[i] |= 1;
lastBit = bit;
}
if (lastBit) {
a.digits.push_back(1);
}
}
void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
BigInteger b;
b.negative = a.negative;
b.digits.resize(a.digits.size() + 1);
b.digits[a.digits.size()] = carry;
for (int i = 0; i < (int)a.digits.size(); ++i) {
b.digits[i] = a.digits[i];
}
a.digits.swap(b.digits);
}
void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
c.digits.clear();
c.digits.push_back(0);
BigInteger two("2");
BigInteger e = b;
BigInteger f("1");
BigInteger g = a;
BigInteger one("1");
while (cmpWithoutSign(g, e) >= 0) {
shiftLeft(e);
shiftLeft(f);
}
shiftRight(e);
shiftRight(f);
while (cmpWithoutSign(g, b) >= 0) {
g -= e;
c += f;
while (cmpWithoutSign(g, e) < 0) {
shiftRight(e);
shiftRight(f);
}
}
e = c;
e *= b;
f = a;
f -= e;
d = f;
}
BigInteger::BigInteger(const BigInteger& other) {
digits = other.digits;
negative = other.negative;
}
BigInteger::BigInteger(const char* other) {
digits.push_back(0);
negative = false;
BigInteger ten;
ten.digits[0] = 10;
const char* c = other;
bool make_negative = false;
if (*c == '-') {
make_negative = true;
++c;
}
while (*c != 0) {
BigInteger digit;
digit.digits[0] = *c - '0';
*this *= ten;
*this += digit;
++c;
}
negative = make_negative;
}
bool BigInteger::isOdd() const {
return (digits[0] & 0x1) != 0;
}
BigInteger& BigInteger::operator=(const BigInteger& other) {
if (this == &other) // handle self assignment
return *this;
digits = other.digits;
negative = other.negative;
return *this;
}
BigInteger& BigInteger::operator+=(const BigInteger& other) {
BigInteger result;
if (negative) {
if (other.negative) {
result.negative = true;
addWithoutSign(result, *this, other);
} else {
int a = cmpWithoutSign(*this, other);
if (a < 0) {
result.negative = false;
subWithoutSign(result, other, *this);
} else if (a > 0) {
result.negative = true;
subWithoutSign(result, *this, other);
} else {
result.negative = false;
result.digits.clear();
result.digits.push_back(0);
}
}
} else {
if (other.negative) {
int a = cmpWithoutSign(*this, other);
if (a < 0) {
result.negative = true;
subWithoutSign(result, other, *this);
} else if (a > 0) {
result.negative = false;
subWithoutSign(result, *this, other);
} else {
result.negative = false;
result.digits.clear();
result.digits.push_back(0);
}
} else {
result.negative = false;
addWithoutSign(result, *this, other);
}
}
negative = result.negative;
digits.swap(result.digits);
return *this;
}
BigInteger& BigInteger::operator-=(const BigInteger& other) {
BigInteger neg_other = other;
neg_other.negative = !neg_other.negative;
return *this += neg_other;
}
BigInteger& BigInteger::operator*=(const BigInteger& other) {
BigInteger result;
for (int i = 0; i < (int)digits.size(); ++i) {
BigInteger mult;
multByDigitWithoutSign(mult, other, digits[i]);
BigInteger shift;
shiftLeftByBase(shift, mult, i);
BigInteger add;
addWithoutSign(add, result, shift);
result = add;
}
if (negative != other.negative) {
result.negative = true;
} else {
result.negative = false;
}
//std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
negative = result.negative;
digits.swap(result.digits);
return *this;
}
BigInteger& BigInteger::operator/=(const BigInteger& other) {
BigInteger result, tmp;
divideWithoutSign(result, tmp, *this, other);
result.negative = (negative != other.negative);
negative = result.negative;
digits.swap(result.digits);
return *this;
}
BigInteger& BigInteger::operator%=(const BigInteger& other) {
BigInteger c, d;
divideWithoutSign(c, d, *this, other);
*this = d;
return *this;
}
bool BigInteger::operator>(const BigInteger& other) const {
if (negative) {
if (other.negative) {
return cmpWithoutSign(*this, other) < 0;
} else {
return false;
}
} else {
if (other.negative) {
return true;
} else {
return cmpWithoutSign(*this, other) > 0;
}
}
}
BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
BigInteger zero("0");
BigInteger one("1");
BigInteger e = exponent;
BigInteger base = *this;
*this = one;
while (cmpWithoutSign(e, zero) != 0) {
//std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
if (e.isOdd()) {
*this *= base;
*this %= modulus;
}
shiftRight(e);
base *= BigInteger(base);
base %= modulus;
}
return *this;
}
std::string BigInteger::toString() const {
std::ostringstream os;
if (negative)
os << "-";
BigInteger tmp = *this;
BigInteger zero("0");
BigInteger ten("10");
tmp.negative = false;
std::stack<char> s;
while (cmpWithoutSign(tmp, zero) != 0) {
BigInteger tmp2, tmp3;
divideWithoutSign(tmp2, tmp3, tmp, ten);
s.push((char)(tmp3.digits[0] + '0'));
tmp = tmp2;
}
while (!s.empty()) {
os << s.top();
s.pop();
}
/*
for (int i = digits.size()-1; i >= 0; --i) {
os << digits[i];
if (i != 0) {
os << ",";
}
}
*/
return os.str();
And an example usage.
BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344")
// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);
Its fast too, and has unlimited number of digits.
Two things:
Are you using the appropriate data type? In other words, does UINT_MAX allow you to have 673109 as an argument?
No, it does not, since at one point you have Your code does not work because at one point you have num = 2^16 and the num = ... causes overflow. Use a bigger data type to hold this intermediate value.
How about taking modulo at every possible overflow oppertunity such as:
test = ((test % mod) * (num % mod)) % mod;
Edit:
unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
unsigned long long test;
unsigned long long n = num;
for(test = 1; pow; pow >>= 1)
{
if (pow & 1)
test = ((test % mod) * (n % mod)) % mod;
n = ((n % mod) * (n % mod)) % mod;
}
return test; /* note this is potentially lossy */
}
int main(int argc, char* argv[])
{
/* (2 ^ 168277) % 673109 */
printf("%u\n", mod_pow(2, 168277, 673109));
return 0;
}
package playTime;
public class play {
public static long count = 0;
public static long binSlots = 10;
public static long y = 645;
public static long finalValue = 1;
public static long x = 11;
public static void main(String[] args){
int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};
x = BME(x, count, binArray);
System.out.print("\nfinal value:"+finalValue);
}
public static long BME(long x, long count, int[] binArray){
if(count == binSlots){
return finalValue;
}
if(binArray[(int) count] == 1){
finalValue = finalValue*x%y;
}
x = (x*x)%y;
System.out.print("Array("+binArray[(int) count]+") "
+"x("+x+")" +" finalVal("+ finalValue + ")\n");
count++;
return BME(x, count,binArray);
}
}
LL is for long long int
LL power_mod(LL a, LL k) {
if (k == 0)
return 1;
LL temp = power(a, k/2);
LL res;
res = ( ( temp % P ) * (temp % P) ) % P;
if (k % 2 == 1)
res = ((a % P) * (res % P)) % P;
return res;
}
Use the above recursive function for finding the mod exp of the number. This will not result in overflow because it calculates in a bottom up manner.
Sample test run for :
a = 2 and k = 168277 shows output to be 518358 which is correct and the function runs in O(log(k)) time;
You could use following identity:
(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)
Try using it straightforward way and incrementally improve.
if (pow & 1)
test = ((test % mod) * (num % mod)) % mod;
num = ((num % mod) * (num % mod)) % mod;