Coverage for orcanet/parser.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1""" 

2Run OrcaNet functionalities from command line. 

3 

4""" 

5import warnings 

6import argparse 

7from orcanet import version 

8 

9# imports involving tf are done inside the functions to speed up startup 

10 

11 

def train(directory, list_file=None, config_file=None, model_file=None, to_epoch=None):
    """
    Train and validate a model inside an OrcaNet directory.

    If no epoch has been saved yet, a new model is built from a toml model
    file before training starts; otherwise no model is passed and training
    resumes (the CLI help states the model file is not needed to resume).

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file : str, optional
        Path to the toml list file, forwarded to the Organizer.
    config_file : str, optional
        Path to the toml config file, forwarded to the Organizer.
    model_file : str, optional
        Path to the toml model file. Only used when building a new model;
        if None, a file called 'model.toml' is searched in ``directory``.
    to_epoch : int, optional
        Forwarded to ``Organizer.train_and_validate``.

    Returns
    -------
    Whatever ``Organizer.train_and_validate`` returns.
    """
    # tf-heavy imports are deferred until the command actually runs.
    from orcanet.core import Organizer
    from orcanet.model_builder import ModelBuilder
    from orcanet.misc import find_file

    organizer = Organizer(directory, list_file, config_file, tf_log_level=1)

    already_started = organizer.io.get_latest_epoch() is not None
    if already_started:
        # A saved epoch exists: do not build a model, just continue training.
        return organizer.train_and_validate(model=None, to_epoch=to_epoch)

    print("Building new model")
    toml_path = find_file(directory, "model.toml") if model_file is None else model_file
    fresh_model = ModelBuilder(toml_path).build(organizer, verbose=False)
    return organizer.train_and_validate(model=fresh_model, to_epoch=to_epoch)

29 

30 

def _add_parser_train(subparsers):
    """
    Register the ``train`` sub-command.

    Adds the common directory/list/config arguments plus ``--model_file``
    and ``--to_epoch``, and dispatches to :func:`train`.
    """
    train_parser = subparsers.add_parser("train", description="Train and validate a model.")
    _add_common_args(train_parser)

    model_file_help = (
        "Path to toml model file. Will be used to build a model at "
        "the start of the training. Not needed to resume training. "
        "Default: Look for a file called 'model.toml' in the "
        "given OrcaNet directory."
    )
    to_epoch_help = "Train up to and including this epoch. Default: Train forever."

    train_parser.add_argument("--model_file", type=str, default=None, help=model_file_help)
    train_parser.add_argument("--to_epoch", type=int, default=None, help=to_epoch_help)
    train_parser.set_defaults(func=train)

53 

54 

def predict(directory, list_file=None, config_file=None, epoch=None, fileno=None):
    """
    Predict on the prediction files of an OrcaNet directory.

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file, config_file : str, optional
        Paths to the toml list/config files, forwarded to the Organizer.
    epoch, fileno : int, optional
        Which saved model to load, forwarded to ``Organizer.predict``.

    Returns
    -------
    The first element of the sequence returned by ``Organizer.predict``.
    """
    from orcanet.core import Organizer

    organizer = Organizer(directory, list_file, config_file, tf_log_level=1)
    prediction_results = organizer.predict(epoch=epoch, fileno=fileno)
    first_result = prediction_results[0]
    return first_result

60 

61 

def _add_parser_predict(subparsers):
    """
    Register the ``predict`` sub-command.

    Adds the common directory/list/config arguments plus ``--epoch`` and
    ``--fileno`` (which saved model to load), and dispatches to
    :func:`predict`.
    """
    parser = subparsers.add_parser(
        "predict",
        description="Load a trained model and save its prediction on "
        "the predictions files to h5.",
    )
    _add_common_args(parser)
    parser.add_argument(
        "--epoch",
        type=int,
        help="Epoch of model to load. Default: best",
        default=None,
    )
    parser.add_argument(
        "--fileno",
        type=int,
        help="Fileno of model to load. Default: best",
        default=None,
    )
    parser.set_defaults(func=predict)


# Backward-compatible alias for the original, misspelled name so that
# existing callers of ``_add_paser_predict`` keep working.
_add_paser_predict = _add_parser_predict

79 

80 

def inference(directory, list_file=None, config_file=None, epoch=None, fileno=None):
    """
    Run inference on the inference files of an OrcaNet directory.

    Parameters
    ----------
    directory : str
        Path to the OrcaNet directory.
    list_file, config_file : str, optional
        Paths to the toml list/config files, forwarded to the Organizer.
    epoch, fileno : int, optional
        Which saved model to load, forwarded to ``Organizer.inference``.

    Returns
    -------
    Whatever ``Organizer.inference`` returns.
    """
    from orcanet.core import Organizer

    organizer = Organizer(directory, list_file, config_file, tf_log_level=1)
    inference_result = organizer.inference(epoch=epoch, fileno=fileno)
    return inference_result

86 

87 

def _add_parser_inference(subparsers):
    """
    Register the ``inference`` sub-command.

    Adds the common directory/list/config arguments plus ``--epoch`` and
    ``--fileno`` (which saved model to load), and dispatches to
    :func:`inference`.
    """
    inference_parser = subparsers.add_parser(
        "inference",
        description=(
            "Load a trained model and save its prediction on the "
            "inference files to h5."
        ),
    )
    _add_common_args(inference_parser)

    # (flag, help text) for the model-selection options, in registration order.
    model_selectors = (
        ("--epoch", "Epoch of model to load. Default: best"),
        ("--fileno", "Fileno of model to load. Default: best"),
    )
    for flag, help_text in model_selectors:
        inference_parser.add_argument(flag, type=int, help=help_text, default=None)

    inference_parser.set_defaults(func=inference)

108 

109 

def inference_on_file(
    input_file,
    output_file=None,
    config_file=None,
    saved_model=None,
    directory=None,
    epoch=None,
    fileno=None,
):
    """
    Run a trained model on a single input file and save its prediction.

    The model is either loaded from ``saved_model`` or, if that is not given,
    taken from the OrcaNet ``directory`` (selected by ``epoch``/``fileno``).

    Parameters
    ----------
    input_file : str
        Path to the file to run inference on.
    output_file : str, optional
        Where to save the output; forwarded to ``Organizer.inference_on_file``.
    config_file : str, optional
        Path to a toml config file, forwarded to the Organizer.
    saved_model : str, optional
        Path to a saved model. Takes precedence over ``directory``.
    directory : str, optional
        Path to an OrcaNet directory. Ignored (with a warning) if
        ``saved_model`` is also given.
    epoch, fileno : int, optional
        Forwarded to ``Organizer.inference_on_file``.

    Returns
    -------
    Whatever ``Organizer.inference_on_file`` returns.

    Raises
    ------
    ValueError
        If neither ``directory`` nor ``saved_model`` is given.
    """
    # Validate the arguments before the slow (tf-importing) orcanet.core
    # import, so that an invalid call fails immediately.
    if directory is None and saved_model is None:
        raise ValueError("Either directory or saved_model is required!")
    if directory is not None and saved_model is not None:
        # stacklevel=2 attributes the warning to the caller, not to this line.
        warnings.warn(
            "Warning: Ignoring given directory since saved_model was given.",
            stacklevel=2,
        )
        directory = None

    from orcanet.core import Organizer

    if directory is not None:
        orga = Organizer(directory, config_file=config_file)
    else:
        # No OrcaNet directory available: "." is passed with toml discovery
        # disabled (presumably just as a placeholder directory - confirm).
        orga = Organizer(".", config_file=config_file, discover_tomls=False)

    return orga.inference_on_file(
        input_file,
        output_file=output_file,
        saved_model=saved_model,
        epoch=epoch,
        fileno=fileno,
    )

139 

140 

def _add_parser_inf_on_file(subparsers):
    """
    Register the ``inference_on_file`` sub-command.

    Adds a positional ``input_file`` plus optional arguments selecting the
    output file, config, and which model to load (a saved model path or an
    OrcaNet directory with epoch/fileno), and dispatches to
    :func:`inference_on_file`.
    """
    description = (
        "Load a trained model and save its prediction on the given input "
        "file to the given output file.\n"
        "Useful for sharing a fully trained model, since the usual "
        "orcanet directory structure is not necessarily required.\n Can either load "
        "a saved model from a given path, or use the usual orcanet "
        "directory method of loading the best model of a training. "
    )
    file_parser = subparsers.add_parser("inference_on_file", description=description)

    file_parser.add_argument(
        "input_file",
        type=str,
        help="Path to a DL file (i.e. output of OrcaSong) on which the inference should be done on.",
    )

    # (flag, type, help) of the optional arguments, in registration order.
    # Every one of them defaults to None.
    optional_specs = (
        (
            "--output_file",
            str,
            "Save output to an h5 file with this name. Default: auto-generate "
            "name and save in same directory as the input file.",
        ),
        ("--config_file", str, "Path to toml config file. Default: None."),
        (
            "--saved_model",
            str,
            "Optional path to a saved model, which will be used instead of "
            "loading the one with the given epoch/fileno. ",
        ),
        (
            "--directory",
            str,
            "Path to an OrcaNet directory. Only relevant if saved_model is not given.",
        ),
        (
            "--epoch",
            int,
            "Epoch of a model to load from the directory. Only relevant if saved_model is not given. "
            "Default: lowest val loss.",
        ),
        (
            "--fileno",
            int,
            "File number of a model to load from the directory. Only relevant if saved_model is not given. "
            "Default: lowest val loss.",
        ),
    )
    for flag, arg_type, help_text in optional_specs:
        file_parser.add_argument(flag, type=arg_type, help=help_text, default=None)

    file_parser.set_defaults(func=inference_on_file)

197 

198 

def _add_common_args(prsr):
    """
    Add the arguments shared by the directory-based sub-commands to *prsr*.

    These are the positional OrcaNet ``directory`` and the optional
    ``--list_file`` / ``--config_file`` paths (both default to None).
    """
    prsr.add_argument("directory", help="Path to OrcaNet directory.")

    toml_options = (
        (
            "--list_file",
            "Path to toml list file. Default: Look for a file called "
            "'list.toml' in the given OrcaNet directory.",
        ),
        (
            "--config_file",
            "Path to toml config file. Default: Look for a file called "
            "'config.toml' in the given OrcaNet directory.",
        ),
    )
    for flag, help_text in toml_options:
        prsr.add_argument(flag, type=str, help=help_text, default=None)

218 

219 

def _add_parser_summarize(subparsers):
    """
    Register the ``summarize`` sub-command.

    The arguments are inherited from the parser returned by
    ``summarize_training.get_parser()``, and the command dispatches to
    ``summarize_training.summarize``.
    """
    import orcanet.utilities.summarize_training as summarize_training

    base_parser = summarize_training.get_parser()
    # add_help=False avoids a second -h/--help on the child; presumably the
    # parent parser already defines one (argparse rejects duplicates) - confirm.
    summarize_parser = subparsers.add_parser(
        "summarize",
        parents=[base_parser],
        add_help=False,
        description=base_parser.description,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    summarize_parser.set_defaults(func=summarize_training.summarize)

232 

233 

def main(args=None):
    """
    Entry point of the ``orcanet`` command line interface.

    Builds the parser with all sub-commands, parses the arguments and calls
    the function registered (via ``set_defaults(func=...)``) by the chosen
    sub-command with the remaining parsed arguments as keyword arguments.

    Parameters
    ----------
    args : list of str, optional
        Arguments to parse. Default: None, i.e. argparse uses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        prog="orcanet",
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--version", action="version", version=version)

    subparsers = parser.add_subparsers()
    _add_parser_train(subparsers)
    _add_paser_predict(subparsers)
    _add_parser_inference(subparsers)
    _add_parser_inf_on_file(subparsers)
    _add_parser_summarize(subparsers)

    kwargs = vars(parser.parse_args(args))
    # Sub-commands are optional by default in argparse, so without one there
    # is no "func" entry; report a usage error (exit code 2) instead of
    # crashing with a KeyError.
    func = kwargs.pop("func", None)
    if func is None:
        parser.error("a command is required, see 'orcanet --help'")
    func(**kwargs)