Coverage for orcanet/misc.py: 98%
59 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1""" Odds and ends. """
2import os
3import inspect
4import numpy as np
def get_register():
    """
    Create an empty registry and a decorator that fills it.

    E.g. for storing orcanet layer blocks as custom objects.

    Returns
    -------
    saved : dict
        Maps ``obj.__name__`` to ``obj`` for every object passed to
        ``register``.
    register : callable
        Decorator that stores its argument in ``saved`` under the
        argument's ``__name__`` and hands the argument back unchanged.

    """
    registry = {}

    def register(obj):
        registry.update({obj.__name__: obj})
        return obj

    return registry, register
def from_register(toml_entry, register):
    """
    Get an initialized object via a toml entry.
    Used for loading orcanet built-in sample modifiers etc.

    Parameters
    ----------
    toml_entry : str or dict or list
        The 'sample_modifier' given in the config toml.
        E.g., to initialize "obj_name" from register, these are possible formats:
        "obj_name"
        ["obj_name", True]
        ["obj_name", {"setting_1": True}]
        {"name": "obj_name", "setting_1": True}
    register : dict
        Maps class names to class references.

    Returns
    -------
    If the registered object is a plain function, the function itself
    (no arguments may be given). Otherwise the result of calling the
    registered object with the args/kwargs parsed from ``toml_entry``.

    Raises
    ------
    KeyError
        If a dict entry has no 'name', or the name is not in ``register``.
    TypeError
        If arguments are given for a registered function, or calling the
        registered object with the given arguments raises a TypeError
        (the original error is attached as ``__cause__``).

    """
    args, kwargs = [], {}
    if isinstance(toml_entry, str):
        name = toml_entry
    elif isinstance(toml_entry, dict):
        if "name" not in toml_entry:
            raise KeyError(f"missing entry in dict: 'name', given: {toml_entry}")
        name = toml_entry["name"]
        kwargs = {k: v for k, v in toml_entry.items() if k != "name"}
    else:
        name = toml_entry[0]
        # A single trailing dict is interpreted as keyword arguments,
        # anything else after the name as positional arguments.
        if len(toml_entry) == 2 and isinstance(toml_entry[1], dict):
            kwargs = toml_entry[1]
        else:
            args = toml_entry[1:]

    try:
        obj = register[name]
    except KeyError:
        raise KeyError(
            f"Unknown name {name!r} in register, available: {list(register)}"
        ) from None

    if inspect.isfunction(obj):
        # Functions are returned as-is. This check lives outside the try
        # below so its specific message is not replaced by the generic one.
        if args or kwargs:
            raise TypeError(
                f"Can not pass arguments to function ({args}, {kwargs})"
            )
        return obj

    try:
        return obj(*args, **kwargs)
    except TypeError as err:
        # Chain the original error so e.g. "unexpected keyword argument"
        # is still visible to the user.
        raise TypeError(f"Error initializing {obj}") from err
def dict_to_recarray(array_dict):
    """
    Convert a dict with np arrays to a 2d recarray.
    Column names are derived from the dict keys.

    Parameters
    ----------
    array_dict : dict
        Keys: string
        Values: ND arrays (or array-likes), same length along the first axis.
        All dimensions except the first will get flattened; column ``i``
        (0-based) of the flattened array is named ``f"{key}_{i+1}"``.

    Returns
    -------
    The recarray, one field per flattened column.

    """
    column_names, arrays = [], []
    for key, array in array_dict.items():
        # Accept lists etc. as well; a no-op for ndarrays.
        array = np.asarray(array)
        if array.ndim == 1:
            array = np.expand_dims(array, -1)
        elif array.ndim > 2:
            # Explicit column count instead of -1: reshape(0, -1) is
            # ambiguous and raises for arrays with zero rows.
            n_columns = int(np.prod(array.shape[1:]))
            array = np.reshape(array, (len(array), n_columns))
        for i in range(array.shape[-1]):
            arrays.append(array[:, i])
            column_names.append(f"{key}_{i+1}")
    # np.rec is the public alias of the records module; np.core is
    # deprecated since NumPy 2.0.
    return np.rec.fromarrays(arrays, names=column_names)
def to_ndarray(x, dtype="float32"):
    """
    Turn recarray to ndarray.

    Every field of ``x`` is cast to ``dtype`` and becomes one column (in
    field order), giving an array of shape ``(len(x), number of fields)``.
    """
    field_names = x.dtype.names
    uniform_fields = [(field, dtype) for field in field_names]
    # Cast all fields to the same type so the packed records can be
    # reinterpreted as a flat run of `dtype` values.
    packed = np.ascontiguousarray(x).astype(uniform_fields)
    flat = packed.view(dtype)
    return flat.reshape((len(x), len(field_names)))
def find_file(directory, filename):
    """
    Look for a file with the given name anywhere below ``directory``.

    Returns
    -------
    The path of the single match (also printed), or None if nothing matches.

    Raises
    ------
    ValueError
        If more than one file with that name exists.

    """
    # Names within one directory are unique, so each walked directory
    # contributes at most one match.
    found = [
        os.path.join(root, filename)
        for root, _, files in os.walk(directory)
        if filename in files
    ]
    if not found:
        return None
    if len(found) > 1:
        raise ValueError(f"Can not find {filename}: More than one file found ({found})")
    fpath = found[0]
    print(f"Found {fpath}")
    return fpath