
I'd like to compute the part of the multivariate normal density that is a quadratic form:

(X - mu)^T * S * (X - mu)

Assume the following data:

import numpy as np

mu = np.array([[1,2,3], [4,5,6]])
S = np.array([np.eye(3)*3, np.eye(3)*5])
X = np.random.random((10, 3))
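Because both matrices in S here are scaled identities (3·I and 5·I), each quadratic form reduces to c·‖X[i] - mu[k]‖², which gives an independent check for any vectorized version. A minimal sketch with the same shapes as above; the scale vector `c` is introduced here only for illustration and assumes S[k] == c[k] * I:

```python
import numpy as np

mu = np.array([[1, 2, 3], [4, 5, 6]])
S = np.array([np.eye(3) * 3, np.eye(3) * 5])
X = np.random.random((10, 3))

# Scale factors of the two diagonal matrices in S (assumption: S[k] == c[k] * I)
c = np.array([3.0, 5.0])

# d has shape (10, 2, 3): every row of X minus every mean
d = X[:, None, :] - mu

# For S[k] = c[k] * I the quadratic form is c[k] * sum_j d[..., j]**2
closed_form = c * (d ** 2).sum(-1)  # shape (10, 2)

# Direct evaluation of (X[i] - mu[k]) @ S[k] @ (X[i] - mu[k]) for comparison
direct = np.array([[d[i, k] @ S[k] @ d[i, k] for k in range(2)] for i in range(10)])

print(np.allclose(closed_form, direct))
```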

An iterative approach would be to compute, for X[0]:

(X[0] - mu[0]) @ S[0] @ (X[0] - mu[0]).T, (X[0] - mu[1]) @ S[1] @ (X[0] - mu[1]).T

(I don't need to vectorize over X.) However, I suspect this is not the fastest approach. What I tried is:

np.squeeze((X[0] - mu)[:, None] @ S) @ ((X[0] - mu)).T

But the values I want are on the main diagonal of the resulting matrix. I could use np.diagonal(), but is there a better way to perform the calculation?
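For reference, the loop and the np.diagonal() route give the same numbers. The diagonal route builds a full K × K matrix (here 2 × 2), so its second product computes K² dot products where only K are needed. A small sketch comparing the two for X[0]; the helper name `quad_forms_loop` is hypothetical:

```python
import numpy as np

mu = np.array([[1, 2, 3], [4, 5, 6]])
S = np.array([np.eye(3) * 3, np.eye(3) * 5])
X = np.random.random((10, 3))

def quad_forms_loop(x, mu, S):
    """Return [(x - mu[k]) @ S[k] @ (x - mu[k]) for each component k]."""
    return np.array([(x - m) @ Sk @ (x - m) for m, Sk in zip(mu, S)])

d = X[0] - mu                            # shape (K, 3), here K = 2
full = np.squeeze(d[:, None] @ S) @ d.T  # shape (K, K); only the diagonal is wanted
via_diagonal = np.diagonal(full)

print(np.allclose(via_diagonal, quad_forms_loop(X[0], mu, S)))
```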

thesecond

2 Answers


I think you were almost there. You have the matrix A = np.squeeze((X[0] - mu)[:, None] @ S), which you matrix-multiply with B = (X[0] - mu).T, but you only want the diagonal elements of the product.

As pointed out here, C = np.diag(A.dot(B)) is equivalent to C = (A * B.T).sum(-1), which leads to the following solution:

import numpy as np

mu = np.array([[1,2,3], [4,5,6]])
S = np.array([np.eye(3)*3, np.eye(3)*5])
X = np.array([np.random.random(3*10)]).reshape(10, 3)

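# Loop-style reference from the question: one quadratic form per component
res1 = (X[0] - mu[0]) @ S[0] @ (X[0] - mu[0]).T, (X[0] - mu[1]) @ S[1] @ (X[0] - mu[1]).T

# Row-wise sum of the elementwise product gives only the diagonal of A @ B
res2 = (np.squeeze((X[0] - mu)[:, None] @ S) * (X[0] - mu)).sum(-1)
print(res1)
print(res2)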
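The identity holds because each diagonal entry is (AB)[i, i] = Σ_j A[i, j] B[j, i] = Σ_j A[i, j] (Bᵀ)[i, j], so the off-diagonal entries are never computed. The same idea extends to every row of X at once, without np.einsum, by letting @ broadcast over a leading axis. A sketch assuming the shapes above; indexing with [..., 0, :] instead of np.squeeze keeps the component axis even when there is only one mean:

```python
import numpy as np

mu = np.array([[1, 2, 3], [4, 5, 6]])
S = np.array([np.eye(3) * 3, np.eye(3) * 5])
X = np.random.random((10, 3))

d = X[:, None, :] - mu                 # (10, 2, 3): every row of X minus every mean
dS = (d[..., None, :] @ S)[..., 0, :]  # (10, 2, 3): row vector d[i, k] times S[k]
res_all = (dS * d).sum(-1)             # (10, 2): quadratic forms for all rows and components

# Row 0 matches the single-row computation from the answer
res2 = (np.squeeze((X[0] - mu)[:, None] @ S) * (X[0] - mu)).sum(-1)
print(np.allclose(res_all[0], res2))
```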
Sandro

This can also be expressed with np.einsum, which lets you broadcast over X as well:

import numpy as np
mu = np.array([[1,2,3], [4,5,6]])
S = np.array([np.eye(3)*3, np.eye(3)*5])
X = np.random.random((10, 3))

resOP = np.array([(X[0] - mu[0]) @ S[0] @ (X[0] - mu[0]).T, (X[0] - mu[1]) @ S[1] @ (X[0] - mu[1]).T])
resNin17 = np.einsum("...ij, ...ij->...i", np.einsum("...j, ...ij", (X[:, None] - mu), S), (X[:, None] - mu))

assert np.allclose(resOP, resNin17[0])

Or, to calculate just one row:

assert np.allclose(resNin17[2], np.einsum("...ij, ...ij->...i", np.einsum("...j, ...ij", (X[2] - mu), S), (X[2] - mu)))
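The two nested einsum calls can also be fused into one call with explicit subscripts, which states the quadratic form directly and lets np.einsum choose a contraction order when optimize=True. Note that the inner call above computes S[k] @ d rather than d @ S[k]; the two agree here because each S[k] is symmetric, while the explicit subscripts below follow (X - mu)^T S (X - mu) for any S. A sketch under the same shapes; the subscript letters are arbitrary labels:

```python
import numpy as np

mu = np.array([[1, 2, 3], [4, 5, 6]])
S = np.array([np.eye(3) * 3, np.eye(3) * 5])
X = np.random.random((10, 3))

d = X[:, None] - mu  # (n, K, D) = (10, 2, 3)

# n: rows of X, k: components, i and j: feature dimensions
# res[n, k] = sum_{i, j} d[n, k, i] * S[k, i, j] * d[n, k, j]
res = np.einsum("nki,kij,nkj->nk", d, S, d, optimize=True)

# Nested version from the answer, for comparison
resNin17 = np.einsum("...ij, ...ij->...i", np.einsum("...j, ...ij", d, S), d)
print(np.allclose(res, resNin17))
```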
Nin17