
To visualize the gradient descent of my linear regression model, I'm trying to make a contour plot of the following MSE function:

import jax.numpy as jnp
import numpy as np

def make_mse(x, t):
  def mse(w, b):  # scalar slope w and intercept b
    return np.sum(jnp.power(x.dot(w) + b - t, 2)) / 2
  return mse

where the x and y axes of the plot correspond to the w and b parameters.

The exact values of x and t are not what matters for the plot itself, since each value of x is just multiplied by a single value of w at a time.
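
Written out, the function I want to plot is

$$\mathrm{mse}(w, b) = \frac{1}{2}\sum_i \left(w\,x_i + b - t_i\right)^2$$

evaluated at every (w, b) point of the grid.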

I was trying to do the following:

x = np.linspace(-1.0, 1.0, 500)   # 500 training inputs
t = 5*x + 1                       # targets from the line w=5, b=1

xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)   # (50, 50) grids of w and b values

Z = make_mse(x, t)(w1, w2)

However, I get an obvious error for the dot product:

/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type)
    634   else:
    635     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
--> 636         lhs.shape, rhs.shape))
    637 
    638 

TypeError: Incompatible shapes for dot: got (1000, 1) and (50, 50).
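
The mismatch is easy to confirm by checking the shapes that go into the dot product (np.dot needs the last axis of its first argument to match the second-to-last axis of its second):

print(x.shape)    # shape of the data vector
print(w1.shape)   # (50, 50), the parameter grid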

Is there a Pythonic, efficient way to make a contour plot of this function?


1 Answer


The MSE at each grid point (w, b) is still a sum over all the data points, so np.sum() stays, but it should run along the data axis only. The shape error comes from x.dot(w): on the grid, w is a (50, 50) array rather than a scalar, so the dot product is no longer defined. Broadcast x against the grid instead, so that the model w*x + b is evaluated for every data point at every grid point, and then sum the squared errors over the data axis. The following works:

import numpy as np
import matplotlib.pyplot as plt

def make_mse(x, t):
  def mse(w, b):
    # x[:, None, None] * w broadcasts to shape (n_data, 50, 50):
    # the model w*x + b evaluated for every data point at every grid point
    err = x[:, None, None] * w + b - t[:, None, None]
    return np.sum(np.power(err, 2), axis=0) / 2   # sum over the data axis
  return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5*x + 1

xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)

Z = make_mse(x, t)(w1, w2)    # (50, 50) loss surface
plt.contourf(w1, w2, Z)
plt.show()

which produces the following output contour:

[contour plot of the MSE surface over the (w, b) grid]
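
Since the question already imports jax.numpy, an alternative sketch (one option, assuming JAX is installed) keeps the scalar mse from the question essentially unchanged and lets jax.vmap map it over the grid:

import jax
import jax.numpy as jnp

def make_mse(x, t):
  def mse(w, b):  # scalar w and b; x*w replaces x.dot(w), the same value for scalar w
    return jnp.sum(jnp.power(x * w + b - t, 2)) / 2
  return mse

mse = make_mse(jnp.asarray(x), jnp.asarray(t))
mse_grid = jax.vmap(jax.vmap(mse))              # map the scalar loss over both grid axes
Z = mse_grid(jnp.asarray(w1), jnp.asarray(w2))  # (50, 50), same surface as above

Keeping the loss defined for scalar parameters also plays well with jax.grad, should you later want to overlay the gradient-descent steps on the same contour plot.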
