Source code for spike2py.plot

from typing import Literal, Tuple

import numpy as np
import matplotlib
import matplotlib.pyplot as plt


from spike2py import channels, trial
from spike2py.types import all_channels, ticksline_channels

LINE_WIDTH = 2
FIG_SIZE = (12, 4)
WAVEFORM_FIG_SIZE = (12, 8)
MAX_TRIAL_FIG_HEIGHT = 30
LEGEND_LOC = "upper right"
COLORS = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]
DEFAULT_COLOR = COLORS[0]

matplotlib.rcParams.update({"font.size": 14})


def plot_channel(spike2py_channel: all_channels, save: Literal[True, False]) -> None:
    """Plot individual channels.

    Waveform channels are drawn as a single line plot. Event, Keyboard and
    Wavemark channels are drawn as a "ticks line" via ``_TicksLine``; a
    Wavemark channel additionally gets a second, narrower subplot that
    ``_TicksLine.plot`` uses for the action potentials.

    Parameters
    ----------
    spike2py_channel:
        Instance of spike2py.channels.<ch> where possible ch are
        Event, Keyboard, Wavemark and Waveform
    save:
        Whether or not to save the generated figure. When falsy the figure
        is shown with ``plt.show()`` instead.

    Returns
    -------
    None
    """
    if len(spike2py_channel.times) == 0:
        # f-string so the channel name is actually interpolated in the message.
        print(f"{spike2py_channel.info.name} channel has no data to plot.")
        return

    # The channel type is taken from the first word of the channel's repr.
    channel_type = repr(spike2py_channel).split()[0]
    if channel_type == "Waveform":
        # Split axis generation and plotting to allow reuse of plotting with trial plotting
        fig, ax = plt.subplots(figsize=WAVEFORM_FIG_SIZE)
        _plot_waveform(spike2py_channel, ax)
    else:
        ticks_line = _TicksLine(spike2py_channel)
        if ticks_line.ch_type == "Wavemark":
            fig, ax = plt.subplots(
                1, 2, figsize=FIG_SIZE, gridspec_kw={"width_ratios": [3, 1]}
            )
        else:
            fig, ax = plt.subplots(figsize=FIG_SIZE)
        ticks_line.plot(ax)

    # Wavemark figures hold an array of two Axes; attach the title to the
    # first (main) one instead of calling `.text` on the ndarray itself.
    title_ax = ax[0] if isinstance(ax, np.ndarray) else ax
    title = (
        f"{spike2py_channel.info.subject_id}_"
        f"{spike2py_channel.info.name}_"
        f"{spike2py_channel.info.file.name}"
    )
    title_ax.text(
        0,
        0,
        title,
        horizontalalignment='left',
        verticalalignment='top',
        transform=title_ax.transAxes,
        fontsize=10,
    )

    if save:
        _save_plot(spike2py_channel.info)
    else:
        plt.show()
def _get_color(index: int) -> str:
    """Return a color from ``COLORS``, cycling when ``index`` exceeds its length."""
    return COLORS[index % len(COLORS)]


def _plot_waveform(
    waveform: "channels.Waveform",
    ax: plt.Axes,
    color: str = DEFAULT_COLOR,
) -> None:
    """Draw a waveform channel on ``ax``.

    Parameters
    ----------
    waveform:
        Channel providing ``times``, ``values`` and ``info`` (name, units).
    ax:
        Axes to draw on.
    color:
        Line color.
    """
    ax.plot(waveform.times, waveform.values, label=waveform.info.name, color=color)
    ax.set_xlim(waveform.times[0], waveform.times[-1])
    # Use the channel's units when it has any; `is True` would never match a
    # units string and always fell back to "a.u.".
    units = waveform.info.units if waveform.info.units else "a.u."
    ax.set_ylabel(f"amplitude ({units})")
    ax.legend(loc=LEGEND_LOC)
    ax.set_xlabel("time (s)")
    ax.grid()


def _save_plot(channel_info: "channels.ChannelInfo") -> None:
    """Save the current figure as a PDF in ``channel_info.path_save_figures``.

    The file name includes ``trial_name`` when ``channel_info`` has that
    attribute; otherwise only ``subject_id`` and ``name`` are used. The
    current figure is closed after saving.
    """
    try:
        fig_name = (
            f"{channel_info.subject_id}_"
            f"{channel_info.trial_name}_"
            f"{channel_info.name}.pdf"
        )
    except AttributeError:
        fig_name = f"{channel_info.subject_id}_{channel_info.name}.pdf"
    fig_path = channel_info.path_save_figures / fig_name
    plt.savefig(fig_path, dpi='figure')
    plt.close()


class _TicksLine:
    """Class that manages plotting of Event, Keyboard and Wavemark channels"""

    def __init__(
        self,
        ticks_line_channel: ticksline_channels,
        color: str = COLORS[0],
        y_offset: int = 0,
    ):
        """Initialise TicksLine Class

        Parameters
        ----------
        ticks_line_channel:
            Instance of spike2py.channels.< > where possible channel types are
            Event, Keyboard, and Wavemark
        color:
            Set color of ticks and line
        y_offset:
            Offset of ticks and line in the y-direction
        """
        self.ch = ticks_line_channel
        self.color = color
        self.offset = y_offset
        # Channel type is the first word of the channel's repr.
        self.ch_type = repr(ticks_line_channel).split()[0]
        # Horizontal line spans first to last event, centred on the ticks.
        self.line_start_end = (self.ch.times[0], self.ch.times[-1])
        self.line_y_vals = (0.5 + y_offset, 0.5 + y_offset)
        self.tick_y_vals = (0.2 + y_offset, 0.8 + y_offset)

    def plot(self, ax: plt.Axes):
        """Plot the channel on ``ax``.

        ``ax`` may be a single Axes, or an ndarray whose first element
        receives the ticks line and whose second element receives the
        Wavemark action potentials.
        """
        if isinstance(ax, np.ndarray):
            ax1 = ax[0]
            ax2 = ax[1]
        else:
            ax1 = ax
            ax2 = None
        self._plot_ticks_line(ax1)
        self._finalise_plot(ax1)
        if self.ch_type in ("Keyboard", "Textmark"):
            self._plot_codes(ax1)
        # Explicit None check rather than relying on the truthiness of an Axes.
        if self.ch_type == "Wavemark" and ax2 is not None:
            self._plot_action_potentials(ax2)
            # NOTE(review): assumed to apply only to the two-panel Wavemark
            # layout (calling it for every channel would undo the hspace=0
            # layout used for trial figures) -- confirm against upstream.
            plt.tight_layout()

    def _plot_ticks_line(self, ax1: plt.Axes):
        """Draw one vertical tick per event time plus a horizontal line."""
        for time in self.ch.times:
            ax1.plot(
                (time, time), self.tick_y_vals, linewidth=LINE_WIDTH, color=self.color
            )
        # Only the horizontal line carries the label, giving one legend entry.
        ax1.plot(
            self.line_start_end,
            self.line_y_vals,
            linewidth=LINE_WIDTH,
            label=self.ch.info.name,
            color=self.color,
        )

    def _plot_codes(self, ax1: plt.Axes):
        """Write each code as text just above its tick."""
        for time, code in zip(self.ch.times, self.ch.codes):
            ax1.text(
                time, self.tick_y_vals[1] + 0.2, code, color=self.color, fontsize=10
            )

    def _plot_action_potentials(self, ax2: plt.Axes):
        """Overlay all action potentials, semi-transparent, on hidden axes."""
        for action_potential in self.ch.action_potentials:
            ax2.plot(action_potential, color=self.color, alpha=0.5)
        ax2.get_yaxis().set_visible(False)
        ax2.get_xaxis().set_visible(False)

    def _finalise_plot(self, ax1: plt.Axes):
        """Add legend, x-label and grid; hide the y-axis."""
        ax1.legend(loc=LEGEND_LOC)
        ax1.set_xlabel("time (s)")
        ax1.get_yaxis().set_visible(False)
        ax1.grid()
def plot_trial(spike2py_trial: "trial.Trial", save: Literal[True, False]) -> None:
    """Plot all plottable channels of a trial in one figure of stacked subplots.

    Parameters
    ----------
    spike2py_trial:
        Trial whose channels are to be plotted.
    save:
        Whether or not to save the generated figure. When falsy the figure
        is shown with ``plt.show()`` instead.

    Returns
    -------
    None
    """
    fig_height, n_subplots = _fig_height_n_subplots(spike2py_trial)

    if n_subplots == 0:
        # plt.subplots(nrows=0) would raise; there is nothing to draw.
        print(f"The trial `{spike2py_trial.info.name}` has no plottable channels.")
        return
    if n_subplots == 1:
        print(
            f"The trial `{spike2py_trial.info.name}` has only one plottable channel."
            "\nPlease use `trial_name.ch_name.plot()` instead."
        )
        # With a single subplot plt.subplots returns a bare Axes, which the
        # array-based plotting below cannot index; stop as the message advises.
        return

    fig, ax = plt.subplots(
        sharex=True,
        nrows=n_subplots,
        figsize=(12, fig_height),
        gridspec_kw={"hspace": 0},
    )
    _plot_trial(spike2py_trial, ax)

    title = (
        f"{spike2py_trial.info.subject_id}_"
        f"{spike2py_trial.info.name}_"
        f"{spike2py_trial.info.file.name}"
    )
    # The title is anchored to the bottom subplot.
    bottom_ax = ax[-1]
    bottom_ax.text(
        0,
        0,
        title,
        horizontalalignment='left',
        verticalalignment='top',
        transform=bottom_ax.transAxes,
        fontsize=10,
    )

    if save:
        _save_plot(spike2py_trial.info)
    else:
        plt.show()
def _fig_height_n_subplots(spike2py_trial: "trial.Trial") -> Tuple[int, int]:
    """Determine height and number of subplots to plot trial.

    All non-waveform channels (Event, Keyboard, Wavemark, ...) are plotted on
    the same subplot at the top of the figure, so together they count as one
    subplot, and only if at least one of them has data. Each waveform
    channel gets its own subplot. The height is capped at
    ``MAX_TRIAL_FIG_HEIGHT``.

    Returns
    -------
    Tuple of (figure height, number of subplots).
    """
    fig_height = 4
    n_subplots = 0
    plottable_ticks_line = False
    for channel, channel_type in spike2py_trial.channels:
        if channel_type == "waveform":
            fig_height += 2
            n_subplots += 1
            continue
        if plottable_ticks_line:
            # The shared ticks-line subplot has already been counted.
            continue
        # Mirror `_plot_trial`, which draws every non-waveform channel with
        # data on the first subplot; otherwise such a channel would be drawn
        # on top of a waveform subplot.
        if len(getattr(spike2py_trial, channel).times) != 0:
            fig_height += 2
            n_subplots += 1
            plottable_ticks_line = True
    fig_height = min(fig_height, MAX_TRIAL_FIG_HEIGHT)
    return fig_height, n_subplots


def _plot_trial(spike2py_trial: "trial.Trial", ax: np.ndarray):
    """Plot every channel of a trial onto the supplied subplots.

    Waveforms fill the subplots from the bottom upwards, one per subplot.
    All other channels that have data share the first subplot, each drawn
    as a ticks line with its own color and vertical offset.

    Parameters
    ----------
    spike2py_trial:
        Trial whose channels are plotted.
    ax:
        Array of Axes as returned by ``plt.subplots``; a single Axes is
        also accepted.
    """
    # A bare Axes has no len() and cannot be indexed; normalise to an array.
    ax = np.atleast_1d(ax)
    n_subplots = len(ax)
    waveform_counter = 1
    other_ch_counter = 0
    for channel, channel_type in spike2py_trial.channels:
        current_channel = getattr(spike2py_trial, channel)
        if channel_type == "waveform":
            _plot_waveform(
                waveform=current_channel,
                ax=ax[n_subplots - waveform_counter],
                color=_get_color(waveform_counter - 1),
            )
            waveform_counter += 1
        elif len(current_channel.times) != 0:
            ticks_line = _TicksLine(
                ticks_line_channel=current_channel,
                color=_get_color(other_ch_counter),
                y_offset=other_ch_counter,
            )
            ticks_line.plot(ax[0])
            other_ch_counter += 1