I'm confused about what it means to evaluate a vector-Jacobian product (VJP) when the vector used for the VJP is a row vector other than a standard basis vector. My question pertains to vector-valued functions, not scalar functions like a loss. I will show a concrete example using Python and JAX, but this is a very general question about reverse-mode automatic differentiation.
Consider this simple vector-valued function for which the Jacobian is trivial to write down analytically:
```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev

# Define a vector-valued function (3 inputs --> 2 outputs)
def vector_func(args):
    x, y, z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])

# Define the inputs
x = 2.0
y = 3.0
z = 4.0

# Evaluate the function and get its VJP (pullback) function at the fiducial input point (x,y,z)
val, func_vjp = vjp(vector_func, (x, y, z))
print(val)
# [99, 96]

# Now evaluate the function returned by vjp on basis row vectors to pull out the gradients of the 1st and 2nd output components.
# func_vjp returns a 1-tuple (one entry per primal argument) holding the cotangents for x, y, z.
v1 = jnp.array([1.0, 0.0])  # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., first row of Jacobian
v2 = jnp.array([0.0, 1.0])  # pulls out the gradient of the 2nd component wrt the 3 inputs, i.e., second row of Jacobian
gradient1 = func_vjp(v1)
print(gradient1)
# [8, 18, 32]
gradient2 = func_vjp(v2)
print(gradient2)
# [48, 32, 24]
```
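Since `jacrev` is imported but never used above, here is a quick cross-check, sketched under one assumption not in the original: the three inputs are packed into a single `jnp.array` instead of a tuple, so each VJP comes back as one array. `jacrev` builds the full Jacobian by pulling back each output basis vector (vmapped VJPs), so its rows should match the two VJPs computed one at a time.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev

def vector_func(args):
    x, y, z = args
    a = 2 * x**2 + 3 * y**2 + 4 * z**2
    b = 4 * x * y * z
    return jnp.array([a, b])

# Assumption: inputs packed into one array rather than a tuple of floats.
primal = jnp.array([2.0, 3.0, 4.0])

# Full 2x3 Jacobian via reverse mode.
J = jacrev(vector_func)(primal)
print(J)  # expected values: [[8, 18, 32], [48, 32, 24]]

# The same rows, pulled out one at a time with basis cotangents.
_, func_vjp = vjp(vector_func, primal)
(row1,) = func_vjp(jnp.array([1.0, 0.0]))
(row2,) = func_vjp(jnp.array([0.0, 1.0]))
print(jnp.allclose(J, jnp.stack([row1, row2])))  # True
```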
That much makes sense to me: we're separately feeding [1,0] and [0,1] to func_vjp to get, respectively, the first and second rows of the Jacobian evaluated at our fiducial point (x,y,z)=(2,3,4).
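For reference, the Jacobian that the function above makes "trivial to write down analytically", evaluated at (2,3,4), together with the row-vector product that the VJP computes. The last product follows from linearity alone.

$$
J(x,y,z)=\begin{pmatrix}\partial a/\partial x & \partial a/\partial y & \partial a/\partial z\\ \partial b/\partial x & \partial b/\partial y & \partial b/\partial z\end{pmatrix}
=\begin{pmatrix}4x & 6y & 8z\\ 4yz & 4xz & 4xy\end{pmatrix},
\qquad
J(2,3,4)=\begin{pmatrix}8 & 18 & 32\\ 48 & 32 & 24\end{pmatrix}
$$

$$
\mathrm{vjp}(v)=v^{\top}J:\qquad (1,0)\,J=(8,18,32),\quad (0,1)\,J=(48,32,24),\quad (2,0)\,J=(16,36,64)
$$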
But what if we fed func_vjp a row vector that isn't a standard basis vector, like [2,0]? Is this asking how the fiducial (x,y,z) would need to be perturbed to double the first component of the output? If so, is there a way to see this by evaluating vector_func at the perturbed parameter values?
I tried but I'm not sure:
```python
# Suppose I want to know what perturbations in (x,y,z) cause a doubling of the first output and no change in the second output component
print(func_vjp(jnp.array([2.0, 0.0])))
# [16, 36, 64]

# Attempts to use the output of func_vjp to verify that val becomes [99*2, 96]
# None of these work
print(vector_func([16, 36, 64]))
# [20784, 147456]
print(vector_func([x*16, y*36, z*64]))
# [299184., 3538944.]
```
What am I doing wrong when I use the output of func_vjp to modify the fiducial parameters (x,y,z) and feed them back into vector_func to verify that those perturbations indeed double the first component of the original output and leave the second component unchanged?
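Two properties of the VJP can be checked numerically. This is a minimal sketch; it repackages the inputs as a single `jnp.array` and uses `jax.jvp`, neither of which appears in the original code, and the perturbation `dx` is an arbitrary illustrative value. First, the pullback is linear in the cotangent, so [2,0] returns exactly twice the first Jacobian row. Second, the VJP result lives in input space and pairs with an input perturbation `dx` through the defining identity v·(J dx) = (Jᵀv)·dx.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

def vector_func(args):
    x, y, z = args
    a = 2 * x**2 + 3 * y**2 + 4 * z**2
    b = 4 * x * y * z
    return jnp.array([a, b])

# Assumption: inputs packed into one array so each VJP is a single array.
primal = jnp.array([2.0, 3.0, 4.0])
val, func_vjp = jax.vjp(vector_func, primal)

v = jnp.array([2.0, 0.0])
(vjp_v,) = func_vjp(v)                      # J^T v, same shape as primal
(row1,) = func_vjp(jnp.array([1.0, 0.0]))   # first Jacobian row

# Linearity in v: doubling the cotangent doubles the result.
print(jnp.allclose(vjp_v, 2.0 * row1))  # True

# Defining identity: for any input tangent dx, v . (J dx) == (J^T v) . dx
dx = jnp.array([0.1, -0.2, 0.3])  # arbitrary illustrative perturbation
_, jdx = jax.jvp(vector_func, (primal,), (dx,))  # jdx = J dx (forward mode)
print(jnp.allclose(jnp.dot(v, jdx), jnp.dot(vjp_v, dx)))  # True
```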