from __future__ import print_function, absolute_import, division
import abc
import uuid
import warnings
import tempfile
from six.moves import zip
import numpy as np
from numpy.lib.stride_tricks import as_strided
import dask.array as da
from astropy.wcs import InconsistentAxisTypesError
from astropy.io import fits
from . import wcs_utils
from .utils import WCSWarning
# Public API of this module.
__all__ = [
    'MaskBase',
    'InvertedMask',
    'CompositeMask',
    'BooleanArrayMask',
    'LazyMask',
    'LazyComparisonMask',
    'FunctionMask',
]

# Shared "Parameters" section that is appended (via ``__doc__ +=``) to the
# docstrings of the ``with_spectral_unit`` implementations and of
# ``MaskBase._get_new_wcs``, so the text only has to be written once.
with_spectral_unit_docs = """
Parameters
----------
unit : u.Unit
Any valid spectral unit: velocity, (wave)length, or frequency.
Only vacuum units are supported.
velocity_convention : u.doppler_relativistic, u.doppler_radio, or u.doppler_optical
The velocity convention to use for the output velocity axis.
Required if the output type is velocity.
rest_value : u.Quantity
A rest wavelength or frequency with appropriate units. Required if
output type is velocity. The cube's WCS should include this
already if the *input* type is velocity, but the WCS's rest
wavelength/frequency can be overridden with this parameter.
"""
def is_broadcastable_and_smaller(shp1, shp2):
    """
    Test whether shape ``shp1`` can be broadcast to shape ``shp2``.

    Dimensions are compared right-aligned, as in numpy broadcasting. A pair
    of dimensions is compatible when the ``shp1`` length is 1 or equals the
    ``shp2`` length. Unlike ordinary numpy broadcasting, a length-1 dimension
    in ``shp2`` is *not* allowed to stretch to a longer ``shp1`` dimension.

    Only the trailing dimensions that both shapes have are compared; extra
    leading dimensions of ``shp1`` are not checked.

    Parameters
    ----------
    shp1, shp2 : sequence of int
        The candidate (smaller) shape and the target shape.

    Returns
    -------
    bool
        True if every compared dimension pair is compatible.
    """
    return all(dim1 == 1 or dim1 == dim2
               for dim1, dim2 in zip(reversed(shp1), reversed(shp2)))
def dims_to_skip(shp1, shp2):
    """
    For a shape ``shp1`` that is broadcastable to shape ``shp2``, list the
    dimensions of ``shp2`` that ``shp1`` does not span.

    These are the dimensions where ``shp1`` has length 1 (listed from the
    last dimension backwards), followed by the leading dimensions of
    ``shp2`` that ``shp1`` does not have at all. All indices are in
    ``shp2``'s numbering.

    Parameters
    ----------
    shp1 : tuple of int
        The smaller shape.
    shp2 : tuple of int
        The shape ``shp1`` is broadcast to.

    Returns
    -------
    dims : list of int
        Indices of the dimensions of ``shp2`` to skip.

    Raises
    ------
    ValueError
        If ``is_broadcastable_and_smaller(shp1, shp2)`` is False.
    """
    if not is_broadcastable_and_smaller(shp1, shp2):
        raise ValueError("Cannot broadcast {0} to {1}".format(shp1, shp2))

    ndim2 = len(shp2)
    dims = []
    # Walk both shapes right-aligned, as numpy broadcasting does.
    for ii, (a, b) in enumerate(zip(shp1[::-1], shp2[::-1])):
        if a == 1:
            dims.append(ndim2 - ii - 1)
        elif a != b:
            # Excluded by the broadcastability check above.
            raise ValueError("This should not be possible")

    # Leading dimensions of shp2 that shp1 lacks entirely.
    if len(shp1) < ndim2:
        dims += list(range(ndim2 - len(shp1)))
    return dims
def view_of_subset(shp1, shp2, view):
    """
    Given two shapes and a view, assuming that shape 1 can be broadcast
    to shape 2, return the sub-view that applies to shape 1.

    For basic (integer/slice) views, indexing an array ``a1`` of shape
    ``shp1`` with the result gives an array that broadcasts against
    ``a2[view]`` for an array ``a2`` of shape ``shp2``:

    * entries of ``view`` for leading dimensions of ``shp2`` that ``shp1``
      does not have are dropped;
    * entries for dimensions where ``shp1`` has length 1 become ``0`` when
      they are integer indices (the dimension disappears, as it does in
      ``a2[view]``) and ``slice(None)`` otherwise (the length-1 dimension is
      kept so it broadcasts);
    * all other entries are passed through unchanged.

    Parameters
    ----------
    shp1 : tuple of int
        The smaller shape.
    shp2 : tuple of int
        The shape ``shp1`` is broadcast to.
    view : tuple, or a single slice/integer
        The view into an array of shape ``shp2``. An empty tuple means the
        whole array.

    Returns
    -------
    tuple
        The view into an array of shape ``shp1``. A ``view`` that has no
        length (a single slice or integer) is returned unchanged.

    Raises
    ------
    ValueError
        If ``shp1`` cannot be broadcast to ``shp2`` (from ``dims_to_skip``).
    """
    # A single slice/integer cannot be split per dimension, so return it as-is.
    if not hasattr(view, '__len__'):
        return view

    skipped = set(dims_to_skip(shp1, shp2))
    n_missing = len(shp2) - len(shp1)

    if not view:
        # if no view is specified, still need to slice: one full slice for
        # every dimension of shape 2 (not just 3).
        view = (slice(None),) * len(shp2)

    cv_view = []
    for ii, item in enumerate(view):
        if ii < n_missing:
            # shp1 has no such dimension; broadcasting supplies it.
            continue
        if ii in skipped:
            # shp1 has length 1 here: keep its dimension aligned instead of
            # dropping the entry, which would shift later indices onto the
            # wrong axes.
            cv_view.append(0 if isinstance(item, (int, np.integer))
                           else slice(None))
        else:
            cv_view.append(item)

    # return type matters
    # array[[0,0]] = [array[0], array[0]]
    # array[(0,0)] = array[0,0]
    return tuple(cv_view)
class MaskBase(object):
    """
    Base class for spectral-cube masks.

    Subclasses must implement ``_include``. They may also override
    ``_exclude`` and ``_validate_wcs``. Masks combine with ``&``, ``|`` and
    ``^`` into a :class:`CompositeMask` and invert with ``~`` into an
    :class:`InvertedMask`.
    """

    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. On
    # Python 3 it is an ordinary class attribute, so the
    # ``abc.abstractmethod`` on ``_include`` is not enforced there.
    __metaclass__ = abc.ABCMeta

    def include(self, data=None, wcs=None, view=(), **kwargs):
        """
        Return a boolean array indicating which values should be included.

        If ``view`` is passed, only the sliced mask will be returned, which
        avoids having to load the whole mask in memory. Otherwise, the whole
        mask is returned in-memory.

        kwargs are passed to _validate_wcs
        """
        self._validate_wcs(data, wcs, **kwargs)
        return self._include(data=data, wcs=wcs, view=view)

    # Kept as a compatibility hook, although including it did not fix the
    # matplotlib plotting problems it was originally added for.
    def view(self, view=()):
        """
        Compatibility tool: if a numpy.ma.ufunc is run on the mask, it will try
        to grab a view of the mask, which needs to appear to numpy as a true
        array. This can be important for, e.g., plotting.

        Numpy's convention is that masked=True means "masked out", so this
        returns the *excluded* elements.

        .. note::
            I don't know if there are broader concerns or consequences from
            including this 'view' tool here.
        """
        return self.exclude(view=view)

    def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
        """
        This method can be overridden in cases where the data and WCS have to
        conform to some rules. This gets called automatically when
        ``include`` or ``exclude`` are called. The base implementation
        accepts everything.
        """
        pass

    @abc.abstractmethod
    def _include(self, data=None, wcs=None, view=()):
        """Return the boolean include-array for ``view`` (subclass hook)."""
        pass

    def exclude(self, data=None, wcs=None, view=(), **kwargs):
        """
        Return a boolean array indicating which values should be excluded.

        If ``view`` is passed, only the sliced mask will be returned, which
        avoids having to load the whole mask in memory. Otherwise, the whole
        mask is returned in-memory.

        kwargs are passed to _validate_wcs
        """
        self._validate_wcs(data, wcs, **kwargs)
        return self._exclude(data=data, wcs=wcs, view=view)

    def _exclude(self, data=None, wcs=None, view=()):
        # Default: the element-wise complement of ``_include``.
        return np.logical_not(self._include(data=data, wcs=wcs, view=view))

    def any(self):
        """
        Return whether the mask excludes at least one element.

        The whole mask is evaluated in memory.
        """
        return np.any(self.exclude())

    def _flattened(self, data, wcs=None, view=()):
        """
        Return a flattened array of the included elements of cube

        Parameters
        ----------
        data : array-like
            The data array to flatten
        wcs : `~astropy.wcs.WCS`, optional
            Passed to ``include`` for validation
        view : tuple, optional
            Any slicing to apply to the data before flattening

        Returns
        -------
        flat_array : `~numpy.ndarray`
            A 1-D array containing the flattened output (a dask array if
            ``data`` is a dask array)

        Notes
        -----
        This is an internal method used by :class:`SpectralCube`.
        """
        mask = self.include(data=data, wcs=wcs, view=view)
        # Workaround for https://github.com/dask/dask/issues/6089
        if isinstance(data, da.Array) and not isinstance(mask, da.Array):
            mask = da.asarray(mask, name=str(uuid.uuid4()))
        return data[view][mask]

    def _filled(self, data, wcs=None, fill=np.nan, view=(), use_memmap=False,
                **kwargs):
        """
        Replace the excluded elements of *array* with *fill*.

        Parameters
        ----------
        data : array-like
            Input array
        wcs : `~astropy.wcs.WCS`, optional
            Passed to ``exclude`` for validation
        fill : number
            Replacement value
        view : tuple, optional
            Any slicing to apply to the data before filling
        use_memmap : bool
            Use a memory map to store the output data?
        kwargs : dict
            Passed to ``exclude`` (and from there to ``_validate_wcs``)

        Returns
        -------
        filled_array : `~numpy.ndarray`
            An array with the shape of ``data[view]``, converted to a
            floating-point dtype, with excluded elements set to ``fill``

        Notes
        -----
        This is an internal method used by :class:`SpectralCube`.
        Users should use the property :meth:`MaskBase.filled_data`
        """
        # Must convert to floating point, but should not change from inherited
        # type otherwise
        dt = np.result_type(data.dtype, 0.0)
        subset = data[view]
        if use_memmap and data.size > 0:
            ntf = tempfile.NamedTemporaryFile()
            sliced_data = np.memmap(ntf, mode='w+', shape=subset.shape,
                                    dtype=dt)
            sliced_data[:] = subset
        else:
            sliced_data = subset.astype(dt)
        ex = self.exclude(data=data, wcs=wcs, view=view, **kwargs)
        return np.ma.masked_array(sliced_data, mask=ex).filled(fill)

    def __and__(self, other):
        return CompositeMask(self, other, operation='and')

    def __or__(self, other):
        return CompositeMask(self, other, operation='or')

    def __xor__(self, other):
        return CompositeMask(self, other, operation='xor')

    def __invert__(self):
        return InvertedMask(self)

    @property
    def shape(self):
        raise NotImplementedError("{0} mask classes do not have shape attributes."
                                  .format(self.__class__.__name__))

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def size(self):
        return np.prod(self.shape)

    @property
    def dtype(self):
        return np.dtype('bool')

    def __getitem__(self, view):
        # ``view`` must be accepted so that ``mask[view]`` reaches this
        # NotImplementedError instead of failing with a TypeError.
        raise NotImplementedError("Slicing not supported by mask class {0}"
                                  .format(self.__class__.__name__))

    def quicklook(self, view, wcs=None, filename=None, use_aplpy=True,
                  aplpy_kwargs=None):
        '''
        View a 2D slice of the mask, specified by view.

        Parameters
        ----------
        view : tuple
            Slicing to apply to the mask. Must return a 2D slice.
        wcs : astropy.wcs.WCS, optional
            WCS object to use in plotting the mask slice.
        filename : str, optional
            Filename of the output image. Enables saving of the plot.
        use_aplpy : bool, optional
            Try plotting with the aplpy package. If aplpy is not installed
            or cannot handle the WCS, matplotlib is used instead.
        aplpy_kwargs : dict, optional
            kwargs passed to `~aplpy.FITSFigure`.
        '''
        if aplpy_kwargs is None:
            aplpy_kwargs = {}

        view_twod = self.include(view=view, wcs=wcs)

        if use_aplpy:
            try:
                import aplpy
                if wcs is not None:
                    hdu = fits.PrimaryHDU(view_twod.astype(int), wcs.to_header())
                else:
                    hdu = fits.PrimaryHDU(view_twod.astype(int))
                FITSFigure = aplpy.FITSFigure(hdu, **aplpy_kwargs)
                FITSFigure.show_grayscale()
                FITSFigure.add_colorbar()
                if filename is not None:
                    FITSFigure.save(filename)
            except (InconsistentAxisTypesError, ImportError):
                # aplpy is unavailable or cannot handle this WCS: fall back
                # to plain matplotlib below.
                use_aplpy = False

        if not use_aplpy:
            from matplotlib import pyplot
            image = pyplot.imshow(view_twod)
            if filename is not None:
                # imshow returns an AxesImage; saving is done on its figure.
                image.figure.savefig(filename)

    def _get_new_wcs(self, unit, velocity_convention=None, rest_value=None):
        """
        Returns a new WCS with a different Spectral Axis unit
        """
        from .spectral_axis import convert_spectral_axis, determine_ctype_from_vconv

        out_ctype = determine_ctype_from_vconv(self._wcs.wcs.ctype[self._wcs.wcs.spec],
                                               unit,
                                               velocity_convention=velocity_convention)

        newwcs = convert_spectral_axis(self._wcs, unit, out_ctype,
                                       rest_value=rest_value)

        newwcs.wcs.set()
        return newwcs

    _get_new_wcs.__doc__ += with_spectral_unit_docs
class InvertedMask(MaskBase):
    """
    A mask that includes exactly the elements another mask excludes.

    Parameters
    ----------
    mask : MaskBase
        The mask to invert.
    """

    def __init__(self, mask):
        self._mask = mask

    @property
    def shape(self):
        # Inversion does not change the shape of the wrapped mask.
        return self._mask.shape

    def _include(self, data=None, wcs=None, view=()):
        # Goes through the wrapped mask's public ``include`` so that its own
        # WCS validation runs before the result is negated.
        wrapped_include = self._mask.include(data=data, wcs=wcs, view=view)
        return np.logical_not(wrapped_include)

    def __getitem__(self, view):
        sliced_inner = self._mask[view]
        return InvertedMask(sliced_inner)

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Return an InvertedMask wrapping a copy of the inner mask whose WCS
        uses the requested spectral unit.
        """
        converted_inner = self._mask.with_spectral_unit(
            unit,
            velocity_convention=velocity_convention,
            rest_value=rest_value,
        )
        return InvertedMask(converted_inner)

    with_spectral_unit.__doc__ += with_spectral_unit_docs
class CompositeMask(MaskBase):
    """
    A combination of two masks. The included masks are combined with the
    specified operation.

    Parameters
    ----------
    mask1, mask2 : Masks
        The two masks to composite. One of them may be a plain boolean
        `numpy.ndarray` if the other is a mask with a shape and a ``_wcs``;
        the array is then wrapped in a :class:`BooleanArrayMask`.
    operation : str
        One of 'and', 'or' or 'xor'; the operation used to combine the masks.
        Unsupported operations raise ValueError when the mask is evaluated.
    """

    def __init__(self, mask1, mask2, operation='and'):
        if (isinstance(mask1, np.ndarray) and isinstance(mask2, MaskBase)
                and hasattr(mask2, 'shape')):
            if not is_broadcastable_and_smaller(mask1.shape, mask2.shape):
                raise ValueError("Mask1 shape is not broadcastable to Mask2 shape: "
                                 "%s vs %s" % (mask1.shape, mask2.shape))
            mask1 = BooleanArrayMask(mask1, mask2._wcs, shape=mask2.shape)
        elif (isinstance(mask2, np.ndarray) and isinstance(mask1, MaskBase)
                and hasattr(mask1, 'shape')):
            if not is_broadcastable_and_smaller(mask2.shape, mask1.shape):
                raise ValueError("Mask2 shape is not broadcastable to Mask1 shape: "
                                 "%s vs %s" % (mask2.shape, mask1.shape))
            mask2 = BooleanArrayMask(mask2, mask1._wcs, shape=mask1.shape)

        # both entries must have compatible, which effectively means
        # equal, WCSes. Unless one is a function.
        if hasattr(mask1, '_wcs') and hasattr(mask2, '_wcs'):
            # NOTE(review): ``_validate_wcs`` takes the WCS to compare as
            # ``new_wcs``; passed as ``wcs=`` it lands in **kwargs, so no
            # comparison actually happens here. Switching to ``new_wcs=``
            # would validate without any caller-supplied WCS tolerance, so
            # it is left unchanged pending a decision on that.
            mask1._validate_wcs(new_data=None, wcs=mask2._wcs)
            # In order to composite composites, they must have a _wcs defined.
            # (maybe this should be a property?)
            self._wcs = mask1._wcs
        elif hasattr(mask1, '_wcs'):
            # if one mask doesn't have a WCS, but the other does, the
            # compositemask should have the same WCS as the one that does
            self._wcs = mask1._wcs
        elif hasattr(mask2, '_wcs'):
            self._wcs = mask2._wcs

        self._mask1 = mask1
        self._mask2 = mask2
        self._operation = operation

    def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
        # Both components must accept the new data/WCS.
        self._mask1._validate_wcs(new_data=new_data, new_wcs=new_wcs, **kwargs)
        self._mask2._validate_wcs(new_data=new_data, new_wcs=new_wcs, **kwargs)

    @property
    def shape(self):
        """
        The common shape of both components.

        Raises
        ------
        ValueError
            If a component has no defined shape, or the two shapes differ.
        """
        try:
            shape1 = self._mask1.shape
            shape2 = self._mask2.shape
        except NotImplementedError:
            raise ValueError("The composite mask contains at least one "
                             "component with no defined shape.")
        # Explicit comparison rather than ``assert`` so the check survives
        # ``python -O``.
        if shape1 != shape2:
            raise ValueError("The composite mask does not have a well-defined "
                             "shape; its two components have shapes {0} and "
                             "{1}.".format(shape1, shape2))
        return shape1

    def _include(self, data=None, wcs=None, view=()):
        result_mask_1 = self._mask1._include(data=data, wcs=wcs, view=view)
        result_mask_2 = self._mask2._include(data=data, wcs=wcs, view=view)
        if self._operation == 'and':
            return np.bitwise_and(result_mask_1, result_mask_2)
        elif self._operation == 'or':
            return np.bitwise_or(result_mask_1, result_mask_2)
        elif self._operation == 'xor':
            return np.bitwise_xor(result_mask_1, result_mask_2)
        else:
            raise ValueError("Operation '{0}' not supported".format(self._operation))

    def __getitem__(self, view):
        return CompositeMask(self._mask1[view], self._mask2[view],
                             operation=self._operation)

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Get a CompositeMask copy in which each component has a WCS in the
        modified unit
        """
        newmask1 = self._mask1.with_spectral_unit(unit,
                                                  velocity_convention=velocity_convention,
                                                  rest_value=rest_value)
        newmask2 = self._mask2.with_spectral_unit(unit,
                                                  velocity_convention=velocity_convention,
                                                  rest_value=rest_value)
        return CompositeMask(newmask1, newmask2, self._operation)

    with_spectral_unit.__doc__ += with_spectral_unit_docs
class BooleanArrayMask(MaskBase):
    """
    A mask defined as an array on a spectral cube WCS

    Parameters
    ----------
    mask: `numpy.ndarray`
        A boolean numpy ndarray
    wcs: `astropy.wcs.WCS`
        The WCS object
    shape: tuple
        The shape of the region the array is masking. This is *required* if
        ``mask.ndim != data.ndim`` to provide rules for how to broadcast the
        mask. A mask with fewer dimensions than ``shape``, or with length-1
        dimensions (e.g. ``mask1d[:, None, None]`` or an ``(n, 1, 1)``
        array), is expanded to ``shape`` as a read-only broadcast view; no
        data is copied.
    include: bool
        If True (default), `True` values in ``mask`` mark included elements;
        if False, they mark excluded elements.
    """

    def __init__(self, mask, wcs, shape=None, include=True):
        self._mask_type = 'include' if include else 'exclude'
        self._wcs = wcs
        self._wcs_whitelist = set()

        if shape is not None and not is_broadcastable_and_smaller(mask.shape, shape):
            raise ValueError("Mask cannot be broadcast to the specified shape.")

        # An empty or missing ``shape`` means "use the mask's own shape".
        # Normalise to a tuple so list shapes compare equal to array shapes.
        self._shape = tuple(shape) if shape else mask.shape
        self._mask = mask

        # Expand a smaller mask to the full shape so that ``self._mask[view]``
        # can be indexed like the data. ``np.broadcast_to`` gives every
        # broadcast dimension a zero stride regardless of the input strides
        # (an ``(n, 1, 1)`` array has non-zero strides on its length-1 axes).
        if self._mask.shape != self._shape:
            self._mask = np.broadcast_to(mask, self._shape)

    def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
        """
        Check that the new WCS matches the current one

        Parameters
        ----------
        kwargs : dict
            Passed to `wcs_utils.check_equality`
        """
        if new_data is not None and not is_broadcastable_and_smaller(self._mask.shape,
                                                                     new_data.shape):
            raise ValueError("data shape cannot be broadcast to match mask shape")
        if new_wcs is not None:
            # WCS objects that already passed are remembered to skip re-checks.
            if new_wcs not in self._wcs_whitelist:
                try:
                    if not wcs_utils.check_equality(new_wcs, self._wcs,
                                                    warn_missing=True,
                                                    **kwargs):
                        raise ValueError("WCS does not match mask WCS")
                    else:
                        self._wcs_whitelist.add(new_wcs)
                except InconsistentAxisTypesError:
                    warnings.warn("Inconsistent axis type encountered; WCS is "
                                  "invalid and therefore will not be checked "
                                  "against other WCSes.",
                                  WCSWarning
                                  )
                    self._wcs_whitelist.add(new_wcs)

    def _include(self, data=None, wcs=None, view=()):
        result_mask = self._mask[view]
        return result_mask if self._mask_type == 'include' else np.logical_not(result_mask)

    def _exclude(self, data=None, wcs=None, view=()):
        result_mask = self._mask[view]
        return result_mask if self._mask_type == 'exclude' else np.logical_not(result_mask)

    @property
    def shape(self):
        return self._shape

    def __getitem__(self, view):
        sliced_mask = self._mask[view]
        # The include/exclude interpretation must be carried over; otherwise
        # a sliced exclude-mask would be read as an include-mask (inverted).
        return BooleanArrayMask(sliced_mask,
                                wcs_utils.slice_wcs(self._wcs, view,
                                                    shape=self.shape,
                                                    drop_degenerate=True),
                                shape=sliced_mask.shape,
                                include=self._mask_type == 'include')

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Get a BooleanArrayMask copy with a WCS in the modified unit
        """
        newwcs = self._get_new_wcs(unit, velocity_convention, rest_value)

        newmask = BooleanArrayMask(self._mask, newwcs,
                                   include=self._mask_type == 'include')
        return newmask

    with_spectral_unit.__doc__ += with_spectral_unit_docs
class LazyMask(MaskBase):
    """
    A boolean mask defined by the evaluation of a function on a fixed dataset.

    This is conceptually identical to a fixed boolean mask as in
    :class:`BooleanArrayMask` but defers the evaluation of the mask until it
    is needed.

    Parameters
    ----------
    function : callable
        The function to apply to ``data``. It receives a numpy array that is
        a (possibly sliced) subset of the stored data and should return a
        boolean array in which `True` marks pixels that are valid /
        unaffected by masking.
    cube : object, optional
        An object providing ``_data`` and ``_wcs`` attributes (e.g. a
        spectral cube). Must not be combined with ``data`` or ``wcs``.
    data : array-like, optional
        The array to evaluate ``function`` on. This should support Numpy-like
        slicing syntax. Required together with ``wcs`` if ``cube`` is not
        given.
    wcs : `~astropy.wcs.WCS`, optional
        The WCS of the input data, which is used to define the coordinates
        for which the boolean mask is defined.

    Raises
    ------
    ValueError
        If ``cube`` is given together with ``data`` or ``wcs``, or if neither
        a cube nor both ``data`` and ``wcs`` are given.
    """

    def __init__(self, function, cube=None, data=None, wcs=None):
        self._function = function

        if cube is not None:
            if data is not None or wcs is not None:
                raise ValueError("Pass only cube or (data & wcs)")
            data, wcs = cube._data, cube._wcs
        elif data is None or wcs is None:
            raise ValueError("Either a cube or (data & wcs) is required.")

        self._data = data
        self._wcs = wcs
        # WCS objects already checked against ``self._wcs``.
        self._wcs_whitelist = set()

    @property
    def shape(self):
        return self._data.shape

    def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
        """
        Verify that new data fits the stored data's shape and that a new
        WCS matches the stored WCS.

        Parameters
        ----------
        kwargs : dict
            Passed to `wcs_utils.check_equality`
        """
        if (new_data is not None and
                not is_broadcastable_and_smaller(new_data.shape, self._data.shape)):
            raise ValueError("data shape cannot be broadcast to match mask shape")

        if new_wcs is None or new_wcs in self._wcs_whitelist:
            return

        if not wcs_utils.check_equality(new_wcs, self._wcs,
                                        warn_missing=True, **kwargs):
            raise ValueError("WCS does not match mask WCS")
        self._wcs_whitelist.add(new_wcs)

    def _include(self, data=None, wcs=None, view=()):
        self._validate_wcs(data, wcs)
        subset = self._data[view]
        return self._function(subset)

    def __getitem__(self, view):
        sub_data = self._data[view]
        sub_wcs = wcs_utils.slice_wcs(self._wcs, view,
                                      shape=self._data.shape,
                                      drop_degenerate=True)
        return LazyMask(self._function, data=sub_data, wcs=sub_wcs)

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Return a LazyMask over the same data and function whose WCS uses the
        requested spectral unit.
        """
        converted_wcs = self._get_new_wcs(unit, velocity_convention, rest_value)
        return LazyMask(self._function, data=self._data, wcs=converted_wcs)

    with_spectral_unit.__doc__ += with_spectral_unit_docs
class LazyComparisonMask(LazyMask):
    """
    A boolean mask defined by the evaluation of a comparison function between a
    fixed dataset and some other value.

    This is conceptually similar to the :class:`LazyMask` but it will ensure
    that the comparison value can be compared to the data.

    Parameters
    ----------
    function : callable
        The function to apply to ``data``. This method should accept
        a numpy array, which will be the data array passed to __init__, and a
        second argument also passed to __init__. It should return a boolean
        array, where `True` values indicate that which pixels are
        valid/unaffected by masking.
    comparison_value : float or array
        The comparison value for the array. If it has a non-empty shape, that
        shape must be broadcastable to the data shape (e.g. ``(ny, nx)`` or
        ``(nz, 1, 1)`` for data of shape ``(nz, ny, nx)``).
    cube : object, optional
        An object providing ``_data`` and ``_wcs`` attributes (e.g. a
        spectral cube). Must not be combined with ``data`` or ``wcs``.
    data : array-like
        The array to evaluate ``function`` on. This should support Numpy-like
        slicing syntax.
    wcs : `~astropy.wcs.WCS`
        The WCS of the input data, which is used to define the coordinates
        for which the boolean mask is defined.

    Raises
    ------
    ValueError
        If the cube/data/wcs arguments are inconsistent, or the comparison
        value cannot be broadcast to the data shape.
    """

    def __init__(self, function, comparison_value, cube=None, data=None,
                 wcs=None):
        self._function = function
        if cube is not None and (data is not None or wcs is not None):
            raise ValueError("Pass only cube or (data & wcs)")
        elif cube is not None:
            self._data = cube._data
            self._wcs = cube._wcs
        elif data is not None and wcs is not None:
            self._data = data
            self._wcs = wcs
        else:
            raise ValueError("Either a cube or (data & wcs) is required.")

        # The comparison value is broadcast against (and sliced relative to)
        # the data, so it is *its* shape that must broadcast to the data
        # shape, not the other way around.
        if (hasattr(comparison_value, 'shape') and not
                is_broadcastable_and_smaller(comparison_value.shape,
                                             self._data.shape)):
            raise ValueError("The data and the comparison value cannot "
                             "be broadcast to match shape")

        self._comparison_value = comparison_value
        self._wcs_whitelist = set()

    def _comparison_view(self, view):
        """Return the comparison value sliced to line up with ``self._data[view]``."""
        cv = self._comparison_value
        if hasattr(cv, 'shape') and cv.shape:
            cv_view = view_of_subset(cv.shape, self._data.shape, view)
            return cv[cv_view]
        # Scalars (and 0-d values) need no slicing.
        return cv

    def _include(self, data=None, wcs=None, view=()):
        self._validate_wcs(data, wcs)
        comparison = self._comparison_view(view)
        return self._function(self._data[view], comparison)

    def __getitem__(self, view):
        comparison = self._comparison_view(view)
        return LazyComparisonMask(self._function, data=self._data[view],
                                  comparison_value=comparison,
                                  wcs=wcs_utils.slice_wcs(self._wcs, view,
                                                          shape=self._data.shape,
                                                          drop_degenerate=True))

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Get a LazyComparisonMask copy with a WCS in the modified unit
        """
        newwcs = self._get_new_wcs(unit, velocity_convention, rest_value)

        newmask = LazyComparisonMask(self._function, data=self._data,
                                     comparison_value=self._comparison_value,
                                     wcs=newwcs)
        return newmask

    with_spectral_unit.__doc__ += with_spectral_unit_docs
class FunctionMask(MaskBase):
    """
    A mask defined by a function that is evaluated at run-time using the data
    passed to the mask.

    This function differs from :class:`LazyMask` in the arguments which
    are passed to the function. FunctionMasks receive an array,
    wcs object, and view, whereas LazyMasks receive pre-sliced views
    into an array specified at mask-creation time.

    Parameters
    ----------
    function : callable
        The function to evaluate the mask. The call signature should be
        ``function(data, wcs, view)`` where ``data`` and ``wcs`` are the
        arguments that get passed to e.g. ``include``, ``exclude``,
        ``_filled``, and ``_flattened``. The function should return
        a boolean array, where `True` values indicate that which pixels
        are valid / unaffected by masking. Its shape must equal
        ``data[view].shape``.
    """

    def __init__(self, function):
        self._function = function

    def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
        # A function mask has no WCS of its own, so any data/WCS is accepted.
        pass

    def _include(self, data=None, wcs=None, view=()):
        result = self._function(data, wcs, view)
        expected_shape = data[view].shape
        if result.shape != expected_shape:
            raise ValueError("Function did not return mask with correct shape - expected {0}, got {1}".format(expected_shape, result.shape))
        return result

    def __getitem__(self, view):
        # The function receives the view at evaluation time, so there is
        # nothing to pre-slice here. (Parameter renamed from ``slice`` to
        # avoid shadowing the builtin.)
        return self

    def with_spectral_unit(self, unit, velocity_convention=None, rest_value=None):
        """
        Functional masks do not have WCS defined, so this simply returns a copy
        of the current mask in order to be consistent with
        ``with_spectral_unit`` from other Masks
        """
        return FunctionMask(self._function)