Coverage for orcanet/lib/dataset_modifiers.py: 98%

50 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1import orcanet.misc as misc 

2 

# Module-level registry hook: `register` is applied as a decorator to every
# dataset modifier defined below, so that they can be looked up by name
# (e.g. when the modifier is chosen in a toml config).
# NOTE(review): presumably misc.get_register() returns a (dict, decorator)
# pair where the decorator stores each function in `dmods` under its name
# -- confirm in orcanet.misc.
dmods, register = misc.get_register()

5 

6 

7@register 

def as_array(info_blob):
    """
    Save network output as ndarrays to h5. This is the default dataset modifier.

    Every output layer gets its own prediction dataset and, when labels are
    present, its own label dataset. For an output layer called "energy",
    the datasets "pred_energy" and "label_energy" are produced.

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a mapping of output layer name -> prediction.
        May contain "ys" (mapping of output layer name -> label) and
        "y_values"; entries that are absent or None are skipped.

    Returns
    -------
    dict
        Dataset name -> data. Prediction datasets come first, then label
        datasets, then "y_values" (if given).

    """
    predictions = info_blob["y_pred"]
    result = {"pred_" + layer: predictions[layer] for layer in predictions}

    labels = info_blob.get("ys")
    if labels is not None:
        result.update({"label_" + layer: labels[layer] for layer in labels})

    extra_values = info_blob.get("y_values")
    if extra_values is not None:
        result["y_values"] = extra_values

    return result

33 

34 

35@register 

def as_recarray(info_blob):
    """
    Save network output as recarray to h5. Intended for when network
    outputs are 2D, i.e. (batchsize, X).

    Output from network:
        Dict with arrays, shapes (batchsize, x_i).
        E.g. {"foo": ndarray, "bar": ndarray}

    dtypes that will get saved to h5:
        (foo_1, foo_2, ..., bar_1, bar_2, ... )

    Parameters
    ----------
    info_blob : dict
        "y_pred" is converted with misc.dict_to_recarray. The optional
        "ys" (labels) is converted the same way, and the optional
        "y_values" is stored unchanged. Entries that are None are skipped.

    Returns
    -------
    dict
        Datasets "pred", and optionally "true" and "y_values".

    """
    result = {"pred": misc.dict_to_recarray(info_blob.get("y_pred"))}

    labels = info_blob.get("ys")
    if labels is not None:
        result["true"] = misc.dict_to_recarray(labels)

    extra_values = info_blob.get("y_values")
    if extra_values is not None:
        # stored as-is: per the original author it already is a structured array
        result["y_values"] = extra_values

    return result

61 

62 

63@register 

def as_recarray_dist(info_blob):
    """
    Save network output as recarray to h5. Intended for when network
    outputs are distributions and thus 3D (for example when using
    OutputRegNormal as output layer block).
    I.e. (batchsize, 2, X), with [:, 0] being mu and [:, 1] being std.

    Example output from network:
        shape {"A": (bs, 2), "B": (bs, 2, 3)}
        [:, 0] is reco, [:, 1] is err

    dtypes that will get saved to h5:
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    Each prediction is split into "<name>" (array[:, 0]) and "<name>_err"
    (array[:, 1]). For the labels only array[:, 0] is kept, because the
    err part of the trues is just padding.

    The passed info_blob is not modified. A shallow copy with the
    transformed "y_pred" (and "ys") is handed on to as_recarray.

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a dict of output name -> array that can be
        indexed as array[:, 0] and array[:, 1]. May contain "ys" (same
        layout) and any other keys used by as_recarray.

    Returns
    -------
    dict
        The datasets produced by as_recarray.

    """
    y_pred = {}
    for output_name, array in info_blob["y_pred"].items():
        # [:, 0] is mu and [:, 1] is err
        y_pred[output_name] = array[:, 0]
        y_pred[f"{output_name}_err"] = array[:, 1]

    # Work on a shallow copy so the caller's info_blob is left untouched.
    transformed = dict(info_blob)
    transformed["y_pred"] = y_pred

    ys = info_blob.get("ys")
    if ys is not None:
        # errs for the trues are just padded, so skip them
        transformed["ys"] = {name: array[:, 0] for name, array in ys.items()}

    return as_recarray(transformed)

96 

97 

98@register 

def as_recarray_dist_split(info_blob):
    """
    Save network output as recarray to h5. Intended for networks that
    output recos and errs in separate towers (for example when using
    OutputRegNormalSplit as output layer block).

    Example output from network:
        shape {"A": (bs, 1), "A_err": (bs, 2, 1),
               "B": (bs, 3), "B_err": (bs, 2, 3)}
        In "A_err": [:, 0] is mu, [:, 1] is sigma

    Only the "<name>_err" towers are used. They are renamed to "<name>"
    and passed on to as_recarray_dist, which splits them into reco and err.

    dtypes that will get saved to h5 (naming as in as_recarray_dist):
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    The passed info_blob is not modified. A shallow copy with the
    transformed "y_pred" (and "ys") is handed on to as_recarray_dist.

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a dict of output name -> array. May contain
        "ys" with the same keys, plus any other keys used downstream.

    Returns
    -------
    dict
        The datasets produced by as_recarray_dist.

    """
    suffix = "_err"

    def transform(network_output):
        """Skip e.g. "A" and return "A_err" renamed to "A"."""
        return {
            output_name[: -len(suffix)]: output_value
            for output_name, output_value in network_output.items()
            if output_name.endswith(suffix)
        }

    # Work on a shallow copy so the caller's info_blob is left untouched.
    transformed = dict(info_blob)
    transformed["y_pred"] = transform(info_blob["y_pred"])
    if info_blob.get("ys") is not None:
        transformed["ys"] = transform(info_blob["ys"])

    return as_recarray_dist(transformed)