Source code for spike2py.read

import sys
from pathlib import Path
import textwrap
from typing import List, Final

import scipy.io as sio
import numpy as np

from spike2py.types import (
    mat_data,
    parsed_wavemark,
    parsed_waveform,
    parsed_event,
    parsed_textmark,
    parsed_keyboard,
    parsed_spike2py_data,
    parsed_mat_data,
)

CHANNEL_DATA_LENGTH: Final = {
    "event": 5,
    "keyboard": 6,
    "textmark": 8,
    "waveform": 10,
    "wavemark": 14,
}
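# The number of fields in the structured array that scipy.io.loadmat returns for a
# channel identifies its type; _parse_mat_data below indexes this mapping with
# len(value.dtype) to pick the matching _parse_mat_* helper.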


class WrongFileType(Exception):
    """Custom exception to use when `.mat` file not provided"""

    pass


def read(file: Path, channels: List[str] = None) -> parsed_spike2py_data:
    """Interface to read data files

    Parameters
    ----------
    file
        Absolute path to data file. Only .mat files are currently supported.
    channels
        List of channel names, as they appeared in the original .smr file.
        Example: ['biceps', 'triceps', 'torque']
        If not included, all channels will be processed.

    Raises
    ------
    WrongFileType
        `file` parameter is not a `.mat` file

    Returns
    -------
    dict
        Data from a trial where `keys` are channel names and `values` are
        dicts of parsed channel data and metadata (see the
        `_parse_mat_<channel type>` helpers for the keys available for each
        channel type).
    """
    file_extension = Path(file).suffix
    if file_extension != ".mat":
        raise WrongFileType(
            f"Processing {file_extension} files is not supported."
            "\nIn Spike2 export the data to .mat and start over."
        )
    return _parse_mat_data(_read_mat(file, channels))


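# A minimal usage sketch (the file name and channel names below are hypothetical;
# substitute your own Spike2 .mat export):
#
#     from pathlib import Path
#     from spike2py.read import read
#
#     data = read(Path("sub01_trial01.mat"), channels=["biceps", "torque"])
#     biceps = data["biceps"]  # waveform channel: 'times', 'values', 'units',
#                              # 'sampling_frequency', 'ch_type'

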
def _read_mat(mat_file: Path, channels: List[str]) -> mat_data:
    """Read Spike2 data exported to a Matlab .mat file

    Parameters
    ----------
    mat_file
        Absolute path to .mat file
    channels
        List of channel names, as they appeared in the original .smr file,
        or None, in which case all channels are processed.

    Returns
    -------
    dict
        Requested channels with channel names as `keys` and deeply nested
        arrays containing channel data as `values`.
    """
    try:
        data: dict = sio.loadmat(mat_file)
    except OSError:
        print(
            f"File {mat_file.name} not found. Please verify path and file name and try again."
        )
        sys.exit(1)
    all_channels = [
        data_key for data_key in data.keys() if not data_key.startswith("__")
    ]
    if channels is None:
        channels = all_channels
    else:
        _verify_channels_exists(channels, all_channels, mat_file)
    return {key: value for (key, value) in data.items() if key in channels}


def _verify_channels_exists(channels, all_channels, mat_file):
    for channel in channels:
        if channel not in all_channels:
            print(
                f"Channel {channel} does not exist in {mat_file.name}. \n"
                f"Available channels include:\n"
            )
            for ch in all_channels:
                print(ch)
            sys.exit(1)


def _parse_mat_data(mat_data: mat_data) -> parsed_mat_data:
    """Parse deeply nested arrays that contain channel data

    Parameters
    ----------
    mat_data
        Deeply nested array containing channel data and metadata.

    Returns
    -------
    dict
        Channel data and metadata. The `keys` and `values` will differ for the
        different channel types. See the `_parse_mat_<channel type>` helper
        functions for details.
    """
    parser_lookup = {
        CHANNEL_DATA_LENGTH["event"]: _parse_mat_events,
        CHANNEL_DATA_LENGTH["textmark"]: _parse_mat_textmark,
        CHANNEL_DATA_LENGTH["keyboard"]: _parse_mat_keyboard,
        CHANNEL_DATA_LENGTH["waveform"]: _parse_mat_waveform,
        CHANNEL_DATA_LENGTH["wavemark"]: _parse_mat_wavemark,
    }
    parsed_data = dict()
    for key, value in mat_data.items():
        parsed_data[key] = parser_lookup[len(value.dtype)](value)
    return parsed_data


def _parse_mat_events(mat_events: np.ndarray) -> parsed_event:
    """Parse event channel data as exported by Spike2 to .mat

    Parameters
    ----------
    mat_events
        Deeply nested array containing event channel data and metadata.

    Returns
    -------
    dict
        Data from event channel.
    """
    return {
        "times": _flatten_array(mat_events["times"]),
        "ch_type": "event",
    }


def _flatten_array(array):
    return array[0][0].flatten()


def _parse_mat_keyboard(mat_keyboard: np.ndarray) -> parsed_keyboard:
    """Parse keyboard channel data as exported by Spike2 to .mat

    Parameters
    ----------
    mat_keyboard
        Deeply nested array containing keyboard channel data and metadata.

    Returns
    -------
    dict
        Data from keyboard channel.
    """
    keyboard_codes = _flatten_array(mat_keyboard["codes"])
    characters = None
    if len(keyboard_codes) != 0:
        characters = _keyboard_codes_to_characters(keyboard_codes)
    return {
        "codes": characters,
        "times": _flatten_array(mat_keyboard["times"]),
        "ch_type": "keyboard",
    }


def _parse_mat_textmark(mat_textmark: np.ndarray) -> parsed_textmark:
    """Parse textmark ('Memory') channel data as exported by Spike2 to .mat

    Parameters
    ----------
    mat_textmark
        Deeply nested array containing textmark channel data and metadata.

    Returns
    -------
    dict
        Data from textmark channel.
    """
    codes = list(mat_textmark["text"][0][0])
    return {
        "codes": codes,
        "times": _flatten_array(mat_textmark["times"]),
        "ch_type": "textmark",
    }


def _keyboard_codes_to_characters(keyboard_codes: List[int]) -> List[str]:
    """Helper function that converts encoded character(s) into list of str

    Parameters
    ----------
    keyboard_codes
        List of int values, where each keyboard entry is encoded by four int values.
        e.g. single keyboard entry: [42, 0, 0, 0]
        e.g. multi keyboard entries: [42, 0, 0, 0, 57, 0, 0, 0, 73, 0, 0, 0]

    Returns
    -------
    list
        List of str values, corresponding to each of the keyboard entries.
    """
    hex_keyboard_codes = textwrap.fill(keyboard_codes.tobytes().hex(), 8).split("\n")
    return [
        bytearray.fromhex(hex_code[0:8][:2]).decode() for hex_code in hex_keyboard_codes
    ]


def _parse_mat_waveform(mat_waveform: np.ndarray) -> parsed_waveform:
    """Parse waveform channel data as exported by Spike2 to .mat

    Parameters
    ----------
    mat_waveform
        Deeply nested array containing waveform channel data and metadata.

    Returns
    -------
    dict
        Data from waveform channel.
    """
    units_flattened = _flatten_array(mat_waveform["units"])
    units = None
    if units_flattened.size > 0:
        units = units_flattened[0]
    times = _flatten_array(mat_waveform["times"])
    values = _flatten_array(mat_waveform["values"])
    shortest_array = min(len(times), len(values))
    return {
        "times": times[:shortest_array],
        "units": units,
        "values": values[:shortest_array],
        "sampling_frequency": int(1 / _flatten_array(mat_waveform["interval"])[0]),
        "ch_type": "waveform",
    }


def _parse_mat_wavemark(mat_wavemark: np.ndarray) -> parsed_wavemark:
    """Parse wavemark channel data as exported by Spike2 to .mat

    Parameters
    ----------
    mat_wavemark
        Deeply nested array containing wavemark channel data and metadata.

    Returns
    -------
    dict
        Data from wavemark channel.
    """
    units = None
    times = None
    sampling_frequency = None
    action_potentials = None
    units_flattened = _flatten_array(mat_wavemark["units"])
    if units_flattened.size > 0:
        units = units_flattened[0]
        times = mat_wavemark["times"][0][0].flatten()
        sampling_frequency = int(1 / mat_wavemark["interval"][0][0].flatten()[0])
        action_potentials = _extract_wavemarks(mat_wavemark)
    return {
        "units": units,
        "times": times,
        "sampling_frequency": sampling_frequency,
        "action_potentials": action_potentials,
        "ch_type": "wavemark",
    }


def _extract_wavemarks(mat_wavemark: np.ndarray) -> List[List[int]]:
    """Helper function to flatten, extract and group wavemark values"""
    template_length = int(_flatten_array(mat_wavemark["length"])[0])
    concatenated_wavemarks = _flatten_array(mat_wavemark["values"])
    number_of_wavemarks = int(len(concatenated_wavemarks) / template_length)
    return concatenated_wavemarks.reshape(template_length, number_of_wavemarks)

