Coverage for orcanet/lib/dataset_modifiers.py: 98%
50 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1import orcanet.misc as misc
# Registry for the dataset modifiers defined in this module, so that they can
# be loaded via toml. `register` is applied as a decorator to every modifier
# function below.
# NOTE(review): presumably `dmods` is the container holding the registered
# functions - confirm against misc.get_register.
dmods, register = misc.get_register()
@register
def as_array(info_blob):
    """
    Default dataset modifier: store the network output as ndarrays in h5.

    For every output layer, one dataset is created for the prediction and,
    if labels are present, one for the label. E.g. a model with an output
    layer called "energy" yields the datasets "pred_energy" and
    "label_energy".

    Reads from info_blob:
        "y_pred" (required): mapping of output layer name to prediction.
        "ys" (optional): mapping of output layer name to label.
        "y_values" (optional): passed through unchanged as "y_values".

    Returns a dict mapping dataset name to the data to be saved.

    """
    predictions = info_blob["y_pred"]
    datasets = {"pred_" + name: predictions[name] for name in predictions}

    labels = info_blob.get("ys")
    if labels is not None:
        datasets.update({"label_" + name: labels[name] for name in labels})

    y_values = info_blob.get("y_values")
    if y_values is not None:
        datasets["y_values"] = y_values

    return datasets
@register
def as_recarray(info_blob):
    """
    Store the network output as recarrays in h5. Meant for networks whose
    outputs are 2D, i.e. of shape (batchsize, X).

    Output from network:
        Dict with arrays, shapes (batchsize, x_i).
        E.g. {"foo": ndarray, "bar": ndarray}

    dtypes that will get saved to h5:
        (foo_1, foo_2, ..., bar_1, bar_2, ... )

    Reads from info_blob:
        "y_pred": converted via misc.dict_to_recarray, saved as "pred".
        "ys" (optional): converted the same way, saved as "true".
        "y_values" (optional): passed through unchanged as "y_values".

    Returns a dict mapping dataset name to the data to be saved.

    """
    datasets = {"pred": misc.dict_to_recarray(info_blob.get("y_pred"))}

    labels = info_blob.get("ys")
    if labels is not None:
        datasets["true"] = misc.dict_to_recarray(labels)

    y_values = info_blob.get("y_values")
    if y_values is not None:
        # No conversion needed: y_values is already a structured array.
        datasets["y_values"] = y_values

    return datasets
@register
def as_recarray_dist(info_blob):
    """
    Save network output as recarray to h5. Intended for when network
    outputs are distributions and thus 3D (for example when using
    OutputRegNormal as output layer block).
    I.e. (batchsize, 2, X), with [:, 0] being mu and [:, 1] being std.

    Example output from network:
        shape {"A": (bs, 2), "B": (bs, 2, 3)}
        [:, 0] is reco, [:, 1] is err

    dtypes that will get saved to h5:
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    The passed info_blob is not modified: the split outputs are stored in
    a shallow copy, which is then handed to as_recarray.

    Reads from info_blob:
        "y_pred" (required): dict of output name to array; each array is
            split into "<name>" ([:, 0]) and "<name>_err" ([:, 1]).
        "ys" (optional): dict of output name to array; only [:, 0] is kept.

    Returns the datasets dict produced by as_recarray.

    """
    # Shallow copy so that replacing "y_pred"/"ys" below does not alter
    # the dict owned by the caller.
    blob = dict(info_blob)

    predictions = {}
    for output_name, array in info_blob["y_pred"].items():
        # [:, 0] is mu and [:, 1] is err
        predictions[output_name] = array[:, 0]
        predictions[f"{output_name}_err"] = array[:, 1]
    blob["y_pred"] = predictions

    ys = info_blob.get("ys")
    if ys is not None:
        # errs for the trues are just padded, so skip them
        blob["ys"] = {
            output_name: array[:, 0] for output_name, array in ys.items()
        }

    return as_recarray(blob)
@register
def as_recarray_dist_split(info_blob):
    """
    Save network output as recarray to h5. Intended for networks that
    output recos and errs in separate towers (for example when using
    OutputRegNormalSplit as output layer block).

    Example output from network:
        shape {"A": (bs, 1), "A_err": (bs, 2, 1),
               "B": (bs, 3), "B_err": (bs, 2, 3)}
        In "A_err": [:, 0] is mu, [:, 1] is sigma

    Only the "*_err" outputs are kept; they are renamed to drop the
    "_err" suffix and then processed by as_recarray_dist.

    dtypes that will get saved to h5 (naming as in as_recarray_dist):
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    The passed info_blob is not modified: the transformed outputs are
    stored in a shallow copy.

    Reads from info_blob:
        "y_pred" (required): dict of output name to array.
        "ys" (optional): dict of output name to array.

    Returns the datasets dict produced by as_recarray_dist.

    """
    suffix = "_err"

    def transform(network_output):
        """Skip A and rename A_err to A."""
        return {
            name[: -len(suffix)]: value
            for name, value in network_output.items()
            if name.endswith(suffix)
        }

    # Shallow copy so that replacing "y_pred"/"ys" below does not alter
    # the dict owned by the caller.
    blob = dict(info_blob)
    blob["y_pred"] = transform(info_blob["y_pred"])

    ys = info_blob.get("ys")
    if ys is not None:
        blob["ys"] = transform(ys)

    return as_recarray_dist(blob)