Source code for glue.viewers.matplotlib.viewer

from textwrap import indent

import numpy as np

from matplotlib.patches import Rectangle
from matplotlib.artist import setp as msetp

from glue.config import settings
from glue.viewers.common.viewer import Viewer
from glue.viewers.matplotlib.mpl_axes import update_appearance_from_settings, init_mpl
from echo import delay_callback
from glue.utils import mpl_to_datetime64

# Public names exported by a star-import of this module.
__all__ = ['MatplotlibViewerMixin', 'SimpleMatplotlibViewer']

# Template for the start of an exported plotting script: creates the figure
# and axes and the (initially empty) containers used to build the legend.
# Placeholders: {aspect}, {projection}.
SCRIPT_HEADER = """
# Initialize figure
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect='{aspect}', projection={projection})

# for the legend
legend_handles = []
legend_labels = []
legend_handler_dict = dict()
""".strip()

# Template fragment setting the axis limits.
# Placeholders: {x_min}, {x_max}, {y_min}, {y_max}.
LIMIT_SCRIPT = """
# Set limits
ax.set_xlim(left={x_min}, right={x_max})
ax.set_ylim(bottom={y_min}, top={y_max})
"""

# Template for the end of an exported plotting script: limits, scales, axis
# and tick label properties, figure colors, and saving the figure.
# {limit_script} and {scale_script} are themselves pre-rendered fragments.
SCRIPT_FOOTER = """
{limit_script}

{scale_script}

# Set axis label properties
ax.set_xlabel('{x_axislabel}', weight='{x_axislabel_weight}', size={x_axislabel_size})
ax.set_ylabel('{y_axislabel}', weight='{y_axislabel_weight}', size={y_axislabel_size})

# Set tick label properties
ax.tick_params('x', labelsize={x_ticklabel_size})
ax.tick_params('y', labelsize={y_ticklabel_size})

# Set the figure colors
# To create the figure with the default settings, comment out the next line
# Alternatively, to create the figure with colors from the current glue environment:
#  - from glue.config import settings
#  - set_figure_colors(ax, background=settings.BACKGROUND_COLOR, foreground=settings.FOREGROUND_COLOR)
set_figure_colors(ax, background={background_color}, foreground={foreground_color})

# For manual edition of the plot
#  - Uncomment the next code line (plt.show)
#  - Also change the matplotlib backend to qt5Agg
#  - And comment the "plt.close" line
# plt.show()

# Save figure
fig.savefig('glue_plot.png')
plt.close(fig)
""".strip()

# Template for the ``ax.legend`` call of an exported script.
# {edge_color} is inserted unquoted, unlike the other colors, so the value
# formatted into it must already be a valid Python expression.
SCRIPT_LEGEND = """
ax.legend(legend_handles, legend_labels,
    handler_map=legend_handler_dict,
    loc='{location}',            # location
    framealpha={alpha:.2f},      # opacity of the frame
    title='{title}',             # title of the legend
    title_fontsize={fontsize},   # fontsize of the title
    fontsize={fontsize},          # fontsize of the labels
    facecolor='{frame_color}',
    labelcolor='{label_color}',
    edgecolor={edge_color}
)
""".strip()

# Z-order given to the axes spines; the loading overlay is placed relative to it.
ZORDER_MAX = 100000


class MatplotlibViewerMixin(object):
    """Mixin that keeps a Matplotlib axes in sync with a viewer state.

    The methods read ``self.axes``, ``self.figure``, ``self.state`` and
    ``self._layer_artist_container``. None of these are created here, so the
    host class must provide them before calling :meth:`setup_callbacks`.
    """

    def setup_callbacks(self):
        """Create the loading overlay and connect state/Matplotlib callbacks.

        Also performs the initial synchronisation of limits, scales, axis
        labels, tick labels and the legend.
        """
        for spine in self.axes.spines.values():
            spine.set_zorder(ZORDER_MAX)

        # Overlay shown in axes coordinates; both artists are created hidden.
        self.loading_rectangle = Rectangle((0, 0), 1, 1, color='0.9', alpha=0.9,
                                           zorder=ZORDER_MAX - 1,
                                           transform=self.axes.transAxes)
        self.loading_rectangle.set_visible(False)
        self.axes.add_patch(self.loading_rectangle)

        self.loading_text = self.axes.text(0.4, 0.5, 'Computing', color='k',
                                           zorder=self.loading_rectangle.get_zorder() + 1,
                                           ha='left', va='center',
                                           transform=self.axes.transAxes)
        self.loading_text.set_visible(False)

        self.state.add_callback('x_min', self.limits_to_mpl)
        self.state.add_callback('x_max', self.limits_to_mpl)
        self.state.add_callback('y_min', self.limits_to_mpl)
        self.state.add_callback('y_max', self.limits_to_mpl)

        self._sync_initial_limits()

        self.state.add_callback('x_log', self.update_x_log, priority=1000)
        self.state.add_callback('y_log', self.update_y_log, priority=1000)

        self.update_x_log()
        self.update_y_log()

        self.axes.callbacks.connect('xlim_changed', self.limits_from_mpl)
        self.axes.callbacks.connect('ylim_changed', self.limits_from_mpl)

        self.figure.canvas.mpl_connect('resize_event', self._on_resize)
        self._on_resize()
        self.axes.set_autoscale_on(False)

        self.state.add_callback('x_axislabel', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_weight', self.update_x_axislabel)
        self.state.add_callback('x_axislabel_size', self.update_x_axislabel)

        self.state.add_callback('y_axislabel', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_weight', self.update_y_axislabel)
        self.state.add_callback('y_axislabel_size', self.update_y_axislabel)

        self.state.add_callback('x_ticklabel_size', self.update_x_ticklabel)
        self.state.add_callback('y_ticklabel_size', self.update_y_ticklabel)

        # Properties that change the legend layout rebuild it (draw_legend);
        # purely visual properties only restyle the existing one.
        self.state.legend.add_callback('visible', self.draw_legend)
        self.state.legend.add_callback('location', self.draw_legend)
        self.state.legend.add_callback('alpha', self.update_legend)
        self.state.legend.add_callback('title', self.draw_legend)
        self.state.legend.add_callback('fontsize', self.draw_legend)
        self.state.legend.add_callback('frame_color', self.update_legend)
        self.state.legend.add_callback('show_edge', self.update_legend)
        self.state.legend.add_callback('text_color', self.update_legend)

        self.update_x_axislabel()
        self.update_y_axislabel()
        self.update_x_ticklabel()
        self.update_y_ticklabel()

        self.draw_legend()

    def _sync_initial_limits(self):
        """Reconcile the state limits with the Matplotlib limits at setup.

        If any of the four limits is ``None`` the state is filled in from the
        axes; otherwise the state limits are pushed to the axes.
        """
        limits = (self.state.x_min, self.state.x_max,
                  self.state.y_min, self.state.y_max)
        # Each limit is tested for identity with None. Chaining the values
        # with ``or`` would return the first truthy limit and hide a missing
        # one, and limits_to_mpl does nothing while any limit is None.
        if any(value is None for value in limits):
            self.limits_from_mpl()
        else:
            self.limits_to_mpl()

    def _update_computation(self, message=None):
        """Do nothing; placeholder accepting an optional ``message``."""
        pass

    def update_x_axislabel(self, *event):
        """Apply the x axis label text, weight and size from the state."""
        self.axes.set_xlabel(self.state.x_axislabel,
                             weight=self.state.x_axislabel_weight,
                             size=self.state.x_axislabel_size)
        self.redraw()

    def update_y_axislabel(self, *event):
        """Apply the y axis label text, weight and size from the state."""
        self.axes.set_ylabel(self.state.y_axislabel,
                             weight=self.state.y_axislabel_weight,
                             size=self.state.y_axislabel_size)
        self.redraw()

    def update_x_ticklabel(self, *event):
        """Apply the x tick label size, including the axis offset text."""
        self.axes.tick_params(axis='x', labelsize=self.state.x_ticklabel_size)
        self.axes.xaxis.get_offset_text().set_fontsize(self.state.x_ticklabel_size)
        self.redraw()

    def update_y_ticklabel(self, *event):
        """Apply the y tick label size, including the axis offset text."""
        self.axes.tick_params(axis='y', labelsize=self.state.y_ticklabel_size)
        self.axes.yaxis.get_offset_text().set_fontsize(self.state.y_ticklabel_size)
        self.redraw()

    def get_handles_legend(self):
        """Collect the handles and labels from each layer artist.

        Returns
        -------
        tuple
            ``(handles, labels, handler_dict)``. Layer artists whose handle
            is ``None`` are skipped; ``handler_dict`` maps a handle to its
            handler for the artists that supplied one.
        """
        handles = []
        labels = []
        handler_dict = {}
        for layer_artist in self._layer_artist_container:
            handle, label, handler = layer_artist.get_handle_legend()
            if handle is None:
                continue
            handles.append(handle)
            labels.append(label)
            if handler is not None:
                handler_dict[handle] = handler
        return handles, labels, handler_dict

    def _update_legend_visual(self, legend):
        """Update the legend colors and opacity. No redraw."""
        legend_state = self.state.legend
        msetp(legend.get_title(), color=legend_state.text_color)
        msetp(legend.get_texts(), color=legend_state.text_color)
        msetp(legend.get_frame(),
              alpha=legend_state.alpha,
              facecolor=legend_state.frame_color,
              edgecolor=legend_state.edge_color)

    def update_legend(self, *args):
        """Update the legend colors and opacity."""
        legend = self.axes.get_legend()
        if legend is not None:
            self._update_legend_visual(legend)
        self.redraw()

    def draw_legend(self, *args):
        """Rebuild the legend if it is visible, otherwise remove it; redraw."""
        if self.state.legend.visible:
            handles, labels, handler_map = self.get_handles_legend()
            kwargs = dict(loc=self.state.legend.mpl_location,
                          title=self.state.legend.title,
                          title_fontsize=self.state.legend.fontsize,
                          fontsize=self.state.legend.fontsize)
            if handler_map is not None:
                kwargs["handler_map"] = handler_map
            legend = self.axes.legend(handles, labels, **kwargs)
            self._update_legend_visual(legend)
            legend.set_draggable(self.state.legend.draggable)
        else:
            legend = self.axes.get_legend()
            if legend is not None:
                legend.remove()
        self.redraw()

    def redraw(self):
        """Request a deferred redraw of the figure canvas."""
        self.figure.canvas.draw_idle()

    def update_x_log(self, *args):
        """Set the x scale to 'log' or 'linear' according to the state."""
        self.axes.set_xscale('log' if self.state.x_log else 'linear')
        self.redraw()

    def update_y_log(self, *args):
        """Set the y scale to 'log' or 'linear' according to the state."""
        self.axes.set_yscale('log' if self.state.y_log else 'linear')
        self.redraw()

    def limits_from_mpl(self, *args, **kwargs):
        """Copy the current axes limits into the state.

        Does nothing while ``_skip_limits_from_mpl`` is set (see
        :meth:`limits_to_mpl`). When the existing state minimum for an axis
        is a ``np.datetime64``, that axis' limits are converted with
        ``mpl_to_datetime64`` before being stored.
        """
        if getattr(self, '_skip_limits_from_mpl', False):
            return

        if isinstance(self.state.x_min, np.datetime64):
            x_min, x_max = [mpl_to_datetime64(x) for x in self.axes.get_xlim()]
        else:
            x_min, x_max = self.axes.get_xlim()

        if isinstance(self.state.y_min, np.datetime64):
            y_min, y_max = [mpl_to_datetime64(y) for y in self.axes.get_ylim()]
        else:
            y_min, y_max = self.axes.get_ylim()

        # Delay the callbacks on the four limits until all have been set.
        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
            self.state.x_min = x_min
            self.state.x_max = x_max
            self.state.y_min = y_min
            self.state.y_max = y_max

    def limits_to_mpl(self, *args):
        """Push the state limits to the axes.

        Does nothing if any limit is ``None``. On a log axis, non-positive
        limits are replaced: ``(0.1, 1)`` when the maximum is <= 0, otherwise
        a minimum of ``max / 10`` when only the minimum is <= 0.
        """
        if (self.state.x_min is None or self.state.x_max is None or
                self.state.y_min is None or self.state.y_max is None):
            return

        x_min, x_max = self.state.x_min, self.state.x_max
        if self.state.x_log:
            if self.state.x_max <= 0:
                x_min, x_max = 0.1, 1
            elif self.state.x_min <= 0:
                x_min = x_max / 10

        y_min, y_max = self.state.y_min, self.state.y_max
        if self.state.y_log:
            if self.state.y_max <= 0:
                y_min, y_max = 0.1, 1
            elif self.state.y_min <= 0:
                y_min = y_max / 10

        # Since we deal with aspect ratio internally, there are no conditions
        # under which we would want to immediately change the state once the
        # limits are set in Matplotlib, so we avoid this by setting the
        # _skip_limits_from_mpl attribute which is then recognized inside
        # limits_from_mpl
        self._skip_limits_from_mpl = True
        try:
            self.axes.set_xlim(x_min, x_max)
            self.axes.set_ylim(y_min, y_max)
            self.axes.figure.canvas.draw_idle()
        finally:
            self._skip_limits_from_mpl = False

    @property
    def axes_ratio(self):
        """Height-to-width ratio of the axes, in figure size units."""
        # Get figure aspect ratio
        fig_width, fig_height = self.figure.get_size_inches()
        fig_aspect = fig_height / fig_width

        # Get axes aspect ratio; bounds are (left, bottom, width, height)
        _, _, rel_width, rel_height = self.axes.get_position().bounds
        return fig_aspect * (rel_height / rel_width)

    def _on_resize(self, *args):
        """Pass the current axes ratio on to the state."""
        self.state._set_axes_aspect_ratio(self.axes_ratio)

    def _update_appearance_from_settings(self, message=None):
        """Re-apply the appearance settings to the axes and redraw."""
        update_appearance_from_settings(self.axes)
        self.redraw()

    def get_layer_artist(self, cls, layer=None, layer_state=None):
        """Instantiate ``cls`` for this viewer's axes and state."""
        return cls(self.axes, self.state, layer=layer, layer_state=layer_state)

    def apply_roi(self, roi, override_mode=None):
        """This method must be implemented by subclasses"""
        raise NotImplementedError

    def _script_header(self):
        """Return ``(import_lines, header_code)`` for an exported script."""
        state_dict = self.state.as_dict()
        # The projection is inserted unquoted in the template, so the quotes
        # are added here and a missing plot_mode becomes the literal None.
        if hasattr(self.state, 'plot_mode'):
            proj = "'" + self.state.plot_mode + "'"
        else:
            proj = "None"
        state_dict['projection'] = proj
        imports = ['import matplotlib',
                   "matplotlib.use('Agg')",
                   "# matplotlib.use('qt5Agg')",
                   'import matplotlib.pyplot as plt',
                   'from glue.viewers.matplotlib.mpl_axes import set_figure_colors']
        return imports, SCRIPT_HEADER.format(**state_dict)

    def _script_footer(self):
        """Return ``(import_lines, footer_code)`` for an exported script.

        Scale commands are omitted for the full-sphere and polar modes,
        limits are omitted for the full-sphere modes, and the axis labels
        are blanked in polar mode.
        """
        mode = 'rectilinear' if not hasattr(self.state, 'plot_mode') else self.state.plot_mode
        state_dict = self.state.as_dict()
        full_sphere = ['aitoff', 'hammer', 'lambert', 'mollweide']

        temp_str = ''
        if mode not in full_sphere + ['polar']:
            x_log_str = 'log' if self.state.x_log else 'linear'
            temp_str += "ax.set_xscale('" + x_log_str + "')\n"
            y_log_str = 'log' if self.state.y_log else 'linear'
            temp_str += "ax.set_yscale('" + y_log_str + "')\n"

        if mode == 'polar':
            state_dict['x_axislabel'] = ''
            state_dict['y_axislabel'] = ''

        state_dict['background_color'] = "\"{color}\"".format(color=settings.BACKGROUND_COLOR)
        state_dict['foreground_color'] = "\"{color}\"".format(color=settings.FOREGROUND_COLOR)
        state_dict['scale_script'] = ("# Set scale (log or linear)\n" + temp_str) if temp_str else ''
        state_dict['limit_script'] = '' if mode in full_sphere else LIMIT_SCRIPT.format(**state_dict)
        return [], SCRIPT_FOOTER.format(**state_dict)

    def _script_legend(self):
        """Return ``(import_lines, legend_code)`` for an exported script.

        The legend code is commented out when the legend is not visible.
        """
        state_dict = self.state.legend.as_dict()
        state_dict['location'] = self.state.legend.mpl_location
        state_dict['edge_color'] = self.state.legend.edge_color
        # NOTE(review): the live legend is colored with legend.text_color
        # (see _update_legend_visual) while the script uses the global
        # foreground color - confirm this difference is intended.
        state_dict['label_color'] = settings.FOREGROUND_COLOR
        legend_str = SCRIPT_LEGEND.format(**state_dict)
        if not self.state.legend.visible:
            legend_str = indent(legend_str, "# ")
        return [], legend_str
class SimpleMatplotlibViewer(MatplotlibViewerMixin, Viewer):
    """Viewer whose figure and axes are created by ``init_mpl``."""

    def __init__(self, session, parent=None, wcs=None, state=None, projection=None):
        """Initialise the base viewer, create the figure/axes, wire callbacks.

        ``session`` and ``state`` are forwarded to the parent initialiser;
        ``wcs`` and ``projection`` are forwarded to ``init_mpl``. ``parent``
        is accepted but not used in this constructor.
        """
        super().__init__(session, state=state)
        figure, axes = init_mpl(wcs=wcs, projection=projection)
        self.figure = figure
        self.axes = axes
        # Invoke the mixin's implementation explicitly rather than through
        # ``self`` so that exactly this version of setup_callbacks runs.
        MatplotlibViewerMixin.setup_callbacks(self)