I have implemented this solution for finding a root of the cubic function
f(x) = ax^3 + bx^2 + cx + d
given a, b, c, and d, where the function is guaranteed to be monotonic.
After submitting the solution to an online judge (without being shown the test cases), I get a Time Limit Exceeded error. The values of a, b, c, and d guarantee that the function is monotonic, and we know it is continuous. The code first finds an interval [A, B] such that f(A) * f(B) < 0; then it runs a bisection search on that interval.
What I want to know is whether there is some way to reduce the time complexity of my code so it passes the online judge. The input is a, b, c, d, and the output should be the root with an error of at most 0.000001.
Code:
#include <iostream>
#include <algorithm>
//#include <cmath>
//#include <string>
using namespace std;
int f(double a, double b, double c, double d, double x) {
    return x*(x*(a*x + b) + c) + d;
}
int main() {
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    double a, b, c, d, A, B, x = 1, res;
    cin >> a >> b >> c >> d;
    // determining the interval
    double f_x = f(a, b, c, d, x);
    if (a > 0) { // strictly increasing
        if (f_x > 0) {
            B = 0;
            while (f(a, b, c, d, x) >= 0) { x -= x; }
            A = x;
        }
        else {
            A = 0;
            while (f(a, b, c, d, x) <= 0) { x += x; }
            B = x;
        }
    }
    else { // strictly decreasing
        if (f_x > 0) {
            A = 0;
            while (f(a, b, c, d, x) >= 0) { x += x; }
            B = x;
        }
        else {
            B = 0;
            while (f(a, b, c, d, x) <= 0) { x -= x; }
            A = x;
        }
    }
    // Bisection search
    double l = A;
    while ((B - A) >= 0.000001)
    {
        // Find middle point
        l = (A + B) / 2;
        // Check if middle point is root
        if (f(a, b, c, d, l) == 0.0)
            break;
        // Decide the side to repeat the steps
        else if (f(a, b, c, d, l)*f(a, b, c, d, A) < 0)
            B = l;
        else
            A = l;
    }
    res = l;
    cout.precision(6);
    cout << fixed << " " << res;
    return 0;
}
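For comparison, here is a minimal sketch of the bracket-then-bisect approach described in the question, assuming a monotonic, non-constant cubic. The names cubic, bracket_root, and bisect_root are hypothetical. Two details differ from the code above: f is evaluated as double (an int return type truncates the value), and the bracket is widened by doubling in both directions, since x -= x always yields 0 and so never widens the interval.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: a*x^3 + b*x^2 + c*x + d via Horner's rule.
// It returns double; an int return type would truncate the value of f(x).
double cubic(double a, double b, double c, double d, double x) {
    return ((a * x + b) * x + c) * x + d;
}

// Bracketing step: widen [-r, r] by doubling r (the x += x idea) until
// f(-r) and f(r) differ in sign. Assumes f is monotonic and not constant,
// so a sign change exists and the loop terminates.
void bracket_root(double a, double b, double c, double d, double& lo, double& hi) {
    double r = 1.0;
    while (std::signbit(cubic(a, b, c, d, -r)) == std::signbit(cubic(a, b, c, d, r)))
        r *= 2;
    lo = -r;
    hi = r;
}

// Bisection on the bracket until it is narrower than eps (or cannot shrink).
double bisect_root(double a, double b, double c, double d, double eps) {
    double lo, hi;
    bracket_root(a, b, c, d, lo, hi);
    assert(lo < hi);
    const bool lo_negative = std::signbit(cubic(a, b, c, d, lo));
    while (hi - lo >= eps) {
        const double mid = lo + (hi - lo) / 2;
        if (mid <= lo || mid >= hi)
            break;  // lo and hi are adjacent doubles; no further progress possible
        if (std::signbit(cubic(a, b, c, d, mid)) == lo_negative)
            lo = mid;  // same sign as f(lo): the root is in [mid, hi]
        else
            hi = mid;  // the sign change is in [lo, mid]
    }
    return lo + (hi - lo) / 2;
}
```

For example, bisect_root(1, 0, 0, -8, 1e-6) approximates the root 2 of x^3 - 8. The break on adjacent doubles keeps the loop from spinning forever when eps is smaller than the spacing of doubles near a large root.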
There is no need to determine the initial interval, just take [-DBL_MAX, +DBL_MAX]. The tolerance can be chosen to be 1 ULP.
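For a sense of scale (a rough estimate that ignores rounding of the midpoint): bisection halves the bracket at every step, so shrinking a bracket of width B - A below a tolerance ε takes the N steps shown first. Even the full range [-DBL_MAX, DBL_MAX], whose width is below 2^1025, bisected down to adjacent doubles (which are never closer than 2^-1074) needs only on the order of two thousand steps.

```latex
\begin{aligned}
N &= \left\lceil \log_2 \frac{B - A}{\varepsilon} \right\rceil,
\qquad B - A = 2,\ \varepsilon = 10^{-6}
\;\Rightarrow\; N = \lceil \log_2 (2 \cdot 10^{6}) \rceil = \lceil 20.93\ldots \rceil = 21, \\
N_{\text{full range}} &\lesssim \log_2 \frac{2^{1025}}{2^{-1074}} = 2099.
\end{aligned}
```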
The following code implements these ideas:
#include <cmath>
#include <limits>
// This function is available since C++20 as std::midpoint (in <numeric>)
double midpoint(double x, double y) {
    if (std::isnormal(x) && std::isnormal(y))
        return x / 2 + y / 2;
    else
        return (x + y) / 2;
}
int main() {
    ...
    const auto fn = [=](double x) { return x * (x * (x * a + b) + c) + d; };
    auto left = -std::numeric_limits<double>::max();
    auto right = std::numeric_limits<double>::max();
    while (true) {
        const auto mid = midpoint(left, right);
        if (mid <= left || mid >= right)
            break;
        if (std::signbit(fn(left)) == std::signbit(fn(mid)))
            left = mid;
        else
            right = mid;
    }
    const double answer = left;
    ...
}
Initially, fn(x) can overflow and return +inf or -inf. No special handling of this case is needed, because std::signbit still reports the correct sign of an infinity.
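A self-contained version of the loop above, wrapped in a function so it can be called and checked. The name solve_cubic_full_range is hypothetical, and the input reading and output formatting elided by "..." are left out. It returns left once left and right are adjacent doubles at which the computed fn has different signs, which is the 1-ULP tolerance mentioned above.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Same helper as in the answer: a pre-C++20 stand-in for std::midpoint.
double midpoint(double x, double y) {
    if (std::isnormal(x) && std::isnormal(y))
        return x / 2 + y / 2;  // halve first so that x + y cannot overflow
    else
        return (x + y) / 2;
}

// Hypothetical wrapper around the answer's loop.
double solve_cubic_full_range(double a, double b, double c, double d) {
    const auto fn = [=](double x) { return x * (x * (x * a + b) + c) + d; };
    double left = -std::numeric_limits<double>::max();
    double right = std::numeric_limits<double>::max();
    // A non-constant monotonic cubic changes sign somewhere on this range
    // (fn may be +-inf at the ends; signbit still gives the right sign).
    assert(std::signbit(fn(left)) != std::signbit(fn(right)));
    while (true) {
        const double mid = midpoint(left, right);
        if (mid <= left || mid >= right)
            break;  // left and right are adjacent doubles
        if (std::signbit(fn(left)) == std::signbit(fn(mid)))
            left = mid;
        else
            right = mid;
    }
    return left;
}
```

For example, solve_cubic_full_range(1, 0, 0, -8) should land within one ULP of 2; printing it with six decimals then matches the required 0.000001 error.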
I have the following question, which is actually from a coding test I recently took:
Question:
Consider the function f(n) = a*n + b*n*floor(log(n)/log(2)) + c*n*n*n.
For some n, f(n) = k.
Given k, a, b, c, find n.
For a given value of k, if no n value exists, then return 0.
Limits:
1 <= n < 2^63-1
0 < a, b < 100
0 <= c < 100
0 < k < 2^63-1
The logic here is that since f(n) is strictly increasing for given a, b, and c, I can find n by binary search.
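To spell out why f is strictly increasing under these limits: each bracket below is non-negative (both factors of n*floor(log2(n)) are non-decreasing and non-negative), and b, c >= 0, so every step raises f by at least a > 0.

```latex
f(n+1) - f(n)
  = a
  + b\bigl[(n+1)\lfloor \log_2 (n+1) \rfloor - n \lfloor \log_2 n \rfloor\bigr]
  + c\bigl[(n+1)^3 - n^3\bigr]
  \;\ge\; a \;>\; 0 .
```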
The code I wrote was as follows:
#include<iostream>
#include<stdlib.h>
#include<math.h>
using namespace std;
unsigned long long logToBase2Floor(unsigned long long n){
    return (unsigned long long)(double(log(n))/double(log(2)));
}
#define f(n, a, b, c) (a*n + b*n*(logToBase2Floor(n)) + c*n*n*n)
unsigned long long findNByBinarySearch(unsigned long long k, unsigned long long a, unsigned long long b, unsigned long long c){
    unsigned long long low = 1;
    unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
    unsigned long long n;
    while(low<=high){
        n = (low+high)/2;
        cout<<"\n\n k= "<<k;
        cout<<"\n f(n,a,b,c)= "<<f(n,a,b,c)<<" low = "<<low<<" mid="<<n<<" high = "<<high;
        if(f(n,a,b,c) == k)
            return n;
        else if(f(n,a,b,c) < k)
            low = n+1;
        else high = n-1;
    }
    return 0;
}
I then tried it with a few test cases:
int main(){
    unsigned long long n, a, b, c;
    n = (unsigned long long)pow(2,63)-1;
    a = 99;
    b = 99;
    c = 99;
    cout<<"\nn="<<n<<" a="<<a<<" b="<<b<<" c="<<c<<" k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    n = 1000;
    cout<<"\nn="<<n<<" a="<<a<<" b="<<b<<" c="<<c<<" k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    return 0;
}
Then something weird happened.
The code works for the test case n = (unsigned long long)pow(2,63)-1, correctly returning that value of n, but it does not work for n = 1000. I printed the output and saw the following:
n=1000 a=99 b=99 c=99 k = 99000990000
k= 99000990000
f(n,a,b,c)= 4611686018427387904 low = 1 mid=4611686018427387904 high = 9223372036854775807
...
...
k= 99000990000
f(n,a,b,c)= 172738215936 low = 1 mid=67108864 high = 134217727
k= 99000990000
f(n,a,b,c)= 86369107968 low = 1 mid=33554432 high = 67108863
k= 99000990000
f(n,a,b,c)= 129553661952 low = 33554433 mid=50331648 high = 67108863
...
...
k= 99000990000
f(n,a,b,c)= 423215328047139441 low = 37748737 mid=37748737 high = 37748737
ANSWER: 0
Something didn't seem right mathematically. How was it that the value of f(1000) was greater than the value of f(33554432)?
So I tried the same formula in Python and got the following values:
>>> f(1000, 99, 99, 99)
99000990000L
>>> f(33554432, 99, 99, 99)
3740114254432845378355200L
So f(33554432) is definitely larger.
Questions:
What is happening exactly?
How should I solve it?
What is happening exactly?
The problem is here:
unsigned long long low = 1;
// Side note: This is simply (2ULL << 62) - 1
unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
unsigned long long n;
while (/* irrelevant */) {
    n = (low + high) / 2;
    // Some stuff that does not modify n...
    f(n, a, b, c) // <-- Here!
}
In the first iteration, you have low = 1 and high = 2^63 - 1, which means that n = (1 + 2^63 - 1) / 2 = 2^63 / 2 = 2^62. Now, let's look at f:
#define f(n, a, b, c) (/* I do not care about this... */ + c*n*n*n)
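You have n^3 in f, so for n = 2^62, n^3 = 2^186, which is far too large for your unsigned long long (which is most likely 64 bits wide). Unsigned arithmetic silently wraps around modulo 2^64, so what you compute is f(n) mod 2^64. That value is no longer monotonic in n, so the comparisons with k can send the binary search in the wrong direction.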
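The wraparound reproduces two of the printed values exactly. Because 2^64 = 4 * 2^62, a multiple m * 2^62 reduces to (m mod 4) * 2^62, and any multiple of 2^64 reduces to 0. Assuming the floating-point log yields floor(log2 n) = 62 for n = 2^62 and 25 for n = 2^25 = 33554432 (which the printed output is consistent with), with a = b = c = 99:

```latex
\begin{aligned}
n = 2^{62}:\quad
a\,n &= 99 \cdot 2^{62} \equiv 3 \cdot 2^{62} \pmod{2^{64}} \\
b\,n \lfloor \log_2 n \rfloor &= 99 \cdot 62 \cdot 2^{62} = 6138 \cdot 2^{62} \equiv 2 \cdot 2^{62} \pmod{2^{64}} \\
c\,n^3 &= 99 \cdot 2^{186} \equiv 0 \pmod{2^{64}} \\
f(n) &\equiv 5 \cdot 2^{62} \equiv 2^{62} = 4611686018427387904 \pmod{2^{64}} \\[6pt]
n = 2^{25}:\quad
c\,n^3 &= 99 \cdot 2^{75} \equiv 0 \pmod{2^{64}} \\
f(n) &\equiv 99 \cdot 2^{25} + 99 \cdot 25 \cdot 2^{25} = 2574 \cdot 2^{25} = 86369107968
\end{aligned}
```

The second value is smaller than k = 99000990000, so the search moves up, away from n = 1000, even though the true f(33554432) is far larger.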
How should I solve it?
The main issue here is overflow during the binary search, so you should handle the overflowing case separately.
Preamble: I am using ull_t because I am lazy. You should avoid macros in C++: prefer a function and let the compiler inline it. Also, I prefer a loop over the floating-point log function to compute the log2 of an unsigned long long, since the floating-point result can round the wrong way near powers of two (see the bottom of this answer for the implementation of log2 and is_overflow).
using ull_t = unsigned long long;
constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) { // Avoid log2(0)
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}
Here is a slightly modified version of the binary search:
constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            // This should never happen since ub < 2^63 and lb <= ub, so lb + ub < 2^64
            return 0ULL;
        }
        // Compute the middle point (guaranteed not to overflow, see the check above).
        auto tn = (lb + ub) / 2;
        // If there is an overflow, then lower the upper bound.
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        // Otherwise, do a standard binary search...
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}
As you can see, only one test is relevant here: is_overflow(tn, a, b, c). The first test, on lb + ub, is irrelevant, since lb <= ub < 2^63 implies lb + ub < 2^64, which fits in a 64-bit unsigned long long.
Complete implementation:
#include <limits>
#include <type_traits>
using ull_t = unsigned long long;
template <typename T,
          typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto log2 (T n) {
    T log = 0;
    while (n >>= 1) ++log;
    return log;
}
constexpr bool is_overflow (ull_t n, ull_t a, ull_t b, ull_t c) {
    ull_t max = std::numeric_limits<ull_t>::max();
    if (n > max / a) {
        return true;
    }
    if (n > max / b) {
        return true;
    }
    // log2(n) is 0 for n == 1, so guard the division
    if (log2(n) != 0 && b * n > max / log2(n)) {
        return true;
    }
    if (c != 0) {
        if (n > max / c) return true;
        if (c * n > max / n) return true;
        if (c * n * n > max / n) return true;
    }
    if (a * n > max - c * n * n * n) {
        return true;
    }
    if (a * n + c * n * n * n > max - b * n * log2(n)) {
        return true;
    }
    return false;
}
constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) {
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}
constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            return 0ULL; // Unreachable: lb <= ub < 2^63, so lb + ub < 2^64
        }
        auto tn = (lb + ub) / 2;
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}
Compile time check:
Below is a small piece of code that you can use to check the above code at compile time (since everything is constexpr):
template <unsigned long long n, unsigned long long a,
          unsigned long long b, unsigned long long c>
struct check: public std::true_type {
    enum {
        k = f(n, a, b, c)
    };
    static_assert(k != 0, "Value out of bound for (n, a, b, c).");
    static_assert(n == find_n(k, a, b, c), "");
};
template <unsigned long long a,
          unsigned long long b,
          unsigned long long c>
struct check<0, a, b, c>: public std::true_type {
    static_assert(a != a, "Ambiguous values for n when k = 0.");
};
template <unsigned long long n>
struct check<n, 0, 0, 0>: public std::true_type {
    static_assert(n != n, "Ambiguous values for n when a = b = c = 0.");
};
#define test(n, a, b, c) static_assert(check<n, a, b, c>::value, "");
test(1000, 99, 99, 0);
test(1000, 99, 99, 99);
test(453333, 99, 99, 99);
test(495862, 99, 99, 9);
test(10000000, 1, 1, 0);
Note: The maximum value of k is about 2^63, so for a given triplet (a, b, c), the maximum value of n is the one such that f(n, a, b, c) < 2^63 and f(n + 1, a, b, c) >= 2^63. For a = b = c = 99, this maximum value is n = 453333 (found empirically), which is why I tested it above.
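As an alternative to the division-based is_overflow checks, here is a hedged sketch (a swapped-in technique, not the answer's code) that clamps every intermediate product and sum at the maximum uint64_t value. Clamping turns f into min(f(n), 2^64 - 1), which is still non-decreasing, so a plain binary search works; since k < 2^63, the clamped value can never be mistaken for k. The names sat_mul, sat_add, floor_log2, f_sat, and find_n_sat are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

using u64 = std::uint64_t;
constexpr u64 kMax = std::numeric_limits<u64>::max();

// Saturating helpers: return kMax instead of wrapping around modulo 2^64.
u64 sat_mul(u64 x, u64 y) {
    if (x != 0 && y > kMax / x) return kMax;
    return x * y;
}
u64 sat_add(u64 x, u64 y) {
    return (y > kMax - x) ? kMax : x + y;
}

// floor(log2(n)) for n >= 1, computed with shifts (no floating point).
u64 floor_log2(u64 n) {
    u64 r = 0;
    while (n >>= 1) ++r;
    return r;
}

// f(n) = a*n + b*n*floor(log2 n) + c*n^3, clamped to kMax on overflow.
u64 f_sat(u64 n, u64 a, u64 b, u64 c) {
    const u64 t1 = sat_mul(a, n);
    const u64 t2 = sat_mul(sat_mul(b, n), floor_log2(n));
    const u64 t3 = sat_mul(sat_mul(sat_mul(c, n), n), n);
    return sat_add(sat_add(t1, t2), t3);
}

// Returns n in [1, 2^63 - 2] with f(n) == k, or 0 if there is none.
u64 find_n_sat(u64 k, u64 a, u64 b, u64 c) {
    assert(a > 0 && b > 0);  // from the problem's limits
    u64 lo = 1, hi = (u64(1) << 63) - 2;  // n < 2^63 - 1
    while (lo <= hi) {
        const u64 mid = lo + (hi - lo) / 2;
        const u64 v = f_sat(mid, a, b, c);
        if (v == k) return mid;
        if (v < k)
            lo = mid + 1;
        else
            hi = mid - 1;  // mid >= 1; if this makes hi < lo the loop ends
    }
    return 0;
}
```

With the question's inputs, find_n_sat(99000990000, 99, 99, 99) searches down from the clamped values at huge n and should return 1000.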
So I have this function:
Vector getNthRoots(double a, double b, double c, int n)
{
    Vector v;
    int i;
    v.length = 0;
    double m, a2, b2, c2;
    if (n % 2 == 0)
    {
        a2 = a;
        b2 = b;
        c2 = c;
        if (a<0)
            a2 = a*(-1);
        if (b<0)
            b2 = b*(-1);
        if (c<0)
            c2 = c*(-1);
        m = floor(pow(max(a2, b2, c2),1/n));
        for (i = 1; i <= m; i++)
            if (pow(i, n) >= min(a2, b2, c2) && pow(i, n) <= max(a2, b2, c2))
            {
                v.values[v.length] = i;
                v.length++;
                v.values[v.length] = (-1)*i;
                v.length++;
            }
        return v;
    }
    else {
        for (i = ceil(pow(min(a, b, c),1/n)); i <= floor(pow(max(a, b, c),1/n)); i++)
            if (pow(i, n) >= min(a, b, c) && pow(i, n) <= max(a, b, c))
            {
                v.values[v.length] = i;
                v.length++;
            }
        return v;
    }
}
This function is supposed to return the integers i whose n-th power (i^n) lies in the interval [min(a,b,c), max(a,b,c)].
Other functions/headers
double max(double a, double b, double c)
{
    if (a >= b && a >= c)
        return a;
    if (b >= a && b >= c)
        return b;
    if (c >= a && c >= b)
        return c;
    return a;
}
double min(double a, double b, double c)
{
    if (a <= b && a <= c)
        return a;
    if (b <= a && b <= c)
        return b;
    if (c <= a && c <= b)
        return c;
    return a;
}
#include <iostream>
#include <cmath>
using namespace std;
#define MAX_ARRAY_LENGTH 100
struct Vector
{
    unsigned int length;
    int values[MAX_ARRAY_LENGTH];
};
It seems I don't get the right answer. For example:
getNthRoots(32,15,37,5) should return the vector [2], because 2^5 = 32 lies in the interval [15,37], but I get nothing back;
getNthRoots(32,1,7,5) should return the vector [1,2], but I only get 1.
I suspect the problem is in this line: for (i = ceil(pow(min(a, b, c),1/n)); i <= floor(pow(max(a, b, c),1/n)); i++), but I don't know how to fix it.
1/n evaluates to 0, because it is evaluated as an integer expression. Try replacing all the "1/n"s with "1.0/n"s.
Take care to handle the case where n is 0.
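Even with 1.0/n, floor(pow(x, 1.0/n)) can land one below the true integer root, because 1.0/n is not exact in binary; for example, pow(1000, 1.0/3) typically returns 9.999999999999998. Below is a hedged sketch that corrects the pow estimate with exact checks. It assumes n >= 1 and moderate magnitudes (|i^n| <= 2^53, so the double products are exact), interprets the task as "all integers i with min(a,b,c) <= i^n <= max(a,b,c)" (including 0 and negative i), and returns a std::vector instead of the fixed-size Vector struct. The names ipow, nth_root_floor, and powers_in_range are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// i^n by repeated multiplication; exact while |i^n| <= 2^53.
double ipow(long long i, int n) {
    double r = 1.0;
    for (int k = 0; k < n; ++k) r *= static_cast<double>(i);
    return r;
}

// Largest integer r >= 0 with r^n <= x. The pow() estimate may be off by
// one because of rounding, so nudge it down or up with exact checks.
long long nth_root_floor(double x, int n) {
    assert(x >= 0 && n >= 1);
    long long r = static_cast<long long>(std::floor(std::pow(x, 1.0 / n)));
    while (r > 0 && ipow(r, n) > x) --r;
    while (ipow(r + 1, n) <= x) ++r;
    return r;
}

// All integers i (in ascending order) with min(a,b,c) <= i^n <= max(a,b,c).
std::vector<long long> powers_in_range(double a, double b, double c, int n) {
    std::vector<long long> out;
    if (n < 1) return out;  // n == 0 (and negative n) not handled in this sketch
    const double lo = std::min({a, b, c});
    const double hi = std::max({a, b, c});
    const long long r = nth_root_floor(std::max(std::fabs(lo), std::fabs(hi)), n);
    for (long long i = -r; i <= r; ++i) {
        const double p = ipow(i, n);
        if (p >= lo && p <= hi) out.push_back(i);
    }
    return out;
}
```

With the question's examples, powers_in_range(32, 15, 37, 5) gives [2] and powers_in_range(32, 1, 7, 5) gives [1, 2].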
double p1::root(double (*pf)(double k), int a, int b, double e)
I'm not sure how to go about it. I understand that I need a loop that repeatedly computes the midpoint, and so on.
double p1::root(double (*pf)(double k), int a, int b, double e) {
    // void nrerror(char error_text[]);
    int j;
    double dx, f, fmid, xmid, rtb;
    f = (*pf)(a);
    fmid = (*pf)(b);
    // if (f*fmid >= 0.0) nrerror("root must be bracketed for bisection in rtbis");
    rtb = f < 0.0 ? (dx=b-a,a) : (dx=a-b,b);
    for (j = 1; j < 40; j++) {
        fmid = (*pf)(xmid = rtb+(dx *= .5));
        if (fmid <= 0.0) rtb = xmid;
        if (fabs(dx) < e || fmid == 0.0) return rtb;
    }
    // nrerror("too many bisections in rtbis");
    return 0.0;
}
double p1::test_function(double k) {
    return (pow(k, 3) -2);
}
Then in main I have this:
double (*pf)(double k);
pf = &p1::test_function;
//double result = p1::root(pf, a, b, e);
Maybe Numerical Recipes will give you an idea. Hint: it's recursive.
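Following the hint, here is a hedged sketch of a recursive bisection with the same parameters as the requested root, written as free functions rather than members of p1. (If test_function is a non-static member of p1, &p1::test_function is a pointer-to-member and will not convert to double (*)(double); a static member or free function avoids that.) The helper bisect_step is hypothetical, and the sketch assumes pf(a) and pf(b) have opposite signs.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// One recursive bisection step on [lo, hi]; flo = pf(lo) is passed along so
// each level evaluates pf only once.
double bisect_step(double (*pf)(double), double lo, double hi, double flo, double e) {
    const double mid = lo + (hi - lo) / 2;
    const double fmid = pf(mid);
    // Stop when the bracket is narrower than e, mid is an exact root,
    // or no double lies strictly between lo and hi.
    if (hi - lo < e || fmid == 0.0 || mid <= lo || mid >= hi)
        return mid;
    if ((flo < 0) == (fmid < 0))
        return bisect_step(pf, mid, hi, fmid, e);  // same sign: root in [mid, hi]
    return bisect_step(pf, lo, mid, flo, e);       // sign change in [lo, mid]
}

// Free-function version of the requested root(pf, a, b, e).
double root(double (*pf)(double), int a, int b, double e) {
    assert(e > 0);
    double lo = a, hi = b;
    if (lo > hi) std::swap(lo, hi);
    const double flo = pf(lo);
    if (flo == 0.0) return lo;
    return bisect_step(pf, lo, hi, flo, e);
}

double test_function(double k) {
    return std::pow(k, 3) - 2;
}
```

For example, root(test_function, 1, 2, 1e-9) should approximate the cube root of 2 (about 1.259921). Each call halves the bracket, so the recursion depth is roughly log2((b - a) / e).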