Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# (c) Stefan Countryman, 2019 

2 

3""" 

4Tools for analyzing large-scale batches of LLAMA data, e.g. to run simulations 

5off of injected or randomized data, with inputs and results stored locally or 

6on a cloud storage service. 

7""" 

8 

9import os 

10import json 

11import shutil 

12import logging 

13from textwrap import dedent 

14from socket import gethostname 

15from tempfile import TemporaryDirectory 

16from random import randint 

17from itertools import product 

18from argparse import Action, Namespace 

19from collections import OrderedDict, namedtuple 

20from llama.cli import CliParser 

21from llama.com.s3 import PrivateFileCacher, upload_file 

22from llama.utils import RemoteFileCacher 

23from llama.com.slack import alert_maintainers 

24 

25LOGGER = logging.getLogger(__name__) 

26 

27 

class DictAction(Action):
    """
    Split a list of command line arguments into an ``OrderedDict`` along the
    first occurrence of "=" in each one.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Accept a single value as well as a list/tuple of values.
        if not isinstance(values, (list, tuple)):
            values = [values]
        invalid_args = [v for v in values if '=' not in v]
        if invalid_args:
            msg = (f"When specifying ``{option_string}`` values, each argument"
                   " must be of the form KEY=VALUE. The following arguments "
                   "are missing an '=' sign:\n\t" + "\n\t".join(invalid_args))
            parser.error(msg)
        # Split on the first '=' only, so VALUEs may themselves contain '='.
        setattr(namespace, self.dest,
                OrderedDict(v.split('=', 1) for v in values))

45 

46 

def batch_error_dump(errdump, event):
    """
    Get an error callback function like
    ``llama.cli.traceback_alert_maintainers`` to pass to
    ``llama.cli.log_exceptions_and_recover`` that dumps the results of the
    failing event to the directories listed in ``errdump``.

    Parameters
    ----------
    errdump : array-like
        A list of locations in which dumped data will go. These can be S3
        "directory" paths compatible with ``put_file``.
    event : llama.event.Event
        The event which errored. Files will be dumped from here.
    """

    def dump_callback(func, err, tb, self, *args, **kwargs):
        """Callback to dump event data in the event of an uncaught
        exception."""
        LOGGER.error("Dumping errored event files to %s", errdump)
        # same logic as ``Event.cruft_files`` but leave out
        # ``Event.auxiliary_paths`` since this contains event history.
        non_cruft = {f for fh in event.files.values()
                     for f in {fh.FILENAME}.union(fh.auxiliary_paths)}
        upload_names = non_cruft.intersection(os.listdir(event.eventdir))
        LOGGER.error("Dumped file names: %s", upload_names)
        # Write the traceback file once up front; its contents are identical
        # for every destination, so there is no need to rewrite it per-dest.
        tbpath = os.path.join(event.eventdir, 'ERRDUMP.log')
        with open(tbpath, 'w') as tbfile:
            tbfile.write(tb)
        for dest in errdump:
            LOGGER.error("Dumping to %s", dest)
            for fname in upload_names:
                put_file(os.path.join(event.eventdir, fname),
                         os.path.join(dest, fname), public=False)
            LOGGER.error("Dumping traceback to errdump file ERRDUMP")
            put_file(tbpath, os.path.join(dest, 'ERRDUMP.log'))

    return dump_callback

85 

86 

def batch_error_alert_maintainers(errdump, event, params):
    """
    Get an error callback function like
    ``llama.cli.traceback_alert_maintainers`` to pass to
    ``llama.cli.log_exceptions_and_recover`` that provides batch-specific
    information about the error.

    Parameters
    ----------
    errdump : array-like
        A list of locations in which dumped data will go.
    event : llama.event.Event
        The event which errored. Will be noted in the Slack alert message.
    params : dict
        The dictionary that was used to format ``errdump`` containing values
        from ``--params``. Contains information on the current event being
        processed.
    """

    def alert_callback(func, err, tb, self, *args, **kwargs):
        """Callback to run in the event of an uncaught exception."""
        LOGGER.error("Attempting to `alert_maintainers` with traceback...")
        # Slack-flavored markdown message summarizing the failure context.
        template = dedent("""
        BATCH PROCESS ERROR ON {hostname}: {err}.

        Traceback:

        ```
        {tb}
        ```

        - `event`: `{event}`
        - `args`: `{args}`
        - `kwargs`: `{kwargs}`
        - Batch Params: `{params}`

        Files saved in:

        ```
        {errdump}
        ```
        """).strip()
        message = template.format(hostname=gethostname(), err=err, tb=tb,
                                  args=args, kwargs=kwargs, event=event,
                                  errdump=errdump, params=params)
        response = alert_maintainers(message, func.__name__, recover=True)
        if response.get('ok', False):
            LOGGER.error("Alerted maintainers. Traceback sent: %s", tb)

    return alert_callback

137 

138 

def put_file(source: str, dest: str, public: bool = False):
    """
    Save a file somewhere. Same as copying the file, but with S3 compatibility.
    The destination directory and all parent directories will be created if
    they don't exist.

    Parameters
    ----------
    source: str
        The path to the file to be uploaded.
    dest: str
        Destination path. Works for local paths or for S3 keys (by prepending
        ``s3://{bucket}/`` where by ``{bucket}`` is the name of the S3 bucket
        to which the file should be uploaded). Unlike ``read_file``, does not
        work for uploading via HTTP/HTTPS. Intermediate directories will be
        made for local paths.
    public: bool, optional
        Whether files uploaded to S3 should be publicly available. This will
        have no effect on local file destinations.
    """
    if dest.startswith('s3://'):
        bucket, key = dest[5:].split('/', 1)
        upload_file(source, key, bucket=bucket, public=public)
    else:
        destdir = os.path.dirname(dest)
        # ``exist_ok=True`` so repeated dumps into the same directory do not
        # raise ``FileExistsError``; skip entirely for bare filenames, where
        # ``dirname`` is '' and ``makedirs`` would raise.
        if destdir:
            os.makedirs(destdir, exist_ok=True)
        shutil.copy2(source, dest)

165 

166 

def read_file(path: str):
    """
    Load data from a file path, either local, via HTTP/HTTPS, or from an S3
    bucket.

    Parameters
    ----------
    path : str
        Path to a JSON or (newline-delimited) text file containing a list of
        values to load. Path can also be an S3 object, denoted with "s3://" as
        the prefix, or a URL, denoted with "http://" or "https://" as the
        prefix. In these cases, the file is not saved locally.

    Returns
    -------
    data : bytes
        The contents of the file.
    """

    def _fetch(cacher):
        # Download (if needed) into the temp dir and slurp the cached copy.
        cacher.get()
        with open(cacher.localpath, 'rb') as infile:
            return infile.read()

    if path.startswith(('https://', 'http://')):
        with TemporaryDirectory() as tmp:
            tmppath = os.path.join(tmp, os.path.basename(path))
            return _fetch(RemoteFileCacher(path, tmppath))
    if path.startswith('s3://'):
        with TemporaryDirectory() as tmp:
            tmppath = os.path.join(tmp, os.path.basename(path))
            bucket, key = path[5:].split('/', 1)
            return _fetch(PrivateFileCacher(key, bucket, tmppath))
    with open(path, 'rb') as infile:
        return infile.read()

203 

204 

def load_params(path: str):
    """
    Use ``read_file`` to read a local or remote path. Decode the file contents
    as either a JSON list or newline-delimited text file, and return the
    decoded list.
    """
    contents = read_file(path).decode()
    try:
        parsed = json.loads(contents)
    except json.JSONDecodeError:
        # Not JSON at all: treat the file as one value per line.
        return contents.strip().split('\n')
    # Valid JSON that is not a list is an error, not a fallback case.
    if not isinstance(parsed, list):
        raise ValueError(f"Could not load a list from {path}. Instead, "
                         f"got: {parsed}")
    return parsed

220 

221 

def postprocess_param_iterator(self: CliParser, namespace: Namespace):
    """
    Overrides the ``namespace.params`` value with an iterator that captures all
    desired combinations of the specified ``params``. Loads from file, HTTP,
    HTTPS, or S3 for each param list. If the value for ``namespace.random`` is
    not None, will instead become an infinitely long generator serving random
    combinations of the parameters.

    Parameters
    ----------
    self : CliParser
        The ``CliParser`` applying this post-processing.
    namespace : Namespace
        The processed arguments parsed by ``self``.

    Returns
    -------
    combinations : GeneratorType
        A generator yielding an ``OrderedDict`` whose keys are the keys of
        ``namespace.params`` and whose values are drawn from the
        corresponding loaded parameter lists.
    """
    loaded = OrderedDict((key, load_params(src))
                         for key, src in namespace.params.items())
    if namespace.random is None:
        # Exhaustive cartesian product over every parameter list.
        combinations = (OrderedDict(zip(loaded, combo))
                        for combo in product(*loaded.values()))
    else:
        def random_combos():
            drawn = 0
            # keep going forever if a negative number is given
            while drawn != namespace.random:
                drawn += 1
                picks = [vals[randint(0, len(vals) - 1)]
                         for vals in loaded.values()]
                yield OrderedDict(zip(loaded, picks))
        combinations = random_combos()
    setattr(namespace, 'params', combinations)