
I have a general question about how Neural ODE nets are trained in Julia. Are data points sampled from the tspan on which the Neural ODE is defined, with the parameter updates computed on those samples? In other words, is there some shuffling and batching happening during training, or is the loss computed over all data points in the tspan?

SimonAda

2 Answers


The parameters are optimized by minimizing the loss function, so it is up to you to define how the sampling occurs inside that loss function. Typically you compare the model's output to discrete data points, in which case those data points become your sampled time points.

But the Neural ODE isn't handling this -- you are, via the loss function.

  • I would not agree here. The point of the loss function is only to measure the performance of the network w.r.t. the current parameter setting; that measurement is then used to compute the gradients and updates for those parameters. Flux.train! is where the training logic should be defined, and at the moment that is handled by the library. – SimonAda Feb 20 '20 at 08:33
  • I found an answer about what the library is doing here: https://github.com/FluxML/Flux.jl/blob/master/src/optimise/train.jl . I think that to do batching one needs to sample data points and then run Flux.train! on them in a loop. – SimonAda Feb 20 '20 at 08:36
  • @SimonAda, I think you may be confusing terms here. A Neural ODE can take batched training inputs, typically a list of different initial conditions, and the output will be a time series for each initial condition. But Flux itself never sees this return value; that is handled by the loss function. How the loss function chooses to compute the loss (e.g. sampling the Neural ODE's return object at discrete time points, or some other way) is entirely up to the loss function. All that Flux gets is the scalar value of the loss summed over the whole batch. A batch is just your discrete inputs. – Alexander Hamilton Feb 21 '20 at 14:18
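To make the answer's point concrete: the loss function, not the Neural ODE, decides which time points enter the loss. A minimal dependency-free sketch, where `solve_node` is a hypothetical stand-in for solving the Neural ODE over the tspan (in real code it would be something like `sol = neural_ode(u0, tspan, θ)`):

```julia
# Toy stand-in for a Neural ODE solve: returns a callable "solution" u(t).
# Here u(t) = θ[1]*t + θ[2]; a real solve would integrate the network's dynamics.
function solve_node(θ)
    t -> θ[1] * t + θ[2]
end

# Discrete data points: the loss samples the solution at exactly these times,
# even though the solution itself is defined on the whole tspan.
t_data = [0.0, 0.5, 1.0]
u_data = [1.0, 2.0, 3.0]

function loss(θ)
    sol = solve_node(θ)
    sum(abs2, [sol(t) for t in t_data] .- u_data)  # squared error only at the data's times
end
```

All the optimizer ever sees is the scalar `loss(θ)`; which points of the tspan contributed to it was decided entirely inside `loss`.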

I found an answer about what Julia is doing here: https://github.com/JuliaDiffEq/DiffEqFlux.jl/blob/master/src/train.jl .

    "Optimizes the `loss(θ,curdata...)` function with respect to the parameter vector
`θ` iterating over the `data`. By default the data iterator is empty, i.e.
`loss(θ)` is used. The first output of the loss function is considered the loss.
Extra outputs are passed to the callback."

I think that to do batching one needs to sample data points and then run Flux.train! in a loop, giving it the batched data points as input.
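The loop described above can be sketched without any library dependencies. The batching logic is real Julia; the `Flux.train!` call it would wrap is shown only as a comment, since the exact loss and optimizer setup depend on your model:

```julia
using Random

# Split 1:n into shuffled minibatches of (at most) `batchsize` indices.
# Reshuffling each epoch is an assumption about the desired training scheme.
function minibatches(n, batchsize; rng = Random.default_rng())
    idx = shuffle(rng, 1:n)
    [idx[i:min(i + batchsize - 1, n)] for i in 1:batchsize:n]
end

t_data = collect(0.0:0.1:0.9)   # 10 data points on the tspan
u_data = 2 .* t_data .+ 1

seen = Int[]
for batch in minibatches(length(t_data), 4)
    # In real code, one call per batch, e.g.:
    #   Flux.train!(loss, ps, [(t_data[batch], u_data[batch])], opt)
    append!(seen, batch)        # stand-in for the per-batch training step
end
```

After one pass of the loop, every data point has been visited exactly once, in shuffled order, i.e. one epoch of minibatch training.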

SimonAda