I have implemented an approximate natural log function based on a Padé Approximation of a truncated Taylor series. The accuracy is acceptable (±0.000025), but despite several rounds of optimization, its execution time is still about 2.5x that of the standard library ln
function! If it isn't faster and it isn't as accurate, it is worthless! Nevertheless, I am using this as a way to learn how to optimize my Rust code. (My timings come from using the criterion
crate. I used black_box, summed the values in the loop, and created a string from the result to defeat the optimizer.)
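For anyone who wants to reproduce the measurement idea without pulling in criterion, here is a minimal sketch of the same defeat-the-optimizer technique using only the standard library's std::hint::black_box (stable since Rust 1.66). The approx_ln function here is a placeholder standing in for my implementation, not the real code:

```rust
use std::hint::black_box;
use std::time::Instant;

// Placeholder for the approximation under test; swap in the real function.
fn approx_ln(x: f64) -> f64 {
    x.ln()
}

// Sum approx_ln over 1..=n, using black_box so the optimizer can neither
// precompute the inputs nor discard the result.
fn timed_sum(n: u32) -> f64 {
    let start = Instant::now();
    let mut sum = 0.0_f64;
    for i in 1..=n {
        sum += approx_ln(black_box(i as f64));
    }
    let sum = black_box(sum);
    // Creating a string from the result also forces the computation to happen.
    println!("sum = {:.6}, elapsed = {:?}", sum, start.elapsed());
    sum
}

fn main() {
    timed_sum(1_000_000);
}
```

criterion does this more rigorously (warm-up, outlier detection, statistics), so the above is only for quick spot checks.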
On Rust Playground, my code is:
Algorithm
An overview of my algorithm, which works on a ratio of unsigned integers:
- Range reduction to the interval [1, 2) by dividing by the largest power of two not exceeding the value:
- Change representation of numerator →
2ⁿ·N where 1 ≤ N < 2
- Change representation of denominator →
2ᵈ·D where 1 ≤ D < 2
- This makes the result
log(numerator/denominator) = log(2ⁿ·N / 2ᵈ·D) = (n-d)·log(2) + log(N) - log(D)
- To compute log(N): the Taylor series does not converge in the neighborhood of zero, but it does near one...
- ... since N is near one, substitute x = N - 1 so that we now need to evaluate log(1 + x)
- Perform a substitution of
y = x/(2+x)
- Consider the related function
f(y) = Log((1+y)/(1-y))
= Log((1 + x/(2+x)) / (1 - x/(2+x)))
= Log( (2+2x) / 2)
= Log(1 + x)
- f(y) has a Taylor expansion which converges much faster than the expansion for Log(1+x) ...
- For Log(1+x) →
x - x²/2 + x³/3 - x⁴/4 + ...
- For Log((1+y)/(1-y)) →
2·(y + y³/3 + y⁵/5 + ...)
- Use the Padé Approximation for the truncated series
2·(y + y³/3 + y⁵/5)
- ... Which is
2y·(15 - 4y²)/(15 - 9y²)
- Repeat for the denominator and combine the results.
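Putting the steps above together, the whole pipeline might look like the following sketch. This is my reconstruction from the description, not the exact Playground code; for brevity it finds the msb position with u64::leading_zeros rather than a hand-rolled routine:

```rust
/// Padé Approximation of ln(1 + x) for x in [0, 1), as derived above.
fn log_1_plus_x(x: f64) -> f64 {
    let y = x / (2.0 + x);
    let y2 = y * y;
    0.8888888888888889 * y * (3.75 - y2) / (1.6666666666666667 - y2)
}

/// Approximate ln(numerator/denominator) for nonzero unsigned integers.
fn approx_ln_ratio(numerator: u64, denominator: u64) -> f64 {
    // Positions of the most significant bits, i.e. floor(log2(...)).
    let n = 63 - numerator.leading_zeros();
    let d = 63 - denominator.leading_zeros();
    // Range reduction: numerator == 2^n * big_n with big_n in [1, 2); likewise the denominator.
    let big_n = numerator as f64 / (1u64 << n) as f64;
    let big_d = denominator as f64 / (1u64 << d) as f64;
    // ln(2^n·N / 2^d·D) = (n - d)·ln(2) + ln(N) - ln(D)
    (n as f64 - d as f64) * std::f64::consts::LN_2
        + log_1_plus_x(big_n - 1.0)
        - log_1_plus_x(big_d - 1.0)
}

fn main() {
    let approx = approx_ln_ratio(1000, 3);
    let exact = (1000.0_f64 / 3.0).ln();
    println!("approx = {approx:.7}, exact = {exact:.7}");
}
```

In my spot checks this agrees with the standard library to within a few times 10⁻⁵, consistent with the stated accuracy.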
Padé Approximation
Here is the Padé Approximation part of the code:
```rust
/// Approximate the natural logarithm of one plus a number in the range [0, 1).
///
/// Uses a Padé Approximation for the truncated Taylor series for Log((1+y)/(1-y)).
///
/// - x - must be a value in [0, 1).
#[inline]
fn log_1_plus_x(x: f64) -> f64 {
    // This is private and its caller already checks for negatives, so no need to check again here.
    // Also, though ln(1 + 0) == 0 is an easy case, it is not so much more likely to be the argument
    // than other values, so no need for a special test.
    let y = x / (2.0 + x);
    let y_squared = y * y;
    // Original formula: 2y·(15 - 4y²)/(15 - 9y²)
    // 2.0 * y * (15.0 - 4.0 * y_squared) / (15.0 - 9.0 * y_squared)
    // Reduce multiplications: (8/9)·y·(3.75 - y²)/((5/3) - y²)
    0.8888888888888889 * y * (3.75 - y_squared) / (1.6666666666666667 - y_squared)
}
```
Clearly, not much more to speed up there!
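As a quick sanity check of the approximation (not a benchmark), one can scan the intended domain and compare against the standard library's f64::ln_1p; the function is reproduced here so the sketch runs standalone:

```rust
/// The Padé Approximation from above, reproduced so this sketch is self-contained.
fn log_1_plus_x(x: f64) -> f64 {
    let y = x / (2.0 + x);
    let y_squared = y * y;
    0.8888888888888889 * y * (3.75 - y_squared) / (1.6666666666666667 - y_squared)
}

/// Scan [0, 1) and return the worst absolute error versus f64::ln_1p.
fn worst_error() -> f64 {
    let mut worst = 0.0_f64;
    for i in 0..1000 {
        let x = i as f64 / 1000.0;
        worst = worst.max((log_1_plus_x(x) - x.ln_1p()).abs());
    }
    worst
}

fn main() {
    println!("worst absolute error on [0, 1): {:e}", worst_error());
}
```

The error grows toward x = 1, peaking at roughly 2.5×10⁻⁵, which is where the stated accuracy figure comes from.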
Most Significant Bit
The change that has had the most impact so far was optimizing the calculation that gets the position of the most significant bit, which I need for the range reduction. Here is my msb
function:
````rust
/// Provide an `msb` method for numeric types to obtain the zero-based
/// position of the most significant bit set.
///
/// Algorithms used are based on this article:
/// https://prismoskills.appspot.com/lessons/Bitwise_Operators/Find_position_of_MSB.jsp
pub trait MostSignificantBit {
    /// Get the zero-based position of the most significant bit of an integer type.
    /// If the number is zero, return zero.
    ///
    /// ## Examples:
    ///
    /// ```
    /// use clusterphobia::clustering::msb::MostSignificantBit;
    ///
    /// assert!(0_u64.msb() == 0);
    /// assert!(1_u64.msb() == 0);
    /// assert!(2_u64.msb() == 1);
    /// assert!(3_u64.msb() == 1);
    /// assert!(4_u64.msb() == 2);
    /// assert!(255_u64.msb() == 7);
    /// assert!(1023_u64.msb() == 9);
    /// ```
    fn msb(self) -> usize;
}
````
```rust
/// Return whether floor(log2(x)) != floor(log2(y)),
/// with zero for false and one for true, because this came from C!
#[inline]
fn ld_neq(x: u64, y: u64) -> u64 {
    let neq = (x ^ y) > (x & y);
    if neq { 1 } else { 0 }
}
```
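Why (x ^ y) > (x & y) works: the highest bit where x and y differ shows up in x ^ y, and it outranks every bit they share (all of x & y) exactly when the two numbers have different most significant bits. A small exhaustive check of that property (my own harness, with ld_neq reproduced so it runs standalone):

```rust
fn ld_neq(x: u64, y: u64) -> u64 {
    let neq = (x ^ y) > (x & y);
    if neq { 1 } else { 0 }
}

/// Reference: compare floor(log2) positions computed directly.
fn floor_log2(x: u64) -> u32 {
    63 - x.leading_zeros()
}

fn main() {
    // Exhaustively check the property over a small nonzero range.
    for x in 1..256u64 {
        for y in 1..256u64 {
            let expected = (floor_log2(x) != floor_log2(y)) as u64;
            assert_eq!(ld_neq(x, y), expected);
        }
    }
    println!("ld_neq agrees with direct floor(log2) comparison on 1..256");
}
```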
```rust
impl MostSignificantBit for u64 {
    #[inline]
    fn msb(self) -> usize {
        /*
        // SLOWER CODE THAT I REPLACED:
        // Bisection guarantees performance of O(log B) where B is the number of bits in the integer.
        let mut high = 63_usize;
        let mut low = 0_usize;
        while (high - low) > 1
        {
            let mid = (high + low) / 2;
            let mask_high = (1 << high) - (1 << mid);
            if (mask_high & self) != 0 { low = mid; }
            else { high = mid; }
        }
        low
        */
        // This algorithm is found on pg 16 of "Matters Computational" at https://www.jjj.de/fxt/fxtbook.pdf
        // It avoids most if-branches and has no looping.
        // Using this instead of bisection and looping shaved off 1/3 of the time.
        const MU0: u64 = 0x5555555555555555; // MU0 == ((-1UL)/3UL) == ...01010101_2
        const MU1: u64 = 0x3333333333333333; // MU1 == ((-1UL)/5UL) == ...00110011_2
        const MU2: u64 = 0x0f0f0f0f0f0f0f0f; // MU2 == ((-1UL)/17UL) == ...00001111_2
        const MU3: u64 = 0x00ff00ff00ff00ff; // MU3 == ((-1UL)/257UL) == (8 ones)
        const MU4: u64 = 0x0000ffff0000ffff; // MU4 == ((-1UL)/65537UL) == (16 ones)
        const MU5: u64 = 0x00000000ffffffff; // MU5 == ((-1UL)/4294967297UL) == (32 ones)
        let r: u64 = ld_neq(self, self & MU0)
            + (ld_neq(self, self & MU1) << 1)
            + (ld_neq(self, self & MU2) << 2)
            + (ld_neq(self, self & MU3) << 3)
            + (ld_neq(self, self & MU4) << 4)
            + (ld_neq(self, self & MU5) << 5);
        r as usize
    }
}
```
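To convince myself the branchless version is correct, here is a brute-force cross-check against a simple shift loop (my own test harness; it repeats the logic as free functions so it runs standalone):

```rust
fn ld_neq(x: u64, y: u64) -> u64 {
    if (x ^ y) > (x & y) { 1 } else { 0 }
}

/// Free-function copy of the branchless msb logic, for easy testing.
fn msb_branchless(x: u64) -> usize {
    const MU0: u64 = 0x5555555555555555;
    const MU1: u64 = 0x3333333333333333;
    const MU2: u64 = 0x0f0f0f0f0f0f0f0f;
    const MU3: u64 = 0x00ff00ff00ff00ff;
    const MU4: u64 = 0x0000ffff0000ffff;
    const MU5: u64 = 0x00000000ffffffff;
    (ld_neq(x, x & MU0)
        + (ld_neq(x, x & MU1) << 1)
        + (ld_neq(x, x & MU2) << 2)
        + (ld_neq(x, x & MU3) << 3)
        + (ld_neq(x, x & MU4) << 4)
        + (ld_neq(x, x & MU5) << 5)) as usize
}

/// Straightforward reference: shift right until the number is exhausted.
fn msb_naive(mut x: u64) -> usize {
    let mut pos = 0;
    while x > 1 {
        x >>= 1;
        pos += 1;
    }
    pos
}

fn main() {
    for x in 0..10_000u64 {
        assert_eq!(msb_branchless(x), msb_naive(x));
    }
    assert_eq!(msb_branchless(u64::MAX), 63);
    println!("branchless msb agrees with the naive loop");
}
```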
Rust u64::next_power_of_two, unsafe code and intrinsics
Now, I know that Rust has a fast method, next_power_of_two, for finding the lowest power of two greater than or equal to a number. I need this, but I also need the bit position, because that is the equivalent of the log base 2 of my numbers. (For example, next_power_of_two(255) yields 256, but I want 8, because 256 has the 8th bit set.) Looking at the source code for next_power_of_two, I see this line inside a private helper method called fn one_less_than_next_power_of_two
:
```rust
let z = unsafe { intrinsics::ctlz_nonzero(p) };
```
So is there an intrinsic that I can use to get the bit position the same way? Is it used in a public method that I have access to? Or is there a way to write unsafe code to call some intrinsic I don't know about (which is most of them)?
If there is such a method or intrinsic I can call, I suspect it will greatly speed up my program, but maybe there are other things that will also help.
UPDATE:
Head smack! I can use 63 - x.leading_zeros()
to find the position of the most significant bit! I just didn't think of coming from the other end. I will try this and see if it speeds things up...
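For reference, here is what that looks like, with one wrinkle worth noting: leading_zeros() returns 64 for zero, so 63 - x.leading_zeros() would underflow there; since the msb convention above maps zero to zero, a guard is needed (my own sketch):

```rust
/// msb position via count-leading-zeros, matching the msb() convention above
/// (zero maps to zero). leading_zeros typically compiles to a single
/// count-leading-zeros instruction on modern targets.
fn msb_via_clz(x: u64) -> usize {
    if x == 0 {
        0
    } else {
        (63 - x.leading_zeros()) as usize
    }
}

fn main() {
    // Same values as the doctest examples for the trait method.
    assert_eq!(msb_via_clz(0), 0);
    assert_eq!(msb_via_clz(1), 0);
    assert_eq!(msb_via_clz(2), 1);
    assert_eq!(msb_via_clz(3), 1);
    assert_eq!(msb_via_clz(4), 2);
    assert_eq!(msb_via_clz(255), 7);
    assert_eq!(msb_via_clz(1023), 9);
    println!("63 - leading_zeros matches the msb doctest examples");
}
```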