Source code for spike2py.sig_proc

from typing import List, Literal

import numpy as np
from scipy.signal import butter, filtfilt, detrend

from spike2py.types import filt_cutoff_single, filt_cutoff_pair, filt_cutoff


class SignalProcessing:
    """Mixin class that adds signal processing methods.

    The methods read and write `self.values` and, where relevant,
    `self.times` and `self.info.sampling_frequency`; the host class is
    expected to provide these attributes.

    Every public method returns `self` so that calls can be chained, and
    stores the resulting `values` on the instance under a `proc_*`
    attribute describing the processing step. Each step rebinds `values`
    to a new array rather than modifying it in place, so `proc_*`
    snapshots from earlier steps are not altered by later ones.
    """

    def _setattr(self, name: str):
        """Store the current `values` on the instance as attribute `name`."""
        setattr(self, name, self.values)

    def remove_mean(self, first_n_samples: int = None):
        """Subtract mean of first n samples (default is all samples).

        Parameters
        ----------
        first_n_samples
            Number of leading samples used to compute the mean. When
            `None`, the mean of the entire signal is used.

        Raises
        ------
        TypeError
            If `first_n_samples` is not an integer.
        ValueError
            If `first_n_samples` is not between 1 and the signal length.
        """
        if first_n_samples is None:
            baseline = self.values
        else:
            # Check the type first so that non-integers always get the
            # dedicated message, whatever their magnitude.
            if not isinstance(first_n_samples, (int, np.integer)):
                raise TypeError("first_n_samples must be a whole number, an integer")
            if (first_n_samples < 1) or (first_n_samples > len(self.values)):
                raise ValueError(
                    "first_n_samples must be between 1 and "
                    f"the length of the signal (i.e. {len(self.values)})."
                )
            baseline = self.values[:first_n_samples]
        # Rebind instead of `-=` so earlier `proc_*` snapshots are untouched.
        self.values = self.values - np.mean(baseline)
        self._setattr("proc_remove_mean")
        return self

    def remove_value(self, value: float):
        """Subtract `value` from `values`.

        Raises
        ------
        TypeError
            If `value` cannot be subtracted from `values`.
        """
        try:
            shifted = self.values - value
        except TypeError as err:
            # NumPy's ufunc type errors are subclasses of TypeError.
            raise TypeError(
                "`value` must be a whole number or a decimal number."
            ) from err
        self.values = shifted
        str_value = self._float_to_string_with_underscore(value)
        self._setattr(f"proc_remove_value_{str_value}")
        return self

    def _float_to_string_with_underscore(self, float_value: float):
        """Return `abs(float_value)` as a string with "." replaced by "_"."""
        return str(abs(float_value)).replace(".", "_")

    def lowpass(self, cutoff: "filt_cutoff_single", order: int = 4):
        """Apply dual-pass Butterworth lowpass filter to `values`"""
        self._filt(cutoff, order, "lowpass")
        return self

    def highpass(self, cutoff: "filt_cutoff_single", order: int = 4):
        """Apply dual-pass Butterworth highpass filter to `values`"""
        self._filt(cutoff, order, "highpass")
        return self

    def bandpass(self, cutoff: "filt_cutoff_pair", order: int = 4):
        """Apply dual-pass Butterworth bandpass filter to `values`"""
        self._filt(cutoff, order, "bandpass")
        return self

    def bandstop(self, cutoff: "filt_cutoff_pair", order: int = 4):
        """Apply dual-pass Butterworth bandstop filter to `values`"""
        self._filt(cutoff, order, "bandstop")
        return self

    def _filt(
        self,
        cutoff: "filt_cutoff",
        order: int,
        filt_type: Literal["lowpass", "highpass", "bandstop", "bandpass"],
    ):
        """Validate the arguments, filter `values` and record the result.

        The cutoff is expressed as a fraction of the Nyquist frequency
        (half of `self.info.sampling_frequency`) before being passed to
        `butter`; `filtfilt` then applies the filter forwards and backwards.
        """
        cutoff_1d_array = self._convert_cutoff_to_1d_array(cutoff)
        self._check_cutoff_count(cutoff_1d_array, filt_type)
        self._check_valid_cutoff(cutoff_1d_array)
        self._check_valid_filter_order(order)
        critical_fq = cutoff_1d_array / (self.info.sampling_frequency / 2)
        filt_coef_b, filt_coef_a = butter(order, critical_fq, filt_type)
        self.values = filtfilt(filt_coef_b, filt_coef_a, self.values)
        self._setattr(
            f"proc_filt_{self._cutoff_to_string(cutoff_1d_array)}_{filt_type}"
        )

    def _convert_cutoff_to_1d_array(self, cutoff: "filt_cutoff") -> np.ndarray:
        """Return `cutoff` (a scalar or a sequence) as an array with ndim >= 1."""
        return np.atleast_1d(cutoff)

    def _check_cutoff_count(self, cutoff: np.ndarray, filt_type: str):
        """Raise ValueError if the number of cutoffs does not suit `filt_type`."""
        expected = 2 if filt_type in ("bandpass", "bandstop") else 1
        if cutoff.size != expected:
            raise ValueError(
                f"A {filt_type} filter requires {expected} cutoff "
                f"frequency value(s), but {cutoff.size} were given."
            )

    def _check_valid_cutoff(self, cutoff: np.ndarray):
        """Raise ValueError unless every cutoff lies strictly between 0 and Nyquist."""
        nyquist_fq = self.info.sampling_frequency / 2
        for value in cutoff:
            # `butter` needs 0 < Wn < 1, so the Nyquist frequency itself
            # is rejected as well.
            if (value <= 0) or (value >= nyquist_fq):
                raise ValueError(
                    f"Filter cutoff frequency must be between 0 and "
                    f"{int(self.info.sampling_frequency/2)}"
                )

    def _check_valid_filter_order(self, order: int):
        """Raise ValueError unless `order` is a whole number from 1 to 16."""
        if order not in range(1, 17):
            raise ValueError("Filter order must be a whole number between 1 and 16")

    def _cutoff_to_string(self, cutoff: np.ndarray) -> str:
        """Format one cutoff as "x" or a pair as "low_high" for attribute names."""
        if len(cutoff) == 2:
            low = self._float_to_string_with_underscore(cutoff[0])
            high = self._float_to_string_with_underscore(cutoff[1])
            return f"{low}_{high}"
        return self._float_to_string_with_underscore(cutoff[0])

    def calibrate(self, slope: float, offset: float = None):
        """Calibrate `values` using linear formula y=slope*x-offset

        Parameters
        ----------
        slope
            Factor by which `values` is multiplied.
        offset
            Amount subtracted after scaling. Nothing is subtracted when
            `None` (the default).
        """
        # NOTE(review): the offset is subtracted, as in the original
        # implementation, although the original docstring stated
        # "+offset". Confirm the intended sign convention with callers.
        calibrated = self.values * slope
        if offset is not None:
            calibrated = calibrated - offset
        self.values = calibrated
        self._setattr("proc_calib")
        return self

    def norm_percentage(self):
        """Normalise `values` to a percentage of their maximum.

        The result spans 0-100% only when `values` is non-negative and
        its maximum is positive.
        """
        self.values = (self.values / np.max(self.values)) * 100
        self._setattr("proc_norm_percentage")
        return self

    def norm_proportion(self):
        """Normalise `values` to a proportion of their maximum.

        The result spans 0-1 only when `values` is non-negative and its
        maximum is positive.
        """
        self.values = self.values / np.max(self.values)
        self._setattr("proc_norm_proportion")
        return self

    def norm_percent_value(self, value: float):
        """Normalise `values` to a percentage of `value`"""
        self.values = (self.values / value) * 100
        self._setattr("proc_norm_value")
        return self

    def rect(self):
        """Rectify values (replace each value with its absolute value)"""
        self.values = np.abs(self.values)
        self._setattr("proc_rect")
        return self

    def interp_new_times(self, new_times: List[float]):
        """Interpolate `values` to a new time axis

        Parameters
        ----------
        new_times
            New time axis for interpolated data. Cannot be longer in
            duration than current time axis, `times`. If includes only a
            portion of the current time axis, only values associated
            with that portion of the time axis will be interpolated.

        Raises
        ------
        ValueError
            If `new_times` ends after the current `times`.
        """
        self._check_new_times(new_times)
        self._interp(new_times)
        self._setattr("proc_interp_new_times")
        return self

    def _check_new_times(self, new_times: List[float]):
        """Raise ValueError if `new_times` ends after the current `times`."""
        if new_times[-1] > self.times[-1]:
            raise ValueError(
                "New time axis for interpolation cannot be longer "
                "in duration than current time axis."
            )

    def interp_new_fs(self, new_sampling_frequency: int):
        """Interpolate `values` to a new sampling frequency

        The new time axis starts at `times[0]` and advances in steps of
        `1 / new_sampling_frequency`, stopping before `times[-1]`.

        Raises
        ------
        ValueError
            If `new_sampling_frequency` is not greater than zero.
        """
        if new_sampling_frequency <= 0:
            raise ValueError("new_sampling_frequency must be greater than 0.")
        # NOTE(review): `self.info.sampling_frequency` is not updated here,
        # so filters applied afterwards still use the previous value.
        new_times = np.arange(
            start=self.times[0], stop=self.times[-1], step=1 / new_sampling_frequency
        )
        self._interp(new_times)
        self._setattr("proc_interp_new_fs")
        return self

    def _interp(self, new_times: List[float]):
        """Linearly interpolate `values` onto `new_times`.

        The previous time axis is kept in `times_pre_interp` before
        `times` is replaced by `new_times`.
        """
        self.values = np.interp(x=new_times, xp=self.times, fp=self.values)
        self.times_pre_interp = self.times
        self.times = new_times

    def linear_detrend(self):
        """Remove linear trend from `values`"""
        self.values = detrend(self.values, type="linear")
        self._setattr("proc_linear_detrend")
        return self