"""
YASA (Yet Another Spindle Algorithm): fast and robust detection of spindles,
slow-waves, and rapid eye movements from sleep EEG recordings.
- Author: Raphael Vallat (www.raphaelvallat.com)
- GitHub: https://github.com/raphaelvallat/yasa
- License: BSD 3-Clause License
"""
import mne
import logging
import numpy as np
import pandas as pd
from scipy import signal
from mne.filter import filter_data
from collections import OrderedDict
from scipy.interpolate import interp1d
from scipy.fftpack import next_fast_len
from sklearn.ensemble import IsolationForest
from .spectral import stft_power
from .numba import _detrend, _rms
from .io import set_log_level, is_tensorpac_installed, is_pyriemann_installed
from .others import (moving_transform, trimbothstd, get_centered_indices,
sliding_window, _merge_close, _zerocrossings)
logger = logging.getLogger('yasa')
__all__ = ['art_detect', 'spindles_detect', 'SpindlesResults',
'sw_detect', 'SWResults', 'rem_detect', 'REMResults']
#############################################################################
# DATA PREPROCESSING
#############################################################################
def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None,
check_amp=True):
"""Helper functions for preprocessing of data and hypnogram."""
# 1) Extract data as a 2D NumPy array
if isinstance(data, mne.io.BaseRaw):
sf = data.info['sfreq'] # Extract sampling frequency
ch_names = data.ch_names # Extract channel names
data = data.get_data() * 1e6 # Convert from V to uV
else:
assert sf is not None, 'sf must be specified if not using MNE Raw.'
data = np.asarray(data, dtype=np.float64)
assert data.ndim in [1, 2], 'data must be 1D (times) or 2D (chan, times).'
if data.ndim == 1:
# Force to 2D array: (n_chan, n_samples)
data = data[None, ...]
n_chan, n_samples = data.shape
# 2) Check channel names
if ch_names is None:
ch_names = ['CHAN' + str(i).zfill(3) for i in range(n_chan)]
else:
assert len(ch_names) == n_chan
# 3) Check hypnogram
if hypno is not None:
hypno = np.asarray(hypno, dtype=int)
assert hypno.ndim == 1, 'Hypno must be one dimensional.'
assert hypno.size == n_samples, 'Hypno must have same size as data.'
unique_hypno = np.unique(hypno)
logger.info('Number of unique values in hypno = %i', unique_hypno.size)
assert include is not None, 'include cannot be None if hypno is given'
include = np.atleast_1d(np.asarray(include))
assert include.size >= 1, '`include` must have at least one element.'
assert hypno.dtype.kind == include.dtype.kind, ('hypno and include '
'must have same dtype')
assert np.in1d(hypno, include).any(), ('None of the stages specified '
'in `include` are present in '
'hypno.')
# 4) Check data amplitude
logger.info('Number of samples in data = %i', n_samples)
logger.info('Sampling frequency = %.2f Hz', sf)
logger.info('Data duration = %.2f seconds', n_samples / sf)
all_ptp = np.ptp(data, axis=-1)
all_trimstd = trimbothstd(data, cut=0.05)
bad_chan = np.zeros(n_chan, dtype=bool)
for i in range(n_chan):
logger.info('Trimmed standard deviation of %s = %.4f uV'
% (ch_names[i], all_trimstd[i]))
logger.info('Peak-to-peak amplitude of %s = %.4f uV'
% (ch_names[i], all_ptp[i]))
if check_amp and not(0.1 < all_trimstd[i] < 1e3):
logger.error('Wrong data amplitude for %s '
'(trimmed STD = %.3f). Unit of data MUST be uV! '
'Channel will be skipped.'
% (ch_names[i], all_trimstd[i]))
bad_chan[i] = True
# 5) Create sleep stage vector mask
if hypno is not None:
mask = np.in1d(hypno, include)
else:
mask = np.ones(n_samples, dtype=bool)
return (data, sf, ch_names, hypno, include, mask, n_chan, n_samples,
bad_chan)
#############################################################################
# BASE DETECTION RESULTS CLASS
#############################################################################
class _DetectionResults(object):
"""Main class for detection results."""
def __init__(self, events, data, sf, ch_names, hypno, data_filt):
self._events = events
self._data = data
self._sf = sf
self._hypno = hypno
self._ch_names = ch_names
self._data_filt = data_filt
def get_mask(self):
"""get_mask"""
from yasa.others import _index_to_events
mask = np.zeros(self._data.shape, dtype=int)
for i in self._events['IdxChannel'].unique():
ev_chan = self._events[self._events['IdxChannel'] == i]
idx_ev = _index_to_events(
ev_chan[['Start', 'End']].to_numpy() * self._sf)
mask[i, idx_ev] = 1
return np.squeeze(mask)
def summary(self, event_type, grp_chan=False, grp_stage=False,
aggfunc='mean', sort=True):
"""Summary"""
grouper = []
if grp_stage is True and 'Stage' in self._events:
grouper.append('Stage')
if grp_chan is True and 'Channel' in self._events:
grouper.append('Channel')
if not len(grouper):
return self._events.copy()
if event_type == 'spindles':
aggdict = {'Start': 'count',
'Duration': aggfunc,
'Amplitude': aggfunc,
'RMS': aggfunc,
'AbsPower': aggfunc,
'RelPower': aggfunc,
'Frequency': aggfunc,
'Oscillations': aggfunc,
'Symmetry': aggfunc}
# if 'SOPhase' in self._events:
# from scipy.stats import circmean
# aggdict['SOPhase'] = lambda x: circmean(x, low=-np.pi,
# high=np.pi)
elif event_type == 'sw':
aggdict = {'Start': 'count',
'Duration': aggfunc,
'ValNegPeak': aggfunc,
'ValPosPeak': aggfunc,
'PTP': aggfunc,
'Slope': aggfunc,
'Frequency': aggfunc}
if 'PhaseAtSigmaPeak' in self._events:
from scipy.stats import circmean
aggdict['PhaseAtSigmaPeak'] = lambda x: circmean(x, low=-np.pi,
high=np.pi)
aggdict['ndPAC'] = aggfunc
else: # REM
aggdict = {'Start': 'count',
'Duration': aggfunc,
'LOCAbsValPeak': aggfunc,
'ROCAbsValPeak': aggfunc,
'LOCAbsRiseSlope': aggfunc,
'ROCAbsRiseSlope': aggfunc,
'LOCAbsFallSlope': aggfunc,
'ROCAbsFallSlope': aggfunc}
# Apply grouping
df_grp = self._events.groupby(grouper, sort=sort,
as_index=False).agg(aggdict)
df_grp = df_grp.rename(columns={'Start': 'Count'})
# Calculate density (= number per min of each stage)
if self._hypno is not None and grp_stage is True:
stages = np.unique(self._events['Stage'])
dur = {}
for st in stages:
# Get duration in minutes of each stage present in dataframe
dur[st] = self._hypno[self._hypno == st].size / (60 * self._sf)
# Insert new density column in grouped dataframe after count
df_grp.insert(
loc=df_grp.columns.get_loc('Count') + 1, column='Density',
value=df_grp.apply(lambda rw: rw['Count'] / dur[rw['Stage']],
axis=1))
return df_grp.set_index(grouper)
def get_sync_events(self, center, time_before, time_after,
filt=(None, None)):
"""Get_sync_events
(not for REM, spindles & SW only)
"""
from yasa.others import get_centered_indices
assert time_before >= 0
assert time_after >= 0
bef = int(self._sf * time_before)
aft = int(self._sf * time_after)
# TODO: Step size is determined by sf: 0.01 sec at 100 Hz, 0.002 sec at
# 500 Hz, 0.00390625 sec at 256 Hz. Should we add a step_size=0.01
# option?
time = np.arange(-bef, aft + 1, dtype='int') / self._sf
if any(filt):
data = mne.filter.filter_data(self._data, self._sf, l_freq=filt[0],
h_freq=filt[1], method='fir',
verbose=False)
else:
data = self._data
df_sync = pd.DataFrame()
for i in self._events['IdxChannel'].unique():
ev_chan = self._events[self._events['IdxChannel'] == i].copy()
ev_chan['Event'] = np.arange(ev_chan.shape[0])
peaks = (ev_chan[center] * self._sf).astype(int).to_numpy()
# Get centered indices
idx, idx_valid = get_centered_indices(data[i, :], peaks, bef, aft)
# If no good epochs are returned raise a warning
if len(idx_valid) == 0:
logger.error(
'Time before and/or time after exceed data bounds, please '
'lower the temporal window around center. '
'Skipping channel.'
)
continue
# Get data at indices and time vector and convert to df
amps = data[i, idx]
df_chan = pd.DataFrame(amps.T)
df_chan['Time'] = time
# Convert to long-format
df_chan = df_chan.melt(id_vars='Time', var_name='Event',
value_name='Amplitude')
# Append stage
if 'Stage' in self._events:
df_chan = df_chan.merge(
ev_chan[['Event', 'Stage']].iloc[idx_valid]
)
# Append channel name
df_chan['Channel'] = ev_chan['Channel'].iloc[0]
df_chan['IdxChannel'] = i
# Append to master dataframe
df_sync = df_sync.append(df_chan, ignore_index=True)
return df_sync
def plot_average(self, event_type, center='Peak', hue='Channel',
time_before=1, time_after=1, filt=(None, None),
figsize=(6, 4.5), **kwargs):
"""plot_average
(not for REM, spindles & SW only)
"""
import seaborn as sns
import matplotlib.pyplot as plt
df_sync = self.get_sync_events(center=center, time_before=time_before,
time_after=time_after, filt=filt)
assert not df_sync.empty, "Could not calculate event-locked data."
assert hue in ['Stage', 'Channel'], "hue must be 'Channel' or 'Stage'"
assert hue in df_sync.columns, "%s is not present in data." % hue
if event_type == 'spindles':
title = "Average spindle"
else: # "sw":
title = "Average SW"
# Start figure
fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.lineplot(data=df_sync, x='Time', y='Amplitude', hue=hue, ax=ax,
**kwargs)
# ax.legend(frameon=False, loc='lower right')
ax.set_xlim(df_sync['Time'].min(), df_sync['Time'].max())
ax.set_title(title)
ax.set_xlabel('Time (sec)')
ax.set_ylabel('Amplitude (uV)')
return ax
def plot_detection(self):
"""Plot an overlay of the detected events on the signal."""
import matplotlib.pyplot as plt
import ipywidgets as ipy
# Define mask
sf = self._sf
win_size = 10
mask = self.get_mask()
highlight = self._data * mask
highlight = np.where(highlight == 0, np.nan, highlight)
highlight_filt = self._data_filt * mask
highlight_filt = np.where(highlight_filt == 0, np.nan, highlight_filt)
n_epochs = int((self._data.shape[-1] / sf) / win_size)
times = np.arange(self._data.shape[-1]) / sf
# Define xlim and xrange
xlim = [0, win_size]
xrng = np.arange(xlim[0] * sf, (xlim[1] * sf + 1), dtype=int)
# Plot
fig, ax = plt.subplots(figsize=(12, 4))
plt.plot(times[xrng], self._data[0, xrng], 'k', lw=1)
plt.plot(times[xrng], highlight[0, xrng], 'indianred')
plt.xlabel('Time (seconds)')
plt.ylabel('Amplitude (uV)')
fig.canvas.header_visible = False
fig.tight_layout()
# WIDGETS
layout = ipy.Layout(
width="50%",
justify_content='center',
align_items='center'
)
sl_ep = ipy.IntSlider(
min=0,
max=n_epochs,
step=1,
value=0,
layout=layout,
description="Epoch:",
)
sl_amp = ipy.IntSlider(
min=25,
max=500,
step=25,
value=150,
layout=layout,
orientation='horizontal',
description="Amplitude:"
)
dd_ch = ipy.Dropdown(
options=self._ch_names, value=self._ch_names[0],
description='Channel:'
)
dd_win = ipy.Dropdown(
options=[1, 5, 10, 30, 60],
value=win_size,
description='Window size:',
)
dd_check = ipy.Checkbox(
value=False,
description='Filtered',
)
def update(epoch, amplitude, channel, win_size, filt):
"""Update plot."""
n_epochs = int((self._data.shape[-1] / sf) / win_size)
sl_ep.max = n_epochs
xlim = [epoch * win_size, (epoch + 1) * win_size]
xrng = np.arange(xlim[0] * sf, (xlim[1] * sf), dtype=int)
# Check if filtered
data = self._data if not filt else self._data_filt
overlay = highlight if not filt else highlight_filt
try:
ax.lines[0].set_data(times[xrng], data[dd_ch.index, xrng])
ax.lines[1].set_data(times[xrng], overlay[dd_ch.index, xrng])
ax.set_xlim(xlim)
except IndexError:
pass
ax.set_ylim([-amplitude, amplitude])
return ipy.interact(update, epoch=sl_ep, amplitude=sl_amp,
channel=dd_ch, win_size=dd_win, filt=dd_check)
#############################################################################
# SPINDLES DETECTION
#############################################################################
[docs]def spindles_detect(data, sf=None, ch_names=None, hypno=None,
include=(1, 2, 3), freq_sp=(12, 15), freq_broad=(1, 30),
duration=(0.5, 2), min_distance=500,
thresh={'rel_pow': 0.2, 'corr': 0.65, 'rms': 1.5},
multi_only=False, remove_outliers=False, verbose=False):
"""Spindles detection.
Parameters
----------
data : array_like
Single or multi-channel data. Unit must be uV and shape (n_samples) or
(n_chan, n_samples). Can also be a :py:class:`mne.io.BaseRaw`,
in which case ``data``, ``sf``, and ``ch_names`` will be automatically
extracted, and ``data`` will also be automatically converted from
Volts (MNE) to micro-Volts (YASA).
sf : float
Sampling frequency of the data in Hz.
Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw`.
.. tip:: If the detection is taking too long, make sure to downsample
your data to 100 Hz (or 128 Hz). For more details, please refer to
:py:func:`mne.filter.resample`.
ch_names : list of str
Channel names. Can be omitted if ``data`` is a
:py:class:`mne.io.BaseRaw`.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is loaded, the
detection will only be applied to the value defined in
``include`` (default = N1 + N2 + N3 sleep).
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Values in ``hypno`` that will be included in the mask. The default is
(1, 2, 3), meaning that the detection is applied on N1, N2 and N3
sleep. This has no effect when ``hypno`` is None.
freq_sp : tuple or list
Spindles frequency range. Default is 12 to 15 Hz. Please note that YASA
uses a FIR filter (implemented in MNE) with a 1.5Hz transition band,
which means that for `freq_sp = (12, 15 Hz)`, the -6 dB points are
located at 11.25 and 15.75 Hz.
freq_broad : tuple or list
Broad band frequency range. Default is 1 to 30 Hz.
duration : tuple or list
The minimum and maximum duration of the spindles.
Default is 0.5 to 2 seconds.
min_distance : int
If two spindles are closer than ``min_distance`` (in ms), they are
merged into a single spindles. Default is 500 ms.
thresh : dict
Detection thresholds:
* ``'rel_pow'``: Relative power (= power ratio freq_sp / freq_broad).
* ``'corr'``: Moving correlation between original signal and
sigma-filtered signal.
* ``'rms'``: Number of standard deviations above the mean of a moving
root mean square of sigma-filtered signal.
You can disable one or more threshold by putting ``None`` instead:
.. code-block:: python
thresh = {'rel_pow': None, 'corr': 0.65, 'rms': 1.5}
thresh = {'rel_pow': None, 'corr': None, 'rms': 3}
multi_only : boolean
Define the behavior of the multi-channel detection. If True, only
spindles that are present on at least two channels are kept. If False,
no selection is applied and the output is just a concatenation of the
single-channel detection dataframe. Default is False.
remove_outliers : boolean
If True, YASA will automatically detect and remove outliers spindles
using :py:class:`sklearn.ensemble.IsolationForest`.
The outliers detection is performed on all the spindles
parameters with the exception of the ``Start``, ``Peak``, ``End``,
``Stage``, and ``SOPhase`` columns.
YASA uses a random seed (42) to ensure reproducible results.
Note that this step will only be applied if there are more than 50
detected spindles in the first place. Default to False.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
sp : :py:class:`yasa.SpindlesResults`
To get the full detection dataframe, use:
>>> sp = spindles_detect(...)
>>> sp.summary()
This will give a :py:class:`pandas.DataFrame` where each row is a
detected spindle and each column is a parameter (= feature or property)
of this spindle. To get the average spindles parameters per channel and
sleep stage:
>>> sp.summary(grp_chan=True, grp_stage=True)
Notes
-----
The parameters that are calculated for each spindle are:
* ``'Start'``: Start time of the spindle, in seconds from the beginning of
data.
* ``'Peak'``: Time at the most prominent spindle peak (in seconds).
* ``'End'`` : End time (in seconds).
* ``'Duration'``: Duration (in seconds)
* ``'Amplitude'``: Peak-to-peak amplitude of the (detrended) spindle in
the raw data (in µV).
* ``'RMS'``: Root-mean-square (in µV)
* ``'AbsPower'``: Median absolute power (in log10 µV^2),
calculated from the Hilbert-transform of the ``freq_sp`` filtered signal.
* ``'RelPower'``: Median relative power of the ``freq_sp`` band in spindle
calculated from a short-term fourier transform and expressed as a
proportion of the total power in ``freq_broad``.
* ``'Frequency'``: Median instantaneous frequency of spindle (in Hz),
derived from an Hilbert transform of the ``freq_sp`` filtered signal.
* ``'Oscillations'``: Number of oscillations (= number of positive peaks
in spindle.)
* ``'Symmetry'``: Location of the most prominent peak of spindle,
normalized from 0 (start) to 1 (end). Ideally this value should be close
to 0.5, indicating that the most prominent peak is halfway through the
spindle.
* ``'Stage'`` : Sleep stage during which spindle occured, if ``hypno``
was provided.
All parameters are calculated from the broadband-filtered EEG
(frequency range defined in ``freq_broad``).
For better results, apply this detection only on artefact-free NREM sleep.
References
----------
The sleep spindles detection algorithm is based on:
* Lacourse, K., Delfrate, J., Beaudry, J., Peppard, P., & Warby, S. C.
(2018). `A sleep spindle detection algorithm that emulates human expert
spindle scoring. <https://doi.org/10.1016/j.jneumeth.2018.08.014>`_
Journal of Neuroscience Methods.
Examples
--------
For a walkthrough of the spindles detection, please refer to the following
Jupyter notebooks:
https://github.com/raphaelvallat/yasa/blob/master/notebooks/01_spindles_detection.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/02_spindles_detection_multi.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/03_spindles_detection_NREM_only.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/04_spindles_slow_fast.ipynb
"""
set_log_level(verbose)
(data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
) = _check_data_hypno(data, sf, ch_names, hypno, include)
# If all channels are bad
if sum(bad_chan) == n_chan:
logger.warning('All channels have bad amplitude. Returning None.')
return None
# Check detection thresholds
if 'rel_pow' not in thresh.keys():
thresh['rel_pow'] = 0.20
if 'corr' not in thresh.keys():
thresh['corr'] = 0.65
if 'rms' not in thresh.keys():
thresh['rms'] = 1.5
do_rel_pow = thresh['rel_pow'] not in [None, "none", "None"]
do_corr = thresh['corr'] not in [None, "none", "None"]
do_rms = thresh['rms'] not in [None, "none", "None"]
n_thresh = sum([do_rel_pow, do_corr, do_rms])
assert n_thresh >= 1, 'At least one threshold must be defined.'
# Filtering
nfast = next_fast_len(n_samples)
# 1) Broadband bandpass filter (optional -- careful of lower freq for PAC)
data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1],
method='fir', verbose=0)
# 2) Sigma bandpass filter
# The width of the transition band is set to 1.5 Hz on each side,
# meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located at
# 11.25 and 15.75 Hz.
data_sigma = filter_data(data, sf, freq_sp[0], freq_sp[1],
l_trans_bandwidth=1.5, h_trans_bandwidth=1.5,
method='fir', verbose=0)
# Hilbert power (to define the instantaneous frequency / power)
analytic = signal.hilbert(data_sigma, N=nfast)[:, :n_samples]
inst_phase = np.angle(analytic)
inst_pow = np.square(np.abs(analytic))
inst_freq = (sf / (2 * np.pi) * np.diff(inst_phase, axis=-1))
# Extract the SO signal for coupling
# if coupling:
# # We need to use the original (non-filtered data)
# data_so = filter_data(data, sf, freq_so[0], freq_so[1], method='fir',
# l_trans_bandwidth=0.1, h_trans_bandwidth=0.1,
# verbose=0)
# # Now extract the instantaneous phase using Hilbert transform
# so_phase = np.angle(signal.hilbert(data_so, N=nfast)[:, :n_samples])
# Initialize empty output dataframe
df = pd.DataFrame()
for i in range(n_chan):
# ####################################################################
# START SINGLE CHANNEL DETECTION
# ####################################################################
# First, skip channels with bad data amplitude
if bad_chan[i]:
continue
# Compute the pointwise relative power using interpolated STFT
# Here we use a step of 200 ms to speed up the computation.
# Note that even if the threshold is None we still need to calculate it
# for the individual spindles parameter (RelPow).
f, t, Sxx = stft_power(data_broad[i, :], sf, window=2, step=.2,
band=freq_broad, interp=False, norm=True)
idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
rel_pow = Sxx[idx_sigma].sum(0)
# Let's interpolate `rel_pow` to get one value per sample
# Note that we could also have use the `interp=True` in the
# `stft_power` function, however 2D interpolation is much slower than
# 1D interpolation.
func = interp1d(t, rel_pow, kind='cubic', bounds_error=False,
fill_value=0)
t = np.arange(n_samples) / sf
rel_pow = func(t)
if do_corr:
_, mcorr = moving_transform(x=data_sigma[i, :], y=data_broad[i, :],
sf=sf, window=.3, step=.1,
method='corr', interp=True)
if do_rms:
_, mrms = moving_transform(x=data_sigma[i, :], sf=sf,
window=.3, step=.1, method='rms',
interp=True)
# Let's define the thresholds
if hypno is None:
thresh_rms = mrms.mean() + thresh['rms'] * \
trimbothstd(mrms, cut=0.10)
else:
thresh_rms = mrms[mask].mean() + thresh['rms'] * \
trimbothstd(mrms[mask], cut=0.10)
# Avoid too high threshold caused by Artefacts / Motion during Wake
thresh_rms = min(thresh_rms, 10)
logger.info('Moving RMS threshold = %.3f', thresh_rms)
# Boolean vector of supra-threshold indices
idx_sum = np.zeros(n_samples)
if do_rel_pow:
idx_rel_pow = (rel_pow >= thresh['rel_pow']).astype(int)
idx_sum += idx_rel_pow
logger.info('N supra-theshold relative power = %i',
idx_rel_pow.sum())
if do_corr:
idx_mcorr = (mcorr >= thresh['corr']).astype(int)
idx_sum += idx_mcorr
logger.info('N supra-theshold moving corr = %i', idx_mcorr.sum())
if do_rms:
idx_mrms = (mrms >= thresh_rms).astype(int)
idx_sum += idx_mrms
logger.info('N supra-theshold moving RMS = %i', idx_mrms.sum())
# Make sure that we do not detect spindles outside mask
if hypno is not None:
idx_sum[~mask] = 0
# The detection using the three thresholds tends to underestimate the
# real duration of the spindle. To overcome this, we compute a soft
# threshold by smoothing the idx_sum vector with a 100 ms window.
w = int(0.1 * sf)
idx_sum = np.convolve(idx_sum, np.ones(w) / w, mode='same')
# And we then find indices that are strictly greater than 2, i.e. we
# find the 'true' beginning and 'true' end of the events by finding
# where at least two out of the three treshold were crossed.
where_sp = np.where(idx_sum > (n_thresh - 1))[0]
# If no events are found, skip to next channel
if not len(where_sp):
logger.warning('No spindle were found in channel %s.', ch_names[i])
continue
# Merge events that are too close
if min_distance is not None and min_distance > 0:
where_sp = _merge_close(where_sp, min_distance, sf)
# Extract start, end, and duration of each spindle
sp = np.split(where_sp, np.where(np.diff(where_sp) != 1)[0] + 1)
idx_start_end = np.array([[k[0], k[-1]] for k in sp]) / sf
sp_start, sp_end = idx_start_end.T
sp_dur = sp_end - sp_start
# Find events with bad duration
good_dur = np.logical_and(sp_dur > duration[0], sp_dur < duration[1])
# If no events of good duration are found, skip to next channel
if all(~good_dur):
logger.warning('No spindle were found in channel %s.', ch_names[i])
continue
# Initialize empty variables
sp_amp = np.zeros(len(sp))
sp_freq = np.zeros(len(sp))
sp_rms = np.zeros(len(sp))
sp_osc = np.zeros(len(sp))
sp_sym = np.zeros(len(sp))
sp_abs = np.zeros(len(sp))
sp_rel = np.zeros(len(sp))
sp_sta = np.zeros(len(sp))
sp_pro = np.zeros(len(sp))
# sp_cou = np.zeros(len(sp))
# Number of oscillations (number of peaks separated by at least 60 ms)
# --> 60 ms because 1000 ms / 16 Hz = 62.5 m, in other words, at 16 Hz,
# peaks are separated by 62.5 ms. At 11 Hz peaks are separated by 90 ms
distance = 60 * sf / 1000
for j in np.arange(len(sp))[good_dur]:
# Important: detrend the signal to avoid wrong PTP amplitude
sp_x = np.arange(data_broad[i, sp[j]].size, dtype=np.float64)
sp_det = _detrend(sp_x, data_broad[i, sp[j]])
# sp_det = signal.detrend(data_broad[i, sp[i]], type='linear')
sp_amp[j] = np.ptp(sp_det) # Peak-to-peak amplitude
sp_rms[j] = _rms(sp_det) # Root mean square
sp_rel[j] = np.median(rel_pow[sp[j]]) # Median relative power
# Hilbert-based instantaneous properties
sp_inst_freq = inst_freq[i, sp[j]]
sp_inst_pow = inst_pow[i, sp[j]]
sp_abs[j] = np.median(np.log10(sp_inst_pow[sp_inst_pow > 0]))
sp_freq[j] = np.median(sp_inst_freq[sp_inst_freq > 0])
# Number of oscillations
peaks, peaks_params = signal.find_peaks(sp_det, distance=distance,
prominence=(None, None))
sp_osc[j] = len(peaks)
# For frequency and amplitude, we can also optionally use these
# faster alternatives. If we use them, we do not need to compute
# the Hilbert transform of the filtered signal.
# sp_freq[j] = sf / np.mean(np.diff(peaks))
# sp_amp[j] = peaks_params['prominences'].max()
# Peak location & symmetry index
# pk is expressed in sample since the beginning of the spindle
pk = peaks[peaks_params['prominences'].argmax()]
sp_pro[j] = sp_start[j] + pk / sf
sp_sym[j] = pk / sp_det.size
# SO-spindles coupling
# if coupling:
# sp_cou[j] = so_phase[i, sp[j]][pk]
# Sleep stage
if hypno is not None:
sp_sta[j] = hypno[sp[j]][0]
# Create a dataframe
sp_params = {'Start': sp_start,
'Peak': sp_pro,
'End': sp_end,
'Duration': sp_dur,
'Amplitude': sp_amp,
'RMS': sp_rms,
'AbsPower': sp_abs,
'RelPower': sp_rel,
'Frequency': sp_freq,
'Oscillations': sp_osc,
'Symmetry': sp_sym,
# 'SOPhase': sp_cou,
'Stage': sp_sta}
df_chan = pd.DataFrame(sp_params)[good_dur]
# We need at least 50 detected spindles to apply the Isolation Forest.
if remove_outliers and df_chan.shape[0] >= 50:
col_keep = ['Duration', 'Amplitude', 'RMS', 'AbsPower', 'RelPower',
'Frequency', 'Oscillations', 'Symmetry']
ilf = IsolationForest(contamination='auto', max_samples='auto',
verbose=0, random_state=42)
good = ilf.fit_predict(df_chan[col_keep])
good[good == -1] = 0
logger.info('%i outliers were removed in channel %s.'
% ((good == 0).sum(), ch_names[i]))
# Remove outliers from DataFrame
df_chan = df_chan[good.astype(bool)]
logger.info('%i spindles were found in channel %s.'
% (df_chan.shape[0], ch_names[i]))
# ####################################################################
# END SINGLE CHANNEL DETECTION
# ####################################################################
df_chan['Channel'] = ch_names[i]
df_chan['IdxChannel'] = i
df = df.append(df_chan, ignore_index=True)
# If no spindles were detected, return None
if df.empty:
logger.warning('No spindles were found in data. Returning None.')
return None
# Remove useless columns
to_drop = []
if hypno is None:
to_drop.append('Stage')
else:
df['Stage'] = df['Stage'].astype(int)
# if not coupling:
# to_drop.append('SOPhase')
if len(to_drop):
df = df.drop(columns=to_drop)
# Find spindles that are present on at least two channels
if multi_only and df['Channel'].nunique() > 1:
# We round to the nearest second
idx_good = np.logical_or(df['Start'].round(0).duplicated(keep=False),
df['End'].round(0).duplicated(keep=False)
).to_list()
df = df[idx_good].reset_index(drop=True)
return SpindlesResults(events=df, data=data, sf=sf, ch_names=ch_names,
hypno=hypno, data_filt=data_sigma)
[docs]class SpindlesResults(_DetectionResults):
"""Output class for spindles detection.
Attributes
----------
_events : :py:class:`pandas.DataFrame`
Output detection dataframe
_data : array_like
Original EEG data of shape *(n_chan, n_samples)*.
_data_filt : array_like
Sigma-filtered EEG data of shape *(n_chan, n_samples)*.
_sf : float
Sampling frequency of data.
_ch_names : list
Channel names.
_hypno : array_like or None
Sleep staging vector.
"""
[docs] def __init__(self, events, data, sf, ch_names, hypno, data_filt):
super().__init__(events, data, sf, ch_names, hypno, data_filt)
[docs] def summary(self, grp_chan=False, grp_stage=False, aggfunc='mean',
sort=True):
"""Return a summary of the spindles detection, optionally grouped
across channels and/or stage.
Parameters
----------
grp_chan : bool
If True, group by channel (for multi-channels detection only).
grp_stage : bool
If True, group by sleep stage (provided that an hypnogram was
used).
aggfunc : str or function
Averaging function (e.g. ``'mean'`` or ``'median'``).
sort : bool
If True, sort group keys when grouping.
"""
return super().summary(event_type='spindles',
grp_chan=grp_chan, grp_stage=grp_stage,
aggfunc=aggfunc, sort=sort)
[docs] def get_mask(self):
"""Return a boolean array indicating for each sample in data if this
sample is part of a detected event (True) or not (False).
"""
return super().get_mask()
[docs] def get_sync_events(self, center='Peak', time_before=1, time_after=1,
filt=(None, None)):
"""
Return the raw or filtered data of each detected event after
centering to a specific timepoint.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the center peak of the spindles.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
Returns
-------
df_sync : :py:class:`pandas.DataFrame`
Long-format dataframe::
'Event' : Event number
'Time' : Timing of the events (in seconds)
'Amplitude' : Raw or filtered data for event
'Channel' : Channel
'IdxChannel' : Index of channel in data
'Stage': Sleep stage in which the events occured (if available)
"""
return super().get_sync_events(center=center, time_before=time_before,
time_after=time_after,
filt=filt)
[docs] def plot_average(self, center='Peak', hue='Channel', time_before=1,
time_after=1, filt=(None, None), figsize=(6, 4.5),
**kwargs):
"""
Plot the average spindle.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the most prominent peak of the spindle.
hue : str
Grouping variable that will produce lines with different colors.
Can be either 'Channel' or 'Stage'.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
figsize : tuple
Figure size in inches.
**kwargs : dict
Optional argument that are passed to :py:func:`seaborn.lineplot`.
"""
return super().plot_average(event_type='spindles', center=center,
hue=hue, time_before=time_before,
time_after=time_after, filt=filt,
figsize=figsize, **kwargs)
[docs] def plot_detection(self):
"""Plot an overlay of the detected spindles on the EEG signal.
This only works in Jupyter and it requires the ipywidgets
(https://ipywidgets.readthedocs.io/en/latest/) package.
To activate the interactive mode, make sure to run:
>>> %matplotlib widget
.. versionadded:: 0.4.0
"""
return super().plot_detection()
#############################################################################
# SLOW-WAVES DETECTION
#############################################################################
[docs]def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3),
freq_sw=(0.3, 1.5), dur_neg=(0.3, 1.5), dur_pos=(0.1, 1),
amp_neg=(40, 300), amp_pos=(10, 200), amp_ptp=(75, 500),
coupling=False, freq_sp=(12, 16), remove_outliers=False,
verbose=False):
"""Slow-waves detection.
Parameters
----------
data : array_like
Single or multi-channel data. Unit must be uV and shape (n_samples) or
(n_chan, n_samples). Can also be a :py:class:`mne.io.BaseRaw`,
in which case ``data``, ``sf``, and ``ch_names`` will be automatically
extracted, and ``data`` will also be automatically converted from
Volts (MNE) to micro-Volts (YASA).
sf : float
Sampling frequency of the data in Hz.
Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw`.
.. tip:: If the detection is taking too long, make sure to downsample
your data to 100 Hz (or 128 Hz). For more details, please refer to
:py:func:`mne.filter.resample`.
ch_names : list of str
Channel names. Can be omitted if ``data`` is a
:py:class:`mne.io.BaseRaw`.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is loaded, the
detection will only be applied to the value defined in
``include`` (default = N2 + N3 sleep).
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Values in ``hypno`` that will be included in the mask. The default is
(2, 3), meaning that the detection is applied on N2 and N3
sleep. This has no effect when ``hypno`` is None.
freq_sw : tuple or list
Slow wave frequency range. Default is 0.3 to 1.5 Hz. Please note that
YASA uses a FIR filter (implemented in MNE) with a 0.2 Hz transition
band, which means that the -6 dB points are located at 0.2 and 1.6 Hz.
dur_neg : tuple or list
The minimum and maximum duration of the negative deflection of the
slow wave. Default is 0.3 to 1.5 second.
dur_pos : tuple or list
The minimum and maximum duration of the positive deflection of the
slow wave. Default is 0.1 to 1 second.
amp_neg : tuple or list
Absolute minimum and maximum negative trough amplitude of the
slow-wave. Default is 40 uV to 300 uV. Can also be in unit of standard
deviations if the data has been previously z-scored. If you do not want
to specify any negative amplitude thresholds,
use ``amp_neg=(None, None)``.
amp_pos : tuple or list
Absolute minimum and maximum positive peak amplitude of the
slow-wave. Default is 10 uV to 200 uV. Can also be in unit of standard
deviations if the data has been previously z-scored.
If you do not want to specify any positive amplitude thresholds,
use ``amp_pos=(None, None)``.
amp_ptp : tuple or list
Minimum and maximum peak-to-peak amplitude of the slow-wave.
Default is 75 uV to 500 uV. Can also be in unit of standard
deviations if the data has been previously z-scored.
Use ``np.inf`` to set no upper amplitude threshold
(e.g. ``amp_ptp=(75, np.inf)``).
coupling : boolean
If True, YASA will also calculate the phase-amplitude coupling between
the slow-waves phase and the spindles-related sigma band
amplitude. Specifically, the following columns will be added to the
output dataframe:
1. ``'SigmaPeak'``: The location (in seconds) of the maximum sigma peak
amplitude within a 4-seconds epoch centered around the negative peak
(through) of the current slow-wave.
2. ``PhaseAtSigmaPeak``: the phase of the bandpas-filtered slow-wave
signal (in radians) at ``'SigmaPeak'``.
Importantly, since ``PhaseAtSigmaPeak`` is expressed in radians,
one should use circular statistics to calculate the mean direction
and vector length:
.. code-block:: python
import pingouin as pg
mean_direction = pg.circ_mean(sw['PhaseAtSigmaPeak'])
vector_length = pg.circ_r(sw['PhaseAtSigmaPeak'])
3. ``ndPAC``: the normalized Mean Vector Length
(also called the normalized direct PAC, or ndPAC) within a 4-sec
epoch centered around the negative peak of the slow-wave.
The lower and upper frequencies for the slow-waves and
spindles-related sigma signals are defined in ``freq_sw`` and
``freq_sp``, respectively.
For more details, please refer to the `Jupyter notebook
<https://github.com/raphaelvallat/yasa/blob/master/notebooks/12_spindles-SO_coupling.ipynb>`_
Note that setting ``coupling=True`` may significantly increase
computation time.
.. versionadded:: 0.2.0
freq_sp : tuple or list
Spindles-related frequency of interest. This is only relevant if
``coupling=True``. Default is 12 to 16 Hz, with a wide transition
bandwidth of 1.5 Hz.
.. versionadded:: 0.2.0
remove_outliers : boolean
If True, YASA will automatically detect and remove outliers slow-waves
using :py:class:`sklearn.ensemble.IsolationForest`.
The outliers detection is performed on the frequency, amplitude and
duration parameters of the detected slow-waves. YASA uses a random seed
(42) to ensure reproducible results. Note that this step will only be
applied if there are more than 50 detected slow-waves in the first
place. Default to False.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
sw : :py:class:`yasa.SWResults`
To get the full detection dataframe, use:
>>> sw = sw_detect(...)
>>> sw.summary()
This will give a :py:class:`pandas.DataFrame` where each row is a
detected slow-wave and each column is a parameter (= property).
To get the average SW parameters per channel and sleep stage:
>>> sw.summary(grp_chan=True, grp_stage=True)
Notes
-----
The parameters that are calculated for each slow-wave are:
* ``'Start'``: Start time of each detected slow-wave, in seconds from the
beginning of data.
* ``'NegPeak'``: Location of the negative peak (in seconds)
* ``'MidCrossing'``: Location of the negative-to-positive zero-crossing
(in seconds)
* ``'Pospeak'``: Location of the positive peak (in seconds)
* ``'End'``: End time(in seconds)
* ``'Duration'``: Duration (in seconds)
* ``'ValNegPeak'``: Amplitude of the negative peak (in uV, calculated
on the ``freq_sw`` bandpass-filtered signal)
* ``'ValPosPeak'``: Amplitude of the positive peak (in uV, calculated
on the ``freq_sw`` bandpass-filtered signal)
* ``'PTP'``: Peak-to-peak amplitude (= ``ValPosPeak`` - ``ValNegPeak``,
calculated on the ``freq_sw`` bandpass-filtered signal)
* ``'Slope'``: Slope between ``NegPeak`` and ``MidCrossing`` (in uV/sec,
calculated on the ``freq_sw`` bandpass-filtered signal)
* ``'Frequency'``: Frequency of the slow-wave (= 1 / ``Duration``)
* ``'SigmaPeak'``: Location of the sigma peak amplitude within a 4-sec
epoch centered around the negative peak of the slow-wave. This is only
calculated when ``coupling=True``.
* ``'PhaseAtSigmaPeak'``: SW phase at max sigma amplitude within a
4-sec epoch centered around the negative peak of the slow-wave. This is
only calculated when ``coupling=True``
* ``'ndPAC'``: Normalized direct PAC within a 4-sec epoch centered
the negative peak of the slow-wave. This is only calculated when
``coupling=True``
* ``'Stage'``: Sleep stage (only if hypno was provided)
.. image:: https://raw.githubusercontent.com/raphaelvallat/yasa/master/docs/pictures/slow_waves.png # noqa
:width: 500px
:align: center
:alt: slow-wave
For better results, apply this detection only on artefact-free NREM sleep.
References
----------
The slow-waves detection algorithm is based on:
* Massimini, M., Huber, R., Ferrarelli, F., Hill, S., & Tononi, G.
(2004). `The sleep slow oscillation as a traveling wave.
<https://doi.org/10.1523/JNEUROSCI.1318-04.2004>`_. The Journal of
Neuroscience, 24(31), 6862–6870.
* Carrier, J., Viens, I., Poirier, G., Robillard, R., Lafortune, M.,
Vandewalle, G., Martin, N., Barakat, M., Paquet, J., & Filipini, D.
(2011). `Sleep slow wave changes during the middle years of life.
<https://doi.org/10.1111/j.1460-9568.2010.07543.x>`_
The European Journal of Neuroscience, 33(4), 758–766.
Examples
--------
For an example of how to run the detection, please refer to the tutorial:
https://github.com/raphaelvallat/yasa/blob/master/notebooks/05_sw_detection.ipynb
"""
set_log_level(verbose)
(data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
) = _check_data_hypno(data, sf, ch_names, hypno, include)
# If all channels are bad
if sum(bad_chan) == n_chan:
logger.warning('All channels have bad amplitude. Returning None.')
return None
# Define time vector
times = np.arange(data.size) / sf
idx_mask = np.where(mask)[0]
# Bandpass filter
nfast = next_fast_len(n_samples)
data_filt = filter_data(data, sf, freq_sw[0], freq_sw[1], method='fir',
verbose=0, l_trans_bandwidth=0.2,
h_trans_bandwidth=0.2)
# Extract the spindles-related sigma signal for coupling
if coupling:
is_tensorpac_installed()
import tensorpac.methods as tpm
# The width of the transition band is set to 1.5 Hz on each side,
# meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located
# at 11.25 and 15.75 Hz. The frequency band for the amplitude signal
# must be large enough to fit the sidebands caused by the assumed
# modulating lower frequency band (Aru et al. 2015).
# https://doi.org/10.1016/j.conb.2014.08.002
data_sp = filter_data(data, sf, freq_sp[0], freq_sp[1], method='fir',
l_trans_bandwidth=1.5, h_trans_bandwidth=1.5,
verbose=0)
# Now extract the instantaneous phase/amplitude using Hilbert transform
sw_pha = np.angle(signal.hilbert(data_filt, N=nfast)[:, :n_samples])
sp_amp = np.abs(signal.hilbert(data_sp, N=nfast)[:, :n_samples])
# Initialize empty output dataframe
df = pd.DataFrame()
for i in range(n_chan):
# ####################################################################
# START SINGLE CHANNEL DETECTION
# ####################################################################
# First, skip channels with bad data amplitude
if bad_chan[i]:
continue
# Find peaks in data
# Negative peaks with value comprised between -40 to -300 uV
idx_neg_peaks, _ = signal.find_peaks(-1 * data_filt[i, :],
height=amp_neg)
# Positive peaks with values comprised between 10 to 150 uV
idx_pos_peaks, _ = signal.find_peaks(data_filt[i, :], height=amp_pos)
# Intersect with sleep stage vector
idx_neg_peaks = np.intersect1d(idx_neg_peaks, idx_mask,
assume_unique=True)
idx_pos_peaks = np.intersect1d(idx_pos_peaks, idx_mask,
assume_unique=True)
# If no peaks are detected, return None
if len(idx_neg_peaks) == 0 or len(idx_pos_peaks) == 0:
logger.warning('No SW were found in channel %s.', ch_names[i])
continue
# Make sure that the last detected peak is a positive one
if idx_pos_peaks[-1] < idx_neg_peaks[-1]:
# If not, append a fake positive peak one sample after the last neg
idx_pos_peaks = np.append(idx_pos_peaks, idx_neg_peaks[-1] + 1)
# For each negative peak, we find the closest following positive peak
pk_sorted = np.searchsorted(idx_pos_peaks, idx_neg_peaks)
closest_pos_peaks = idx_pos_peaks[pk_sorted] - idx_neg_peaks
closest_pos_peaks = closest_pos_peaks[np.nonzero(closest_pos_peaks)]
idx_pos_peaks = idx_neg_peaks + closest_pos_peaks
# Now we compute the PTP amplitude and keep only the good peaks
sw_ptp = (np.abs(data_filt[i, idx_neg_peaks]) +
data_filt[i, idx_pos_peaks])
good_ptp = np.logical_and(sw_ptp > amp_ptp[0], sw_ptp < amp_ptp[1])
# If good_ptp is all False
if all(~good_ptp):
logger.warning('No SW were found in channel %s.', ch_names[i])
continue
sw_ptp = sw_ptp[good_ptp]
idx_neg_peaks = idx_neg_peaks[good_ptp]
idx_pos_peaks = idx_pos_peaks[good_ptp]
# Now we need to check the negative and positive phase duration
# For that we need to compute the zero crossings of the filtered signal
zero_crossings = _zerocrossings(data_filt[i, :])
# Make sure that there is a zero-crossing after the last detected peak
if zero_crossings[-1] < max(idx_pos_peaks[-1], idx_neg_peaks[-1]):
# If not, append the index of the last peak
zero_crossings = np.append(zero_crossings,
max(idx_pos_peaks[-1],
idx_neg_peaks[-1]))
# Find distance to previous and following zc
neg_sorted = np.searchsorted(zero_crossings, idx_neg_peaks)
previous_neg_zc = zero_crossings[neg_sorted - 1] - idx_neg_peaks
following_neg_zc = zero_crossings[neg_sorted] - idx_neg_peaks
neg_phase_dur = (np.abs(previous_neg_zc) + following_neg_zc) / sf
# Distance (in samples) between the positive peaks and the previous and
# following zero-crossings
pos_sorted = np.searchsorted(zero_crossings, idx_pos_peaks)
previous_pos_zc = zero_crossings[pos_sorted - 1] - idx_pos_peaks
following_pos_zc = zero_crossings[pos_sorted] - idx_pos_peaks
pos_phase_dur = (np.abs(previous_pos_zc) + following_pos_zc) / sf
# We now compute a set of metrics
sw_start = times[idx_neg_peaks + previous_neg_zc]
sw_end = times[idx_pos_peaks + following_pos_zc]
sw_dur = sw_end - sw_start # Same as pos_phase_dur + neg_phase_dur
sw_midcrossing = times[idx_neg_peaks + following_neg_zc]
sw_idx_neg = times[idx_neg_peaks] # Location of negative peak
sw_idx_pos = times[idx_pos_peaks] # Location of positive peak
# Slope between peak trough and midcrossing
sw_slope = sw_ptp / (sw_midcrossing - sw_idx_neg)
# Hypnogram
if hypno is not None:
sw_sta = hypno[idx_neg_peaks]
else:
sw_sta = np.zeros(sw_dur.shape)
# And we apply a set of thresholds to remove bad slow waves
good_sw = np.logical_and.reduce((
# Data edges
previous_neg_zc != 0,
following_neg_zc != 0,
previous_pos_zc != 0,
following_pos_zc != 0,
# Duration criteria
neg_phase_dur > dur_neg[0],
neg_phase_dur < dur_neg[1],
pos_phase_dur > dur_pos[0],
pos_phase_dur < dur_pos[1],
# Sanity checks
sw_midcrossing > sw_start,
sw_midcrossing < sw_end,
sw_slope > 0,
))
if all(~good_sw):
logger.warning('No SW were found in channel %s.', ch_names[i])
continue
# Filter good events
idx_neg_peaks = idx_neg_peaks[good_sw]
idx_pos_peaks = idx_pos_peaks[good_sw]
sw_start = sw_start[good_sw]
sw_idx_neg = sw_idx_neg[good_sw]
sw_midcrossing = sw_midcrossing[good_sw]
sw_idx_pos = sw_idx_pos[good_sw]
sw_end = sw_end[good_sw]
sw_dur = sw_dur[good_sw]
sw_ptp = sw_ptp[good_sw]
sw_slope = sw_slope[good_sw]
sw_sta = sw_sta[good_sw]
# Create a dictionnary
sw_params = OrderedDict({
'Start': sw_start,
'NegPeak': sw_idx_neg,
'MidCrossing': sw_midcrossing,
'PosPeak': sw_idx_pos,
'End': sw_end,
'Duration': sw_dur,
'ValNegPeak': data_filt[i, idx_neg_peaks],
'ValPosPeak': data_filt[i, idx_pos_peaks],
'PTP': sw_ptp,
'Slope': sw_slope,
'Frequency': 1 / sw_dur,
'Stage': sw_sta,
})
# Add phase (in radians) of slow-oscillation signal at maximum
# spindles-related sigma amplitude within a 4-seconds centered epochs.
if coupling:
# Get phase and amplitude for each centered epoch
# TODO: allow user-specified window size.
time_before = time_after = 2
bef = int(sf * time_before)
aft = int(sf * time_after)
# Center of each epoch is defined as the negative peak of the SW
n_peaks = idx_neg_peaks.shape[0]
# idx.shape = (len(idx_valid), bef + aft + 1)
idx, idx_valid = get_centered_indices(data[i, :], idx_neg_peaks,
bef, aft)
sw_pha_ev = sw_pha[i, idx]
sp_amp_ev = sp_amp[i, idx]
# 1) Find location of max sigma amplitude in epoch
idx_max_amp = sp_amp_ev.argmax(axis=1)
# Now we need to append it back to the original unmasked shape
# to avoid error when idx.shape[0] != idx_valid.shape, i.e.
# some epochs were out of data bounds.
sw_params['SigmaPeak'] = np.ones(n_peaks) * np.nan
# Timestamp at sigma peak, expressed in seconds from negative peak
# e.g. -0.39, 0.5, 1, 2 -- limits are [time_before, time_after]
time_sigpk = (idx_max_amp - bef) / sf
# convert to absolute time from beginning of the recording
# time_sigpk only includes valid epoch
time_sigpk_abs = sw_idx_neg[idx_valid] + time_sigpk
sw_params['SigmaPeak'][idx_valid] = time_sigpk_abs
# 2) PhaseAtSigmaPeak
# Find SW phase at max sigma amplitude in epoch
pha_at_max = np.squeeze(np.take_along_axis(sw_pha_ev,
idx_max_amp[..., None],
axis=1))
sw_params['PhaseAtSigmaPeak'] = np.ones(n_peaks) * np.nan
sw_params['PhaseAtSigmaPeak'][idx_valid] = pha_at_max
# 3) Normalized Direct PAC, without thresholding
ndp = np.squeeze(tpm.norm_direct_pac(sw_pha_ev[None, ...],
sp_amp_ev[None, ...], p=1))
sw_params['ndPAC'] = np.ones(n_peaks) * np.nan
sw_params['ndPAC'][idx_valid] = ndp
# Make sure that Stage is the last column of the dataframe
sw_params.move_to_end('Stage')
# Convert to dataframe, keeping only good events
df_chan = pd.DataFrame(sw_params)
# Remove all duplicates
df_chan = df_chan.drop_duplicates(subset=['Start'], keep=False)
df_chan = df_chan.drop_duplicates(subset=['End'], keep=False)
# We need at least 50 detected slow waves to apply the Isolation Forest
if remove_outliers and df_chan.shape[0] >= 50:
col_keep = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Slope',
'Frequency']
ilf = IsolationForest(contamination='auto', max_samples='auto',
verbose=0, random_state=42)
good = ilf.fit_predict(df_chan[col_keep])
good[good == -1] = 0
logger.info('%i outliers were removed in channel %s.'
% ((good == 0).sum(), ch_names[i]))
# Remove outliers from DataFrame
df_chan = df_chan[good.astype(bool)]
logger.info('%i slow-waves were found in channel %s.'
% (df_chan.shape[0], ch_names[i]))
# ####################################################################
# END SINGLE CHANNEL DETECTION
# ####################################################################
df_chan['Channel'] = ch_names[i]
df_chan['IdxChannel'] = i
df = df.append(df_chan, ignore_index=True)
# If no SW were detected, return None
if df.empty:
logger.warning('No SW were found in data. Returning None.')
return None
if hypno is None:
df = df.drop(columns=['Stage'])
else:
df['Stage'] = df['Stage'].astype(int)
return SWResults(events=df, data=data, sf=sf, ch_names=ch_names,
hypno=hypno, data_filt=data_filt)
[docs]class SWResults(_DetectionResults):
"""Output class for slow-waves detection.
Attributes
----------
_events : :py:class:`pandas.DataFrame`
Output detection dataframe
_data : array_like
EEG data of shape *(n_chan, n_samples)*.
_data_filt : array_like
Slow-wave filtered EEG data of shape *(n_chan, n_samples)*.
_sf : float
Sampling frequency of data.
_ch_names : list
Channel names.
_hypno : array_like or None
Sleep staging vector.
"""
[docs] def __init__(self, events, data, sf, ch_names, hypno, data_filt):
super().__init__(events, data, sf, ch_names, hypno, data_filt)
[docs] def summary(self, grp_chan=False, grp_stage=False, aggfunc='mean',
sort=True):
"""Return a summary of the SW detection, optionally grouped across
channels and/or stage.
Parameters
----------
grp_chan : bool
If True, group by channel (for multi-channels detection only).
grp_stage : bool
If True, group by sleep stage (provided that an hypnogram was
used).
aggfunc : str or function
Averaging function (e.g. ``'mean'`` or ``'median'``).
sort : bool
If True, sort group keys when grouping.
"""
return super().summary(event_type='sw',
grp_chan=grp_chan, grp_stage=grp_stage,
aggfunc=aggfunc, sort=sort)
[docs] def get_mask(self):
"""Return a boolean array indicating for each sample in data if this
sample is part of a detected event (True) or not (False).
"""
return super().get_mask()
[docs] def get_sync_events(self, center='NegPeak', time_before=0.4,
time_after=0.8, filt=(None, None)):
"""
Return the raw data of each detected event after
centering to a specific timepoint.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the negative peak of the slow-wave.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
Returns
-------
df_sync : :py:class:`pandas.DataFrame`
Ouput long-format dataframe::
'Event' : Event number
'Time' : Timing of the events (in seconds)
'Amplitude' : Raw or filtered data for event
'Channel' : Channel
'IdxChannel' : Index of channel in data
'Stage': Sleep stage in which the events occured (if available)
"""
return super().get_sync_events(center=center, time_before=time_before,
time_after=time_after,
filt=filt)
[docs] def plot_average(self, center='NegPeak', hue='Channel', time_before=0.4,
time_after=0.8, filt=(None, None), figsize=(6, 4.5),
**kwargs):
"""
Plot the average slow-wave.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the negative peak of the slow-wave.
hue : str
Grouping variable that will produce lines with different colors.
Can be either 'Channel' or 'Stage'.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
figsize : tuple
Figure size in inches.
**kwargs : dict
Optional argument that are passed to :py:func:`seaborn.lineplot`.
"""
return super().plot_average(event_type='sw', center=center,
hue=hue, time_before=time_before,
time_after=time_after, filt=filt,
figsize=figsize, **kwargs)
[docs] def plot_detection(self):
"""Plot an overlay of the detected slow-waves on the EEG signal.
This only works in Jupyter and it requires the ipywidgets
(https://ipywidgets.readthedocs.io/en/latest/) package.
To activate the interactive mode, make sure to run:
>>> %matplotlib widget
.. versionadded:: 0.4.0
"""
return super().plot_detection()
#############################################################################
# REMs DETECTION
#############################################################################
[docs]def rem_detect(loc, roc, sf, hypno=None, include=4, amplitude=(50, 325),
duration=(0.3, 1.2), freq_rem=(0.5, 5), remove_outliers=False,
verbose=False):
"""Rapid eye movements (REMs) detection.
This detection requires both the left EOG (LOC) and right EOG (LOC).
The units of the data must be uV. The algorithm is based on an amplitude
thresholding of the negative product of the LOC and ROC
filtered signal.
.. versionadded:: 0.1.5
Parameters
----------
loc, roc : array_like
Continuous EOG data (Left and Right Ocular Canthi, LOC / ROC) channels.
Unit must be uV.
.. warning::
The default unit of :py:class:`mne.io.BaseRaw` is Volts.
Therefore, if passing data from a :py:class:`mne.io.BaseRaw`,
you need to multiply the data by 1e6 to convert to micro-Volts
(1 V = 1,000,000 uV), e.g.:
>>> data = raw.get_data() * 1e6 # Make sure that data is in uV
sf : float
Sampling frequency of the data, in Hz.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is loaded, the
detection will only be applied to the value defined in
``include`` (default = REM sleep).
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Values in ``hypno`` that will be included in the mask. The default is
(4), meaning that the detection is applied on REM sleep.
This has no effect when ``hypno`` is None.
amplitude : tuple or list
Minimum and maximum amplitude of the peak of the REM.
Default is 50 uV to 325 uV.
duration : tuple or list
The minimum and maximum duration of the REMs.
Default is 0.3 to 1.2 seconds.
freq_rem : tuple or list
Frequency range of REMs. Default is 0.5 to 5 Hz.
remove_outliers : boolean
If True, YASA will automatically detect and remove outliers REMs
using :py:class:`sklearn.ensemble.IsolationForest`.
YASA uses a random seed (42) to ensure reproducible results.
Note that this step will only be applied if there are more than
50 detected REMs in the first place. Default to False.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
rem : :py:class:`yasa.REMResults`
To get the full detection dataframe, use:
>>> rem = rem_detect(...)
>>> rem.summary()
This will give a :py:class:`pandas.DataFrame` where each row is a
detected REM and each column is a parameter (= property).
To get the average parameters sleep stage:
>>> rem.summary(grp_stage=True)
Notes
-----
The parameters that are calculated for each REM are:
* ``'Start'``: Start of each detected REM, in seconds from the
beginning of data.
* ``'Peak'``: Location of the peak (in seconds of data)
* ``'End'``: End time (in seconds)
* ``'Duration'``: Duration (in seconds)
* ``'LOCAbsValPeak'``: LOC absolute amplitude at REM peak (in uV)
* ``'ROCAbsValPeak'``: ROC absolute amplitude at REM peak (in uV)
* ``'LOCAbsRiseSlope'``: LOC absolute rise slope (in uV/s)
* ``'ROCAbsRiseSlope'``: ROC absolute rise slope (in uV/s)
* ``'LOCAbsFallSlope'``: LOC absolute fall slope (in uV/s)
* ``'ROCAbsFallSlope'``: ROC absolute fall slope (in uV/s)
* ``'Stage'``: Sleep stage (only if hypno was provided)
Note that all the output parameters are computed on the filtered LOC and
ROC signals.
For better results, apply this detection only on artefact-free REM sleep.
References
----------
The rapid eye movements detection algorithm is based on:
* Agarwal, R., Takeuchi, T., Laroche, S., & Gotman, J. (2005).
`Detection of rapid-eye movements in sleep studies.
<https://doi.org/10.1109/TBME.2005.851512>`_
IEEE Transactions on Bio-Medical Engineering, 52(8), 1390–1396.
* Yetton, B. D., Niknazar, M., Duggan, K. A., McDevitt, E. A., Whitehurst,
L. N., Sattari, N., & Mednick, S. C. (2016). `Automatic detection of
rapid eye movements (REMs): A machine learning approach.
<https://doi.org/10.1016/j.jneumeth.2015.11.015>`_
Journal of Neuroscience Methods, 259, 72–82.
Examples
--------
For an example of how to run the detection, please refer to
https://github.com/raphaelvallat/yasa/blob/master/notebooks/07_REMs_detection.ipynb
"""
set_log_level(verbose)
# Safety checks
loc = np.squeeze(np.asarray(loc, dtype=np.float64))
roc = np.squeeze(np.asarray(roc, dtype=np.float64))
assert loc.ndim == 1, 'LOC must be 1D.'
assert roc.ndim == 1, 'ROC must be 1D.'
assert loc.size == roc.size, 'LOC and ROC must have the same size.'
data = np.vstack((loc, roc))
(data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
) = _check_data_hypno(data, sf, ['LOC', 'ROC'], hypno, include)
# If all channels are bad
if any(bad_chan):
logger.warning('At least one channel has bad amplitude. '
'Returning None.')
return None
# Bandpass filter
data_filt = filter_data(data, sf, freq_rem[0], freq_rem[1], verbose=0)
# Calculate the negative product of LOC and ROC, maximal during REM.
negp = -data_filt[0, :] * data_filt[1, :]
# Find peaks in data
# - height: required height of peaks (min and max.)
# - distance: required distance in samples between neighboring peaks.
# - prominence: required prominence of peaks.
# - wlen: limit search for bases to a specific window.
hmin, hmax = amplitude[0]**2, amplitude[1]**2
pks, pks_params = signal.find_peaks(negp, height=(hmin, hmax),
distance=(duration[0] * sf),
prominence=(0.8 * hmin),
wlen=(duration[1] * sf))
# Intersect with sleep stage vector
# We do that before calculating the features in order to gain some time
idx_mask = np.where(mask)[0]
pks, idx_good, _ = np.intersect1d(pks, idx_mask, True, True)
for k in pks_params.keys():
pks_params[k] = pks_params[k][idx_good]
# If no peaks are detected, return None
if len(pks) == 0:
logger.warning('No REMs were found in data. Returning None.')
return None
# Hypnogram
if hypno is not None:
# The sleep stage at the beginning of the REM is considered.
rem_sta = hypno[pks_params['left_bases']]
else:
rem_sta = np.zeros(pks.shape)
# Calculate time features
pks_params['Start'] = pks_params['left_bases'] / sf
pks_params['Peak'] = pks / sf
pks_params['End'] = pks_params['right_bases'] / sf
pks_params['Duration'] = pks_params['End'] - pks_params['Start']
# Time points in minutes (HH:MM:SS)
# pks_params['StartMin'] = pd.to_timedelta(pks_params['Start'], unit='s').dt.round('s') # noqa
# pks_params['PeakMin'] = pd.to_timedelta(pks_params['Peak'], unit='s').dt.round('s') # noqa
# pks_params['EndMin'] = pd.to_timedelta(pks_params['End'], unit='s').dt.round('s') # noqa
# Absolute LOC / ROC value at peak (filtered)
pks_params['LOCAbsValPeak'] = abs(data_filt[0, pks])
pks_params['ROCAbsValPeak'] = abs(data_filt[1, pks])
# Absolute rising and falling slope
dist_pk_left = (pks - pks_params['left_bases']) / sf
dist_pk_right = (pks_params['right_bases'] - pks) / sf
locrs = (data_filt[0, pks] -
data_filt[0, pks_params['left_bases']]) / dist_pk_left
rocrs = (data_filt[1, pks] -
data_filt[1, pks_params['left_bases']]) / dist_pk_left
locfs = (data_filt[0, pks_params['right_bases']] -
data_filt[0, pks]) / dist_pk_right
rocfs = (data_filt[1, pks_params['right_bases']] -
data_filt[1, pks]) / dist_pk_right
pks_params['LOCAbsRiseSlope'] = abs(locrs)
pks_params['ROCAbsRiseSlope'] = abs(rocrs)
pks_params['LOCAbsFallSlope'] = abs(locfs)
pks_params['ROCAbsFallSlope'] = abs(rocfs)
pks_params['Stage'] = rem_sta # Sleep stage
# Convert to Pandas DataFrame
df = pd.DataFrame(pks_params)
# Make sure that the sign of ROC and LOC is opposite
df['IsOppositeSign'] = (np.sign(data_filt[1, pks]) !=
np.sign(data_filt[0, pks]))
df = df[np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])]
# Remove bad duration
tmin, tmax = duration
good_dur = np.logical_and(pks_params['Duration'] >= tmin,
pks_params['Duration'] < tmax)
df = df[good_dur]
# Keep only useful channels
df = df[['Start', 'Peak', 'End', 'Duration', 'LOCAbsValPeak',
'ROCAbsValPeak', 'LOCAbsRiseSlope', 'ROCAbsRiseSlope',
'LOCAbsFallSlope', 'ROCAbsFallSlope', 'Stage']]
if hypno is None:
df = df.drop(columns=['Stage'])
else:
df['Stage'] = df['Stage'].astype(int)
# We need at least 100 detected REMs to apply the Isolation Forest.
if remove_outliers and df.shape[0] >= 50:
col_keep = ['Duration', 'LOCAbsValPeak', 'ROCAbsValPeak',
'LOCAbsRiseSlope', 'ROCAbsRiseSlope', 'LOCAbsFallSlope',
'ROCAbsFallSlope']
ilf = IsolationForest(contamination='auto', max_samples='auto',
verbose=0, random_state=42)
good = ilf.fit_predict(df[col_keep])
good[good == -1] = 0
logger.info('%i outliers were removed.', (good == 0).sum())
# Remove outliers from DataFrame
df = df[good.astype(bool)]
logger.info('%i REMs were found in data.', df.shape[0])
df = df.reset_index(drop=True)
return REMResults(events=df, data=data, sf=sf, ch_names=ch_names,
hypno=hypno, data_filt=data_filt)
[docs]class REMResults(_DetectionResults):
"""Output class for REMs detection.
Attributes
----------
_events : :py:class:`pandas.DataFrame`
Output detection dataframe
_data : array_like
EOG data of shape *(n_chan, n_samples)*, where the two channels are
LOC and ROC.
_data_filt : array_like
Filtered EOG data of shape *(n_chan, n_samples)*, where the two
channels are LOC and ROC.
_sf : float
Sampling frequency of data.
_ch_names : list
Channel names (= ``['LOC', 'ROC']``)
_hypno : array_like or None
Sleep staging vector.
"""
[docs] def __init__(self, events, data, sf, ch_names, hypno, data_filt):
super().__init__(events, data, sf, ch_names, hypno, data_filt)
[docs] def summary(self, grp_stage=False, aggfunc='mean', sort=True):
"""Return a summary of the REM detection, optionally grouped across
stage.
Parameters
----------
grp_stage : bool
If True, group by sleep stage (provided that an hypnogram was
used).
aggfunc : str or function
Averaging function (e.g. ``'mean'`` or ``'median'``).
sort : bool
If True, sort group keys when grouping.
"""
# ``grp_chan`` is always False for REM detection because the
# REMs are always detected on a combination of LOC and ROC.
return super().summary(event_type='rem', grp_chan=False,
grp_stage=grp_stage, aggfunc=aggfunc, sort=sort)
[docs] def get_mask(self):
"""Return a boolean array indicating for each sample in data if this
sample is part of a detected event (True) or not (False).
"""
# We cannot use super() because "Channel" is not present in _events.
from yasa.others import _index_to_events
mask = np.zeros(self._data.shape, dtype=int)
idx_ev = _index_to_events(
self._events[['Start', 'End']].to_numpy() * self._sf)
mask[:, idx_ev] = 1
return mask
[docs] def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,
filt=(None, None)):
"""
Return the raw or filtered data of each detected event after
centering to a specific timepoint.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the peak of the REM.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
Returns
-------
df_sync : :py:class:`pandas.DataFrame`
Ouput long-format dataframe::
'Event' : Event number
'Time' : Timing of the events (in seconds)
'Amplitude' : Raw or filtered data for event
'Channel' : Channel
'IdxChannel' : Index of channel in data
"""
from yasa.others import get_centered_indices
assert time_before >= 0
assert time_after >= 0
bef = int(self._sf * time_before)
aft = int(self._sf * time_after)
if any(filt):
data = mne.filter.filter_data(self._data, self._sf, l_freq=filt[0],
h_freq=filt[1], method='fir',
verbose=False)
else:
data = self._data
time = np.arange(-bef, aft + 1, dtype='int') / self._sf
# Get location of peaks in data
peaks = (self._events[center] * self._sf).astype(int).to_numpy()
# Get centered indices (here we could use second channel as well).
idx, idx_valid = get_centered_indices(data[0, :], peaks, bef, aft)
# If no good epochs are returned raise a warning
assert len(idx_valid), (
'Time before and/or time after exceed data bounds, please '
'lower the temporal window around center.')
# Initialize empty dataframe
df_sync = pd.DataFrame()
# Loop across both EOGs (LOC and ROC)
for i, ch in enumerate(self._ch_names):
amps = data[i, idx]
df_chan = pd.DataFrame(amps.T)
df_chan['Time'] = time
df_chan = df_chan.melt(id_vars='Time', var_name='Event',
value_name='Amplitude')
df_chan['Channel'] = ch
df_chan['IdxChannel'] = i
df_sync = df_sync.append(df_chan, ignore_index=True)
return df_sync
[docs] def plot_average(self, center='Peak', time_before=0.4,
time_after=0.4, filt=(None, None), figsize=(6, 4.5),
**kwargs):
"""
Plot the average REM.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the peak of the REM.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
figsize : tuple
Figure size in inches.
**kwargs : dict
Optional argument that are passed to :py:func:`seaborn.lineplot`.
"""
import seaborn as sns
import matplotlib.pyplot as plt
df_sync = self.get_sync_events(center=center, time_before=time_before,
time_after=time_after, filt=filt)
# Start figure
fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.lineplot(data=df_sync, x='Time', y='Amplitude', hue='Channel',
ax=ax, **kwargs)
# ax.legend(frameon=False, loc='lower right')
ax.set_xlim(df_sync['Time'].min(), df_sync['Time'].max())
ax.set_title("Average REM")
ax.set_xlabel('Time (sec)')
ax.set_ylabel('Amplitude (uV)')
return ax
#############################################################################
# ARTEFACT DETECTION
#############################################################################
[docs]def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
method='covar', threshold=3, n_chan_reject=1, verbose=False):
r"""
Automatic artifact rejection.
.. versionadded:: 0.2.0
Parameters
----------
data : array_like
Single or multi-channel EEG data.
Unit must be uV and shape *(n_chan, n_samples)*.
Can also be a :py:class:`mne.io.BaseRaw`, in which case ``data``
and ``sf`` will be automatically extracted,
and ``data`` will also be automatically converted from Volts (MNE)
to micro-Volts (YASA).
.. warning::
``data`` must only contains EEG channels. Please make sure to
exclude any EOG, EKG or EMG channels.
sf : float
Sampling frequency of the data in Hz.
Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw` object.
window : float
The window length (= resolution) for artifact rejection, in seconds.
Default to 5 seconds. Shorter windows (e.g. 1 or 2-seconds) will
drastically increase computation time when ``method='covar'``.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is passed, the
detection will be applied separately for each of the stages defined in
``include``.
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Sleep stages in ``hypno`` on which to perform the artifact rejection.
The default is ``hypno=(1, 2, 3, 4)``, meaning that the artifact
rejection is applied separately for all sleep stages, excluding wake.
This parameter has no effect when ``hypno`` is None.
method : str
Artifact detection method (see Notes):
* ``'covar'`` : Covariance-based, default for 4+ channels data
* ``'std'`` : Standard-deviation-based, default for single-channel data
threshold : float
The number of standard deviations above or below which an
epoch is considered an artifact. Higher values will result in a more
conservative detection, i.e. less rejected epochs.
n_chan_reject : int
The number of channels that must be below or above ``threshold`` on any
given epochs to consider this epoch as an artefact when
``method='std'``. The default is 1, which means that the epoch will
be marked as artifact as soon as one channel is above or below the
threshold. This may be too conservative when working with a large
number of channels (e.g.hdEEG) in which case users can increase
``n_chan_reject``. Note that this parameter only has an effect
when ``method='std'``.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
art_epochs : array_like
1-D array of shape *(n_epochs)* where 1 = Artefact and 0 = Good.
zscores : array_like
Array of z-scores, shape is *(n_epochs)* if ``method='covar'`` and
*(n_epochs, n_chan)* if ``method='std'``.
Notes
-----
.. caution::
This function will only detect major body artefacts present on the EEG
channel. It will not detect EKG contamination or eye blinks. For more
artifact rejection tools, please refer to the `MNE Python package
<https://mne.tools/stable/auto_tutorials/preprocessing/plot_10_preprocessing_overview.html>`_.
.. tip::
For best performance, apply this function on pre-staged data and make
sure to pass the hypnogram.
Sleep stages have very different EEG signatures
and the artifect rejection will be much more accurate when applied
separately on each sleep stage.
We provide below a short description of the different methods. For
multi-channel data, and if computation time is not an issue, we recommend
using ``method='covar'`` which uses a clustering approach on
variance-covariance matrices, and therefore takes into account
not only the variance in each channel and each epoch, but also the
inter-relationship (covariance) between channel.
``method='covar'`` is however not supported for single-channel EEG or when
less than 4 channels are present in ``data``. In these cases, one can
use the much faster ``method='std'`` which is simply based on a z-scoring
of the log-transformed standard deviation of each channel and each epoch.
**1/ Covariance-based multi-channel artefact rejection**
``method='covar'`` is essentially a wrapper around the
:py:class:`pyriemann.clustering.Potato` class implemented in the
`pyRiemann package
<https://pyriemann.readthedocs.io/en/latest/index.html>`_.
The main idea of this approach is to estimate a reference covariance
matrix :math:`\bar{C}` (for each sleep stage separately if ``hypno`` is
present) and reject every epoch which is too far from this reference
matrix.
The distance of the covariance matrix of the current epoch :math:`C`
from the reference matrix is calculated using Riemannian
geometry, which is more adapted than Euclidean geometry for
symmetric positive definite covariance matrices:
.. math:: d = {\left( \sum_i \log(\lambda_i)^2 \right)}^{-1/2}
where :math:`\lambda_i` are the joint eigenvalues of :math:`C` and
:math:`\bar{C}`. The epoch with covariance matric :math:`C`
will be marked as an artifact if the distance :math:`d`
is greater than a threshold :math:`T`
(typically 2 or 3 standard deviations).
:math:`\bar{C}` is iteratively estimated using a clustering approach.
**2/ Standard-deviation-based single and multi-channel artefact rejection**
``method='std'`` is a much faster and straightforward approach which
is simply based on the distribution of the standard deviations of each
epoch. Specifically, one first calculate the standard
deviations of each epoch and each channel. Then, the resulting array of
standard deviations is log-transformed and z-scored (for each sleep
stage separately if ``hypno`` is present). Any epoch with one or more
channel exceeding the threshold will be marked as artifact.
Note that this approach is more sensitive to noise and/or the influence of
one bad channel (e.g. electrode fell off at some point during the night).
We therefore recommend that you visually inspect and remove any bad
channels prior to using this function.
References
----------
* Barachant, A., Andreev, A., & Congedo, M. (2013). `The Riemannian
Potato: an automatic and adaptive artifact detection method for online
experiments using Riemannian geometry.
<https://hal.archives-ouvertes.fr/hal-00781701/>`_ TOBI
Workshop lV, 19–20.
* Barthélemy, Q., Mayaud, L., Ojeda, D., & Congedo, M. (2019).
`The Riemannian Potato Field: A Tool for Online Signal Quality Index of
EEG. <https://doi.org/10.1109/TNSRE.2019.2893113>`_
IEEE Transactions on Neural Systems and Rehabilitation Engineering:
A Publication of the IEEE Engineering in Medicine and Biology Society,
27(2), 244–255.
* https://pyriemann.readthedocs.io/en/latest/index.html
Examples
--------
For an example of how to run the detection, please refer to
https://github.com/raphaelvallat/yasa/blob/master/notebooks/13_artifact_rejection.ipynb
"""
###########################################################################
# PREPROCESSING
###########################################################################
set_log_level(verbose)
(data, sf, _, hypno, include, _, n_chan, n_samples, _
) = _check_data_hypno(data, sf, ch_names=None, hypno=hypno,
include=include, check_amp=False)
assert isinstance(n_chan_reject, int), 'n_chan_reject must be int.'
assert n_chan_reject >= 1, 'n_chan_reject must be >= 1.'
assert n_chan_reject <= n_chan, 'n_chan_reject must be <= n_chan.'
# Safety check: sampling frequency and window
assert isinstance(sf, (int, float)), 'sf must be int or float'
assert isinstance(window, (int, float)), 'window must be int or float'
if isinstance(sf, float):
assert sf.is_integer(), 'sf must be a whole number.'
sf = int(sf)
win_sec = window
window = win_sec * sf # Convert window to samples
if isinstance(window, float):
assert window.is_integer(), 'window * sf must be a whole number.'
window = int(window)
# Safety check: hypnogram
if hypno is not None:
# Extract hypnogram with only complete epochs
idx_max_full_epoch = int(np.floor(n_samples / window))
hypno_win = hypno[::window][:idx_max_full_epoch]
# Safety checks: methods
assert isinstance(method, str), "method must be a string."
method = method.lower()
if method in ['cov', 'covar', 'covariance', 'riemann', 'potato']:
method = 'covar'
is_pyriemann_installed()
from pyriemann.estimation import Covariances, Shrinkage
from pyriemann.clustering import Potato
# Must have at least 4 channels to use method='covar'
if n_chan <= 4:
logger.warning("Must have at least 4 channels for method='covar'. "
"Automatically switching to method='std'.")
method = 'std'
###########################################################################
# START THE REJECTION
###########################################################################
# Remove flat channels
isflat = (np.nanstd(data, axis=-1) == 0)
if isflat.any():
logger.warning('Flat channel(s) were found and removed in data.')
data = data[~isflat]
n_chan = data.shape[0]
# Epoch the data (n_epochs, n_chan, n_samples)
_, epochs = sliding_window(data, sf, window=win_sec)
n_epochs = epochs.shape[0]
# We first need to identify epochs with flat data (n_epochs, n_chan)
isflat = (epochs == epochs[:, :, 1][..., None]).all(axis=-1)
# 1 when all channels are flat, 0 when none ar flat (n_epochs)
prop_chan_flat = isflat.sum(axis=-1) / n_chan
# If >= 50% of channels are flat, automatically mark as artefact
epoch_is_flat = prop_chan_flat >= 0.5
where_flat_epochs = np.nonzero(epoch_is_flat)[0]
n_flat_epochs = where_flat_epochs.size
# Now let's make sure that we have an hypnogram and an include variable
if 'hypno_win' not in locals():
# [-2, -2, -2, -2, ...], where -2 stands for unscored
hypno_win = -2 * np.ones(n_epochs, dtype='float')
include = np.array([-2], dtype='float')
# We want to make sure that hypno-win and n_epochs have EXACTLY same shape
assert n_epochs == hypno_win.shape[-1], 'Hypno and epochs do not match.'
# Finally, we make sure not to include any flat epochs in calculation
# just using a random number that is unlikely to be picked by users
if n_flat_epochs > 0:
hypno_win[where_flat_epochs] = -111991
# Add logger info
logger.info('Number of channels in data = %i', n_chan)
logger.info('Number of samples in data = %i', n_samples)
logger.info('Sampling frequency = %.2f Hz', sf)
logger.info('Data duration = %.2f seconds', n_samples / sf)
logger.info('Number of epochs = %i' % n_epochs)
logger.info('Artifact window = %.2f seconds' % win_sec)
logger.info('Method = %s' % method)
logger.info('Threshold = %.2f standard deviations' % threshold)
# Create empty `hypno_art` vector (1 sample = 1 epoch)
epoch_is_art = np.zeros(n_epochs, dtype='int')
if method == 'covar':
# Calculate the covariance matrices,
# shape (n_epochs, n_chan, n_chan)
covmats = Covariances().fit_transform(epochs)
# Shrink the covariance matrix (ensure positive semi-definite)
covmats = Shrinkage().fit_transform(covmats)
# Define Potato instance: 0 = clean, 1 = art
# To increase speed we set the max number of iterations from 10 to 100
potato = Potato(metric='riemann', threshold=threshold, pos_label=0,
neg_label=1, n_iter_max=10)
# Create empty z-scores output (n_epochs)
zscores = np.zeros(n_epochs, dtype='float') * np.nan
for stage in include:
where_stage = np.where(hypno_win == stage)[0]
# At least 30 epochs are required to calculate z-scores
# which amounts to 2.5 minutes when using 5-seconds window
if where_stage.size < 30:
if hypno is not None:
# Only show warnig if user actually pass an hypnogram
logger.warning(f"At least 30 epochs are required to "
f"calculate z-score. Skipping "
f"stage {stage}")
continue
# Apply Potato algorithm, extract z-scores and labels
zs = potato.fit_transform(covmats[where_stage])
art = potato.predict(covmats[where_stage]).astype(int)
if hypno is not None:
# Only shows if user actually pass an hypnogram
perc_reject = 100 * (art.sum() / art.size)
text = (f"Stage {stage}: {art.sum()} / {art.size} "
f"epochs rejected ({perc_reject:.2f}%)")
logger.info(text)
# Append to global vector
epoch_is_art[where_stage] = art
zscores[where_stage] = zs
elif method in ['std', 'sd']:
# Calculate log-transformed standard dev in each epoch
# We add 1 to avoid log warning id std is zero (e.g. flat line)
# (n_epochs, n_chan)
std_epochs = np.log(np.nanstd(epochs, axis=-1) + 1)
# Create empty zscores output (n_epochs, n_chan)
zscores = np.zeros((n_epochs, n_chan), dtype='float') * np.nan
for stage in include:
where_stage = np.where(hypno_win == stage)[0]
# At least 30 epochs are required to calculate z-scores
# which amounts to 2.5 minutes when using 5-seconds window
if where_stage.size < 30:
if hypno is not None:
# Only show warnig if user actually pass an hypnogram
logger.warning(f"At least 30 epochs are required to "
f"calculate z-score. Skipping "
f"stage {stage}")
continue
# Calculate z-scores of STD for each channel x stage
c_mean = np.nanmean(std_epochs[where_stage], axis=0,
keepdims=True)
c_std = np.nanstd(std_epochs[where_stage], axis=0,
keepdims=True)
zs = (std_epochs[where_stage] - c_mean) / c_std
# Any epoch with at least X channel above or below threshold
n_chan_supra = (np.abs(zs) > threshold).sum(axis=1) # >
art = (n_chan_supra >= n_chan_reject).astype(int) # >= !
if hypno is not None:
# Only shows if user actually pass an hypnogram
perc_reject = 100 * (art.sum() / art.size)
text = (f"Stage {stage}: {art.sum()} / {art.size} "
f"epochs rejected ({perc_reject:.2f}%)")
logger.info(text)
# Append to global vector
epoch_is_art[where_stage] = art
zscores[where_stage, :] = zs
# Mark flat epochs as artefacts
if n_flat_epochs > 0:
logger.info(f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
f"that are flat. Z-scores set to np.nan for these epochs.")
epoch_is_art[where_flat_epochs] = 1
# Log total percentage of epochs rejected
perc_reject = 100 * (epoch_is_art.sum() / n_epochs)
text = (f"TOTAL: {epoch_is_art.sum()} / {n_epochs} "
f"epochs rejected ({perc_reject:.2f}%)")
logger.info(text)
# Convert epoch_is_art to boolean [0, 0, 1] -- > [False, False, True]
epoch_is_art = epoch_is_art.astype(bool)
return epoch_is_art, zscores