Coverage for orcanet/lib/losses.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1""" 

2OrcaNet custom loss functions. 

3""" 

4import numpy as np 

5import tensorflow as tf 

6import tensorflow_probability as tfp 

7from orcanet.misc import get_register 

8 

# Small positive fuzz factor for numerical stability: the losses in this
# module add it to the variance or use it as a lower bound on the scale.
EPS = tf.constant(1e-7, dtype="float32")

# Registry used for loading via toml and orcanet custom objects.
# `_register` is applied as a decorator to the public losses in this module;
# `loss_functions` is presumably the matching collection -- confirm against
# orcanet.misc.get_register.
loss_functions, _register = get_register()

13 

14 

@_register
def lkl_normal_tfp(y_true, y_pred):
    """
    Negative log-probability of the label under a tfp Normal distribution.

    Counterpart of lkl_normal that delegates the density evaluation to
    tensorflow_probability.

    Parameters
    ----------
    y_true : tf.Tensor
        y_true[:, 0] is read as the label (true); any other entries along
        that axis are not used.
    y_pred : tf.Tensor
        y_pred[:, 0] is read as mu, and y_pred[:, 1] as sigma.

    Returns
    -------
    tf.Tensor
        -1 * log_prob(label) of Normal(loc=mu, scale=max(sigma, EPS)).

    """
    label = y_true[:, 0]
    loc = y_pred[:, 0]
    # Any predicted sigma below EPS (including negative values) is raised
    # to EPS, so the scale handed to the distribution is strictly positive.
    scale = tf.math.maximum(y_pred[:, 1], EPS)
    gaussian = tfp.distributions.Normal(loc=loc, scale=scale)
    return -1 * gaussian.log_prob(label)

28 

29 

@_register
def lkl_normal(y_true, y_pred):
    """
    Clipped negative log-likelihood of a normal distribution for n
    regression output neurons.

    The computation is delegated to _normal_lkl with clip=True. For
    stability in the case of outliers, the residual-dependent part of the
    loss is thereby capped per sample at 10 * |pred_i - true_i| (the
    default clip threshold of _normal_lkl).

    Parameters
    ----------
    y_true : tf.Tensor
        Shape (bs, 2, n) or (bs, 2).
        y_true[:, 0] holds the label (true) of shape (bs, n). y_true[:, 1]
        is ignored; it only exists because tf 2.1 requires y_true and
        y_pred to have the same shape.
    y_pred : tf.Tensor
        Shape (bs, 2, n) or (bs, 2).
        The output of the network: y_pred[:, 0] is mu, and y_pred[:, 1]
        is sigma.

    Returns
    -------
    tf.Tensor
        The value returned by _normal_lkl for the given label, mu and sigma.

    """
    label = y_true[:, 0]
    mu = y_pred[:, 0]
    sigma = y_pred[:, 1]
    return _normal_lkl(mu_pred=mu, mu_true=label, sigma_pred=sigma, clip=True)

58 

59 

def _normal_lkl(mu_pred, mu_true, sigma_pred, clip=False, clip_thresh=10):
    """
    Negative log-likelihood of mu_true under Normal(mu_pred, sigma_pred).

    Computes, element-wise,
    0.5 * (log(2 * pi) + log(var) + (mu_pred - mu_true) ** 2 / var)
    with var = sigma_pred ** 2 + EPS.

    Parameters
    ----------
    mu_pred : tf.Tensor
        Predicted mean.
    mu_true : tf.Tensor
        True value (label).
    sigma_pred : tf.Tensor
        Predicted standard deviation. Only its square is used.
    clip : bool
        If True, the term log(var) + (mu_pred - mu_true) ** 2 / var is
        limited to at most clip_thresh * |mu_pred - mu_true| before the
        constant and the factor 0.5 are applied.
    clip_thresh : int or float
        Factor in front of the absolute residual used for the cap.

    Returns
    -------
    tf.Tensor
        The element-wise negative log-likelihood.

    """
    residual = mu_pred - mu_true
    # EPS keeps the variance away from zero for both the log and the division.
    variance = sigma_pred ** 2 + EPS
    nll_term = tf.math.log(variance) + residual ** 2 / variance
    if clip:
        nll_term = tf.minimum(nll_term, clip_thresh * tf.abs(residual))
    log_two_pi = tf.constant(np.log(2 * np.pi), dtype="float32")
    return 0.5 * (log_two_pi + nll_term)