Adding a custom loss layer

Built-in loss functions such as LogitCrossEntropyLoss let users quickly prototype models on new problems. Sometimes, however, you need to write your own loss function. This example walks through that process.

To show which functions a custom loss must implement, we build a BinaryLogitCrossEntropyLoss, which acts on a model with a single output and binary targets.

Mathematical background for a Binary Cross Entropy Loss

Consider the following model:

\[p_\theta(X_i) = \sigma (f_\theta(X_i)),\]

where $X_i$ are the input features of the $i^\text{th}$ sample, $\sigma$ is the sigmoid function, given by $\sigma(x)=(1+e^{-x})^{-1}$, and $f_\theta$ is the function defined by your model, with parameters $\theta$. The output $f_\theta (X_i)$ is called the "logit". The loss function we want to calculate is the following:

\[L(\theta| X, Y) = -\sum_i \left [ Y_i\ln{p_\theta (X_i)} + (1-Y_i)\ln{(1-p_\theta (X_i))} \right ],\]

where $Y_i$ is the true binary label of the $i^\text{th}$ sample. To implement this custom loss, we need the gradient of the loss function with respect to the parameters:

\[ \frac{{\partial }}{{\partial } \theta} L(\theta | X, Y) = -\sum_i \left [ Y_i\frac{{\partial }}{{\partial } \theta}\ln{p_\theta (X_i)} + (1-Y_i)\frac{{\partial }}{{\partial } \theta}\ln (1-p_\theta (X_i)) \right ].\]

To simplify this calculation, we can use the facts that $1-\sigma(x)=\sigma(-x)$ and that $\frac{{\partial }}{{\partial } \theta}\ln(p_\theta(X_i))=(1+e^{f_\theta(X_i)})^{-1} \frac{{\partial }}{{\partial } \theta} f_\theta(X_i)$. We are left with:

\[\frac{{\partial }}{{\partial } \theta} L(\theta| X, Y) = -\sum_i \left [ (2Y_i - 1){\left (1+e^{(2Y_i-1)f_\theta(X_i)} \right )}^{-1} \right ] \frac{{\partial }}{{\partial } \theta} f_\theta(X_i).\]
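
For reference, both identities used above follow directly from the definition of $\sigma$. A short derivation, written out as a check:

\[\frac{d}{dx}\ln\sigma(x) = -\frac{d}{dx}\ln(1+e^{-x}) = \frac{e^{-x}}{1+e^{-x}} = \frac{1}{1+e^{x}} = \sigma(-x) = 1-\sigma(x).\]

Substituting the two possible label values into the factor multiplying $\frac{{\partial }}{{\partial } \theta} f_\theta(X_i)$, leading minus sign included, gives

\[Y_i = 1:\ -\left(1+e^{f_\theta(X_i)}\right)^{-1} = p_\theta(X_i) - 1, \qquad Y_i = 0:\ \left(1+e^{-f_\theta(X_i)}\right)^{-1} = p_\theta(X_i),\]

so in both cases the per-sample factor equals $p_\theta(X_i) - Y_i$, the familiar "prediction minus target" form.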

We have now written the derivative of the loss function in terms of the derivative of the model, independently for each sample. The important part of this equation is the factor multiplying $\frac{{\partial }}{{\partial } \theta} f_\theta(X_i)$ (including the leading minus sign): it is the derivative of the loss with respect to the $i^\text{th}$ logit, and it is the quantity passed backwards during back-propagation. From this point, we can begin writing the code.
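
Before moving on, the per-sample factor in the gradient above can be checked numerically against a finite difference. The following is a minimal sketch in plain Julia that does not use SimpleChains; the names `sigmoid`, `sample_loss`, `sample_grad` and `finite_diff` are illustrative helpers introduced here, not part of the library.

```julia
# Illustrative helpers (not part of SimpleChains).
sigmoid(x) = inv(1 + exp(-x))

# Per-sample loss as a function of the logit f and the binary label y.
sample_loss(f, y) = -(y * log(sigmoid(f)) + (1 - y) * log(1 - sigmoid(f)))

# Analytic derivative with respect to the logit, taken from the expression above.
function sample_grad(f, y)
    s = 2y - 1                      # +1 for y = 1, -1 for y = 0
    return -(s * inv(1 + exp(s * f)))
end

# Central finite difference as an independent estimate of the same derivative.
finite_diff(f, y; h = 1e-6) = (sample_loss(f + h, y) - sample_loss(f - h, y)) / (2h)

# The two estimates should agree for both label values and for logits of either sign.
for (f, y) in ((0.3, 1), (0.3, 0), (-1.7, 1), (-1.7, 0))
    @assert isapprox(sample_grad(f, y), finite_diff(f, y); atol = 1e-6)
end
```

If the loop finishes without an assertion error, the closed-form factor matches the numerical derivative at the tested points.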

Implementing a custom loss type

We start by importing SimpleChains.jl into the current namespace:

using SimpleChains

We can now define our own type, which is a subtype of SimpleChains.AbstractLoss:

struct BinaryLogitCrossEntropyLoss{T,Y<:AbstractVector{T}} <: SimpleChains.AbstractLoss{T}
    targets::Y
end

The function used to get the inner targets is called target. We also make the loss callable on a new array of targets, which returns a loss of the same type wrapping those targets:

SimpleChains.target(loss::BinaryLogitCrossEntropyLoss) = loss.targets
(loss::BinaryLogitCrossEntropyLoss)(x::AbstractArray) = BinaryLogitCrossEntropyLoss(x)

Next, we define how to calculate the loss, given some logits:

function calculate_loss(loss::BinaryLogitCrossEntropyLoss, logits)
    y = loss.targets
    total_loss = zero(eltype(logits))
    for i in eachindex(y)
        p_i = inv(1 + exp(-logits[i]))
        y_i = y[i]
        total_loss -= y_i * log(p_i) + (1 - y_i) * log(1 - p_i)
    end
    total_loss
end
function (loss::BinaryLogitCrossEntropyLoss)(previous_layer_output::AbstractArray{T}, p::Ptr, pu) where {T}
    total_loss = calculate_loss(loss, previous_layer_output)
    total_loss, p, pu
end
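
For logits of large magnitude, the sigmoid rounds to exactly 0 or 1 in floating point, and taking the logarithm of the result can produce an infinite loss. One common way around this, shown here as an optional sketch rather than as what SimpleChains does internally, is to use the identity $1-\sigma(x)=\sigma(-x)$ to write the per-sample loss as $\ln(1+e^{-s f})$ with $s = 2Y_i - 1$, and to evaluate it with `log1p`. The name `stable_sample_loss` is a hypothetical helper introduced for illustration.

```julia
# Hypothetical helper (not part of SimpleChains): per-sample loss
# ln(1 + exp(-s * logit)) with s = 2y - 1, evaluated without overflow.
function stable_sample_loss(logit, y)
    z = -(2y - 1) * logit
    # ln(1 + exp(z)) == max(z, 0) + ln(1 + exp(-|z|)) for every real z.
    return max(z, zero(z)) + log1p(exp(-abs(z)))
end

stable_sample_loss(0.0f0, true)     # equals log(2) up to rounding
stable_sample_loss(100.0f0, false)  # close to 100, and finite
```

Summing `stable_sample_loss(logits[i], y[i])` over the samples would be a drop-in alternative for the loop body in `calculate_loss`, if numerical robustness matters for your problem.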

Following the built-in loss functions, we also define two methods that tell SimpleChains this layer needs no preallocated temporary arrays:

function SimpleChains.layer_output_size(::Val{T}, sl::BinaryLogitCrossEntropyLoss, s::Tuple) where {T}
    SimpleChains._layer_output_size_no_temp(Val{T}(), sl, s)
end
function SimpleChains.forward_layer_output_size(::Val{T}, sl::BinaryLogitCrossEntropyLoss, s) where {T}
    SimpleChains._layer_output_size_no_temp(Val{T}(), sl, s)
end

Finally, we define how to back-propagate the gradient from this loss function:

function SimpleChains.chain_valgrad!(
    __,
    previous_layer_output::AbstractArray{T},
    layers::Tuple{BinaryLogitCrossEntropyLoss},
    _::Ptr,
    pu::Ptr{UInt8},
) where {T}
    loss = getfield(layers, 1)
    total_loss = calculate_loss(loss, previous_layer_output)
    y = loss.targets

    # Store the backpropagated gradient in the previous_layer_output array.
    for i in eachindex(y)
        sign_arg = 2 * y[i] - 1
        # Read the current logit before it is overwritten with the gradient
        logit_i = previous_layer_output[i]
        previous_layer_output[i] = -(sign_arg * inv(1 + exp(sign_arg * logit_i)))
    end

    return total_loss, previous_layer_output, pu
end

That's all! We can now use this loss function just like any other:

using SimpleChains

model = SimpleChain(
    static(2),
    TurboDense(tanh, 32),
    TurboDense(tanh, 16),
    TurboDense(identity, 1)
)

batch_size = 64
X = rand(Float32, 2, batch_size)
Y = rand(Bool, batch_size)

parameters = SimpleChains.init_params(model);
gradients = SimpleChains.alloc_threaded_grad(model);

# Add the loss like any other loss type
model_loss = SimpleChains.add_loss(model, BinaryLogitCrossEntropyLoss(Y));

SimpleChains.valgrad!(gradients, model_loss, X, parameters)

Alternatively, to train the parameters on the full (unbatched) data set:

epochs = 100
SimpleChains.train_unbatched!(gradients, parameters, model_loss, X, SimpleChains.ADAM(), epochs);
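
Note that the model outputs logits, not probabilities, so a sigmoid still has to be applied when making predictions. The sketch below shows that conversion in plain Julia on stand-in logits; `logits_to_probabilities` and `logits_to_labels` are hypothetical helper names, and the way of obtaining logits from a trained model mentioned in the comment is an assumption to check against the SimpleChains documentation.

```julia
# Hypothetical helpers (not part of SimpleChains).
logits_to_probabilities(logits) = @. inv(1 + exp(-logits))
logits_to_labels(logits) = logits .> 0   # sigmoid(f) > 0.5 exactly when f > 0

# Stand-in logits. With a trained model these might come from something like
# `vec(model(X, parameters))`; that call is an assumption, not verified here.
logits = Float32[-2.0, 0.5, 3.0]

probabilities = logits_to_probabilities(logits)
labels = logits_to_labels(logits)   # false, true, true
```

Thresholding the logit at zero avoids evaluating the sigmoid at all when only hard labels are needed.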