# Source code for mpi4py_fft.io.file_base

from mpi4py import MPI
import numpy as np

# Names exported by ``from ... import *``; FileBase is the only public one.
__all__ = ('FileBase',)

# Module-level handle to the global MPI communicator. FileBase itself does
# not reference it in this module -- NOTE(review): presumably imported by
# backend modules; confirm before removing.
comm = MPI.COMM_WORLD

class FileBase(object):
    """Base class for reading/writing distributed arrays

    Parameters
    ----------
    filename : str, optional
        Name of backend file used to store data
    domain : sequence, optional
        An optional spatial mesh or domain to go with the data.
        Sequence of either

            - 2-tuples, where each 2-tuple contains the (origin, length)
              of each dimension, e.g., (0, 2*pi).
            - Arrays of coordinates, e.g., np.linspace(0, 2*pi, N). One
              array per dimension.

    Notes
    -----
    Most methods of this class raise :class:`NotImplementedError` and must be
    provided by a concrete backend subclass.
    """
    def __init__(self, filename=None, domain=None):
        # Backend file handle. It is None here; NOTE(review): presumably
        # assigned by a subclass's ``open`` -- not shown in this module.
        self.f = None
        self.filename = filename
        self.domain = domain

    def _check_domain(self, group, field):
        """Check dimensions and store (if missing) self.domain"""
        raise NotImplementedError

    def write(self, step, fields, **kw):
        """Write snapshot ``step`` of ``fields`` to file

        Parameters
        ----------
        step : int
            Index of snapshot.
        fields : dict
            The fields to be dumped to file. (key, value) pairs are group
            name and a list/tuple of either arrays or 2-tuples,
            respectively. The arrays are complete arrays to be stored,
            whereas 2-tuples are arrays with associated *global* slices.
        as_scalar : boolean, optional
            Whether to store rank > 0 arrays as scalars. Default is False.
            When true, every component of a rank ``r`` array is stored in
            its own group, named by appending the component indices to the
            group name (e.g. ``'u'`` of rank 2 gives ``'u00'``, ``'u01'``,
            ...). Any falsy value disables this.

        All keyword arguments (including ``as_scalar``) are forwarded to
        :meth:`_write_group` / :meth:`_write_slice_step`.
        """
        as_scalar = kw.get("as_scalar", False)

        def _write(group, u, sl):
            # No slice: store the complete array. Otherwise store only the
            # part selected by the global slice ``sl``.
            if sl is None:
                self._write_group(group, u, step, **kw)
            else:
                self._write_slice_step(group, step, sl, u, **kw)

        for group, list_of_fields in fields.items():
            assert isinstance(list_of_fields, (tuple, list))
            assert isinstance(group, str)

            for field in list_of_fields:
                if isinstance(field, (tuple, list)):
                    u, sl = field[0], field[1]
                else:
                    u, sl = field, None

                # Truthiness (not ``is False``) so that 0, None or
                # np.False_ also mean "do not split into scalars".
                # Short-circuit: ``u.rank`` is only needed when splitting.
                if not as_scalar or u.rank == 0:
                    self._check_domain(group, u)
                    _write(group, u, sl)
                    continue

                # Split a tensor of any rank into its scalar components.
                # Iteration order and names match nested loops over the
                # leading ``u.rank`` axes.
                for index in np.ndindex(*u.shape[:u.rank]):
                    name = group + ''.join(str(i) for i in index)
                    # Index with a plain integer for rank 1 (as ``u[k]``)
                    # and with a tuple otherwise (as ``u[k, l, ...]``).
                    component = u[index[0]] if len(index) == 1 else u[index]
                    self._check_domain(name, component)
                    _write(name, component, sl)

    def read(self, u, name, **kw):
        """Read field ``name`` into distributed array ``u``

        Parameters
        ----------
        u : array
            The :class:`.DistArray` to read into.
        name : str
            Name of field to be read.
        step : int, optional
            Index of field to be read. Default is 0.
        """
        raise NotImplementedError

    def close(self):
        """Close the self.filename file

        Does nothing if no file handle is held (``self.f is None``), e.g.
        when the file was never opened.
        """
        if self.f is not None:
            self.f.close()

    def open(self, mode='r+'):
        """Open the self.filename file for reading or writing

        Parameters
        ----------
        mode : str
            Open file in this mode. Default is 'r+'.
        """
        raise NotImplementedError

    @staticmethod
    def backend():
        """Return which backend is used to store data"""
        raise NotImplementedError

    def _write_slice_step(self, name, step, slices, field, **kwargs):
        """Store the part of ``field`` selected by global ``slices`` as
        snapshot ``step`` of group ``name``. Must be overridden."""
        raise NotImplementedError

    def _write_group(self, name, u, step, **kwargs):
        """Store the complete array ``u`` as snapshot ``step`` of group
        ``name``. Must be overridden."""
        raise NotImplementedError

    @staticmethod
    def _get_slice_name(slices):
        """Return a name describing the pattern of ``slices``

        Each :class:`slice` entry becomes ``'slice'`` and any other entry its
        ``str``; the parts are joined with ``'_'``. For example
        ``(slice(None), 3)`` gives ``'slice_3'``. An empty sequence gives
        ``''``.
        """
        return '_'.join('slice' if isinstance(ss, slice) else str(ss)
                        for ss in slices)

    @staticmethod
    def _get_local_slices(slices, s):
        """Translate global integer indices in ``slices`` to local indices

        Parameters
        ----------
        slices : list
            Global indices/slices, one entry per axis. Integer entries are
            shifted *in place*, so this must be mutable whenever such an
            entry needs translating.
        s : sequence of slice
            The local portion of each axis held by this processor.

        Returns
        -------
        (slices, inside)
            The (same, modified) ``slices`` and ``1`` if every integer index
            lies inside the local portion, else ``0``.
        """
        inside = 1
        # Only integer entries on axes whose local slice is not the full
        # ``slice(None)`` need translating. ``np.integer`` is accepted too,
        # because e.g. ``np.int64`` is not a subclass of ``int``.
        axes = [i for i, (x, z) in enumerate(zip(slices, s))
                if isinstance(x, (int, np.integer)) and z != slice(None)]
        for i in axes:
            if s[i].start <= slices[i] < s[i].stop:
                slices[i] -= s[i].start
            else:
                inside = 0
        return slices, inside