Coverage for orcanet/lib/losses.py: 100%
21 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1"""
2OrcaNet custom loss functions.
3"""
4import numpy as np
5import tensorflow as tf
6import tensorflow_probability as tfp
7from orcanet.misc import get_register
# Fuzz factor for numerical stability: used by the losses below as an
# additive offset to (or a floor for) the predicted sigma / variance.
EPS = tf.constant(1e-7, dtype=tf.float32)

# Registry of the loss functions defined in this module; used for loading
# them via toml and for passing them to Keras as orcanet custom objects.
loss_functions, _register = get_register()
@_register
def lkl_normal_tfp(y_true, y_pred):
    """
    Negative normal log-likelihood computed with tensorflow_probability.

    Uses the same input layout as ``lkl_normal``: ``y_true[:, 0]`` holds the
    labels, ``y_pred[:, 0]`` the predicted mean and ``y_pred[:, 1]`` the
    predicted standard deviation. In contrast to ``lkl_normal``, no clipping
    is applied, and sigma is used directly as the scale (floored at ``EPS``,
    so a sigma <= EPS is treated as EPS) rather than entering only through
    its square.

    Parameters
    ----------
    y_true : tf.Tensor
        Shape (bs, 2, n) or (bs, 2); only ``y_true[:, 0]`` is used.
    y_pred : tf.Tensor
        Shape (bs, 2, n) or (bs, 2); ``y_pred[:, 0]`` is mu and
        ``y_pred[:, 1]`` is sigma.

    Returns
    -------
    tf.Tensor
        The elementwise negative log-likelihood, shape (bs, n) or (bs,).

    """
    mu_true = y_true[:, 0]
    mu_pred, sigma_pred = y_pred[:, 0], y_pred[:, 1]

    # EPS is a float32 constant; cast it to the prediction dtype so that
    # float16 (mixed precision) or float64 outputs don't hit a dtype
    # mismatch in tf.math.maximum. For float32 this cast is a no-op.
    scale = tf.math.maximum(sigma_pred, tf.cast(EPS, sigma_pred.dtype))
    dist = tfp.distributions.Normal(loc=mu_pred, scale=scale)
    return -1 * dist.log_prob(mu_true)
@_register
def lkl_normal(y_true, y_pred):
    """
    Negative normal log-likelihood function for n regression output neurons
    with clipping for increased stability.

    For stability in the case of outliers, the log-likelihood term
    ``log(sigma_i^2 + EPS) + (pred_i - true_i)^2 / (sigma_i^2 + EPS)`` is
    capped at ``10 * |pred_i - true_i|`` before the constant ``log(2 pi)``
    is added and the sum is halved. The loss l_i of each sample is thus
    at most ``0.5 * log(2 pi) + 5 * |pred_i - true_i|``.

    Parameters
    ----------
    y_true : tf.Tensor
        Shape (bs, 2, n) or (bs, 2).
        y_true[:, 0] is the label of shape (bs, n) (true), and y_true[:, 1]
        is not used (necessary as tf 2.1 requires y_true and y_pred to
        have same shape).
    y_pred : tf.Tensor
        Shape (bs, 2, n) or (bs, 2).
        The output of the network.
        y_pred[:, 0] is mu, and y_pred[:, 1] is sigma.

    Returns
    -------
    tf.Tensor
        The elementwise loss, shape (bs, n) or (bs,). Sigma only enters
        squared, so its sign does not affect the result.

    """
    mu_true = y_true[:, 0]
    mu_pred, sigma_pred = y_pred[:, 0], y_pred[:, 1]

    return _normal_lkl(
        mu_pred=mu_pred, mu_true=mu_true, sigma_pred=sigma_pred, clip=True
    )
def _normal_lkl(mu_pred, mu_true, sigma_pred, clip=False, clip_thresh=10):
    """
    Elementwise negative log-likelihood of ``mu_true`` under a normal
    distribution with mean ``mu_pred`` and standard deviation ``sigma_pred``.

    Computes
    ``0.5 * (log(2 pi) + log(sigma^2 + EPS) + delta^2 / (sigma^2 + EPS))``
    with ``delta = mu_pred - mu_true``. Sigma only enters squared, so its
    sign does not matter.

    Parameters
    ----------
    mu_pred, mu_true, sigma_pred : tf.Tensor
        Broadcast-compatible tensors sharing one floating dtype.
    clip : bool
        If True, the term ``log(sigma^2 + EPS) + delta^2 / (sigma^2 + EPS)``
        is capped at ``clip_thresh * |delta|`` before ``log(2 pi)`` is added
        and the sum is halved. The result is then at most
        ``0.5 * log(2 pi) + 0.5 * clip_thresh * |delta|``.
    clip_thresh : float
        Slope of the cap used when ``clip`` is True.

    Returns
    -------
    tf.Tensor
        The elementwise loss, in the dtype of ``mu_pred - mu_true``.

    """
    delta = tf.convert_to_tensor(mu_pred - mu_true)
    dtype = delta.dtype
    # EPS and log(2 pi) are cast to the input dtype (a no-op for float32).
    # This lets float16 (mixed precision) or float64 inputs work instead of
    # failing with a dtype mismatch in the additions below.
    var = sigma_pred ** 2 + tf.cast(EPS, dtype)
    loglike = tf.math.log(var) + delta ** 2 / var
    if clip:
        loglike = tf.minimum(loglike, clip_thresh * tf.abs(delta))
    log_two_pi = tf.constant(np.log(2 * np.pi), dtype=dtype)
    return 0.5 * (log_two_pi + loglike)