Coverage for orcanet/misc.py: 98%

59 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1""" Odds and ends. """ 

2import os 

3import inspect 

4import numpy as np 

5 

6 

def get_register():
    """
    Build an empty registry together with a decorator that fills it.

    E.g. for storing orcanet layer blocks as custom objects.

    Returns
    -------
    saved : dict
        Maps the ``__name__`` of every registered object to the object.
    register : function
        Decorator which stores the given object in ``saved`` and hands
        it back unchanged.

    """
    registry = {}

    def register(obj):
        # Keyed by __name__, so registering a second object with the
        # same name replaces the first one.
        registry.update({obj.__name__: obj})
        return obj

    return registry, register

16 

17 

def from_register(toml_entry, register):
    """
    Get an initialized object via a toml entry.
    Used for loading orcanet built-in sample modifiers etc.

    Parameters
    ----------
    toml_entry : str or dict or list
        The 'sample_modifier' given in the config toml.
        E.g., to initialize "obj_name" from register, these are possible formats:
            "obj_name"
            ["obj_name", True]
            ["obj_name", {"setting_1": True}]
            {"name": "obj_name", "setting_1": True}
    register : dict
        Maps class names to class references.

    Returns
    -------
    The registered object itself if it is a plain function, otherwise
    the result of calling the registered object with the given arguments.

    Raises
    ------
    KeyError
        If a dict entry has no 'name' key, or the name is not in register.
    TypeError
        If arguments are given for a plain function, or if calling the
        registered object raises a TypeError.

    """
    args, kwargs = [], {}
    if isinstance(toml_entry, str):
        name = toml_entry
    elif isinstance(toml_entry, dict):
        if "name" not in toml_entry:
            raise KeyError(f"missing entry in dict: 'name', given: {toml_entry}")
        name = toml_entry["name"]
        kwargs = {k: v for k, v in toml_entry.items() if k != "name"}
    else:
        name = toml_entry[0]
        if len(toml_entry) == 2 and isinstance(toml_entry[1], dict):
            # A lone dict after the name is read as keyword arguments.
            kwargs = toml_entry[1]
        else:
            args = toml_entry[1:]

    obj = register[name]

    if inspect.isfunction(obj):
        # Plain functions are returned uncalled, so there is nothing to
        # configure. This check sits outside the try block below so that
        # its specific message is not replaced by the generic one.
        if args or kwargs:
            raise TypeError(
                f"Can not pass arguments to function ({args}, {kwargs})"
            )
        return obj

    try:
        return obj(*args, **kwargs)
    except TypeError as err:
        # Keep the original error attached as the explicit cause.
        raise TypeError(f"Error initializing {obj}") from err

63 

64 

def dict_to_recarray(array_dict):
    """
    Convert a dict with np arrays to a recarray.
    Column names are derived from the dict keys.

    Each array contributes one field per flattened column, named
    ``<key>_<n>`` with n starting at 1.

    Parameters
    ----------
    array_dict : dict
        Keys: string
        Values: ND arrays, same length and number of dimensions.
            All dimensions except first will get flattened.

    Returns
    -------
    The recarray.

    """
    column_names, arrays = [], []
    for key, array in array_dict.items():
        if array.ndim == 1:
            # Treat a 1D array as a single column.
            array = np.expand_dims(array, -1)
        elif array.ndim > 2:
            # Collapse everything behind the first axis into columns.
            array = np.reshape(array, (len(array), -1))
        for i in range(array.shape[-1]):
            arrays.append(array[:, i])
            column_names.append(f"{key}_{i+1}")
    # np.rec is the public home of fromarrays; np.core is a private
    # namespace that is deprecated in NumPy 2.
    return np.rec.fromarrays(arrays, names=column_names)

92 

93 

def to_ndarray(x, dtype="float32"):
    """
    Turn recarray to ndarray.

    Every field of x is cast to dtype; the result has one row per
    record and one column per field.
    """
    field_names = x.dtype.names
    # Give all fields the same type, so the record buffer can be
    # reinterpreted as a plain array of that type.
    uniform_dtype = [(field, dtype) for field in field_names]
    records = np.ascontiguousarray(x).astype(uniform_dtype)
    flat = records.view(dtype)
    return flat.reshape((len(x), len(field_names)))

99 

100 

def find_file(directory, filename):
    """
    Search directory and all of its subdirectories for filename.

    Returns the path of the single match (and prints it), or None if
    there is no match. Raises a ValueError if there are multiple matches.
    """
    matches = [
        os.path.join(folder, entry)
        for folder, _, entries in os.walk(directory)
        for entry in entries
        if entry == filename
    ]
    if not matches:
        return None
    if len(matches) > 1:
        raise ValueError(f"Can not find {filename}: More than one file found ({matches})")
    fpath = matches[0]
    print(f"Found {fpath}")
    return fpath

115 return fpath