
I'm confused about what it means to evaluate a vector-Jacobian product when the vector fed to the VJP is not a standard basis row vector. My question is about vector-valued functions, not scalar functions like a loss. I'll show a concrete example using Python and JAX, but this is a very general question about reverse-mode automatic differentiation.

Consider this simple vector-valued function for which the Jacobian is trivial to write down analytically:

from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev

# Define a vector-valued function (3 inputs --> 2 outputs) 
def vector_func(args):
    x,y,z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])

# Define the inputs
x = 2.0
y = 3.0
z = 4.0

# Compute the vector-Jacobian product at the fiducial input point (x,y,z)
val, func_vjp = vjp(vector_func, (x, y, z))

print(val) 
# [99,96]

# now evaluate the function returned by vjp along with basis row vectors to pull out gradient of 1st and 2nd output components 
v1 = jnp.array([1.0, 0.0])  # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., first row of Jacobian
v2 = jnp.array([0.0, 1.0])  # pulls out the gradient of the 2nd component wrt the 3 inputs, i.e., second row of Jacobian

gradient1 = func_vjp(v1)
print(gradient1)
# [8, 18, 32]

gradient2 = func_vjp(v2)
print(gradient2)
# [48,32,24]

That much makes sense to me -- we're separately feeding [1,0] and [0,1] to func_vjp to respectively get the first and second rows of the Jacobian evaluated at our fiducial point (x,y,z)=(2,3,4).
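For reference, those two rows can be cross-checked against the full Jacobian using jacrev (already imported above); a quick sketch, where the stacking just arranges the per-input gradients into a 2x3 matrix:

cols = jacrev(vector_func)((x, y, z))  # one length-2 array per input variable
J = jnp.stack(cols, axis=1)            # shape (2, 3): rows are the gradients of a and b
print(J)
# [[ 8. 18. 32.]
#  [48. 32. 24.]]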

But now what if we fed func_vjp a row vector that isn't a basis vector, like [2,0]? Is this asking how the fiducial (x,y,z) would need to be perturbed to double the first component of the output? If so, is there a way to see this by evaluating vector_func at the perturbed parameter values?

I tried but I'm not sure:

# suppose I want to know what perturbations in (x,y,z) cause a doubling of the first output and no change in second output component 
print(func_vjp(jnp.array([2.0,0.0])))
# [16,36,64] 

### Attempts to use the output of func_vjp to verify that val becomes [99*2, 96]
### none of these work

print(vector_func([16,36,64]))
# [20784, 147456]

print(vector_func([x*16, y*36, z*64]))
# [299184., 3538944.]

What am I doing wrong in using the output of func_vjp to modify the fiducial parameters (x,y,z) and feed those back into vector_func to verify indeed that those parameter perturbations double the first component of the original output and leave the second component unchanged?

Jim Raynor
  • I think you've answered your own question in the last paragraph: VJPs propagate output perturbations to input perturbations, and those output perturbations in general will not be a unit perturbation along a single output value in isolation. Those perturbations may be made up of any mix of values in each output dimension. – jakevdp Mar 30 '23 at 01:50
  • Okay thanks but I'm still not seeing how to use the output of func_vjp([2,0]) to modify the parameters (x,y,z) and feed them back into vector_func to verify that the first component gets doubled. I edited my question to simplify it and make it more concrete. Thank you in advance for any help you can provide! – Jim Raynor Mar 30 '23 at 14:15
  • `[2, 0]` does not mean the component will be doubled; it is a vector specifying the direction of the tangent (i.e. the gradient) that is mapped back to the corresponding tangent in the input space. – jakevdp Mar 30 '23 at 16:40
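
Concretely, the point in that last comment can be seen from the fact that the VJP is linear in its output-space argument: [2, 0] just rescales the direction [1, 0], it does not request that the first output be doubled. A quick check with the func_vjp defined in the question:

print(func_vjp(jnp.array([1.0, 0.0])))  # first row of the Jacobian: [8, 18, 32]
print(func_vjp(jnp.array([2.0, 0.0])))  # the same direction scaled by 2: [16, 36, 64]
# more generally, func_vjp(v) returns v[0]*row1 + v[1]*row2, i.e. J^T @ v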

1 Answer

I think in your question you are confusing primal and tangent vector spaces. The function vector_func is a non-linear function that maps a vector in an input primal vector space (represented by (x, y, z)) to a vector in an output primal vector space (represented by val in your code).

The function func_vjp is a linear function that maps a vector in an output tangent vector space (represented by array([2, 0]) in your question) to a vector in an input tangent vector space ([16,36,64] in your question).

By construction, the tangent vectors in these transformations represent the gradients of the input function at the specified primal values. That is, if you infinitesimally perturb your output primal along the direction of your output tangent, it corresponds to infinitesimally perturbing the input primal along the direction of the input tangent.

If you want to check the values, you could do something like this:

input_primal = (x, y, z)
output_primal, func_vjp = vjp(vector_func, input_primal)

epsilon = 1E-8  # note: small value so we're near the linear regime
output_tangent = epsilon * jnp.array([0.0, 1.0])
input_tangent, = func_vjp(output_tangent)

# Compute the perturbed output given the perturbed input
perturbed_input = [p + t for p, t in zip(input_primal, input_tangent)]
perturbed_output_1 = vector_func(perturbed_input)
print(perturbed_output_1)
# [99.00001728 96.00003904]

# Perturb the output directly
perturbed_output_2 = output_primal + output_tangent
print(perturbed_output_2)
# [99.         96.00000001]

Note that the results don't match exactly, because the VJP is valid in the locally linear limit, and your function is very nonlinear. But hopefully this helps clarify what these primal and tangent values mean in the context of the VJP computation. Mathematically, if we computed this in the limit where epsilon goes to zero, the results would match exactly – gradient computations are all about these kinds of infinitesimal limits.

jakevdp
  • Thank you! It's hard to compare perturbed_output_1 and perturbed_output_2 given that they are so similar. If we were using a linear function, can we get away with a larger epsilon and have perturbed_output_1 and perturbed_output_2 agree for that larger perturbation? For example ```vector_func = lambda args: jnp.array([2*args[0]+3*args[1], 4*args[2]])``` with epsilon = 0.1 and same input_primal = (2,3,4) as before, I get output_primal = [13,16] and perturbed_output_2 = [13,16.1] as expected, but perturbed_output_1 = [13, 17.6] which is wrong. – Jim Raynor Mar 30 '23 at 18:52
  • 1
    Yeah, you're right, there are some scalaing factors missing in my code. But the gist of the answer (regarding mappings between primal and tangent spaces) is how you should be thinking about these things. – jakevdp Mar 30 '23 at 19:25
  • Whew okay thanks I thought I was going nuts. Would be curious to see the missing scaling factors put in and how you derived them (if it's straightforward). So just to make sure I understand: I was under the (mistaken) impression that VJP could help one find the perturbations in inputs that reproduce a desired change in outputs. But that's not true. Instead, we merely tell the VJP in what DIRECTION we want the output to be perturbed, and VJP will give us input perturbations that will move us in that direction. What is the practical use of this? To do many small sequential iterations? – Jim Raynor Mar 30 '23 at 19:34
  • Sorry to bug you @jakevdp but if you have a chance, I'd love to hear briefly what scaling factors are missing in your code such that for a linear function, we can use a larger epsilon and still have perturbed_output_1 and perturbed_output_2 agree (e.g., the linear function in my first comment above). – Jim Raynor Apr 03 '23 at 18:34
  • 1
    I don't know - I suspect it's related to the determinant of the jacobian or something. I don't think you'd ever need to use such perturbations directly; I only intended my answer to explain generally how to think about primal and tangent spaces. – jakevdp Apr 03 '23 at 18:38
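
To make the scaling factor discussed in the last few comments concrete: func_vjp computes J^T v, so perturbing the input by epsilon * J^T v moves the output, to first order, by epsilon * J @ J^T @ v rather than by epsilon * v. A minimal sketch using the linear function from the comment above (the name linear_func is just for this illustration):

import jax.numpy as jnp
from jax import vjp, jacrev

# the linear function and fiducial inputs from the comment above
linear_func = lambda args: jnp.array([2*args[0] + 3*args[1], 4*args[2]])
input_primal = (2.0, 3.0, 4.0)

output_primal, linear_vjp = vjp(linear_func, input_primal)
print(output_primal)
# [13. 16.]

epsilon = 0.1
v = jnp.array([0.0, 1.0])                 # desired output direction
input_tangent, = linear_vjp(epsilon * v)  # epsilon * J^T v = (0, 0, 0.4)

# perturb the input by epsilon * J^T v and re-evaluate
perturbed_input = tuple(p + t for p, t in zip(input_primal, input_tangent))
print(linear_func(perturbed_input))
# [13.  17.6]

# the 1.6 change in the second output (rather than 0.1) is the extra factor of J @ J^T
J = jnp.stack(jacrev(linear_func)(input_primal), axis=1)  # shape (2, 3)
print(output_primal + epsilon * (J @ (J.T @ v)))
# [13.  17.6]

In other words, func_vjp answers "which input direction corresponds to this output direction"; to see how the output actually moves, you have to push that input direction back through the (linearized) function, which introduces the J @ J^T factor.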