import orcanet.misc as misc
# Set up the registry that allows loading dataset modifiers via toml.
# `register` is applied as a decorator to every dataset modifier defined in
# this module; `dmods` is presumably the container it fills
# (TODO confirm against misc.get_register).
dmods, register = misc.get_register()
@register
def as_array(info_blob):
    """
    Save network output as ndarrays to h5. This is the default dataset modifier.

    Every output layer will get one dataset each for both the label and
    the prediction. E.g. if the model has an output layer called "energy",
    the datasets "label_energy" and "pred_energy" will be made.

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a dict mapping output layer names to the
        predictions of the network. May contain "ys" (a dict mapping
        output layer names to the labels) and "y_values"; each of these
        two is skipped if it is missing or None.

    Returns
    -------
    datasets : dict
        Maps the dataset names to the objects taken from the info_blob:
        "pred_<name>" for every entry of "y_pred", "label_<name>" for
        every entry of "ys", and "y_values" if available.

    """
    datasets = {
        "pred_" + out_layer_name: prediction
        for out_layer_name, prediction in info_blob["y_pred"].items()
    }

    ys = info_blob.get("ys")
    if ys is not None:
        for out_layer_name, label in ys.items():
            datasets["label_" + out_layer_name] = label

    y_values = info_blob.get("y_values")
    if y_values is not None:
        datasets["y_values"] = y_values

    return datasets
@register
def as_recarray(info_blob):
    """
    Save network output as recarray to h5. Intended for when network
    outputs are 2D, i.e. (batchsize, X).

    Output from network:
        Dict with arrays, shapes (batchsize, x_i).
        E.g. {"foo": ndarray, "bar": ndarray}

    dtypes that will get saved to h5:
        (foo_1, foo_2, ..., bar_1, bar_2, ... )

    Parameters
    ----------
    info_blob : dict
        The entry "y_pred" (dict with the network output) is always
        converted. The entries "ys" and "y_values" are optional and
        skipped if they are missing or None.

    Returns
    -------
    datasets : dict
        "pred": the content of "y_pred", converted with
        misc.dict_to_recarray.
        "true": the content of "ys", converted the same way (only if
        "ys" is available).
        "y_values": passed through unchanged (only if available).

    """
    datasets = {"pred": misc.dict_to_recarray(info_blob.get("y_pred"))}

    ys = info_blob.get("ys")
    if ys is not None:
        datasets["true"] = misc.dict_to_recarray(ys)

    y_values = info_blob.get("y_values")
    if y_values is not None:
        # is already a structured array, so no conversion needed
        datasets["y_values"] = y_values

    return datasets
@register
def as_recarray_dist(info_blob):
    """
    Save network output as recarray to h5. Intended for when network
    outputs are distributions and thus 3D (for example when using
    OutputRegNormal as output layer block).
    I.e. (batchsize, 2, X), with [:, 0] being mu and [:, 1] being std.

    Example output from network:
        shape {"A": (bs, 2), "B": (bs, 2, 3)}
        [:, 0] is reco, [:, 1] is err

    dtypes that will get saved to h5:
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a dict mapping output names to arrays that
        can be indexed with [:, 0] and [:, 1]. May contain "ys" with the
        labels in the same layout. The passed info_blob is not modified.

    Returns
    -------
    datasets : dict
        The output of as_recarray, called on a copy of the info_blob in
        which every prediction "X" is split up into "X" and "X_err", and
        every label is reduced to its [:, 0] component.

    """
    split_preds = {}
    for output_name, array in info_blob["y_pred"].items():
        # [:, 0] is mu and [:, 1] is err
        split_preds[output_name] = array[:, 0]
        split_preds[f"{output_name}_err"] = array[:, 1]

    # Work on a shallow copy, so that the info_blob of the caller keeps
    # its original network output.
    split_blob = dict(info_blob)
    split_blob["y_pred"] = split_preds

    ys = info_blob.get("ys")
    if ys is not None:
        # errs for the trues are just padded, so skip
        split_blob["ys"] = {
            output_name: array[:, 0] for output_name, array in ys.items()
        }

    return as_recarray(split_blob)
@register
def as_recarray_dist_split(info_blob):
    """
    Save network output as recarray to h5. Intended for networks that
    output recos and errs in separate towers (for example when using
    OutputRegNormalSplit as output layer block).

    Example output from network:
        shape {"A": (bs, 1), "A_err": (bs, 2, 1),
               "B": (bs, 3), "B_err": (bs, 2, 3)}
        In "A_err": [:, 0] is mu, [:, 1] is sigma

    Only the outputs ending in "_err" are used, since they contain both
    mu and sigma. They are renamed from "X_err" to "X" and then handled
    by as_recarray_dist. All other outputs are dropped.

    dtypes that will get saved to h5 (following the naming used by
    as_recarray_dist):
        A_1, A_err_1, B_1, B_2, B_3, B_err_1, B_err_2, B_err_3

    Parameters
    ----------
    info_blob : dict
        Must contain "y_pred", a dict mapping output names to arrays.
        May contain "ys" with the labels, which is filtered and renamed
        in the same way. The passed info_blob is not modified.

    Returns
    -------
    datasets : dict
        The output of as_recarray_dist, called on a copy of the info_blob
        with the filtered and renamed outputs.

    """
    suffix = "_err"

    def transform(network_output):
        """Skip A and rename A_err to A."""
        return {
            output_name[: -len(suffix)]: output_value
            for output_name, output_value in network_output.items()
            if output_name.endswith(suffix)
        }

    # Work on a shallow copy, so that the info_blob of the caller keeps
    # its original network output.
    renamed_blob = dict(info_blob)
    renamed_blob["y_pred"] = transform(info_blob["y_pred"])
    if info_blob.get("ys") is not None:
        renamed_blob["ys"] = transform(info_blob["ys"])

    return as_recarray_dist(renamed_blob)