
I have provided my code below. I am trying to update two different sets of parameters belonging to two neural networks separately, but I cannot figure out how to do that with a JAX optimizer. With the code below, the loss of the objective function does not decrease at all. Please suggest some solutions.

I have tried to update the parameters one by one by passing a different opt_state for each parameter set, but it is not working. I want to update the weights and biases of the first neural network first, and then update the second set of weights and biases: in the first loop the first set should be updated, and in the second loop the second set should be updated.
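Roughly, the update pattern I am aiming for is sketched below. This is illustrative only, not my working code; it assumes the objective, params1, params2, inputs and train_iters defined in the code further down, and gives each network its own Adam state:

from jax import grad
from jax.example_libraries.optimizers import adam

opt_init, opt_update, get_params = adam(step_size=0.001)
opt_state1 = opt_init(params1)   # optimizer state for the first network
opt_state2 = opt_init(params2)   # optimizer state for the second network

# first pass: update only the first network's weights and biases
for i in range(train_iters):
    grads1, _ = grad(objective)((get_params(opt_state1), get_params(opt_state2)), inputs)
    opt_state1 = opt_update(i, grads1, opt_state1)

# second pass: update only the second network's weights and biases
for i in range(train_iters):
    _, grads2 = grad(objective)((get_params(opt_state1), get_params(opt_state2)), inputs)
    opt_state2 = opt_update(i, grads2, opt_state2)

params1, params2 = get_params(opt_state1), get_params(opt_state2)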

import jax.numpy as jnp
from jax import random, vmap, jacfwd, jit, value_and_grad
from jax.example_libraries.optimizers import adam

# params_fwd is a previously trained set of network parameters defined elsewhere in my code;
# temp_arr provides the target values for the data-fit term of the loss.
# Note that f and params_fwd must already be defined when this line runs.
data_points = jnp.linspace(0, 1, num=100).reshape((-1, 1))
temp_arr = f(params_fwd, data_points)


def init_random_params(scale, layer_sizes, key):
    """Build a list of (weights, biases) tuples, one for each layer."""
    # give each layer its own subkey so the layers are not initialised with identical draws
    keys = random.split(key, len(layer_sizes) - 1)
    return [(random.normal(k, (insize, outsize)) * scale,   # weight matrix
             random.normal(k, (outsize,)) * scale)          # bias vector
            for k, (insize, outsize)
            in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:]))]

def swish(x):
    return x / (1.0 + jnp.exp(-x))


def f(params, inputs):
    "Neural network functions"
    for W, b in params:        
        outputs = jnp.dot(inputs, W) + b
        inputs = swish(outputs)    
    return outputs

def kf(params, inputs):
    "Neural network functions"
    for W, b in params:        
        outputs = jnp.dot(inputs, W) + b
        inputs = swish(outputs)    
    return outputs


# Here is our initial guess of params:
params1 = init_random_params(0.1, layer_sizes=[1, 5, 1], key = random.PRNGKey(0))
params2 = init_random_params(0.1, layer_sizes=[1, 5, 1], key = random.PRNGKey(1))

params = (params1, params2)

inputs = jnp.linspace(0, 1, num=50).reshape((-1, 1))

# This is the function we seek to minimize
def objective(params, inputs):
    # These terms should all be zero at the solution
    params1, params2 = params
    u = lambda inputs: f(params1, inputs)

    u_x = lambda inputs: vmap(jacfwd(u, 0))(inputs)

    u_xx = lambda inputs: vmap(jacfwd(jacfwd(u, 0), 0))(inputs)

    # PDE residual: kf(params2, x) * u''(x) - sin(x)
    eq = kf(params2, inputs).reshape(-1, 1) * u_xx(inputs).reshape(-1, 1) - jnp.sin(inputs).reshape(-1, 1)

    bc0 = f(params1, 0) - 15  # BC at x = 0 (u(0) = 15)
    bc1 = f(params1, 1) - 20  # BC at x = 1 (u(1) = 20)

    # data-fit term against the reference values temp_arr computed above
    data_loss = temp_arr.reshape(-1, 1) - f(params1, data_points).reshape(-1, 1)

    return jnp.mean(eq**2) + jnp.sum(bc0**2 + bc1**2) + jnp.sum(data_loss**2)


# Adam optimizer 
@jit
def resnet_update(params, inputs, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = value_and_grad(objective)(params, inputs)
    # note: the step index here is hard-coded to 0, so Adam's bias correction never advances
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, value

opt_init, opt_update, get_params = adam(step_size = 0.001, b1=0.9, b2=0.999, eps=1e-08)
opt_state = opt_init(params)

# My attempt at giving each pass its own optimizer state.
# (Both states are built over the full (params1, params2) tuple because resnet_update
# takes gradients with respect to the whole tuple, so each pass still updates both
# networks together rather than one at a time, which is not what I want.)
opt_state_all = [opt_init(params), opt_init(params)]

train_iters = 50000

for j in range(len(opt_state_all)):
    for i in range(train_iters):
        params, opt_state_all[j], value = resnet_update(params, inputs, opt_state_all[j])
        if i % 1000 == 0:
            print("Iteration {0:3d} objective {1}".format(i, objective(params, inputs)))

params1, params2 = params
