Source code for glue.viewers.image.viewer

import os

from astropy.wcs import WCS

from glue.core.subset import roi_to_subset_state
from glue.core.coordinates import Coordinates, LegacyCoordinates
from glue.core.coordinate_helpers import dependent_axes
from glue.core.data_region import RegionData

from glue.viewers.matplotlib.viewer import SimpleMatplotlibViewer
from glue.viewers.scatter.layer_artist import ScatterLayerArtist, ScatterRegionLayerArtist
from glue.viewers.image.layer_artist import ImageLayerArtist, ImageSubsetLayerArtist
from glue.viewers.image.compat import update_image_viewer_state
from glue.viewers.image.state import ImageViewerState

from glue.viewers.image.frb_artist import imshow
from glue.viewers.image.composite_array import CompositeArray

__all__ = ['MatplotlibImageMixin', 'SimpleImageViewer']


def get_identity_wcs(naxis):

    wcs = WCS(naxis=naxis)
    wcs.wcs.ctype = ['X'] * naxis
    wcs.wcs.crval = [0.] * naxis
    wcs.wcs.crpix = [1.] * naxis
    wcs.wcs.cdelt = [1.] * naxis

    return wcs


EXTRA_FOOTER = """
# Set tick label size - for now tick_params (called lower down) doesn't work
# properly, but these lines won't be needed in future.
ax.coords[{x_att_axis}].set_ticklabel(size={x_ticklabel_size})
ax.coords[{y_att_axis}].set_ticklabel(size={y_ticklabel_size})
""".strip()


[docs]class MatplotlibImageMixin(object):
[docs] def setup_callbacks(self): self._wcs_set = False self._changing_slice_requires_wcs_update = None self.axes.set_adjustable('datalim') self.state.add_callback('x_att', self._set_wcs) self.state.add_callback('y_att', self._set_wcs) self.state.add_callback('slices', self._on_slice_change) self.state.add_callback('reference_data', self._set_wcs, echo_old=True) self.axes._composite = CompositeArray() self.axes._composite_image = imshow(self.axes, self.axes._composite, aspect='auto', origin='lower', interpolation='nearest') self._set_wcs()
[docs] def update_x_ticklabel(self, *event): # We need to overload this here for WCSAxes if hasattr(self, '_wcs_set') and self._wcs_set and self.state.x_att is not None: axis = self.state.reference_data.ndim - self.state.x_att.axis - 1 else: axis = 0 self.axes.coords[axis].set_ticklabel(size=self.state.x_ticklabel_size) self.redraw()
[docs] def update_y_ticklabel(self, *event): # We need to overload this here for WCSAxes if hasattr(self, '_wcs_set') and self._wcs_set and self.state.y_att is not None: axis = self.state.reference_data.ndim - self.state.y_att.axis - 1 else: axis = 1 self.axes.coords[axis].set_ticklabel(size=self.state.y_ticklabel_size) self.redraw()
def _update_axes(self, *args): if self.state.x_att_world is not None: self.state.x_axislabel = self.state.x_att_world.label if self.state.y_att_world is not None: self.state.y_axislabel = self.state.y_att_world.label self.axes.figure.canvas.draw_idle()
[docs] def add_data(self, data): result = super(MatplotlibImageMixin, self).add_data(data) # If this is the first layer (or the first after all layers were) # removed, set the WCS for the axes. if len(self.layers) == 1: self._set_wcs() return result
def _update_data_numerical(self, *args, **kwargs): super()._update_data_numerical(*args, **kwargs) self.state._reference_data_changed(force=True) def _on_slice_change(self, event=None): if self._changing_slice_requires_wcs_update: self._set_wcs(relim=False) def _set_wcs(self, before=None, after=None, relim=True): if self.state.x_att is None or self.state.y_att is None or self.state.reference_data is None: return # A callback event for reference_data is triggered if the choices change # but the actual selection doesn't - so we avoid resetting the WCS in # this case. if after is not None and before is after: return ref_coords = getattr(self.state.reference_data, 'coords', None) if ref_coords is None or isinstance(ref_coords, LegacyCoordinates): self.axes.reset_wcs(slices=self.state.wcsaxes_slice, wcs=get_identity_wcs(self.state.reference_data.ndim)) else: self.axes.reset_wcs(slices=self.state.wcsaxes_slice, wcs=ref_coords) # Reset the axis labels to match the fact that the new axes have no labels self.state.x_axislabel = '' self.state.y_axislabel = '' self._update_appearance_from_settings() self._update_axes() self.update_x_ticklabel() self.update_y_ticklabel() if relim: self.state.reset_limits() # Determine whether changing slices requires changing the WCS if ref_coords is None or type(ref_coords) is Coordinates: self._changing_slice_requires_wcs_update = False else: ix = self.state.x_att.axis iy = self.state.y_att.axis x_dep = list(dependent_axes(ref_coords, ix)) y_dep = list(dependent_axes(ref_coords, iy)) if ix in x_dep: x_dep.remove(ix) if iy in x_dep: x_dep.remove(iy) if ix in y_dep: y_dep.remove(ix) if iy in y_dep: y_dep.remove(iy) self._changing_slice_requires_wcs_update = bool(x_dep or y_dep) self._wcs_set = True
[docs] def apply_roi(self, roi, override_mode=None): # Force redraw to get rid of ROI. We do this because applying the # subset state below might end up not having an effect on the viewer, # for example there may not be any layers, or the active subset may not # be one of the layers. So we just explicitly redraw here to make sure # a redraw will happen after this method is called. self.redraw() if len(self.layers) == 0: return if self.state.x_att is None or self.state.y_att is None or self.state.reference_data is None: return subset_state = roi_to_subset_state(roi, x_att=self.state.x_att, y_att=self.state.y_att) self.apply_subset_state(subset_state, override_mode=override_mode)
def _scatter_artist(self, axes, state, layer=None, layer_state=None): if len(self._layer_artist_container) == 0: raise Exception("Can only add a scatter plot overlay once an image is present") return ScatterLayerArtist(axes, state, layer=layer, layer_state=None) def _region_artist(self, axes, state, layer=None, layer_state=None): if len(self._layer_artist_container) == 0: raise Exception("Can only add a region plot overlay once an image is present") return ScatterRegionLayerArtist(axes, state, layer=layer, layer_state=None)
[docs] def get_data_layer_artist(self, layer=None, layer_state=None): if isinstance(layer, RegionData): cls = self._region_artist elif layer.ndim == 1: cls = self._scatter_artist else: cls = ImageLayerArtist return self.get_layer_artist(cls, layer=layer, layer_state=layer_state)
[docs] def get_subset_layer_artist(self, layer=None, layer_state=None): if isinstance(layer.data, RegionData): cls = self._region_artist elif layer.ndim == 1: cls = self._scatter_artist else: cls = ImageSubsetLayerArtist return self.get_layer_artist(cls, layer=layer, layer_state=layer_state)
[docs] @staticmethod def update_viewer_state(rec, context): return update_image_viewer_state(rec, context)
[docs] def show_crosshairs(self, x, y): if getattr(self, '_crosshairs', None) is not None: self._crosshairs.remove() self._crosshairs, = self.axes.plot([x], [y], '+', ms=12, mfc='none', mec='#d32d26', mew=1, zorder=100) self.axes.figure.canvas.draw_idle()
[docs] def hide_crosshairs(self): if getattr(self, '_crosshairs', None) is not None: self._crosshairs.remove() self._crosshairs = None self.axes.figure.canvas.draw_idle()
def _script_header(self): imports = [] imports.append('import matplotlib.pyplot as plt') imports.append('from glue.viewers.matplotlib.mpl_axes import init_mpl') imports.append('from glue.viewers.image.composite_array import CompositeArray') imports.append('from glue.viewers.image.frb_artist import imshow') imports.append('from glue.viewers.matplotlib.mpl_axes import set_figure_colors') script = "" script += "fig, ax = init_mpl(wcs=True)\n" script += f"ax.set_aspect('{self.state.aspect}')\n" script += '\ncomposite = CompositeArray()\n' script += f"image = imshow(ax, composite, origin='lower', interpolation='nearest', aspect='{self.state.aspect}')\n\n" dindex = self.session.data_collection.index(self.state.reference_data) script += f"ref_data = data_collection[{dindex}]\n" if isinstance(self.state.reference_data.coords, (LegacyCoordinates, type(None))): imports.append('from glue.viewers.image.viewer import get_identity_wcs') ref_wcs = "get_identity_wcs(ref_data.ndim)" else: ref_wcs = "ref_data.coords" script += f"ax.reset_wcs(slices={self.state.wcsaxes_slice}, wcs={ref_wcs})\n" script += "# for the legend\n" script += "legend_handles = []\n" script += "legend_labels = []\n" script += "legend_handler_dict = dict()\n\n" return imports, script def _script_footer(self): imports, script = super(MatplotlibImageMixin, self)._script_footer() options = dict(x_att_axis=0 if self.state.x_att is None else self.state.reference_data.ndim - self.state.x_att.axis - 1, y_att_axis=1 if self.state.y_att is None else self.state.reference_data.ndim - self.state.y_att.axis - 1, x_ticklabel_size=self.state.x_ticklabel_size, y_ticklabel_size=self.state.y_ticklabel_size) return [], EXTRA_FOOTER.format(**options) + os.linesep * 2 + script
[docs]class SimpleImageViewer(MatplotlibImageMixin, SimpleMatplotlibViewer): _state_cls = ImageViewerState def __init__(self, *args, **kwargs): kwargs['wcs'] = True super().__init__(*args, **kwargs) MatplotlibImageMixin.setup_callbacks(self)