
I am reading Functional Programming in Scala and am having trouble understanding a piece of code. I have checked the errata for the book and the passage in question does not have a misprint. (Actually, it does have a misprint, but the misprint does not affect the code that I have a question about.)

The code in question calculates a pseudo-random, non-negative integer that is less than some upper bound. The function that does this is called nonNegativeLessThan.

trait RNG {
  def nextInt: (Int, RNG) // Should generate a random `Int`. 
}

case class Simple(seed: Long) extends RNG {
  def nextInt: (Int, RNG) = {
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL // `&` is bitwise AND. We use the current seed to generate a new seed.
    val nextRNG = Simple(newSeed) // The next state, which is an `RNG` instance created from the new seed.
    val n = (newSeed >>> 16).toInt // `>>>` is right binary shift with zero fill. The value `n` is our new pseudo-random integer.
    (n, nextRNG) // The return value is a tuple containing both a pseudo-random integer and the next `RNG` state.
  }
}

type Rand[+A] = RNG => (A, RNG)

def nonNegativeInt(rng: RNG): (Int, RNG) = {
  val (i, r) = rng.nextInt
  (if (i < 0) -(i + 1) else i, r)
}

def nonNegativeLessThan(n: Int): Rand[Int] = { rng =>
  val (i, rng2) = nonNegativeInt(rng)
  val mod = i % n
  if (i + (n-1) - mod >= 0) (mod, rng2)
  else nonNegativeLessThan(n)(rng2)
}

I have trouble understanding this part of nonNegativeLessThan: if (i + (n-1) - mod >= 0) (mod, rng2) else nonNegativeLessThan(n)(rng2).

The book explains that this entire if-else expression is necessary because a naive implementation that simply takes the mod of the result of nonNegativeInt would be slightly skewed toward lower values, since Int.MaxValue is not guaranteed to be a multiple of n. The code is therefore meant to check whether the output of nonNegativeInt is larger than the largest multiple of n that fits inside a 32-bit value. If it is, the function generates a new pseudo-random number and tries again.

To elaborate, the naive implementation would look like this:

def naiveNonNegativeLessThan(n: Int): Rand[Int] = map(nonNegativeInt){_ % n}

where map is defined as follows:

def map[A,B](s: Rand[A])(f: A => B): Rand[B] = {
  rng => 
    val (a, rng2) = s(rng)
    (f(a), rng2)
}

To repeat, this naive implementation is not desirable because of a slight skew towards lower values when Int.MaxValue is not a perfect multiple of n.
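
To make the skew concrete, here is a minimal sketch on a scaled-down range. It assumes a generator that yields each of 0 through 7 with equal probability (standing in for 0 through Int.MaxValue) and n = 3. SkewDemo and remainderCounts are names made up for this illustration; they are not from the book.

```scala
object SkewDemo {
  // Hypothetical helper: counts how many inputs in 0 until rangeSize
  // land on each remainder when reduced modulo n.
  def remainderCounts(rangeSize: Int, n: Int): Map[Int, Int] =
    (0 until rangeSize).groupBy(_ % n).map { case (r, xs) => r -> xs.size }

  def main(args: Array[String]): Unit = {
    // 8 equally likely inputs, n = 3: remainders 0 and 1 are each hit by
    // three inputs, remainder 2 by only two.
    remainderCounts(8, 3).toSeq.sorted.foreach { case (r, c) =>
      println(s"remainder $r: $c inputs")
    }
  }
}
```

Because 8 is not a multiple of 3, the last block of inputs (6 and 7) is incomplete, so the low remainders get one extra hit each. Rejecting the inputs in that incomplete block would leave every remainder equally likely.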

So, to reiterate the question: what does the following code do, and how does it help us determine whether a number is smaller than the largest multiple of n that fits inside a 32-bit integer? I am talking about this code inside nonNegativeLessThan:

if (i + (n-1) - mod >= 0) (mod, rng2)
else nonNegativeLessThan(n)(rng2)
Allen Han
    As far as I can tell, it doesn't do anything because the condition is never **false**. As `i` is guaranteed to be non-negative, and `n` is _supposed_ to be positive, that puts `mod` in the range of zero to `n-1`, inclusive. So `(n-1) - mod` is also non-negative. Thus the `if` condition is always **true**. – jwvh Oct 31 '19 at 07:39
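
The reasoning in the comment above holds for unbounded integers. A minimal sketch, assuming Scala's Int is a 32-bit signed value whose arithmetic wraps around silently on overflow (as it does on the JVM), suggests the condition can still be false for values of i near Int.MaxValue. OverflowCheckDemo and accepts are hypothetical names introduced here; the condition itself is copied from nonNegativeLessThan.

```scala
object OverflowCheckDemo {
  // Hypothetical helper: the same condition as in nonNegativeLessThan,
  // factored out so it can be tried on specific values of i.
  def accepts(i: Int, n: Int): Boolean = {
    val mod = i % n
    // i - mod is the largest multiple of n that is <= i. Adding n - 1 gives
    // the last value of that block; if it exceeds Int.MaxValue, the Int sum
    // wraps around to a negative number.
    i + (n - 1) - mod >= 0
  }

  def main(args: Array[String]): Unit = {
    val n = 10
    println(accepts(Int.MaxValue, n)) // false: the sum wraps around
    println(accepts(2147483639, n))   // true: last value of a complete block
  }
}
```

With n = 10, the largest multiple of 10 that fits is 2147483640, so the eight values from 2147483640 through Int.MaxValue form an incomplete block. For those, i - mod + (n - 1) exceeds Int.MaxValue, wraps to a negative number, and the value is rejected.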

2 Answers


I have exactly the same confusion about this passage from Functional Programming in Scala, and I absolutely agree with jwvh's analysis: the condition if (i + (n-1) - mod >= 0) will always be true.


In fact, if one tries the same example in Rust, the compiler warns about this (just an interesting comparison of how much static checking is being done). Of course the pencil and paper approach of jwvh is absolutely the right approach.

We first define some type aliases so that the code matches the Scala code more closely (forgive my Rust if it's not quite idiomatic).

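// `RNG` (a trait) and `non_negative_int` are assumed to be defined elsewhere,
// mirroring the Scala versions.
pub type RNGType = Box<dyn RNG>;
pub type Rand<A> = Box<dyn Fn(RNGType) -> (A, RNGType)>;

pub fn non_negative_less_than(n: u32) -> Rand<u32> {
    let t = move |rng: RNGType| {
        let (i, rng2) = non_negative_int(rng);
        let rem = i % n;
        // `i`, `n`, and `rem` are all `u32`, so this comparison can never be false.
        if i + (n - 1) - rem >= 0 {
            (rem, rng2)
        } else {
            non_negative_less_than(n)(rng2)
        }
    };

    Box::new(t)
}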

The compiler warning regarding if i + (n - 1) - rem >= 0 is:

warning: comparison is useless due to type limits
Ahmed Riza