It appears that for larger nnet::multinom multinomial regression models (with a few thousand coefficients), calculating the Hessian (the matrix of second derivatives of the negative log likelihood, also known as the observed Fisher information matrix) becomes extremely slow, which in turn prevents me from calculating the variance-covariance matrix and hence confidence intervals on model predictions.

The culprit seems to be the following pure R function, which calculates the Fisher information matrix analytically using code contributed by David Firth: https://github.com/cran/nnet/blob/master/R/vcovmultinom.R

multinomHess = function (object, Z = model.matrix(object)) 
{
    probs <- object$fitted
    coefs <- coef(object)
    if (is.vector(coefs)) {
        coefs <- t(as.matrix(coefs))
        probs <- cbind(1 - probs, probs)
    }
    coefdim <- dim(coefs)
    p <- coefdim[2L]
    k <- coefdim[1L]
    ncoefs <- k * p
    kpees <- rep(p, k)
    n <- dim(Z)[1L]

##  Now compute the observed (= expected, in this case) information,
##  e.g. as in T Amemiya "Advanced Econometrics" (1985) pp 295-6.
##  Here i and j are as in Amemiya, and x, xbar are vectors
##  specific to (i,j) and to i respectively.

    info <- matrix(0, ncoefs, ncoefs)
    Names <- dimnames(coefs)
    if (is.null(Names[[1L]])) 
        Names <- Names[[2L]]
    else Names <- as.vector(outer(Names[[2L]], Names[[1L]], function(name2, 
        name1) paste(name1, name2, sep = ":")))
    dimnames(info) <- list(Names, Names)
    x0 <- matrix(0, p, k + 1L)
    row.totals <- object$weights
    for (i in seq_len(n)) {
        Zi <- Z[i, ]
        xbar <- rep(Zi, times=k) * rep(probs[i, -1, drop=FALSE], times=kpees)
        for (j in seq_len(k + 1)) {
            x <- x0
            x[, j] <- Zi
            x <- x[, -1, drop = FALSE]
            x <- x - xbar
            dim(x) <- c(1, ncoefs)
            info <- info + (row.totals[i] * probs[i, j] * crossprod(x))
        }
    }
    info
}

The relevant passage of the referenced Advanced Econometrics book is:
[Images: excerpts from T. Amemiya, Advanced Econometrics (1985), pp. 295-6]

From this explanation we can see that the Hessian is indeed just a sum of cross-products. I also found this and this derivation of the Hessian of a multinomial regression model, which may be even more elegant and efficient, as there the Hessian is calculated as a sum of Kronecker products.
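A short derivation, worked out here in the notation of multinomHess rather than taken from the cited sources, suggests that the two formulations give the same matrix, with the row totals w_i (object$weights) entering simply as a per-observation multiplier:

```latex
% Notation (as in multinomHess): observation i has model-matrix row z_i (length p),
% row total w_i, fitted probabilities P_{i0}, ..., P_{ik} (category 0 = baseline),
% p_i = (P_{i1}, ..., P_{ik})^T, and x_{ij} = e_j \otimes z_i for j >= 1, x_{i0} = 0,
% so that \bar{x}_i = \sum_{j=0}^{k} P_{ij} x_{ij} = p_i \otimes z_i.
\begin{aligned}
I &= \sum_{i=1}^{n} w_i \sum_{j=0}^{k} P_{ij}\,(x_{ij}-\bar{x}_i)(x_{ij}-\bar{x}_i)^\top \\
  &= \sum_{i=1}^{n} w_i \Big(\sum_{j=1}^{k} P_{ij}\,x_{ij}x_{ij}^\top - \bar{x}_i\bar{x}_i^\top\Big) \\
  &= \sum_{i=1}^{n} w_i \big(\operatorname{diag}(p_i) - p_i p_i^\top\big) \otimes z_i z_i^\top
\end{aligned}
```

The second line uses \sum_j P_{ij} = 1 and \sum_j P_{ij} x_{ij} = \bar{x}_i; the third uses x_{ij} x_{ij}^\top = (e_j e_j^\top) \otimes z_i z_i^\top and, by the mixed-product property, \bar{x}_i \bar{x}_i^\top = (p_i p_i^\top) \otimes z_i z_i^\top. The coefficient ordering (all p coefficients of outcome 1, then outcome 2, ...) matches the dimnames built in multinomHess.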

For a smallish nnet::multinom model (in which I model the frequency of different SARS-CoV2 lineages through time), the provided function runs quickly:

library(nnet)
library(splines)
download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
              "smallmodel.RData", 
              method = "auto", mode="wb")
load("smallmodel.RData")
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs
system.time(hess <- nnet:::multinomHess(fit_multinom_small)) # 0.11s
dim(hess) # 33 33

but for a large model (again modelling the frequency of different SARS-CoV2 lineages through time, but now across different continents / countries) it takes more than 2 hours, even though the model itself fits in ca. 1 minute:

download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
              "bigmodel.RData", 
              method = "auto", mode="wb")
load("bigmodel.RData")
length(fit_global_multi_last3m$lev) # k=20 outcome levels
dim(coef(fit_global_multi_last3m)) # 19 x 229 = (k-1) x p = 4351 coefficients
system.time(hess <- nnet:::multinomHess(fit_global_multi_last3m)) # takes forever

I am now looking for ways to speed up the above function.

The obvious approach would be to port it to Rcpp, but unfortunately I am not very experienced with that. Does anybody have any thoughts?

EDIT: From the info here and here, it appears that calculating the Hessian for a multinomial fit should just come down to calculating a sum of Kronecker products, which can be done directly in R using efficient matrix algebra, but right now I am unsure how to include my row totals (fit$weights). Does anybody have an idea?

download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
              "smallmodel.RData",
              method = "auto", mode = "wb")

load("smallmodel.RData")
library(nnet)
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs

fit = fit_multinom_small

Z = model.matrix(fit)
P = fitted(fit)[, -1, drop = FALSE]
k = ncol(P) # nr of outcome categories-1
p = ncol(Z) # nr of parameters
n = nrow(Z) # nr of observations
ncoefs = k*p
library(fastmatrix)

# Fisher information matrix
info <- matrix(0, ncoefs, ncoefs)
for (i in 1:n) { # sum over observations
  info = info + kronecker.prod(diag(P[i,]) - tcrossprod(P[i,]), tcrossprod(Z[i,]))
}
Tom Wenseleers
  • If you are using matrix multiplications then RcppArmadillo is probably a better choice – user20650 Sep 22 '22 at 14:01
  • @user20650 Also tried an RcppArmadillo version (see EDIT above), but with my limited Rcpp coding skills unfortunately also still a few bugs there... Any input welcome! Disadvantage of RcppArmadillo is that this just falls on whatever BLAS R is compiled against, and in Windows it's difficult to get hold of one compiled against IntelMKL or OpenBLAS, that would take advantage of multithreading etc (and Microsoft Open R, that was compiled against IntelMKL will be phased out). RcppEigen could be a good alternative - that works well no matter what BLAS is installed... – Tom Wenseleers Sep 22 '22 at 14:23
  • Thanks for the update. Re the first two errors, to simplify, I'd just add the names to the object returned by the cpp function later using R, as they are not used in the function. That leaves repmat! – user20650 Sep 22 '22 at 15:07
  • @user20650 Thanks! And any thoughts where the "no matching function for call to 'repmat(arma::subview_row, int, arma::vec&)'" is coming from & how to fix that? Does repmat not allow for mixed type arguments or something? – Tom Wenseleers Sep 22 '22 at 15:23
  • From a quick look at http://arma.sourceforge.net/docs.html#repmat I'd think `repmat` takes scalar nrow and ncol arguments but `kpees` is a vector. Just using `kpees[0]` instead lets the function compile but as the R code has `rep` I doubt that we can expect it to be a scalar all the time. So likely need another way to reshape your vector. (sorry I don't have time to look just now) – user20650 Sep 22 '22 at 16:28
  • @Tom sadly we cannot drop R functions names as-is into C++ code and hope for the best. Some exist, some don't so generally this is more of struggle line-by-line to compose your bigger picture function from the ground up step by (validated) step. – Dirk Eddelbuettel Sep 22 '22 at 23:42
  • @DirkEddelbuettel Yes, I know, and I know my quick attempt with my limited Rcpp coding experience was probably quite naive, though I also feel it might not be too far off either. Find it quite hard to debug this though. Should I try to put the bits that don't work in separate Rcpp functions to get them working? And do the same for the other parts to verify that each bit is doing what it is supposed to? – Tom Wenseleers Sep 23 '22 at 07:34
  • @DirkEddelbuettel One would somehow hope to be able to really step through your C++ code like with dedicated C++ debuggers. What's the next best thing? Once it compiles, Rf_PrintValue and Rprintf calls I suppose? – Tom Wenseleers Sep 23 '22 at 07:50
  • @user20650 Seems I may not even have to use Rcpp; from https://stats.stackexchange.com/questions/525042/derivation-of-hessian-for-multinomial-logistic-regression-in-b%C3%B6hning-1992 seems I could do it just by calculating a Kronecker product. Though I still don't see how their lambda phat is defined, but if I can figure that out this should be the most convenient & fastest way to do it... (see my EDIT2) – Tom Wenseleers Sep 23 '22 at 12:23
  • Yes, low-tech debugging of individual functions and building up is the way to go at the C++ level. No magic, no shortcuts. So if a Kronecker product helps, great! – Dirk Eddelbuettel Sep 23 '22 at 13:04
  • @microhaus Would you happen to know an answer to this question, given your related earlier answer https://stats.stackexchange.com/questions/525042/derivation-of-hessian-for-multinomial-logistic-regression-in-b%C3%B6hning-1992 ? – Tom Wenseleers Sep 23 '22 at 13:53
  • @DirkEddelbuettel Managed in the end, see below - with a little help from https://beta.openai.com/playground?model=code-davinci-002 , which beautifully translated a small pure R function I made (making use of Kronecker products) to Rcpp using Armadillo classes... Quite handy for an Rcpp beginner like myself! And it even ran the first time without any tweaking! AI is getting better every day! For large model resulting function is now 80 times faster than the original... Cool! – Tom Wenseleers Sep 24 '22 at 21:09
  • @Tom the link doesn't work, sadly. Comes up with an 'empty' code-davinci-002 page for me. Impressive, though. Glad you're covered now. – Dirk Eddelbuettel Sep 24 '22 at 21:14
  • @DirkEddelbuettel Ha maybe that's because I got access to the beta - you can sign up to the waiting list here: https://openai.com/blog/openai-codex/. I just put "/* Convert the following R function to Rcpp and use Armadillo classes followed by my small pure R function */" as input & it gave me the Rcpp code back... It's free as long as it's in beta... – Tom Wenseleers Sep 24 '22 at 21:37
  • No I obviously also signed up and am "in" but it greets me with an empty Javascript template. – Dirk Eddelbuettel Sep 24 '22 at 21:56
  • Ha that's just an example line. You can put as a prompt whatever you like, usually starting with /* and ending with */, best to increase maximum length to near max allowed & then press submit. I put as input a bit of R code and asked it to convert it to Rcpp. Mainly works well for small bits of code... – Tom Wenseleers Sep 24 '22 at 22:00
  • We must indeed be living in the (unevenly distributed) future! – Dirk Eddelbuettel Sep 24 '22 at 22:09
  • This was what it gave me when I asked it to program a Mandelbrot fractal using prompt "/* Show me how to calculate the Mandelbrot fractal in the fastest way possible in R using Rcpp and OpenMP and display it using the base R image() graphics function */": https://twitter.com/TWenseleers/status/1559001987148038144 – Tom Wenseleers Sep 24 '22 at 22:25
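On the repmat question in the comments: in multinomHess, the line xbar <- rep(Zi, times=k) * rep(probs[i, -1, drop=FALSE], times=kpees) repeats each probability p times and multiplies it elementwise by k stacked copies of Z[i, ]. That is exactly the Kronecker product of the (baseline-dropped) probability vector with Z[i, ], so no repmat with a vector argument is needed; in Armadillo something like arma::kron applied to the baseline-dropped probability row and Z.row(i) should give the same row vector (an untested assumption). The standard-library C++ sketch below (xbar_rep and xbar_kron are hypothetical names) builds the vector both ways:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mimics R's rep(z, times = k) * rep(prob, times = rep(p, k)):
// the left factor cycles through z k times, the right factor repeats each prob p times.
std::vector<double> xbar_rep(const std::vector<double>& prob, const std::vector<double>& z) {
    const std::size_t k = prob.size(), p = z.size();
    assert(k > 0 && p > 0);
    std::vector<double> left, right;
    for (std::size_t t = 0; t < k; ++t)
        left.insert(left.end(), z.begin(), z.end());   // rep(z, times = k)
    for (std::size_t j = 0; j < k; ++j)
        right.insert(right.end(), p, prob[j]);         // rep(prob, times = kpees)
    std::vector<double> out(k * p);
    for (std::size_t c = 0; c < k * p; ++c) out[c] = left[c] * right[c];
    return out;
}

// Kronecker product of two vectors: block j (length p) is prob[j] * z.
std::vector<double> xbar_kron(const std::vector<double>& prob, const std::vector<double>& z) {
    std::vector<double> out;
    out.reserve(prob.size() * z.size());
    for (double pj : prob)
        for (double za : z) out.push_back(pj * za);
    return out;
}
```

For prob = (0.5, 0.3) and z = (1, 2, -1) both functions return (0.5, 1, -0.5, 0.3, 0.6, -0.3).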

1 Answer

Figured it out in the end: I was able to calculate the observed Fisher information matrix using Kronecker products, and to port that bit to Rcpp using Armadillo classes. (Full disclosure: I made that Rcpp port just using OpenAI's code-davinci / Codex, https://openai.com/blog/openai-codex/, and surprisingly it worked straight out of the box; AI is getting better every day. RcppParallel's parallelReduce could presumably still be used to parallelize the accumulation. The function was faster than an equivalent RcppEigen implementation I tried.) My mistake was that the formula above gives the observed Fisher information for a single observation, so I had to accumulate it over observations and also take my row totals into account.
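To check the point about accumulating and weighting concretely, here is a minimal standard-library C++ sketch (the helper type Mat and the functions info_crossprod and info_kronecker are hypothetical names, not part of nnet, fastmatrix or Armadillo). info_crossprod mirrors the double loop in nnet:::multinomHess (Amemiya's sum of cross-products, baseline category included), and info_kronecker computes the weighted sum of Kronecker products with the baseline column dropped, as in calc_infmatrix. On toy data the two should agree to rounding error, provided each row of fitted probabilities sums to 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal dense row-major matrix (hypothetical helper type).
struct Mat {
    std::size_t rows, cols;
    std::vector<double> v;
    Mat(std::size_t r, std::size_t c) : rows(r), cols(c), v(r * c, 0.0) {}
    double& operator()(std::size_t i, std::size_t j) { return v[i * cols + j]; }
    double operator()(std::size_t i, std::size_t j) const { return v[i * cols + j]; }
};

// Amemiya-style form, as in nnet:::multinomHess.
// probs: n x (k+1) fitted probabilities INCLUDING baseline column 0; Z: n x p; w: row totals.
Mat info_crossprod(const Mat& probs, const Mat& Z, const std::vector<double>& w) {
    assert(probs.rows == Z.rows && w.size() == Z.rows && probs.cols >= 2);
    const std::size_t n = Z.rows, p = Z.cols, k = probs.cols - 1, nc = k * p;
    Mat info(nc, nc);
    std::vector<double> xbar(nc), x(nc);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < k; ++j)          // xbar = (P_i1 z_i, ..., P_ik z_i)
            for (std::size_t a = 0; a < p; ++a) xbar[j * p + a] = probs(i, j + 1) * Z(i, a);
        for (std::size_t j = 0; j <= k; ++j) {        // j = 0 is the baseline (x = 0 before centring)
            for (std::size_t c = 0; c < nc; ++c) x[c] = -xbar[c];
            if (j > 0)
                for (std::size_t a = 0; a < p; ++a) x[(j - 1) * p + a] += Z(i, a);
            const double s = w[i] * probs(i, j);      // row total times probability
            for (std::size_t r = 0; r < nc; ++r)
                for (std::size_t c = 0; c < nc; ++c) info(r, c) += s * x[r] * x[c];
        }
    }
    return info;
}

// Kronecker form: sum_i w_i * kron(diag(p_i) - p_i p_i^T, z_i z_i^T), baseline dropped.
Mat info_kronecker(const Mat& probs, const Mat& Z, const std::vector<double>& w) {
    assert(probs.rows == Z.rows && w.size() == Z.rows && probs.cols >= 2);
    const std::size_t n = Z.rows, p = Z.cols, k = probs.cols - 1, nc = k * p;
    Mat info(nc, nc);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < k; ++j)
            for (std::size_t l = 0; l < k; ++l) {
                const double pj = probs(i, j + 1), pl = probs(i, l + 1);
                const double a_jl = w[i] * ((j == l ? pj : 0.0) - pj * pl);
                for (std::size_t a = 0; a < p; ++a)
                    for (std::size_t b = 0; b < p; ++b)
                        info(j * p + a, l * p + b) += a_jl * Z(i, a) * Z(i, b);
            }
    return info;
}
```

The row total w_i only scales the contribution of observation i, so it can multiply either Kronecker factor; the Rcpp function below folds it into the z_i z_i^T factor.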

Rcpp function:

// RcppArmadillo utility function to calculate observed Fisher 
// information matrix of multinomial fit, with 
// probs=fitted probabilities (with 1st category/column dropped)
// Z = model matrix
// row_totals = row totals
// We do this using Kronecker products, as in
// https://ieeexplore.ieee.org/abstract/document/1424458
// B. Krishnapuram; L. Carin; M.A.T. Figueiredo; A.J. Hartemink
// Sparse multinomial logistic regression: fast algorithms and
// generalization bounds
// IEEE Transactions on Pattern Analysis and Machine
// Intelligence ( Volume: 27, Issue: 6, June 2005)

#include <RcppArmadillo.h>

using namespace arma;

// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::mat calc_infmatrix_RcppArma(arma::mat probs, arma::mat Z, arma::vec row_totals) {
  int n = Z.n_rows;
  int p = Z.n_cols;
  int k = probs.n_cols;
  int ncoefs = k * p;
  arma::mat info = arma::zeros<arma::mat>(ncoefs, ncoefs);
  arma::mat diag_probs;
  arma::mat tcrossprod_probs;
  arma::mat tcrossprod_Z;
  arma::mat kronecker_prod;
  for (int i = 0; i < n; i++) {
    diag_probs = arma::diagmat(probs.row(i));
    tcrossprod_probs = arma::trans(probs.row(i)) * probs.row(i);
    tcrossprod_Z = (arma::trans(Z.row(i)) * Z.row(i)) * row_totals(i);
    kronecker_prod = arma::kron(diag_probs - tcrossprod_probs, tcrossprod_Z);
    info += kronecker_prod;
  }
  return info;
}
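A possible further optimisation, offered as an untested suggestion: the loop above fills a dense ncoefs x ncoefs matrix with arma::kron for every observation (4351 x 4351 doubles, roughly 150 MB, for the large model) and then adds it to info. Because block (j, l) of the result equals Z^T diag(c_jl) Z, with c_jl[i] = row_totals[i] * (P_ij [j == l] - P_ij P_il), the matrix can instead be filled one p x p block at a time, and only blocks with l >= j need computing since the matrix is symmetric. In Armadillo each block could be written as something like Z.t() * (Z.each_col() % c), which hands the work to BLAS. A standard-library C++ sketch of the idea (info_blockwise is a hypothetical name; inputs are row-major, with the baseline probability column already dropped):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// P: n x k probabilities (baseline dropped), Z: n x p model matrix, w: row totals.
// All row-major. Returns the (k*p) x (k*p) information matrix, row-major.
std::vector<double> info_blockwise(const std::vector<double>& P,
                                   const std::vector<double>& Z,
                                   const std::vector<double>& w,
                                   std::size_t n, std::size_t p, std::size_t k) {
    assert(P.size() == n * k && Z.size() == n * p && w.size() == n);
    const std::size_t nc = k * p;
    std::vector<double> info(nc * nc, 0.0);
    std::vector<double> c(n);
    for (std::size_t j = 0; j < k; ++j) {
        for (std::size_t l = j; l < k; ++l) {              // upper block triangle only
            for (std::size_t i = 0; i < n; ++i) {          // per-observation weight of block (j, l)
                const double pj = P[i * k + j], pl = P[i * k + l];
                c[i] = w[i] * ((j == l ? pj : 0.0) - pj * pl);
            }
            for (std::size_t a = 0; a < p; ++a)            // block = Z^T diag(c) Z
                for (std::size_t b = 0; b < p; ++b) {
                    double s = 0.0;
                    for (std::size_t i = 0; i < n; ++i) s += c[i] * Z[i * p + a] * Z[i * p + b];
                    info[(j * p + a) * nc + (l * p + b)] = s;
                    info[(l * p + b) * nc + (j * p + a)] = s;  // mirror into block (l, j)
                }
        }
    }
    return info;
}
```

This roughly halves the arithmetic compared with the Kronecker loop and never builds a per-observation ncoefs x ncoefs temporary; the result should match calc_infmatrix_RcppArma up to rounding, but it has not been benchmarked on the models above.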

Saved as "calc_infmatrix_arma.cpp" and compiled from R with:

library(Rcpp)
library(RcppArmadillo)
sourceCpp("calc_infmatrix_arma.cpp")

R wrapper function:

# Function to calculate Hessian / observed Fisher information
# matrix of nnet::multinom multinomial fit object
fastmultinomHess <- function(object, Z = model.matrix(object)) {
  
  probs <- object$fitted # predicted probabilities, avoid napredict from fitted.default

  coefs <- coef(object)
  if (is.vector(coefs)){ # ie there are only 2 response categories
    coefs <- t(as.matrix(coefs))
    probs <- cbind(1 - probs, probs)
  }
  coefdim <- dim(coefs)
  p <- coefdim[2L] # nr of parameters
  k <- coefdim[1L] # nr of outcome categories - 1
  ncoefs <- k * p # nr of coefficients
  n <- dim(Z)[1L] # nr of observations
  
  #  Now compute the Hessian = the observed 
  #  (= expected, in this case) 
  #  Fisher information matrix
    
  info <- calc_infmatrix_RcppArma(probs = probs[, -1, drop = FALSE],
                                  Z = Z, 
                                  row_totals = object$weights)

  Names <- dimnames(coefs)
  if (is.null(Names[[1L]])) Names <- Names[[2L]] else Names <- as.vector(outer(Names[[2L]], Names[[1L]],
                                function(name2, name1)
                                  paste(name1, name2, sep = ":")))
  dimnames(info) <- list(Names, Names)

  return(info)
}

For my large model this now takes about 100 s instead of more than 2 hours, i.e. almost 80 times faster:

download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
              "bigmodel.RData", 
              method = "auto", mode="wb")
load("bigmodel.RData")
object = fit_global_multi_last3m # large nnet::multinom fit
system.time(info <- fastmultinomHess(object, Z = model.matrix(object))) # 103s
system.time(info <- nnet:::multinomHess(object, Z = model.matrix(object))) # 8127s = 2.25h
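Most of the remaining time is presumably spent in the per-observation accumulation, which, as noted earlier, could still be parallelized, for example with RcppParallel's parallelReduce. The sketch below shows the same split/accumulate/join pattern with plain std::thread instead (a swapped-in technique, not RcppParallel's API): each thread sums the Kronecker terms for a contiguous chunk of observations into a private buffer, and the buffers are added at the end. accumulate_range and parallel_info are hypothetical names. Every thread needs its own ncoefs x ncoefs buffer, i.e. roughly 150 MB per thread for the large model.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Adds sum_i w_i * kron(diag(p_i) - p_i p_i^T, z_i z_i^T) for i in [begin, end) to `out`.
// P: n x k (baseline dropped), Z: n x p, out: (k*p) x (k*p); all row-major.
void accumulate_range(const std::vector<double>& P, const std::vector<double>& Z,
                      const std::vector<double>& w, std::size_t p, std::size_t k,
                      std::size_t begin, std::size_t end, std::vector<double>& out) {
    const std::size_t nc = k * p;
    for (std::size_t i = begin; i < end; ++i)
        for (std::size_t j = 0; j < k; ++j)
            for (std::size_t l = 0; l < k; ++l) {
                const double pj = P[i * k + j], pl = P[i * k + l];
                const double a_jl = w[i] * ((j == l ? pj : 0.0) - pj * pl);
                for (std::size_t a = 0; a < p; ++a)
                    for (std::size_t b = 0; b < p; ++b)
                        out[(j * p + a) * nc + (l * p + b)] += a_jl * Z[i * p + a] * Z[i * p + b];
            }
}

// Splits observations into contiguous chunks, one private partial sum per thread,
// then adds the partial sums together (the "join" step of a parallel reduce).
std::vector<double> parallel_info(const std::vector<double>& P, const std::vector<double>& Z,
                                  const std::vector<double>& w, std::size_t n, std::size_t p,
                                  std::size_t k, std::size_t n_threads) {
    assert(n_threads > 0 && P.size() == n * k && Z.size() == n * p && w.size() == n);
    const std::size_t nc = k * p;
    std::vector<std::vector<double>> partial(n_threads, std::vector<double>(nc * nc, 0.0));
    std::vector<std::thread> workers;
    const std::size_t chunk = (n + n_threads - 1) / n_threads;
    for (std::size_t t = 0; t < n_threads; ++t) {
        const std::size_t begin = std::min(n, t * chunk), end = std::min(n, begin + chunk);
        workers.emplace_back(accumulate_range, std::cref(P), std::cref(Z), std::cref(w),
                             p, k, begin, end, std::ref(partial[t]));
    }
    for (auto& th : workers) th.join();
    std::vector<double> info(nc * nc, 0.0);
    for (const auto& part : partial)
        for (std::size_t c = 0; c < nc * nc; ++c) info[c] += part[c];
    return info;
}
```

If this pattern were moved into an RcppArmadillo function, the worker threads must not call the R API; the sketch only touches plain C++ vectors.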

A pure R version of the calc_infmatrix function (ca. 5x slower than the Rcpp function above) would be:

# Utility function to calculate observed Fisher information matrix
# of multinomial fit, with
# probs = fitted probabilities (with 1st category/column dropped)
# Z = model matrix
# row_totals = row totals
calc_infmatrix = function(probs, Z, row_totals) {
  require(fastmatrix) # for kronecker.prod Kronecker product function

  n <- nrow(Z)
  p <- ncol(Z)
  k <- ncol(probs)
  ncoefs <- k * p
  info <- matrix(0, ncoefs, ncoefs)
  for (i in seq_len(n)) {
    # nrow = k keeps diag() correct when k = 1 (binary response)
    info <- info + kronecker.prod(diag(probs[i, ], nrow = k) - tcrossprod(probs[i, ]),
                                  tcrossprod(Z[i, ]) * row_totals[i])
  }
  return(info)
}
Tom Wenseleers