
I'm trying to train a UNet in Julia with the help of Flux.

Flux.train!(loss, Flux.params(model), train_data_loader, opt)
batch_loss = loss(train_data, train_targets)

where the loss is

logitcrossentropy

and train_data_loader is

train_data_loader = DataLoader((train_data |> device, train_targets |> device), batchsize=batch_size, shuffle=true)

I don't understand how to get the loss out of Flux.train! for printing (is that the validation loss?). An evalcb callback would also trigger a call to compute the loss, so that is no different. I want to skip the extra calculation. So what I did is call the loss function again, store the result in a variable, and print it per batch. Is there a way to print the loss from Flux.train!() instead of calling loss again?

logankilpatrick
h612

3 Answers


Instead of altering train! as @Tomas suggested, the loss function can be instrumented to log its return value. Printing during the computation is a bad idea for performance, so I've made an example where the loss is pushed into a vector instead:

using ChainRulesCore

# returns another loss function which behaves like `lossfn`,
# but also push!es each return value into the `history` vector
function logged_loss(lossfn, history)
    return function _loss(args...)
        err = lossfn(args...)
        ignore_derivatives() do
            push!(history, err)
        end
        return err
    end
end

# initialize log vector
log_vec = Float32[]

# use function above to create logging loss function
newloss = logged_loss(loss, log_vec)

# run the training
Flux.train!(newloss, Flux.params(W, b), train_data, opt)

At this point, log_vec should contain a record of the return values from the loss function. This is a rough solution which relies on a global vector. How to interpret the logged values also depends on the nature of the optimizer. In my test there was one call per epoch, and it returned a decreasing loss until convergence. [This answer incorporates suggestions from @darsnack]

Note: since log_vec is captured by the loss function, to clear the log it must not be reassigned but emptied in place with empty!(log_vec).
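As a minimal usage sketch (reusing the log_vec and newloss names from above), the log can be inspected after training and then reset in place:

# print the most recently recorded batch loss
println("last recorded loss: ", last(log_vec))

# reset the log in place; reassigning `log_vec` would break the
# closure captured inside `newloss`
empty!(log_vec)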

Dan Getz
    If you are using this answer, then I would recommend avoiding having a specially named global, since this could easily introduce a bug during code changes. The answer can be slightly augmented to close over the logging vector: ``` function logged_loss(lossfn, history) return function _loss(args...) err = lossfn(args...) Zygote.ignore_derivatives() do push!(history, err) end return err end end log_vec = Float32[] newloss = logged_loss(loss, log_vec) ``` – darsnack Sep 16 '22 at 13:54
  • Thanks @darsnack. I've incorporated your suggestions into the answer – Dan Getz Sep 16 '22 at 14:25

Adding to @Dan's answer, you can also augment your loss function with logging on the fly using the do syntax:

using ChainRules

loss_history = Float32[]
Flux.train!(Flux.params(model), train_data_loader, opt) do x, y
    err = loss(x, y)
    ChainRules.ignore_derivatives() do
        push!(loss_history, err)
    end
    return err
end
darsnack

You would need to write your own version of Flux.train! using the withgradient function instead of gradient. withgradient gives you the output of the loss (or, more precisely, of the function being differentiated) alongside the gradients. Flux.train! (https://github.com/FluxML/Flux.jl/blob/8bc0c35932c4a871ac73b42e39146cd9bbb1d446/src/optimise/train.jl#L123) is literally a few lines of code, so adapting it into your own version is very easy.
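For illustration, a rough sketch of such a hand-rolled loop (assuming the model, loss, train_data_loader and opt from the question; this is not the exact Flux.train! code) could look like:

using Flux, Zygote

ps = Flux.params(model)
for (x, y) in train_data_loader
    # withgradient returns the value of the differentiated function
    # together with the gradients, so the loss needs no second forward pass
    batch_loss, grads = Zygote.withgradient(() -> loss(x, y), ps)
    Flux.update!(opt, ps, grads)
    println("batch loss: ", batch_loss)
end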