Coverage for orcanet/parser.py: 0%
83 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1"""
2Run OrcaNet functionalities from command line.
4"""
5import warnings
6import argparse
7from orcanet import version
9# imports involving tf moved inside functions for speed up
def train(directory, list_file=None, config_file=None, model_file=None, to_epoch=None):
    """
    Train and validate a model, building a new one if no epoch exists yet.

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file : str, optional
        Path to the toml list file; passed on to ``Organizer``.
    config_file : str, optional
        Path to the toml config file; passed on to ``Organizer``.
    model_file : str, optional
        Path to the toml model file. Only used when no previous epoch is
        found. If None, ``find_file`` is asked for ``model.toml`` inside
        ``directory``.
    to_epoch : int, optional
        Train up to and including this epoch.

    Returns
    -------
    Whatever ``Organizer.train_and_validate`` returns.
    """
    # Imports involving tf are deferred to keep CLI start-up fast.
    from orcanet.core import Organizer
    from orcanet.model_builder import ModelBuilder
    from orcanet.misc import find_file

    orga = Organizer(directory, list_file, config_file, tf_log_level=1)

    has_previous_epoch = orga.io.get_latest_epoch() is not None
    if has_previous_epoch:
        # Resuming: no model is built here, train_and_validate gets None.
        model = None
    else:
        print("Building new model")
        toml_path = model_file if model_file is not None else find_file(directory, "model.toml")
        model = ModelBuilder(toml_path).build(orga, verbose=False)

    return orga.train_and_validate(model=model, to_epoch=to_epoch)
def _add_parser_train(subparsers):
    """
    Register the ``train`` sub-command on *subparsers*.

    Adds the common directory/list/config arguments plus ``--model_file`` and
    ``--to_epoch``, and binds the sub-command to ``train``.
    """
    train_parser = subparsers.add_parser("train", description="Train and validate a model.")
    _add_common_args(train_parser)

    model_help = (
        "Path to toml model file. Will be used to build a model at "
        "the start of the training. Not needed to resume training. "
        "Default: Look for a file called 'model.toml' in the "
        "given OrcaNet directory."
    )
    train_parser.add_argument("--model_file", default=None, type=str, help=model_help)
    train_parser.add_argument(
        "--to_epoch",
        default=None,
        type=int,
        help="Train up to and including this epoch. Default: Train forever.",
    )

    train_parser.set_defaults(func=train)
def predict(directory, list_file=None, config_file=None, epoch=None, fileno=None):
    """
    Load a trained model and predict on the prediction files.

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file, config_file : str, optional
        Paths to the toml list / config files; passed on to ``Organizer``.
    epoch, fileno : int, optional
        Which saved model to load; passed on to ``Organizer.predict``.

    Returns
    -------
    The first element of the sequence returned by ``Organizer.predict``.
    """
    from orcanet.core import Organizer

    organizer = Organizer(directory, list_file, config_file, tf_log_level=1)
    prediction_results = organizer.predict(epoch=epoch, fileno=fileno)
    return prediction_results[0]
def _add_paser_predict(subparsers):
    """
    Register the ``predict`` sub-command on *subparsers*.

    NOTE: the function name is misspelled ("paser"); it is kept as-is because
    ``main`` calls it under this name.
    """
    predict_parser = subparsers.add_parser(
        "predict",
        description=(
            "Load a trained model and save its prediction on "
            "the predictions files to h5."
        ),
    )
    _add_common_args(predict_parser)

    # Both selectors are optional ints; None lets the organizer pick.
    for flag, help_text in (
        ("--epoch", "Epoch of model to load. Default: best"),
        ("--fileno", "Fileno of model to load. Default: best"),
    ):
        predict_parser.add_argument(flag, type=int, help=help_text, default=None)

    predict_parser.set_defaults(func=predict)
def inference(directory, list_file=None, config_file=None, epoch=None, fileno=None):
    """
    Load a trained model and run it on the inference files.

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file, config_file : str, optional
        Paths to the toml list / config files; passed on to ``Organizer``.
    epoch, fileno : int, optional
        Which saved model to load; passed on to ``Organizer.inference``.

    Returns
    -------
    Whatever ``Organizer.inference`` returns.
    """
    from orcanet.core import Organizer

    return Organizer(directory, list_file, config_file, tf_log_level=1).inference(
        epoch=epoch, fileno=fileno
    )
def _add_parser_inference(subparsers):
    """
    Register the ``inference`` sub-command on *subparsers*.

    Adds the common directory/list/config arguments plus optional
    ``--epoch``/``--fileno`` selectors, and binds the command to ``inference``.
    """
    description_text = (
        "Load a trained model and save its prediction on the "
        "inference files to h5."
    )
    inference_parser = subparsers.add_parser("inference", description=description_text)
    _add_common_args(inference_parser)

    inference_parser.add_argument(
        "--epoch", default=None, type=int, help="Epoch of model to load. Default: best"
    )
    inference_parser.add_argument(
        "--fileno", default=None, type=int, help="Fileno of model to load. Default: best"
    )

    inference_parser.set_defaults(func=inference)
def inference_on_file(
    input_file,
    output_file=None,
    config_file=None,
    saved_model=None,
    directory=None,
    epoch=None,
    fileno=None,
):
    """
    Run a trained model on a single input file.

    The model comes either from ``saved_model`` (a path) or from an OrcaNet
    ``directory`` (selected via ``epoch``/``fileno``). If both are given,
    ``saved_model`` wins and a warning is emitted.

    Parameters
    ----------
    input_file : str
        File to run the inference on.
    output_file : str, optional
        Where to save the output; passed on to ``Organizer.inference_on_file``.
    config_file : str, optional
        Path to a toml config file; passed on to ``Organizer``.
    saved_model : str, optional
        Path to a saved model.
    directory : str, optional
        Path to an OrcaNet directory.
    epoch, fileno : int, optional
        Passed on to ``Organizer.inference_on_file``.

    Returns
    -------
    Whatever ``Organizer.inference_on_file`` returns.

    Raises
    ------
    ValueError
        If neither ``directory`` nor ``saved_model`` is given.
    """
    from orcanet.core import Organizer

    got_directory = directory is not None
    got_saved_model = saved_model is not None

    if not (got_directory or got_saved_model):
        raise ValueError("Either directory or saved_model is required!")
    if got_directory and got_saved_model:
        warnings.warn("Warning: Ignoring given directory since saved_model was given.")
        directory = None

    if directory is None:
        # No OrcaNet directory: use "." as a placeholder; discover_tomls=False
        # presumably stops the organizer from searching it for toml files.
        orga = Organizer(".", config_file=config_file, discover_tomls=False)
    else:
        orga = Organizer(directory, config_file=config_file)

    return orga.inference_on_file(
        input_file,
        output_file=output_file,
        saved_model=saved_model,
        epoch=epoch,
        fileno=fileno,
    )
def _add_parser_inf_on_file(subparsers):
    """
    Register the ``inference_on_file`` sub-command on *subparsers*.

    Unlike the directory-based sub-commands, this one does not add the common
    arguments: the OrcaNet directory is optional (``--directory``) and a saved
    model can be given directly (``--saved_model``).
    """
    parser = subparsers.add_parser(
        "inference_on_file",
        # This subparser uses argparse's default HelpFormatter, which reflows
        # the description; the "\n" below therefore mark paragraph breaks in
        # the source only and are rendered as single spaces.
        description="Load a trained model and save its prediction on the given input "
        "file to the given output file.\n"
        "Useful for sharing a fully trained model, since the usual "
        "orcanet directory structure is not necessarily required.\n"
        "Can either load "
        "a saved model from a given path, or use the usual orcanet "
        "directory method of loading the best model of a training.",
    )
    parser.add_argument(
        "input_file",
        type=str,
        # Fixed duplicated preposition ("on which ... done on").
        help="Path to a DL file (i.e. output of OrcaSong) on which the inference should be done.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Save output to an h5 file with this name. Default: auto-generate "
        "name and save in same directory as the input file.",
        default=None,
    )
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to toml config file. Default: None.",
        default=None,
    )
    parser.add_argument(
        "--saved_model",
        type=str,
        help="Optional path to a saved model, which will be used instead of "
        "loading the one with the given epoch/fileno.",
        default=None,
    )
    parser.add_argument(
        "--directory",
        type=str,
        help="Path to an OrcaNet directory. Only relevant if saved_model is not given.",
        default=None,
    )
    parser.add_argument(
        "--epoch",
        type=int,
        help="Epoch of a model to load from the directory. Only relevant if saved_model is not given. "
        "Default: lowest val loss.",
        default=None,
    )
    parser.add_argument(
        "--fileno",
        type=int,
        help="File number of a model to load from the directory. Only relevant if saved_model is not given. "
        "Default: lowest val loss.",
        default=None,
    )
    parser.set_defaults(func=inference_on_file)
def _add_common_args(prsr):
    """
    Add the arguments shared by the directory-based sub-commands to *prsr*.

    These are the positional ``directory`` plus the optional ``--list_file``
    and ``--config_file`` toml paths, both defaulting to None.
    """
    prsr.add_argument("directory", help="Path to OrcaNet directory.")

    optional_tomls = (
        (
            "--list_file",
            "Path to toml list file. Default: Look for a file called "
            "'list.toml' in the given OrcaNet directory.",
        ),
        (
            "--config_file",
            "Path to toml config file. Default: Look for a file called "
            "'config.toml' in the given OrcaNet directory.",
        ),
    )
    for flag, help_text in optional_tomls:
        prsr.add_argument(flag, type=str, default=None, help=help_text)
def _add_parser_summarize(subparsers):
    """
    Register the ``summarize`` sub-command on *subparsers*.

    Arguments and description are taken from the parser returned by
    ``orcanet.utilities.summarize_training.get_parser``, which is attached as
    a parent parser. ``add_help=False`` is set on the child, presumably
    because the parent already defines ``-h/--help``.
    """
    import orcanet.utilities.summarize_training as summarize_training

    base_parser = summarize_training.get_parser()
    summarize_parser = subparsers.add_parser(
        "summarize",
        parents=[base_parser],
        add_help=False,
        formatter_class=argparse.RawTextHelpFormatter,
        description=base_parser.description,
    )
    summarize_parser.set_defaults(func=summarize_training.summarize)
def main():
    """
    Entry point of the ``orcanet`` command line interface.

    Builds the top-level parser with all sub-commands, parses the command
    line (``sys.argv[1:]``) and calls the function the chosen sub-command
    registered via ``set_defaults(func=...)``. The remaining parsed arguments
    are passed to it as keyword arguments.
    """
    parser = argparse.ArgumentParser(
        prog="orcanet",
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--version", action="version", version=version)

    subparsers = parser.add_subparsers()
    _add_parser_train(subparsers)
    _add_paser_predict(subparsers)
    _add_parser_inference(subparsers)
    _add_parser_inf_on_file(subparsers)
    _add_parser_summarize(subparsers)

    kwargs = vars(parser.parse_args())
    # Subparsers are optional by default, so invoking ``orcanet`` without a
    # sub-command leaves ``func`` unset. Report a usage error (exit status 2)
    # instead of crashing with ``KeyError: 'func'``. (``required=True`` on
    # add_subparsers is avoided: without a ``dest`` it breaks the error
    # message on older Python versions.)
    func = kwargs.pop("func", None)
    if func is None:
        parser.error("a sub-command is required")
    func(**kwargs)