Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# (c) Stefan Countryman, 2019
3"""
4Tools for analyzing large-scale batches of LLAMA data, e.g. to run simulations
5on injected or randomized data, with inputs and results stored locally or
6on a cloud storage service.
7"""
9import os
10import json
11import shutil
12import logging
13from textwrap import dedent
14from socket import gethostname
15from tempfile import TemporaryDirectory
16from random import randint
17from itertools import product
18from argparse import Action, Namespace
19from collections import OrderedDict, namedtuple
20from llama.cli import CliParser
21from llama.com.s3 import PrivateFileCacher, upload_file
22from llama.utils import RemoteFileCacher
23from llama.com.slack import alert_maintainers
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
class DictAction(Action):
    """
    ``argparse`` action that collects ``KEY=VALUE`` command line arguments
    into an ``OrderedDict``. Each argument is split on its first occurrence
    of "=" only, so values may themselves contain "=".
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # A lone (non-sequence) value is handled as a one-element list.
        items = values if isinstance(values, (list, tuple)) else [values]
        missing_eq = [item for item in items if '=' not in item]
        if missing_eq:
            parser.error(
                f"When specifying ``{option_string}`` values, each argument"
                " must be of the form KEY=VALUE. The following arguments "
                "are missing an '=' sign:\n\t" + "\n\t".join(missing_eq)
            )
        pairs = [item.split('=', 1) for item in items]
        setattr(namespace, self.dest, OrderedDict(pairs))
def batch_error_dump(errdump, event):
    """
    Build an error callback, in the style of
    ``llama.cli.traceback_alert_maintainers``, to pass to
    ``llama.cli.log_exceptions_and_recover``. When invoked, the callback
    copies the failing event's files and a traceback log into every location
    listed in ``errdump``.

    Parameters
    ----------
    errdump : array-like
        Destinations for the dumped data. Entries can be local directories or
        S3 "directory" paths; anything ``put_file`` accepts as a destination
        prefix.
    event : llama.event.Event
        The event which errored. Files will be dumped from here.

    Returns
    -------
    dump_callback : function
        Callback taking ``(func, err, tb, self, *args, **kwargs)``.
    """

    def dump_callback(func, err, tb, self, *args, **kwargs):
        """Copy event files and the traceback ``tb`` to each destination in
        ``errdump`` after an uncaught exception."""
        LOGGER.error("Dumping errored event files to %s", errdump)
        # Mirrors ``Event.cruft_files`` but deliberately omits
        # ``Event.auxiliary_paths``, which holds the event history.
        wanted = set()
        for handler in event.files.values():
            wanted.add(handler.FILENAME)
            wanted.update(handler.auxiliary_paths)
        upload_names = wanted.intersection(os.listdir(event.eventdir))
        LOGGER.error("Dumped file names: %s", upload_names)
        tbpath = os.path.join(event.eventdir, 'ERRDUMP.log')
        for target in errdump:
            LOGGER.error("Dumping to %s", target)
            for name in upload_names:
                local = os.path.join(event.eventdir, name)
                put_file(local, os.path.join(target, name), public=False)
            LOGGER.error("Dumping traceback to errdump file ERRDUMP")
            with open(tbpath, 'w') as log_out:
                log_out.write(tb)
            put_file(tbpath, os.path.join(target, 'ERRDUMP.log'))

    return dump_callback
def batch_error_alert_maintainers(errdump, event, params):
    """
    Build an error callback, in the style of
    ``llama.cli.traceback_alert_maintainers``, to pass to
    ``llama.cli.log_exceptions_and_recover``. When invoked, the callback
    sends an ``alert_maintainers`` message carrying batch-specific context
    about the failure.

    Parameters
    ----------
    errdump : array-like
        A list of locations in which dumped data will go; listed in the
        alert message.
    event : llama.event.Event
        The event which errored. Will be noted in the Slack alert message.
    params : dict
        The dictionary that was used to format ``errdump`` containing values
        from ``--params``. Contains information on the current event being
        processed.

    Returns
    -------
    alert_callback : function
        Callback taking ``(func, err, tb, self, *args, **kwargs)``.
    """

    def alert_callback(func, err, tb, self, *args, **kwargs):
        """Send a maintainer alert describing an uncaught exception."""
        LOGGER.error("Attempting to `alert_maintainers` with traceback...")
        template = "\n".join([
            "BATCH PROCESS ERROR ON {hostname}: {err}.",
            "",
            "Traceback:",
            "",
            "```",
            "{tb}",
            "```",
            "",
            "- `event`: `{event}`",
            "- `args`: `{args}`",
            "- `kwargs`: `{kwargs}`",
            "- Batch Params: `{params}`",
            "",
            "Files saved in:",
            "",
            "```",
            "{errdump}",
            "```",
        ])
        message = template.format(
            hostname=gethostname(), err=err, tb=tb, args=args,
            kwargs=kwargs, event=event, errdump=errdump, params=params,
        )
        response = alert_maintainers(message, func.__name__, recover=True)
        if response.get('ok', False):
            LOGGER.error("Alerted maintainers. Traceback sent: %s", tb)

    return alert_callback
def put_file(source: str, dest: str, public: bool = False):
    """
    Save a file somewhere. Same as copying the file, but with S3 compatibility.
    The destination directory and all parent directories will be created if
    they don't exist.

    Parameters
    ----------
    source: str
        The path to the file to be uploaded.
    dest: str
        Destination path. Works for local paths or for S3 keys (by prepending
        ``s3://{bucket}/`` where ``{bucket}`` is the name of the S3 bucket
        to which the file should be uploaded). Unlike ``read_file``, does not
        work for uploading via HTTP/HTTPS. Intermediate directories will be
        made for local paths; it is not an error if they already exist.
    public: bool, optional
        Whether files uploaded to S3 should be publicly available. This will
        have no effect on local file destinations.

    Raises
    ------
    ValueError
        If ``dest`` is an S3 path with no ``/`` separating the bucket name
        from the key.
    """
    if dest.startswith('s3://'):
        bucket, sep, key = dest[5:].partition('/')
        if not sep:
            raise ValueError(f"S3 destination {dest!r} must look like "
                             "s3://BUCKET/KEY")
        upload_file(source, key, bucket=bucket, public=public)
    else:
        destdir = os.path.dirname(dest)
        # ``exist_ok`` lets repeated puts into the same directory succeed
        # (e.g. several files dumped to one location); a bare filename has
        # no directory component to create, and ``makedirs('')`` would fail.
        if destdir:
            os.makedirs(destdir, exist_ok=True)
        shutil.copy2(source, dest)
def read_file(path: str):
    """
    Return the raw bytes stored at ``path``, which may be a local file, an
    HTTP/HTTPS URL, or an S3 object.

    Parameters
    ----------
    path : str
        Path to a JSON or (newline-delimited) text file containing a list of
        values to load. Path can also be an S3 object, denoted with "s3://" as
        the prefix, or a URL, denoted with "http://" or "https://" as the
        prefix. Remote files are only fetched into a temporary directory that
        is removed afterwards; nothing is kept locally.

    Returns
    -------
    data : bytes
        The contents of the file.
    """
    is_web = path.startswith(('https://', 'http://'))
    is_s3 = path.startswith('s3://')
    if not (is_web or is_s3):
        with open(path, 'rb') as local:
            return local.read()
    with TemporaryDirectory() as scratch:
        target = os.path.join(scratch, os.path.basename(path))
        if is_web:
            cacher = RemoteFileCacher(path, target)
        else:
            bucket, key = path[5:].split('/', 1)
            cacher = PrivateFileCacher(key, bucket, target)
        cacher.get()
        with open(cacher.localpath, 'rb') as fetched:
            return fetched.read()
def load_params(path: str):
    """
    Use ``read_file`` to read a local or remote path. Decode the file contents
    as either a JSON list or newline-delimited text file, and return the
    decoded list.

    Contents that parse as a JSON array are returned as the decoded list.
    Contents that parse as a JSON object are rejected. Anything else,
    including a single bare value such as ``100`` or ``true`` that happens
    to also be valid JSON, is treated as newline-delimited text and returned
    as a list of strings. Windows (``\\r\\n``) line endings are handled.

    Raises
    ------
    ValueError
        If the file contains a JSON object rather than a list.
    """
    dat = read_file(path).decode()
    try:
        values = json.loads(dat)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(values, list):
            return values
        if isinstance(values, dict):
            raise ValueError(f"Could not load a list from {path}. Instead, "
                             f"got: {values}")
        # A lone JSON scalar is really a one-line text file; fall through
        # so it is handled exactly like any other newline-delimited file.
    # Strip '\r' so CRLF files don't leave a carriage return on each value.
    return [line.rstrip('\r') for line in dat.strip().split('\n')]
def postprocess_param_iterator(self: CliParser, namespace: Namespace):
    """
    Replace ``namespace.params`` with an iterator over the desired
    combinations of the specified ``params``.

    Each value of ``namespace.params`` (a mapping of parameter name to path,
    presumably built by ``DictAction``) is loaded with ``load_params``, so
    each parameter list can come from a local file, HTTP, HTTPS, or S3. If
    ``namespace.random`` is ``None``, the iterator yields every combination
    (the Cartesian product) of the parameter lists. Otherwise it yields
    ``namespace.random`` randomly drawn combinations, continuing forever if
    a negative number is given.

    Parameters
    ----------
    self : CliParser
        The ``CliParser`` applying this post-processing.
    namespace : Namespace
        The processed arguments parsed by ``self``. Modified in place.

    Returns
    -------
    None
        The result is stored on ``namespace.params``: a generator of
        ``OrderedDict`` objects whose keys are the parameter names (in the
        order given) and whose values are drawn from the corresponding
        loaded parameter lists.
    """
    params = OrderedDict((k, load_params(v))
                         for k, v in namespace.params.items())
    if namespace.random is None:
        combos = (OrderedDict(zip(params, v))
                  for v in product(*params.values()))
    else:
        # Snapshot the requested count so the generator's length does not
        # depend on later mutation of ``namespace``.
        limit = namespace.random

        def rand_vecs():
            i = 0
            # keep going forever if a negative number is given
            while i != limit:
                i += 1
                vec = [v[randint(0, len(v) - 1)] for v in params.values()]
                yield OrderedDict(zip(params, vec))

        combos = rand_vecs()
    namespace.params = combos