
I made the following implementation of the median in C++ and used it in R via Rcpp:

#include <Rcpp.h>
#include <algorithm>
#include <vector>

// [[Rcpp::export]]
double median2(std::vector<double> x){
  double median;
  size_t size = x.size();
  sort(x.begin(), x.end());
  if (size % 2 == 0){
      median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
  }
  else {
      median = x[size / 2];
  }
  return median;
}

If I subsequently compare its performance with R's built-in median function, I get the following results via microbenchmark:

> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
       expr    min     lq     mean median     uq     max neval
  median(x) 25.469 26.990 34.96888 28.130 29.081 518.126   100
 median2(x)  1.140  1.521  2.47486  1.901  2.281  47.897   100

Why is the standard median function so much slower? This isn't what I would expect...

Ruben
  • For starters, look at all the stuff that `median.default` actually does and then try testing against something more fair. – joran Jan 13 '16 at 15:55
  • Ok, so I guess that it is because of all the things around, but actually computing the median takes no time at all. – Ruben Jan 13 '16 at 16:11
  • As an aside, sorting the vector is overkill. You don't care about the ordering of the first n/2 elements -- you just care about what the n/2th element is. The algorithm `std::nth_element` will do this faster than sorting. It can be implemented reasonably easily and efficiently using recursive median-of-median-of-5 and partition, with a short-length alternative algorithm, if you want it in R. Second, use explicit `std::sort` on `std::vector` iterators (there is no guarantee they are defined in `namespace std`, which your code relies upon). – Yakk - Adam Nevraumont Jan 13 '16 at 16:31
  • @Yakk Indeed, note that `median.default` uses the `partial` argument for `sort`, which does something similar to what you're describing I think. – joran Jan 13 '16 at 16:37
  • @joran it actually does more -- it sorts the first half of the vector. `nth_element` only partitions it so that the nth element is at position n, and all elements before are less, and all elements after are more. You can do this .. faster than a half-sort. You do [median of medians](https://en.wikipedia.org/wiki/Median_of_medians) to find an almost-median, partition to find out where it is. Repeat until you find nth. – Yakk - Adam Nevraumont Jan 13 '16 at 18:00

4 Answers


As noted by @joran, your code is very specialized, and generally speaking, less general functions, algorithms, and so on are often more performant. Take a look at `median.default`:

median.default
# function (x, na.rm = FALSE) 
# {
#   if (is.factor(x) || is.data.frame(x)) 
#     stop("need numeric data")
#   if (length(names(x))) 
#     names(x) <- NULL
#   if (na.rm) 
#     x <- x[!is.na(x)]
#   else if (any(is.na(x))) 
#     return(x[FALSE][NA])
#   n <- length(x)
#   if (n == 0L) 
#     return(x[FALSE][NA])
#   half <- (n + 1L)%/%2L
#   if (n%%2L == 1L) 
#     sort(x, partial = half)[half]
#   else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }

There are several operations in place to accommodate the possibility of missing values, and these will definitely impact the overall execution time of the function. Since your function does not replicate this behavior, it can eliminate a bunch of calculations, but consequently it will not provide the same result for vectors with missing values:

median(c(1, 2, NA))
#[1] NA

median2(c(1, 2, NA))
#[1] 2
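
To get a feel for what that extra handling costs, here is a minimal NA-aware sketch (the name `cpp_med_na` and the `na_rm` argument are my own, added for illustration); the extra pass over the data to detect missing values is exactly the kind of work `median.default` has to do:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med_na(Rcpp::NumericVector xx, bool na_rm = false) {
  Rcpp::NumericVector x = Rcpp::clone(xx);       // leave the caller's vector untouched
  if (na_rm) {
    x = x[!Rcpp::is_na(x)];                      // drop missing values
  } else if (Rcpp::is_true(Rcpp::any(Rcpp::is_na(x)))) {
    return NA_REAL;                              // mimic median()'s NA propagation
  }
  std::size_t size = x.size();
  if (size == 0) return NA_REAL;                 // empty input, as in median.default
  std::sort(x.begin(), x.end());
  if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}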

A couple of other factors which probably don't have as much of an effect as the handling of NAs, but are worth pointing out:

  • median, along with a handful of the functions it uses, is an S3 generic, so there is a small amount of time spent on method dispatch
  • median will work with more than just integer and numeric vectors; it will also handle Date, POSIXt, and probably a bunch of other classes, and preserve attributes correctly:

median(Sys.Date() + 0:4)
#[1] "2016-01-15"

median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"

Edit: I should mention that the function below will cause the original vector to be sorted since NumericVectors are proxy objects. If you want to avoid this, you can either Rcpp::clone the input vector and operate on the clone, or use your original signature (with a std::vector<double>), which implicitly requires a copy in the conversion from SEXP to std::vector.

Also note that you can shave off a little more time by using a NumericVector instead of a std::vector<double>:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
  std::size_t size = x.size();
  std::sort(x.begin(), x.end());  // sorts in place: x is a proxy for the R vector
  if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}

microbenchmark::microbenchmark(
  median(x),
  median2(x),
  cpp_med(x),
  times = 200L
)
# Unit: microseconds
#       expr    min      lq      mean  median      uq     max neval
#  median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810   200
# median2(x)  6.474  7.9665  13.90126 11.0570  14.844 151.817   200
# cpp_med(x)  5.737  7.4285  11.25318  9.0270  13.405  52.184   200

Yakk brought up a great point in the comments above (also elaborated on by Jerry Coffin): a complete sort is unnecessary, since only the middle element or two are needed. Here's a rewrite using std::nth_element, benchmarked on a much larger vector:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
  Rcpp::NumericVector x = Rcpp::clone(xx);  // clone so the caller's vector is not reordered
  std::size_t n = x.size() / 2;
  std::nth_element(x.begin(), x.begin() + n, x.end());

  if (x.size() % 2) return x[n];
  // Even length: nth_element leaves every element before position n <= x[n],
  // so the lower middle value is the maximum of that left-hand portion.
  return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}

set.seed(123)
xx <- rnorm(10e5)

all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))

microbenchmark::microbenchmark(
  cpp_med2(xx), median2(xx), 
  median(xx), times = 200L
)
# Unit: milliseconds
#         expr      min       lq     mean   median       uq       max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161  33.92103   200
#  median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301   200
#   median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939   200
nrussell
  • I'd be curious what happens if one used only the last four lines of `median.default` and replaced `mean()` by `.Internal(mean())`. I would guess that would be pretty close to `median2`, maybe even faster. – joran Jan 13 '16 at 16:14
  • ...so having tested that it's definitely not as fast as `median2`, but it is much closer. – joran Jan 13 '16 at 16:18
  • @joran It's probably worth testing on a larger vector as well; when I compared `median.default` to the two C++ versions on a `rnorm(1e5)` vector the timings were much closer. – nrussell Jan 13 '16 at 16:20
  • And I'm also wondering why the author of `median.default` chose to use `any(is.na(x))` instead of `anyNA(x)`, as the latter is much faster... – nrussell Jan 13 '16 at 16:29
  • One call to `nth_element` isn't enough for an even length list. Call it once, then call `min_element` on the right hand side interval (or max element on the left hand side, depending how your math works). Also you have to deal with 0 sized lists with special case code. – Yakk - Adam Nevraumont Jan 13 '16 at 18:41
  • Second, everything done outside of the median algorithm should be really cheap stuff, even in R. But it is slow enough that on small lists it dominates the search time. Note that my version above is sub-optimal, because you can do lots of work to get the (n-1)th element while getting the nth element -- calling `max_element` afterwards ignores the structure we have generated in the lhs list. There isn't, as far as I know, an `nth_elements` function that takes a collection of elements to find, in a way similar to how `nth_element` finds one; you'd have to roll your own. It might be ~1.2x faster though. – Yakk - Adam Nevraumont Jan 13 '16 at 18:56
  • @Yakk Whoops, good catch - thank you. I left out the zero-size check in the previous versions and in your version just because the OP wasn't performing these in their function. And although it's definitely a good idea to check for this in C++, empty vectors are *much* less likely to exist in R just due to the nature of the language, i.e. you pretty much have to explicitly create one with `numeric(0)`. Thanks again for your input. – nrussell Jan 13 '16 at 19:34
  • You should also update the benchmarks (I cannot believe that the max element change had no timing impact) – Yakk - Adam Nevraumont Jan 13 '16 at 19:55
  • @Yakk One would think that the extra call to `std::max_element` would have a noticeable impact, but I just reran the benchmark with the updated version of `cpp_med2` and did not see a measurable increase in the run time; e.g. the median times were all within the range of about 12.4 - 12.8 milliseconds. – nrussell Jan 13 '16 at 20:09
  • I don't think the `Rcpp::NumericVector x = Rcpp::clone(xx);` part in `cpp_med2` is necessary, because we pass by value? – Ruben Jan 14 '16 at 09:42
  • It is necessary, try it both ways and see : ). – nrussell Jan 14 '16 at 10:13
  • No, let me show you: `// [[Rcpp::export]] double median3(NumericVector x) { std::size_t n = x.size() / 2; std::nth_element(x.begin(), x.begin() + n, x.end()); if (x.size() % 2) return x(n); return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.; }` also works. – Ruben Jan 15 '16 at 10:22

[This is more of an extended comment than an answer to the question you actually asked.]

Even your code may be open to significant improvement. In particular, you're sorting the entire input even though you only care about one or two elements.

You can change this from O(n log n) to O(n) by using std::nth_element instead of std::sort. In the case of an even number of elements, you'd typically use std::nth_element to find the element just before the middle, then use std::min_element to find the immediately succeeding element. Because std::nth_element also partitions the input items, the std::min_element call only has to run on the items above the middle, not on the entire input array. That is, after nth_element, you get a situation like this:

(Diagram: after nth_element, the chosen element is in its final sorted position, with no larger elements before it and no smaller elements after it.)

The complexity of std::nth_element is "linear on average", and (of course) std::min_element is linear as well, so the overall complexity is linear.

So, for the simple case (odd number of elements), you get something like:

auto pos = x.begin() + x.size()/2;

std::nth_element(x.begin(), pos, x.end());
return *pos;

...and for the more complex case (even number of elements), where pos must point at the element just before the middle:

auto pos = x.begin() + x.size() / 2 - 1;

std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos + 1, x.end());
return (*pos + *pos2) / 2.0;
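
Putting the two cases together, a complete standalone version might look like this (my own assembly of the snippets above; the name median_nth is hypothetical):

#include <algorithm>
#include <vector>

// Assumes x is non-empty. Taking the vector by value means the
// caller's copy is left untouched.
double median_nth(std::vector<double> x) {
  std::size_t n = x.size();
  if (n % 2 == 1) {
    auto pos = x.begin() + n / 2;
    std::nth_element(x.begin(), pos, x.end());    // linear on average
    return *pos;
  }
  auto pos = x.begin() + n / 2 - 1;               // element just before the middle
  std::nth_element(x.begin(), pos, x.end());      // partitions around *pos
  auto pos2 = std::min_element(pos + 1, x.end()); // smallest element above *pos
  return (*pos + *pos2) / 2.0;
}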
Jerry Coffin

I'm not sure what "standard" implementation you would be referring to.

Anyway: if there were one, then, being part of a standard library, it would certainly not be allowed to change the order of elements in the vector (as your implementation does), so it would definitely have to work on a copy.

Creating this copy would take time and CPU (and significant memory), which would affect the run time.

tofro

From here one can expect that `max_element(ForwardIt first, ForwardIt last)` provides the max from `first` to `last`, but in `return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.` the `x.begin() + n` element seems to be excluded from the calculation. Why the discrepancy?

E.g. cpp_med2({6, 2, 1, 5, 3, 4}) produces x={2, 1, 3, 4, 5, 6} where:

n = 3
x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3

So cpp_med2({6, 2, 1, 5, 3, 4}) returns (4+3)/2 = 3.5, which is the correct median. But why is `*std::max_element(x.begin(), x.begin() + n)` equal to 3 instead of 4? The function actually seems to exclude the last element (4) from the max calculation.

SOLVED (I think): in the documentation's wording,

Finds the greatest element in the range [first, last)

the closing `)` means `last` is excluded from the calculation. Is that correct?
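
That reading is right: standard algorithms operate on half-open ranges, so the element at `last` is never examined. A minimal standalone check (my own sketch, not part of the original post):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> x = {2, 1, 3, 4, 5, 6};
  // [first, last) is half-open: the element at x.begin() + 3 (the value 4)
  // is never examined, so the maximum over {2, 1, 3} is 3.
  std::cout << *std::max_element(x.begin(), x.begin() + 3) << "\n";  // prints 3
  return 0;
}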

Best regards

Guillermo Luijk