Coverage for pythia/misc.py: 91%
127 statements
coverage.py v7.3.1, created at 2023-09-08 17:13 +0000
1"""
2File: pythia/misc.py
3Author: Nando Hegemann
4Gitlab: https://gitlab.com/Nando-Hegemann
5Description: Miscellaneous functions to support PyThia core functionality.
6SPDX-License-Identifier: LGPL-3.0-or-later OR Hippocratic-3.0-ECO-MEDIA-MIL
7"""
8from typing import Sequence, Iterator
9import os
10import datetime
11import shutil
12import numpy as np


def shift_coord(
    x: float | np.ndarray, S: np.ndarray | list, T: np.ndarray | list
) -> np.ndarray:
    """Shift `x` from interval `S` to interval `T`.

    Use an affine transformation to shift points :math:`x` from the source
    interval :math:`S = [t_0, t_1]` to the target interval :math:`T = [a, b]`.

    Parameters
    ----------
    x : array_like
        Points in interval :math:`S`.
    S : array_like
        Source interval.
    T : array_like
        Target interval.

    Returns
    -------
    :
        Shifted values for `x`.
    """
    return ((T[1] - T[0]) * x + T[0] * S[1] - T[1] * S[0]) / (S[1] - S[0])
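
# Illustrative usage (not part of the original module): mapping S = [-1, 1] to
# T = [0, 1] is the affine map x -> (x + 1) / 2, so
#
#   shift_coord(np.array([-1.0, 0.0, 1.0]), [-1.0, 1.0], [0.0, 1.0])
#
# should return array([0. , 0.5, 1. ]).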


def cart_prod(array_list: list[np.ndarray] | np.ndarray) -> np.ndarray:
    """Compute the Cartesian product of two or more arrays.

    Assemble an array containing all possible combinations of the elements
    of the input vectors :math:`v_1,\\dots,v_n`.

    Parameters
    ----------
    array_list : list of array_like
        List of vectors :math:`v_1,\\dots,v_n`.

    Returns
    -------
    :
        Cartesian product array.
    """
    dim = len(array_list)
    if dim == 1:
        return np.array(array_list).T
    x = np.hstack((np.meshgrid(*array_list))).swapaxes(0, 1).reshape(dim, -1).T
    return x
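
# Illustrative usage (not part of the original module):
#
#   cart_prod([np.array([1, 2]), np.array([3, 4, 5])])
#
# should return a (6, 2) array containing the pairs
# (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5).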


def is_contained(val: float | Sequence | np.ndarray, domain: list | np.ndarray) -> bool:
    """Check if a given value (vector) is contained in a domain.

    Check whether each component of the vector lies in the one-dimensional
    interval of the corresponding component of the domain.

    Parameters
    ----------
    val : array_like
        Vector to check for containment in the domain.
    domain : array_like
        Product domain of one-dimensional intervals.

    Returns
    -------
    :
        Bool stating if the value is contained in the domain.
    """
    if not isinstance(val, np.ndarray):
        val = np.array(val)
    if not isinstance(domain, np.ndarray):
        domain = np.array(domain)
    if val.ndim < 2:
        val.shape = 1, -1
    if domain.ndim < 2:
        domain.shape = 1, -1
    assert val.ndim == 2
    assert val.shape[0] == 1
    assert domain.ndim == 2 and domain.shape[1] == 2
    if np.all(domain[:, 0] <= val) and np.all(val <= domain[:, 1]):
        return True
    return False
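
# Illustrative usage (not part of the original module): for the product domain
# [0, 1] x [0, 1],
#
#   is_contained([0.5, 0.25], [[0, 1], [0, 1]])  # -> True
#   is_contained([0.5, 1.25], [[0, 1], [0, 1]])  # -> False (second component)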


def format_time(dt: float) -> str:
    """Convert a time in seconds to a formatted time string.

    Parameters
    ----------
    dt : float
        Time in seconds.

    Returns
    -------
    :
        Formatted time string.
    """
    assert dt >= 0
    dct = {}
    dct["d"], rem = divmod(int(dt), 86400)
    dct["h"], rem = divmod(int(rem), 3600)
    dct["min"], seconds = divmod(int(rem), 60)
    dct["sec"] = seconds + 1  # rounding seconds up
    fmt = ""
    if dct["d"] != 0:
        fmt += "{d} days "
    if dct["h"] != 0:
        fmt += "{h}h "
    if dct["min"] != 0:
        fmt += "{min}min "
    if dct["sec"] > 1:
        fmt += "{sec}s "
    if dt < 1.0:
        fmt += "{:2.2g}s".format(dt)
    fmt = fmt.strip()
    return fmt.format(**dct)
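
# Illustrative usage (not part of the original module): 3725.4 seconds are
# 1h 2min 5.4s, and since the seconds are always rounded up by one,
#
#   format_time(3725.4)
#
# should return "1h 2min 6s".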


def now() -> str:
    """Get string of current machine date and time.

    Returns
    -------
    :
        Formatted date and time string.
    """
    dt = datetime.datetime.now()
    today = "{:04}-{:02}-{:02} ".format(dt.year, dt.month, dt.day)
    now = "{:02}:{:02}:{:02}".format(dt.hour, dt.minute, dt.second)
    return today + now


def line(indicator: str, message: str | None = None) -> str:
    """Build a line of 80 characters by repeating the indicator.

    An additional message can be included in the line.

    Parameters
    ----------
    indicator : str
        Indicator the line consists of, e.g. '-', '+' or '+-'.
    message : str, optional
        Message integrated in the line.

    Returns
    -------
    :
        String of 80 characters length.
    """
    assert len(indicator) > 0
    text = ""
    if message is not None:
        text = 2 * indicator
        text = text[:2] + " " + message + " "
    while len(text) < 80:
        text += indicator
    return text[:80]
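
# Illustrative usage (not part of the original module):
#
#   line("-", "section")
#
# should return a string of exactly 80 characters that starts with
# "-- section " and is padded to the right with "-".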


def save(filename: str, data: np.ndarray, path: str = "./") -> None:
    """Wrapper for numpy save.

    Ensures the target directory is created if necessary and backs up old
    data if a file with the same name already exists.

    Parameters
    ----------
    filename : str
        Filename to save data to.
    data : array_like
        Data to save as .npy file.
    path : str, default='./'
        Path under which the file should be created.
    """
    if not os.path.isdir(path):
        os.makedirs(path)
    if os.path.isfile(path + filename):
        shutil.copyfile(path + filename, path + filename + ".backup")
    np.save(path + filename, data)
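
# Illustrative usage (not part of the original module; file and directory
# names are made up). Note that `path` is concatenated directly with
# `filename`, so it should end with a path separator:
#
#   save("coefficients.npy", np.ones((3, 3)), path="results/")
#
# creates "results/" if needed, backs up an existing
# "results/coefficients.npy" to "results/coefficients.npy.backup" and then
# writes the new data.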


def load(filename: str) -> np.ndarray:
    """Alias for numpy.load()."""
    return np.load(filename)


def str2iter(string: str, iterType: type = list, dataType: type = int) -> Sequence:
    """Cast `str(iterable)` to `iterType` of `dataType`.

    Cast a string of lists, tuples, etc. to the specified iterable and data
    type, i.e., for `iterType=tuple` and `dataType=float` cast
    ``str([1,2,3]) -> (1.0, 2.0, 3.0)``.

    Parameters
    ----------
    string : str
        String representation of an iterable.
    iterType : type, default=list
        Iterable type the string is converted to.
    dataType : type, default=int
        Data type of entries of the iterable, e.g. `int` or `float`.

    Returns
    -------
    :
        Iterable of type `iterType` with entries of type `dataType`.
    """
    items = [s.strip() for s in string[1:-1].split(",")]
    if items[-1] == "":
        items = items[:-1]
    return iterType([dataType(item) for item in items])
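
# Illustrative usage (not part of the original module):
#
#   str2iter("[1, 2, 3]")                # -> [1, 2, 3]
#   str2iter("(1, 2, 3)", tuple, float)  # -> (1.0, 2.0, 3.0)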


def batch(iterable: Sequence, n: int = 1) -> Iterator:
    """Split an iterable into batches of size `n`.

    Parameters
    ----------
    iterable : array_like
        Iterable to split.
    n : int, default=1
        Batch size.

    Yields
    ------
    :
        Iterable for the different batches.
    """
    for ndx in range(0, len(iterable), n):
        yield iterable[ndx : min(ndx + n, len(iterable))]
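
# Illustrative usage (not part of the original module):
#
#   list(batch([1, 2, 3, 4, 5], n=2))
#
# should return [[1, 2], [3, 4], [5]], i.e. the last batch may be smaller.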


def wls_sampling_bound(m: int, c: float = 4) -> int:
    """Compute the weighted least-squares sampling bound.

    The number of samples :math:`n` is chosen such that

    .. math::
        \\frac{n}{\\log(n)} \\geq cm,

    where :math:`m` is the dimension of the Gramian matrix (number of PC
    expansion terms) and :math:`c` is an arbitrary constant. In
    Cohen & Migliorati 2017 the authors observed that the choice :math:`c=4`
    yields a well-conditioned Gramian with high probability.

    Parameters
    ----------
    m : int
        Dimension of Gramian matrix.
    c : float, default=4
        Scaling constant.

    Returns
    -------
    :
        Number of required wLS samples.
    """
    assert m > 0 and c > 0
    jj = max(int(np.ceil(c * m * np.log(c * m))), 2)
    while True:
        if jj / np.log(jj) >= c * m:
            n = jj
            break
        jj += 1
    return n
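
# Illustrative usage (not part of the original module): for m = 10 and the
# default c = 4, the bound asks for the smallest n with n / log(n) >= 40,
#
#   wls_sampling_bound(10)
#
# which should be 215 with this implementation.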


def gelman_rubin_condition(chains: np.ndarray) -> np.ndarray:
    """Compute the Gelman-Rubin criterion.

    Implementation of the Gelman-Rubin convergence criterion for multiple
    parameters. A Markov chain is considered to have converged if the final
    ratio is close to one.

    Parameters
    ----------
    chains : array_like, ndim=3
        Array containing the Markov chains of each parameter. All chains are
        equal in length, the assumed shape is
        ``(#chains, chain length, #params)``.

    Returns
    -------
    :
        Values computed by the Gelman-Rubin criterion for each parameter.
    """
    assert chains.ndim == 3
    M, N, DIM = chains.shape  # chains shape is (#chains, len(chains), #params)

    # Mean and var of chains.
    chain_means = np.mean(chains, axis=1)  # shape is (#chains, #params)
    chain_vars = np.var(chains, axis=1)  # shape is (#chains, #params)

    # Mean across all chains.
    mean = np.mean(chain_means, axis=0).reshape(1, -1)  # shape = (1, #params)

    # Between chain variance.
    B = N / (M - 1) * np.sum((chain_means - mean) ** 2, axis=0)
    # Within chain variance.
    W = 1 / M * np.sum(chain_vars, axis=0)

    # pooled variance
    V = (N - 1) / N * W + (M + 1) / (M * N) * B

    return np.array([np.sqrt(v / w) if w > 0 else np.inf for v, w in zip(V, W)])
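
# Illustrative usage (not part of the original module): for 4 chains of length
# 1000 drawn independently from the same distribution, e.g.
#
#   chains = np.random.default_rng(0).normal(size=(4, 1000, 2))
#   gelman_rubin_condition(chains)
#
# the returned values should be close to 1 for both parameters, indicating
# convergence.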


def confidence_interval(
    samples: np.ndarray, rate: float = 0.95, resolution: int = 500
) -> np.ndarray:
    """Compute confidence intervals of samples.

    Compute the confidence intervals of the 1D marginals of the samples
    (slices). The confidence interval for a given rate is the interval around
    the median (not the mean) of the samples containing roughly a fraction
    `rate` of the total mass. This is computed for the left and right side of
    the median independently.

    Parameters
    ----------
    samples : array_like, ndim < 3
        Array containing the (multidimensional) samples.
    rate : float, default=0.95
        Fraction of the total mass the interval should contain.
    resolution : int, default=500
        Number of bins used in histogramming the samples.

    Returns
    -------
    :
        Confidence intervals for each component.
    """
    if samples.ndim < 2:
        samples.shape = -1, 1
    assert samples.ndim == 2
    assert 0 <= rate <= 1
    conf_intervals = np.empty((samples.shape[1], 2))
    for j, s in enumerate(samples.T):
        median = np.median(s)
        hist, bdry = np.histogram(s, bins=resolution, density=True)
        median_bin = np.argmax(bdry > median) - 1
        cumsum_left = (
            np.cumsum(np.flip(hist[: median_bin + 1]))
            - 0.5 * hist[median_bin]  # do not count the median bin twice
        )
        cumsum_right = (
            np.cumsum(hist[median_bin:])
            - 0.5 * hist[median_bin]  # do not count the median bin twice
        )
        lc = median_bin - np.argmax(cumsum_left >= rate * cumsum_left[-1])
        rc = np.argmax(cumsum_right >= rate * cumsum_right[-1]) + median_bin
        conf_intervals[j] = np.array([bdry[lc], bdry[rc + 1]])
    return conf_intervals
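
# Illustrative usage (not part of the original module): for a large sample of
# a standard normal distribution,
#
#   samples = np.random.default_rng(0).normal(size=(100_000, 1))
#   confidence_interval(samples, rate=0.95)
#
# the returned interval should be roughly [[-1.96, 1.96]].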


def doerfler_marking(
    values: np.ndarray | list,
    idx: np.ndarray | None = None,
    threshold: float = 0.9,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Dörfler marking for arbitrary values.

    Parameters
    ----------
    values : array_like
        Values for the Dörfler marking.
    idx : list of int, optional
        List of indices associated with the entries of `values`.
        If `None`, this is set to ``range(len(values))``.
    threshold : float, default=0.9
        Threshold parameter for the Dörfler marking.

    Returns
    -------
    idx_reordered :
        Reordered indices given by `idx`.
        Ordered from largest to smallest value.
    ordered_values :
        Reordered values. Ordered from largest to smallest.
    marker :
        Threshold marker such that
        ``sum(values[:marker]) > threshold * sum(values)``.
    """
    if isinstance(values, list):
        values = np.array(values)
    assert values.size > 0
    if values.ndim < 2:
        values.shape = -1, 1
    assert values.ndim == 2
    if idx is None:
        idx = np.arange(values.shape[0], dtype=int)
    # index list of largest to smallest (absolute) coefficient values
    sort = np.flip(np.argsort(np.abs(values), axis=None), axis=0)
    idx_reordered = idx[sort]
    ordered_values = values[sort]
    marker = int(
        np.argmax(np.cumsum(ordered_values, axis=0) > threshold * np.sum(values))
    )
    return idx_reordered, ordered_values, marker + 1
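
# Illustrative usage (not part of the original module):
#
#   doerfler_marking([0.5, 0.3, 0.15, 0.05], threshold=0.9)
#
# should return idx_reordered = [0, 1, 2, 3], the values sorted in descending
# order, and marker = 3, since the three largest values already exceed 90% of
# the total sum.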


if __name__ == "__main__":
    pass