import numpy as np
from . import fftw
def _Xfftn_plan_pyfftw(shape, axes, dtype, transforms, options):
"""Plan serial transforms using pyfftw
Parameters
----------
shape : list of ints
shape of input array planned for
axes : list of ints
axes to transform over
dtype : np.dtype
type of input array
options : dict
arguments for planning serial transforms
"""
import pyfftw
opts = dict(
avoid_copy=True,
overwrite_input=True,
auto_align_input=True,
auto_contiguous=True,
threads=1,
)
opts.update(options)
transforms = {} if transforms is None else transforms
if tuple(axes) in transforms:
plan_fwd, plan_bck = transforms[tuple(axes)]
else:
if np.issubdtype(dtype, np.floating):
plan_fwd = pyfftw.builders.rfftn
plan_bck = pyfftw.builders.irfftn
else:
plan_fwd = pyfftw.builders.fftn
plan_bck = pyfftw.builders.ifftn
s = tuple(np.take(shape, axes))
U = pyfftw.empty_aligned(shape, dtype=dtype)
xfftn_fwd = plan_fwd(U, s=s, axes=axes, **opts)
U.fill(0)
if np.issubdtype(dtype, np.floating):
del opts['overwrite_input']
V = xfftn_fwd.output_array
xfftn_bck = plan_bck(V, s=s, axes=axes, **opts)
V.fill(0)
xfftn_fwd.update_arrays(U, V)
xfftn_bck.update_arrays(V, U)
return (xfftn_fwd, xfftn_bck)
def _Xfftn_plan_mpi4py(shape, axes, dtype, transforms, options):
"""Plan serial transforms using local wrapper
Parameters
----------
shape : list of ints
shape of input array planned for
axes : list of ints
axes to transform over
dtype : np.dtype
type of input array
options : dict
arguments for planning serial transforms
"""
opts = dict(
overwrite_input='FFTW_DESTROY_INPUT',
planner_effort='FFTW_MEASURE',
threads=1,
)
opts.update(options)
flags = (fftw.flag_dict[opts['planner_effort']],
fftw.flag_dict[opts['overwrite_input']])
threads = opts['threads']
transforms = {} if transforms is None else transforms
if tuple(axes) in transforms:
plan_fwd, plan_bck = transforms[tuple(axes)]
else:
if np.issubdtype(dtype, np.floating):
plan_fwd = fftw.rfftn
plan_bck = fftw.irfftn
else:
plan_fwd = fftw.fftn
plan_bck = fftw.ifftn
s = tuple(np.take(shape, axes))
U = fftw.aligned(shape, dtype=dtype)
xfftn_fwd = plan_fwd(U, s=s, axes=axes, threads=threads, flags=flags)
U.fill(0)
V = xfftn_fwd.output_array
if np.issubdtype(dtype, np.floating):
flags = (fftw.flag_dict[opts['planner_effort']],)
xfftn_bck = plan_bck(V, s=s, axes=axes, threads=threads, flags=flags, output_array=U)
return (xfftn_fwd, xfftn_bck)
class _Xfftn_wrap(object):
# pylint: disable=too-few-public-methods
__slots__ = ('_xfftn', '__doc__', '_input_array', '_output_array')
def __init__(self, xfftn_obj, input_array, output_array):
object.__setattr__(self, '_xfftn', xfftn_obj)
object.__setattr__(self, '_input_array', input_array)
object.__setattr__(self, '_output_array', output_array)
object.__setattr__(self, '__doc__', xfftn_obj.__doc__)
@property
def input_array(self):
return object.__getattribute__(self, '_input_array')
@property
def output_array(self):
return object.__getattribute__(self, '_output_array')
@property
def xfftn(self):
return object.__getattribute__(self, '_xfftn')
def __call__(self, input_array=None, output_array=None, **options):
if input_array is not None:
self.input_array[...] = input_array
self.xfftn(**options)
if output_array is not None:
output_array[...] = self.output_array
return output_array
else:
return self.output_array
[docs]class FFTBase(object):
"""Base class for serial FFT transforms
Parameters
----------
shape : list or tuple of ints
shape of input array planned for
axes : None, int or tuple of ints, optional
axes to transform over. If None transform over all axes
dtype : np.dtype, optional
Type of input array
padding : bool, number or list of numbers
If False, then no padding. If number, then apply this number as padding
factor for all axes. If list of numbers, then each number gives the
padding for each axis. Must be same length as axes.
"""
def __init__(self, shape, axes=None, dtype=float, padding=False):
shape = list(shape) if np.ndim(shape) else [shape]
assert len(shape) > 0
assert min(shape) > 0
if axes is not None:
axes = list(axes) if np.ndim(axes) else [axes]
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = axis + len(shape)
else:
axes = list(range(len(shape)))
assert min(axes) >= 0
assert max(axes) < len(shape)
assert 0 < len(axes) <= len(shape)
assert sorted(axes) == sorted(set(axes))
dtype = np.dtype(dtype)
assert dtype.char in 'fdgFDG'
self.shape = shape
self.axes = axes
self.dtype = dtype
self.padding = padding
self.real_transform = np.issubdtype(dtype, np.floating)
self.padding_factor = 1
def _truncation_forward(self, padded_array, trunc_array):
axis = self.axes[-1]
if self.padding_factor > 1.0+1e-8:
trunc_array.fill(0)
N0 = self.forward.output_array.shape[axis]
if self.real_transform:
N = trunc_array.shape[axis]
s = [slice(None)]*trunc_array.ndim
s[axis] = slice(0, N)
trunc_array[:] = padded_array[tuple(s)]
if N0 % 2 == 0:
s[axis] = N-1
s = tuple(s)
trunc_array[s] = trunc_array[s].real
trunc_array[s] *= 2
else:
N = trunc_array.shape[axis]
su = [slice(None)]*trunc_array.ndim
su[axis] = slice(0, N//2+1)
trunc_array[tuple(su)] = padded_array[tuple(su)]
su[axis] = slice(-(N//2), None)
trunc_array[tuple(su)] += padded_array[tuple(su)]
def _padding_backward(self, trunc_array, padded_array):
axis = self.axes[-1]
if self.padding_factor > 1.0+1e-8:
padded_array.fill(0)
N0 = self.forward.output_array.shape[axis]
if self.real_transform:
s = [slice(0, n) for n in trunc_array.shape]
padded_array[tuple(s)] = trunc_array[:]
N = trunc_array.shape[axis]
if N0 % 2 == 0: # Symmetric Fourier interpolator
s[axis] = N-1
s = tuple(s)
padded_array[s] = padded_array[s].real
padded_array[s] *= 0.5
else:
N = trunc_array.shape[axis]
su = [slice(None)]*trunc_array.ndim
su[axis] = slice(0, N//2+1)
padded_array[tuple(su)] = trunc_array[tuple(su)]
su[axis] = slice(-(N//2), None)
padded_array[tuple(su)] = trunc_array[tuple(su)]
if N0 % 2 == 0: # Use symmetric Fourier interpolator
su[axis] = N//2
padded_array[tuple(su)] *= 0.5
su[axis] = -(N//2)
padded_array[tuple(su)] *= 0.5
[docs]class FFT(FFTBase):
"""Class for serial FFT transforms
Parameters
----------
shape : list or tuple of ints
shape of input array planned for
axes : None, int or tuple of ints, optional
axes to transform over. If None transform over all axes
dtype : np.dtype, optional
Type of input array
padding : bool, number or list of numbers
If False, then no padding. If number, then apply this number as padding
factor for all axes. If list of numbers, then each number gives the
padding for each axis. Must be same length as axes.
kw : dict
Parameters passed to serial transform object
Methods
-------
forward(input_array=None, output_array=None, **options)
Generic serial forward transform
Parameters
----------
input_array : array, optional
output_array : array, optional
options : dict
parameters to serial transforms
Returns
-------
output_array : array
backward(input_array=None, output_array=None, **options)
Generic serial backward transform
Parameters
----------
input_array : array, optional
output_array : array, optional
options : dict
parameters to serial transforms
Returns
-------
output_array : array
"""
def __init__(self, shape, axes=None, dtype=float, padding=False,
use_pyfftw=False, transforms=None, **kw):
FFTBase.__init__(self, shape, axes, dtype, padding)
plan = _Xfftn_plan_pyfftw if use_pyfftw is True else _Xfftn_plan_mpi4py
self.fwd, self.bck = plan(self.shape, self.axes, self.dtype, transforms, kw)
U, V = self.fwd.input_array, self.fwd.output_array
if use_pyfftw:
self.M = 1./np.prod(np.take(self.shape, self.axes))
else:
self.M = self.fwd.get_normalization()
self.padding_factor = 1.0
if padding is not False:
self.padding_factor = padding[self.axes[-1]] if np.ndim(padding) else padding
if abs(self.padding_factor-1.0) > 1e-8:
assert len(self.axes) == 1
trunc_array = self._get_truncarray(shape, V.dtype)
self.forward = _Xfftn_wrap(self._forward, U, trunc_array)
self.backward = _Xfftn_wrap(self._backward, trunc_array, U)
else:
self.forward = _Xfftn_wrap(self._forward, U, V)
self.backward = _Xfftn_wrap(self._backward, V, U)
def _forward(self, **kw):
self.fwd(None, None, **kw)
self._truncation_forward(self.fwd.output_array, self.forward.output_array)
self.forward._output_array *= self.M
return self.forward.output_array
def _backward(self, **kw):
self._padding_backward(self.backward.input_array, self.bck.input_array)
self.bck(None, None, normalise_idft=False, **kw)
return self.backward.output_array
def _get_truncarray(self, shape, dtype):
axis = self.axes[-1]
if not self.real_transform:
shape = list(shape)
shape[axis] = int(np.round(shape[axis] / self.padding_factor))
return fftw.aligned(shape, dtype=dtype)
shape = list(shape)
shape[axis] = int(np.round(shape[axis] / self.padding_factor))
shape[axis] = shape[axis]//2 + 1
return fftw.aligned(shape, dtype=dtype)
[docs]class FFTNumPy(FFTBase): #pragma: no cover
"""Class for serial FFT transforms using Numpy FFT
Parameters
----------
shape : list or tuple of ints
shape of input array planned for
axes : None, int or tuple of ints, optional
axes to transform over. If None transform over all axes
dtype : np.dtype, optional
Type of input array
padding : bool, number or list of numbers
If False, then no padding. If number, then apply this number as padding
factor for all axes. If list of numbers, then each number gives the
padding for each axis. Must be same length as axes.
kw : dict
Parameters passed to serial transform object
forward(input_array=None, output_array=None, **options)
Generic serial forward transform
Parameters
----------
input_array : array, optional
output_array : array, optional
options : dict
parameters to serial transforms
Returns
-------
output_array : array
backward(input_array=None, output_array=None, **options)
Generic serial backward transform
Parameters
----------
input_array : array, optional
output_array : array, optional
options : dict
parameters to serial transforms
Returns
-------
output_array : array
"""
def __init__(self, shape, axes=None, dtype=float, padding=False, **kw):
FFTBase.__init__(self, shape, axes, dtype, padding)
typecode = self.dtype.char
self.sizes = list(np.take(self.shape, self.axes))
arrayA = np.zeros(self.shape, self.dtype)
if self.real_transform:
axis = self.axes[-1]
self.shape[axis] = self.shape[axis]//2 + 1
arrayB = np.zeros(self.shape, typecode.upper())
fwd = np.fft.rfftn
bck = np.fft.irfftn
else:
arrayB = np.zeros(self.shape, typecode)
fwd = np.fft.fftn
bck = np.fft.ifftn
fwd.input_array = arrayA
fwd.output_array = arrayB
bck.input_array = arrayB
bck.output_array = arrayA
self.fwd, self.bck = fwd, bck
self.padding_factor = 1
if padding is not False:
assert len(self.axes) == 1
self.axis = self.axes[-1]
self.padding_factor = padding[axes[-1]] if np.ndim(padding) else padding
trunc_array = self._get_truncarray(shape, arrayB.dtype)
self.forward = _Xfftn_wrap(self._forward, arrayA, trunc_array)
self.backward = _Xfftn_wrap(self._backward, trunc_array, arrayA)
else:
self.forward = _Xfftn_wrap(self._forward, arrayA, arrayB)
self.backward = _Xfftn_wrap(self._backward, arrayB, arrayA)
def _forward(self, **kw):
self.fwd.output_array[:] = self.fwd(self.fwd.input_array, s=self.sizes,
axes=self.axes, **kw)
self._truncation_forward(self.fwd.output_array, self.forward.output_array)
return self.forward.output_array
def _backward(self, **kw):
self._padding_backward(self.backward.input_array, self.bck.input_array)
self.backward.output_array[:] = self.bck(self.bck.input_array, s=self.sizes,
axes=self.axes, **kw)
return self.backward.output_array
def _get_truncarray(self, shape, dtype):
axis = self.axes[-1]
if not self.real_transform:
shape = list(shape)
shape[axis] = int(np.round(shape[axis] / self.padding_factor))
return np.zeros(shape, dtype=dtype)
shape = list(shape)
shape[axis] = int(np.round(shape[axis] / self.padding_factor))
shape[axis] = shape[axis]//2 + 1
return np.zeros(shape, dtype=dtype)
#FFT = FFTNumPy