# Source code for parse_pd

# -*- coding: utf-8 -*-
"""Find photodiode events.

Take a potentially corrupted photodiode channel and find
the event time samples at which it turned on.
"""
# Authors: Alex Rockhill <aprockhill@mailbox.org>
#
# License: BSD (3-clause)

import os
import os.path as op
import numpy as np
from tqdm import tqdm

import mne


def _read_tsv(fname):
    """Read tab-separated value file data."""
    if op.splitext(fname)[-1] != '.tsv':
        raise ValueError(f'Unable to read {fname}, tab-separated-value '
                         '(tsv) is required.')
    if op.getsize(fname) == 0:
        raise ValueError(f'Error in reading tsv, file {fname} empty')
    df = dict()
    with open(fname, 'r') as fid:
        headers = fid.readline().rstrip().split('\t')
        for header in headers:
            df[header] = list()
        for line in fid:
            line_data = line.rstrip().split('\t')
            if len(line_data) != len(headers):
                raise ValueError(f'Error with file {fname}, the columns are '
                                 'different lengths')
            for i, data in enumerate(line_data):
                numeric = all([c.isdigit() or c in ('.', '-')
                               for c in data])
                if numeric:
                    if data.isdigit():
                        df[headers[i]].append(int(data))
                    else:
                        df[headers[i]].append(float(data))
                else:
                    df[headers[i]].append(data)
    if any([not val for val in df.values()]):  # no empty lists
        raise ValueError(f'Error in reading tsv, file {fname} '
                         'contains no data')
    return df


def _to_tsv(fname, df):
    """Write tab-separated value file data."""
    if op.splitext(fname)[-1] != '.tsv':
        raise ValueError(f'Unable to write to {fname}, tab-separated-value '
                         '(tsv) is required.')
    if len(df.keys()) == 0:
        raise ValueError('Empty data file, no keys')
    first_column = list(df.keys())[0]
    with open(fname, 'w') as fid:
        fid.write('\t'.join([str(k) for k in df.keys()]) + '\n')
        for i in range(len(df[first_column])):
            fid.write('\t'.join([str(val[i]) for val in df.values()]) + '\n')


def _read_raw(raw, preload=None, verbose=True):
    """Read raw object from file if it's not already loaded.

    Parameters
    ----------
    raw : str | mne Raw object
        An already-loaded raw object or a path to a data file; supported
        extensions are fif, edf, bdf, vhdr (brainvision) and set (eeglab).
    preload : bool | None
        Whether to load the data into memory. ``None`` leaves the preload
        state untouched; ``False`` raises if the object is already
        preloaded.
    verbose : bool
        Whether to print which file is being read.

    Returns
    -------
    raw : mne Raw object
        The (possibly newly read) raw object.
    """
    if isinstance(raw, mne.io.BaseRaw):
        if preload:
            raw.load_data()
        elif preload is not None:
            # preload=False explicitly forbids an already-preloaded object
            if raw.preload:
                raise ValueError('`raw` object cannot be preloaded')
        if raw.filenames[0] is None:
            raise ValueError('`raw` object must be loaded from disk')
    else:
        _, ext = op.splitext(raw)
        if verbose:
            print('Reading in {}'.format(raw))
        # dispatch on the file extension
        readers = {'.fif': mne.io.read_raw_fif,
                   '.edf': mne.io.read_raw_edf,
                   '.bdf': mne.io.read_raw_bdf,
                   '.vhdr': mne.io.read_raw_brainvision,
                   '.set': mne.io.read_raw_eeglab}
        if ext not in readers:
            # original message was missing the space before 'fif'
            raise ValueError('Extension {} not recognized, options are '
                             'fif, edf, bdf, vhdr (brainvision), set '
                             '(eeglab)'.format(ext))
        raw = readers[ext](raw, preload=preload)
    return raw


def _load_beh(beh, beh_key):
    """Load the behavioral data frame and check columns."""
    if not isinstance(beh, dict):
        beh = _read_tsv(beh)
    if beh_key not in beh:
        raise ValueError(f'`beh_key` {beh_key} not in the columns of '
                         f'the `beh` behavior dictionary. Please check '
                         'that the correct column is provided')
    for beh_e in beh[beh_key]:
        if beh_e != 'n/a' and not isinstance(beh_e, (int, float)):
            raise ValueError('Expected numeric value or \'n/a\' for '
                             f'behavior values, got {beh_e}')
    return np.array([np.nan if beh_e == 'n/a' else beh_e
                     for beh_e in beh[beh_key]]), beh


def _get_channel_data(raw, ch_names):
    """Get the time-series data from the channel names."""
    missing = [ch for ch in ch_names if ch not in raw.ch_names]
    if missing:
        raise ValueError(f'Not all pd_ch_names, {ch_names}, '
                         'in raw channel names')
    picked = raw.copy().pick_channels(ch_names).load_data()
    if len(ch_names) == 2:
        # two names: bipolar reference, take the difference
        ch_data = picked._data[0] - picked._data[1]
    else:
        ch_data = picked._data[0]
    # center on the median so baseline sits near zero
    return ch_data - np.median(ch_data)


def _get_data(raw, ch_names):
    """Get the names of the photodiode channels from the user.

    Parameters
    ----------
    raw : mne Raw object
        The time-series data to pick the photodiode channel(s) from.
    ch_names : list | None
        The channel name(s). If ``None``, the user is prompted on stdin
        (with an option to plot the raw data first).

    Returns
    -------
    ch_data : np.ndarray
        The photodiode time series from ``_get_channel_data`` (bipolar
        difference if two channels are given, median-subtracted).
    ch_names : list
        The channel names that were used.
    """
    # if pd_ch_names provided
    if ch_names is not None:
        if any([ch not in raw.ch_names for ch in ch_names]):
            raise ValueError(f'Not all channel names, {ch_names}, '
                             'in raw channel names')
    else:  # if no pd_ch_names provided
        ch_names = input('Enter channel names separated by a '
                         'comma or type "plot" to plot the data first:\t')
        if ch_names.lower() == 'plot':
            raw.plot()
        # exactly 'plot' forces at least one pass through the loop below
        n_chs = 0 if ch_names == 'plot' else len(ch_names.split(','))
        # re-prompt until 1 or 2 valid channel names are entered
        while n_chs not in (1, 2) or not all([ch.strip() in raw.ch_names for
                                              ch in ch_names.split(',')]):
            ch_names = input('Enter channel names separated by a comma:\t')
            for ch in ch_names.split(','):
                if not ch.strip() in raw.ch_names:
                    print(f'{ch.strip()} not in raw channel names')
            n_chs = len(ch_names.split(','))
            if n_chs > 2:
                print(f'{n_chs} is too many names, enter 1 name '
                      'for common referenced photodiode data or '
                      '2 names for bipolar reference')
        ch_names = [ch.strip() for ch in ch_names.split(',')]
    # get pd data using channel names
    ch_data = _get_channel_data(raw, ch_names)
    return ch_data, ch_names


def _check_if_pd_event(pd_diff, i, max_len_i, zscore, max_flip_i,
                       baseline_std):
    """Take one stretch of data and determine if there is an event there.

    Use almost all events for on/off due to noise in photodiode causing
    the last event to hop under the threshold and back.
    """
    s = pd_diff[i:i + max_len_i].copy()
    s -= np.median(s)
    s /= baseline_std
    binned_s = np.digitize(s, [-np.inf, -zscore, zscore, np.inf]) - 2
    for direction, binary_s in {'up': binned_s, 'down': -binned_s}.items():
        onset = np.where(binary_s == 1)[0]
        # must be flip on but can't flip back and forth
        if onset.size > 0 and onset.size < max_flip_i:
            e = onset[0]
            almost_all_on = sum(binary_s[onset[0]:onset[-1]]) >= onset.size - 2
            if all(binary_s[:e] == 0) and almost_all_on:  # must start off
                # must have an offset and no more events
                offset = np.where(binary_s[e:] == -1)[0]
                if offset.size > 0 and offset.size < max_flip_i:
                    o = offset[0]
                    almost_all_off = -sum(binary_s[e + o:e + offset[-1]]) \
                        >= offset.size - 2
                    almost_all_zero = \
                        sum(abs(binary_s[e + o + max_flip_i:])) <= 1
                    if almost_all_zero and almost_all_off:
                        return direction, i + e + 1, i + e + o + 2
    return None, None, None


def _find_pd_candidates(pd, max_len, baseline, zscore,
                        max_flip_i, sfreq, verbose=True):
    """Find all points in the signal that look like a square wave.

    Parameters
    ----------
    pd : np.ndarray
        The photodiode time series.
    max_len : float
        Maximum event length in seconds (converted to samples via sfreq).
    baseline : float
        Fraction of ``max_len`` used as the window size for estimating the
        baseline noise standard deviation.
    zscore : float
        Threshold in baseline standard deviations for a deflection to be
        checked as a candidate.
    max_flip_i : int
        Bounds the number of threshold crossings allowed within one event
        (see ``_check_if_pd_event``) and the minimum sample separation
        between consecutive accepted candidates.
    sfreq : float
        Sampling frequency of ``pd``.
    verbose : bool
        Whether to print progress.

    Returns
    -------
    onsets, offsets : np.ndarray
        Sorted candidate onset and offset samples for whichever deflection
        direction ('up' or 'down') produced more candidates.
    """
    if verbose:
        print('Finding photodiode events')
    max_len_i = np.round(sfreq * max_len).astype(int)
    baseline_i = np.round(max_len_i * baseline).astype(int)
    # zscore photodiode based on baseline values
    pd_diff = np.diff(pd)
    pd_diff -= np.median(pd_diff)
    # robust noise estimate: median std over consecutive baseline windows
    median_std = np.median([np.std(pd_diff[i - baseline_i:i]) for i in
                            range(baseline_i, len(pd_diff) - baseline_i,
                                  baseline_i)])
    # find indices to check based on being the first above zscore
    check_i = set(np.where(abs(pd_diff) / median_std > zscore)[0])
    check_remove = set()
    for i in check_i:
        if i + 1 in check_i:
            check_remove.add(i + 1)  # drop runs, keep only the first sample
    check_i = check_i.difference(check_remove)
    # check for clean onset and offset
    pd_candidates = dict(up=list(), down=list(),
                         up_off=list(), down_off=list())
    for i in tqdm(sorted(list(check_i))):
        direction, onset, offset = _check_if_pd_event(
            pd_diff, i, max_len_i, zscore, max_flip_i, median_std)
        # no events immediately following (caused by noise)
        if onset is not None:
            in_flip = (onset - pd_candidates[direction][-1]) < max_flip_i if \
                pd_candidates[direction] else False
            if not in_flip:
                pd_candidates[direction].append(onset)
                pd_candidates[direction + '_off'].append(offset)
    # keep whichever polarity produced more candidates
    this_dir = 'down' if \
        len(pd_candidates['down']) > len(pd_candidates['up']) else 'up'
    pd_candidates = pd_candidates[this_dir], pd_candidates[this_dir + '_off']
    if len(pd_candidates[0]) == 0:
        raise ValueError('No photodiode candidates found, please raise an '
                         'issue with code to reproduce the error on GitHub')
    if verbose:
        print(f'{len(pd_candidates[0])} {this_dir}-deflection photodiode '
              'candidate events found')
    return (np.array(sorted(pd_candidates[0])),
            np.array(sorted(pd_candidates[1])))


def _get_audio_zscore(audio, fs):
    """Plot the normalized audio and ask the user for a z-score threshold.

    An interactive figure is shown where the arrow keys adjust the view
    (up/down zoom the y axis, left/right stretch the time axis); the
    threshold is then read from standard input.
    """
    import matplotlib as mpl
    mpl.rcParams['toolbar'] = 'None'
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(6, 6))
    fig.subplots_adjust(top=0.75, left=0.15)

    def scale(event):
        # key-press handler; left/down shrink the view, right/up grow it
        amount = 0.95 if event.key in ('left', 'down') else 1.25
        if event.key in ('up', 'down'):
            ymin, ymax = ax.get_ylim()
            # ymin < 0 and ymax > 0 because median subtracted
            ymin *= amount
            ymax *= amount
            ax.set_ylim([ymin, ymax])
        elif event.key in ('left', 'right'):
            xmin, xmax = ax.get_xlim()
            # stretch/compress the time window
            xmin /= amount
            xmax *= amount
            if xmin < xmax:
                ax.set_xlim([xmin, xmax])
        fig.canvas.draw()

    ax.set_title(
        'Use the left/right keys to scale time on the x axis '
        '\nand use the up/down keys to zoom the yaxis in and out'
        '\nfind a y value that includes all the events'
        '\nit is recommended to choose the lowest value that is above baseline'
        '\nclose the window when finished')
    ax.set_xlabel('time (s)')
    ax.set_ylabel('zscore')
    ax.plot(audio, color='b')
    # start the view centered on the middle 20 seconds of the recording
    xmin = max([0, audio.size // 2 - 10 * fs])
    xmax = min([audio.size, audio.size // 2 + 10 * fs])
    ax.set_xlim(xmin, xmax)
    ax.set_xticks(np.linspace(0, audio.size, 5))
    ax.set_xticklabels(np.round(np.linspace(0, audio.size / fs, 5), 2))
    ax.set_ylim(audio.min() * 0.9, audio.max() * 1.25)
    fig.canvas.mpl_connect('key_press_event', scale)
    fig.show()

    # get user input of zscore
    zscore = None
    while zscore is None:
        zscore = input('What zscore should be used? ')
        # only digits and '.' are accepted, i.e. a non-negative float
        if not zscore or any([not d.isdigit() and d != '.' for d in zscore]):
            print('A positive number input is required for zscore')
            zscore = None
        else:
            zscore = float(zscore)
    return zscore


def _find_audio_candidates(audio, max_len, zscore, sfreq, verbose=True):
    if verbose:
        print('Finding points where the audio is above `zscore` threshold...')
    max_len_i = np.round(max_len * sfreq).astype(int)
    audio = abs((audio - np.median(audio)) / audio.std())
    if zscore is None:
        zscore = _get_audio_zscore(audio, sfreq)
    candidates = np.where(audio > zscore)[0]
    delete_indices = list()
    for i, candidate in enumerate(candidates):
        if any(audio[candidate - max_len_i: candidate] > zscore):
            delete_indices.append(i)
    candidates = np.delete(candidates, delete_indices)
    if verbose:
        print(f'{len(candidates)} audio candidate events found')
    return candidates


def _event_dist(beh_e, candidates_set, max_samp, resync_i):
    """Find the shortest distance from the behavioral event to a pd event."""
    j = 0
    if np.isnan(beh_e):
        return np.nan, np.nan
    beh_e = np.round(beh_e).astype(int)
    while beh_e + j < max_samp + resync_i and beh_e - j > 0 and j < resync_i:
        if beh_e - j in candidates_set:
            return j, beh_e - j
        if beh_e + j in candidates_set:
            return -j, beh_e + j
        j += 1
    return np.nan, np.nan


def _check_alignment(beh_events, alignment, candidates, candidates_set,
                     resync_i, check_i=None):
    """Check the alignment, account for misalignment accumulation.

    Parameters
    ----------
    beh_events : np.ndarray
        Behavioral event samples (before the alignment offset is applied).
    alignment : int | float
        Offset in samples added to each behavioral event before matching.
    candidates : np.ndarray
        Sorted candidate synchronization event samples.
    candidates_set : set
        The same candidates as a set for fast membership tests.
    resync_i : int
        Maximum per-event error (in samples) that is still propagated to
        later events as a drift correction.
    check_i : int | None
        Search radius passed to ``_event_dist``; defaults to ``resync_i``.

    Returns
    -------
    beh_events : np.ndarray
        Drift-adjusted behavioral events (a copy; input is not modified).
    events : np.ndarray
        Matched synchronization sample for each behavioral event.
    """
    check_i = resync_i if check_i is None else check_i
    beh_events = beh_events.copy()  # don't modify original
    events = np.zeros((beh_events.size))
    # anchor on the behavioral event closest to any candidate
    start = np.argmin([abs(beh_e - candidates).min()
                       for beh_e in beh_events + alignment])
    # walk forward from `start`, subtracting each matched error from all
    # later events so clock drift does not accumulate
    for i, beh_e in enumerate(beh_events[start:]):
        error, events[start + i] = \
            _event_dist(beh_e + alignment, candidates_set, candidates[-1],
                        check_i)
        if abs(error) <= resync_i and start + i + 1 < beh_events.size:
            beh_events[start + i + 1:] -= error
    # then walk backward from `start` applying the same correction
    for i, beh_e in enumerate(beh_events[:start][::-1]):
        error, events[start - i - 1] = \
            _event_dist(beh_e + alignment, candidates_set, candidates[-1],
                        check_i)
        if abs(error) <= resync_i and start - i - 2 > 0:
            # NOTE(review): slice stops before `start - i - 2`, which skips
            # correcting that event, unlike the forward pass -- confirm
            beh_events[:start - i - 2] -= error
    return beh_events, events


def _plot_trial_errors(beh_events, alignment, events,
                       errors, exclude_shift, sfreq):
    """Plot the synchronization error on every trial.

    Left panel: dot raster of the aligned behavioral events against the
    matched synchronization events (view starts on the first 10). Right
    panel: per-trial difference in ms with the ``exclude_shift`` bounds
    drawn in red.
    """
    import matplotlib.pyplot as plt
    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    # plot scrollable dot pattern to first pass check alignment
    beh_events_s = (beh_events + alignment) / sfreq
    beh_d = np.diff(beh_events_s).mean()  # mean inter-event interval (s)
    ax.scatter(beh_events_s, np.repeat(-1, beh_events.size))
    ax.scatter(events / sfreq, np.repeat(1, events.size))
    ax.set_xlim([beh_events_s.min() - beh_d, beh_events_s[:10].max() + beh_d])
    ax.set_ylim([-5, 5])
    ax.set_xlabel('Time (s)')
    ax.set_yticks([-1, 1])
    ax.set_yticklabels(['Beh', 'Sync'])
    ax.set_title('Alignment (First 10)')
    # plot the difference between expected and adjusted behavior
    errors = errors.copy()  # don't modify the original
    # don't show huge errors
    errors[abs(errors) / sfreq > 2 * exclude_shift] = np.nan
    ax2.plot(errors / sfreq * 1000)  # samples -> milliseconds
    exclude_shift_a = np.array([exclude_shift, exclude_shift]) * 1000
    ax2.plot([0, errors.size], exclude_shift_a, color='r')
    ax2.plot([0, errors.size], -exclude_shift_a, color='r')
    ax2.set_ylabel('Difference (ms)')
    ax2.set_xlabel('Trial')
    ax2.set_title('Event Differences')
    fig.tight_layout()
    fig.show()


def _find_best_alignment(beh_events, candidates, exclude_shift, resync,
                         sfreq, verbose=True):
    """Find the beh event that causes the best alignment when used to start.

    Every behavioral event is tried as an anchor against every candidate
    synchronization event; candidates are binned around the shifted
    behavioral events and the alignment with the smallest total matching
    error (after drift adjustment via ``_check_alignment``) wins.

    Returns
    -------
    beh_adjusted : np.ndarray
        Drift-adjusted behavioral events (``np.nan`` where missing).
    best_alignment : int | float
        The winning offset in samples (sync event minus behavioral event).
    events : np.ndarray
        Matched synchronization events (``np.nan`` where missing).
    """
    beh_adjusted = np.zeros((beh_events.size))
    events = np.zeros((beh_events.size))
    beh_idx = np.where(~np.isnan(beh_events))[0]
    missing_idx = np.where(np.isnan(beh_events))[0]
    beh_events = beh_events[~np.isnan(beh_events)]  # can't use missing
    resync_i = np.round(sfreq * resync).astype(int)
    min_error = best_alignment = None
    # half the smallest inter-event interval so bins never overlap
    bin_size = np.diff(beh_events).min() / 2
    candidates_set = set(candidates)
    if verbose:
        print('Checking best alignments')
    for beh_e in tqdm(beh_events):
        this_min_error = alignment = None
        for sync_e in candidates:
            # interleaved left/right bin edges centered on each beh event
            bins = np.zeros((2 * beh_events.size))
            bins[::2] = beh_events - beh_e - bin_size / 2
            bins[1::2] = beh_events - beh_e + bin_size / 2
            # odd bin index -> the candidate fell inside a beh event's bin
            indices = np.digitize(candidates - sync_e, bins=bins)
            matched_b = \
                beh_events[(indices[indices % 2 == 1] - 1) // 2] - beh_e
            matched_c = candidates[indices % 2 == 1] - sync_e
            unmatched_b = beh_events.size - \
                np.unique(indices[indices % 2 == 1]).size
            errors = abs(matched_b - matched_c)
            # each unmatched beh event is penalized by a full bin width
            error = np.median(errors) + bin_size * unmatched_b
            if this_min_error is None or this_min_error > error:
                alignment = sync_e - beh_e
                this_min_error = error
        beh_events_adjusted, these_events = _check_alignment(
            beh_events, alignment, candidates,
            candidates_set, resync_i)
        errors = beh_events_adjusted - these_events + alignment
        # missed events (nan errors) are penalized by the resync distance
        error = np.nansum(abs(errors)) + \
            resync_i * errors[np.isnan(errors)].size
        if min_error is None or error < min_error:
            min_error = error
            best_alignment = alignment
    best_beh_events_adjusted, best_events = _check_alignment(
        beh_events, best_alignment, candidates, candidates_set, resync_i,
        check_i=3 * resync_i)  # get all errors even if more than resync away
    if verbose:
        best_errors = best_beh_events_adjusted - best_events + best_alignment
        errors = best_errors[~np.isnan(best_errors)] / sfreq * 1000
        errors = errors[abs(errors) < resync * 1000]
        n_missed_events = beh_events.size - errors.size
        beh0 = beh_events[~np.isnan(beh_events)][0]
        shift = (beh0 + best_alignment - candidates[0]) / sfreq
        print('Best alignment is with the first behavioral event shifted '
              '{:.2f} s relative to the first synchronization event and '
              'has errors: min {:.2f} ms, q1 {:.2f} ms, med {:.2f} ms, '
              'q3 {:.2f} ms, max {:.2f} ms, {:d} missed events'.format(
                  shift, min(errors), np.quantile(errors, 0.25),
                  np.median(errors), np.quantile(errors, 0.75),
                  max(errors), n_missed_events))
        _plot_trial_errors(beh_events, best_alignment, best_events,
                           best_errors, exclude_shift, sfreq)
    # scatter results back into full-length arrays, nan for missing trials
    beh_adjusted[beh_idx] = best_beh_events_adjusted
    beh_adjusted[missing_idx] = np.nan
    events[beh_idx] = best_events
    events[missing_idx] = np.nan
    return beh_adjusted, best_alignment, events


def _recover_event(idx, ch_data, beh_e, exclude_shift,
                   zscore, max_len, sfreq):
    """Recover with a corrupted baseline or plateau but not on/offset.

    Looks for above-threshold deflections within ``exclude_shift`` of the
    expected event time, plots each one and asks the user on stdin
    whether to recover it.

    Returns
    -------
    event : int | float
        The recovered sample number, or ``np.nan`` if none was accepted.
    text : str
        A short description used for reporting and plot titles.
    """
    import matplotlib.pyplot as plt
    beh_e_i = np.round(beh_e).astype(int)
    max_len_i = np.round(max_len * sfreq).astype(int)
    exclude_shift_i = np.round(exclude_shift * sfreq).astype(int)
    # derivative of the window around the expected event time
    section = np.diff(ch_data[beh_e_i - exclude_shift_i:
                              beh_e_i + exclude_shift_i])
    # noise estimated from the window just before that
    baseline = np.diff(ch_data[beh_e_i - 2 * exclude_shift_i:
                               beh_e_i - exclude_shift_i])
    section = (section - np.median(baseline)) / baseline.std()
    check_i = set(np.where(abs(section) > zscore)[0])
    # keep only the first sample of each run of consecutive crossings
    check_remove = set()
    for i in check_i:
        if i + 1 in check_i:
            check_remove.add(i + 1)
    check_i = check_i.difference(check_remove)
    if len(check_i) == 0:
        return np.nan, f'{idx}\nnone found to recover'
    elif len(check_i) > 3:  # only can recover 3, don't overwhelm user
        return np.nan, f'{idx}\ntoo many ({len(check_i)}) to recover'
    event, text = np.nan, f'{idx}\nrecovered but discarded'
    for i in check_i:
        sync_e = i + beh_e_i - exclude_shift_i  # back to absolute samples
        fig, ax = plt.subplots()
        section_size = np.round(1.5 * max_len_i).astype(int)
        section = ch_data[sync_e - section_size: sync_e + section_size]
        ax.plot(np.linspace(-1.5 * max_len, 1.5 * max_len, section.size),
                section)
        ax.plot([0, 0], [section.min(), section.max()])
        ax.set_title(f'Corrupted Event {idx}')
        ax.set_xlabel('time (s)')
        ax.set_ylabel('voltage')
        fig.show()
        if input('Recover event? (y/N) ').lower() == 'y':
            return sync_e, f'{idx}\nrecovered (not excluded)'
    return event, text


def _plot_excluded_events(section_data, max_len):
    """Plot events that were more than `exclude_shift` away.

    Parameters
    ----------
    section_data : list
        Tuples of ``(event, title, section)`` where ``section`` is the
        channel trace around the expected event time.
    max_len : float
        Maximum event length in seconds, used for the x-axis tick labels.
    """
    import matplotlib.pyplot as plt
    n_events_ex = len(section_data)
    if not n_events_ex:
        return  # nothing was excluded, nothing to show
    # near-square grid of subplots
    nrows = int(n_events_ex**0.5)
    ncols = int(np.ceil(n_events_ex / nrows))
    # NOTE(review): figsize is (width, height); (nrows * 10, ncols * 5)
    # looks transposed -- confirm this is intended
    fig, axes = plt.subplots(nrows, ncols, figsize=(nrows * 10,
                                                    ncols * 5))
    fig.suptitle('Excluded Events')
    fig.subplots_adjust(hspace=0.75, wspace=0.5)
    if nrows == 1 and ncols == 1:
        axes = [axes]  # plt.subplots returns a bare Axes in this case
    else:
        axes = axes.flatten()
    for ax in axes[n_events_ex:]:
        ax.axis('off')  # turn off all unused axes
    # shared y scale from the lower quartile of the section amplitudes
    # NOTE(review): raises if every section is empty (quantile of []) --
    # confirm callers always supply at least one non-empty trace
    ymax = np.quantile([abs(sect[2]).max() for sect in section_data
                        if sect[2].size > 0], 0.25) * 1.1
    for i, (event, title, section) in enumerate(section_data):
        axes[i].plot(np.linspace(-1, 1, section.size), section)
        axes[i].plot([0, 0], [-ymax, ymax], color='r')  # expected time
        axes[i].set_ylim([-ymax, ymax])
        axes[i].set_title(title, fontsize=12)
        if i % ncols == 0:  # left column gets the y label
            axes[i].set_ylabel('voltage')
        axes[i].set_yticks([])
        if i // ncols == nrows - 1:  # bottom row gets the time labels
            axes[i].set_xticks(np.linspace(-1, 1, 3))
            axes[i].set_xticklabels(
                np.round(np.linspace(-2 * max_len, 2 * max_len, 3), 2))
            axes[i].set_xlabel('time (s)')
        else:
            axes[i].set_xticks([])
    fig.show()


def _exclude_ambiguous_events(beh_events, alignment, events, ch_data,
                              candidates, exclude_shift, max_len, sfreq,
                              recover, zscore, verbose=True):
    """Exclude all events that are outside the given shift compared to beh.

    Events are set to ``np.nan`` when more than one synchronization event
    lies within ``max_len`` of the match, or when the match is farther
    than ``exclude_shift`` from the behavioral event; if ``recover`` is
    True the user is asked interactively to recover such events instead.
    """
    if verbose:
        section_data = list()  # (event, title, trace) tuples for plotting
        print('Excluding events that have zero close synchronization events '
              'or more than one synchronization event within `max_len` time')
    max_len_i = np.round(sfreq * max_len).astype(int)
    exclude_shift_i = np.round(sfreq * exclude_shift).astype(int)
    for i, (beh_e, sync_e) in enumerate(zip(beh_events, events)):
        error = beh_e - sync_e + alignment  # residual in samples
        if np.abs(error) < exclude_shift_i:
            # count sync candidates within max_len of the matched event
            n_events = np.logical_and(candidates > (sync_e - max_len_i),
                                      candidates < (sync_e + max_len_i)).sum()
            if n_events > 1:  # ambiguous: more than one nearby sync event
                events[i] = np.nan
                text = (f'{i}\n{n_events} sync events found')
                if recover:
                    events[i], text = _recover_event(
                        i, ch_data, beh_e + alignment, exclude_shift, zscore,
                        max_len, sfreq)
                if verbose:
                    print(text.replace('\n', ' '))
                    event = np.round(beh_e + alignment).astype(int)
                    section_data.append(
                        (beh_e, text, ch_data[event - 2 * max_len_i:
                                              event + 2 * max_len_i]))
        elif not np.isnan(beh_e):  # matched event is too far away
            if recover:
                events[i], text = _recover_event(
                    i, ch_data, beh_e + alignment, exclude_shift, zscore,
                    max_len, sfreq)
            else:
                events[i] = np.nan
                # if off by a less than max_len, report samples
                text = f'{i}\noff by {int(error / sfreq * 1000)} ms' \
                    if abs(error) < max_len_i else f'{i}\nnone found'
            if verbose:
                print(text.replace('\n', ' '))
                event = np.round(beh_e + alignment).astype(int)
                section_data.append(
                    (beh_e, text, ch_data[event - 2 * max_len_i:
                                          event + 2 * max_len_i]))
    if verbose:
        _plot_excluded_events(section_data, max_len)
    return events


def _save_data(raw, events, event_id, ch_names, beh=None,
               add_events=False, overwrite=False):
    """Save the events determined from the photodiode.

    Writes a ``<basename>_pd_parser_data`` directory next to the raw file
    containing the annotations (``*_annot.fif``), the channel names
    (``*_ch_names.tsv``) and, if ``beh`` is given, the behavior tsv with
    a ``pd_parser_sample`` column added.

    Returns
    -------
    annot : mne.Annotations
        The annotations that were saved.
    samples : list | None
        The ``pd_parser_sample`` column, or ``None`` if no ``beh``.
    """
    fname = raw.filenames[0]
    basename = op.splitext(op.basename(fname))[0]
    out_dir = op.join(op.dirname(fname), basename + '_pd_parser_data')
    if not op.isdir(out_dir):
        os.makedirs(out_dir)
    behf = op.join(out_dir, basename + '_beh_df.tsv')
    if beh is None:
        # no behavior: remove any stale behavior file when overwriting
        if op.isfile(behf) and overwrite:
            os.remove(behf)
    else:
        if 'pd_parser_sample' in beh and not add_events and not overwrite:
            # fixed unbalanced parenthesis in the original message
            raise ValueError(
                'The key (column name) `pd_parser_sample` is not allowed '
                'in the behavior tsv file (it\'s reserved for internal '
                'use). Please rename that key (column) to continue.')
        if not add_events:
            beh['pd_parser_sample'] = ['n/a' if np.isnan(e) else int(e) for
                                       e in events]
            _to_tsv(behf, beh)
    # annotate the found events (nan events were excluded)
    onsets = events[~np.isnan(events)].astype(int)
    annot = mne.Annotations(onset=raw.times[onsets],
                            duration=np.repeat(0.1, len(onsets)),
                            description=np.repeat(event_id,
                                                  len(onsets)))
    if add_events:
        # merge with previously saved annotations and channel names
        annot_orig, ch_names_orig, _ = _load_data(raw)
        annot += annot_orig
        ch_names += [ch for ch in ch_names_orig if ch not in ch_names]
        overwrite = True  # must overwrite to extend the existing file
    annot.save(op.join(out_dir, basename + '_annot.fif'), overwrite=overwrite)
    with open(op.join(out_dir, basename + '_ch_names.tsv'), 'w') as fid:
        fid.write('\t'.join(ch_names))
    # NOTE(review): with add_events=True this assumes `beh` already carries
    # 'pd_parser_sample' from the original parse -- confirm against callers
    return annot, None if beh is None else beh['pd_parser_sample']


def _load_data(raw):
    """Load previously saved photodiode data--annot and pd channel names."""
    raw = _read_raw(raw, preload=None)
    fname = raw.filenames[0]
    basename = op.splitext(op.basename(fname))[0]
    out_dir = op.join(op.dirname(fname), basename + '_pd_parser_data')
    annot_fname = op.join(out_dir, basename + '_annot.fif')
    channels_fname = op.join(out_dir, basename + '_ch_names.tsv')
    behf = op.join(out_dir, basename + '_beh_df.tsv')
    # both the annotations and the channel-name files are required
    missing = not op.isfile(annot_fname) or not op.isfile(channels_fname)
    if missing:
        raise ValueError(f'pd-parser data not found in {out_dir}, '
                         f'specifically, {annot_fname} and '
                         f'{channels_fname}. Either `parse_pd` was '
                         f'not run, or it failed or {out_dir} '
                         'may have been moved or deleted. Rerun '
                         '`parse_pd` and optionally `add_relative_events` '
                         'to fix this')
    with open(channels_fname, 'r') as fid:
        ch_names = fid.readline().rstrip().split('\t')
    # the behavior tsv is optional
    if op.isfile(behf):
        beh_df = _read_tsv(behf)
    else:
        beh_df = None
    annot = mne.read_annotations(annot_fname)
    return annot, ch_names, beh_df


def _check_overwrite(raw, add_events, overwrite):
    """Check if the ``pd-parser`` data directory already exists."""
    basename = op.splitext(op.basename(raw.filenames[0]))[0]
    if op.isdir(op.join(op.dirname(
        raw.filenames[0]), basename + '_pd_parser_data')) and \
            not overwrite and not add_events:
        raise ValueError('Photodiode data directory already exists and '
                         'overwrite=False, set overwrite=True to overwrite')


def find_pd_params(raw, pd_ch_names=None, verbose=True):
    """Plot the data so the user can determine the right parameters.

    The user can adjust window size to determine max_len and horizontal
    line height to determine zscore.

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    pd_ch_names : list
        Names of the channel(s) containing the photodiode data.
        One channel is to be given for a common reference and
        two for a bipolar reference. If no channels are provided,
        the data will be plotted and the user will provide them.
    verbose : bool
        Whether to display or supress text output on the progress
        of the function.
    """
    # matplotlib is imported lazily; the toolbar is disabled so that the
    # keypresses used below don't also trigger toolbar shortcuts
    import matplotlib as mpl
    mpl.rcParams['toolbar'] = 'None'
    import matplotlib.pyplot as plt
    # load raw data file with the photodiode data
    raw = _read_raw(raw, verbose=verbose)
    pd, _ = _get_data(raw, pd_ch_names)
    fig, ax = plt.subplots(figsize=(6, 6))
    fig.subplots_adjust(top=0.75, left=0.15)
    plot_data = dict()  # line artists and saved x-limits shared by closures
    recs = dict()  # the recommended `max_len` and `zscore` values

    def zoom(factor):
        # multiplicatively grow/shrink the y-limits
        # ymin < 0 and ymax > 0 because median subtracted
        ymin, ymax = ax.get_ylim()
        ax.set_ylim([ymin * factor, ymax * factor])
        fig.canvas.draw()

    def scale(amount):
        # widen (positive) or narrow (negative) the x-limits symmetrically
        xmin, xmax = ax.get_xlim()
        xmin, xmax = xmin - amount, xmax + amount
        if xmin < xmax:
            ax.set_xlim([xmin, xmax])
        fig.canvas.draw()

    def set_zscore(event):
        # stage 3: move the threshold line, then compute recommendations
        if event.key == 'enter':
            ymin, ymax = ax.get_ylim()
            xmin, xmax = plot_data['xlims']
            pd_diff = np.diff(pd)
            baseline_i = np.round(0.25 * raw.info['sfreq']).astype(int)
            # typical baseline variability, estimated over disjoint windows
            median_std = np.median(
                [np.std(pd_diff[i - baseline_i:i]) for i in
                 range(baseline_i, len(pd_diff) - baseline_i, baseline_i)])
            zy = plot_data['zscore'].get_ydata()[0]
            recs['zscore'] = zy / median_std
            recommendations = (
                'Recommendations\nmax_len: {:.2f}, zscore: {:.2f}\n'
                'Try using these parameters for `parse_pd` and\n'
                'please report to the developers if there are issues\n'
                ''.format(recs['max_len'], recs['zscore']))
            ax.set_title(recommendations + 'You may now close the window')
            print(recommendations)
            fig.canvas.draw()
        elif event.key in ('up', 'down'):
            ymin, ymax = ax.get_ylim()
            step = (ymax - ymin) / 100
            if event.key == 'down':
                step = -step
            zy = plot_data['zscore'].get_ydata()[0] + step
            zy_ref = plot_data['zscore_reflection'].get_ydata()[0] - step
            plot_data['zscore'].set_ydata(np.ones(pd.size) * zy)
            plot_data['zscore_reflection'].set_ydata(
                np.ones(pd.size) * zy_ref)
            fig.canvas.draw()
        elif event.key in ('-', '+', '='):
            scale(1 if event.key == '-' else -1)

    def set_max_len(event):
        # stage 2: size the window, then hand off to the zscore stage
        if event.key == 'enter':
            xmin, xmax = ax.get_xlim()
            plot_data['xlims'] = (xmin, xmax)
            recs['max_len'] = (xmax - xmin) / 2 * 1.1
            eid = fig.canvas.mpl_connect('key_press_event', set_zscore)
            fig.canvas.mpl_disconnect(eid - 1)  # disconnect previous
            plot_data['zscore'] = ax.plot(
                raw.times, np.ones(pd.size) * np.quantile(pd, 0.25),
                color='g')[0]
            plot_data['zscore_reflection'] = ax.plot(
                raw.times, -np.ones(pd.size) * np.quantile(pd, 0.25),
                color='r')[0]
            ax.set_title(
                'Scale\nUse the up/down arrows to set the horizontal line \n'
                'half way up the photodiode onset event with the baseline \n'
                'in the middle of the y-axis\n'
                'Use +/- to scale the time axis to see more events\n'
                'press enter when finished')
            fig.canvas.draw()
        elif event.key in ('up', 'down'):
            # ymin < 0 and ymax > 0 because median subtracted
            xmin, xmax = ax.get_xlim()
            step = 0.1 if event.key == 'up' else -0.1
            ax.set_xlim([xmin + step, xmax - step])
            fig.canvas.draw()

    def align_keypress(event):
        # stage 1: center an uncorrupted event, then hand off to sizing
        if event.key == 'enter':
            eid = fig.canvas.mpl_connect('key_press_event', set_max_len)
            fig.canvas.mpl_disconnect(eid - 1)  # disconnect previous
            ax.set_title(
                'Window\nUse the up/down arrows to increase/decrease the\n'
                'size of the window so that only one pd event is in the\n'
                'window (leave room for the longest event if this '
                'isn\'t it)\n'
                'press enter when finished')
            fig.canvas.draw()
        elif event.key in ('-', '+', '='):
            zoom(1.1 if event.key == '-' else 0.9)
        elif event.key in ('left', 'right'):
            step = 0.1 if event.key == 'right' else -0.1
            xmin, xmax = ax.get_xlim()
            ax.set_xlim([xmin + step, xmax + step])
            # keep the center marker aligned with the shifted window
            zerox = plot_data['zero'].get_xdata()[0] + step
            plot_data['zero'].set_xdata([zerox, zerox])
            fig.canvas.draw()
        elif event.key in ('up', 'down'):
            ymin, ymax = ax.get_ylim()
            step = (ymax - ymin) / 100
            if event.key == 'down':
                step = -step
            ax.set_ylim([ymin + step, ymax + step])
            fig.canvas.draw()

    ax.set_title(
        'Align\nUse the left/right keys to find an uncorrupted photodiode '
        'event\nand align the onset to the center of the window\n'
        'use +/- to zoom the yaxis in and out (up/down to translate)\n'
        'press enter when finished')
    ax.set_xlabel('time (s)')
    ax.set_ylabel('voltage')
    ax.plot(raw.times, pd, color='b')
    midpoint = raw.times[pd.size // 2]
    plot_data['zero'] = ax.plot(
        [midpoint, midpoint], [pd.min() * 10, pd.max() * 10], color='k')[0]
    ax.set_xlim(midpoint - 2.5, midpoint + 2.5)
    ax.set_ylim(pd.min() * 1.25, pd.max() * 1.25)
    fig.canvas.mpl_connect('key_press_event', align_keypress)
    fig.show()
def parse_pd(raw, pd_event_name='Fixation', beh=None,
             beh_key='fix_onset_time', pd_ch_names=None, exclude_shift=0.03,
             resync=0.075, max_len=1., zscore=10, max_flip_i=40,
             baseline=0.25, add_events=False, recover=False,
             overwrite=False, verbose=True):
    """Parse photodiode events.

    Parses photodiode events from a likely very corrupted channel
    using behavioral data to sync events to determine which behavioral
    events don't have a match and are thus corrupted and should be
    excluded (while ignoring events that look like photodiode events
    but don't match behavior).

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    pd_event_name: str
        The name of the event corresponding to the photodiode.
    beh: str | dict
        The dictionary or filepath to a tsv file with the behavioral timing.
    beh_key: str
        The key (column name) of the beh dictionary that corresponds
        to the events.
    pd_ch_names : list
        Names of the channel(s) containing the photodiode data.
        One channel is to be given for a common reference and
        two for a bipolar reference. If no channels are provided,
        the data will be plotted and the user will provide them.
    exclude_shift: float
        How many seconds different than expected from the behavior events
        to exclude that event. Use `find_pd_params` to determine if unsure.
    resync: float
        The number of seconds of difference allowed to still use a
        photodiode event to resynchronize with time-stamped events. Events
        with differences between `resync` and `exclude_shift` will still
        be used for alignment but will be excluded from the events.
        This parameter allows event differences greater than
        `exclude_shift` to be removed without losing an alignment which
        depends on resynchronizing to these events between `exclude_shift`
        and `resync`. This is most likely to happen when the drift between
        behavior events and the photodiode is large, so many events are to
        be excluded for being off by a small amount but still correctly
        correspond to a behavior event.
    max_len: float
        The longest photodiode event can be.
    zscore: float
        How large of a z-score difference to use to threshold photodiode
        events. Note, this must be large enough that any overshoot when
        returning to threshold is less than zscore compared to baseline.
    max_flip_i: int
        The maximum number of samples the photodiode event can take to
        transition. This shouldn't usually need to be changed unless
        the transition takes longer.
    baseline: float
        How much relative to the max_len to use to identify the time
        before the photodiode event. This should not be changed most
        likely unless there is a specific reason/issue.
    add_events: bool
        Whether to add the events found from the current call of
        `parse_pd` to events found previously (e.g. first parse with
        `pd_event_name='Fixation'` and then parse with
        `pd_event_name='Response'`).
        Note: `pd_parser.add_relative_events` will be relative to the
        first event added.
    recover: bool
        Whether to recover corrupted events manually.
    overwrite: bool
        Whether to overwrite existing data if it exists.
    verbose: bool
        Whether to display or supress text output on the progress
        of the function.

    Returns
    -------
    annot: mne.Annotations
        The annotations with the added events.
    samples: list
        The samples corresponding to the events, with 'n/a' if no event
        is found.
    """
    if baseline <= 0 or baseline > 1:
        raise ValueError(f'baseline must be between 0 and 1, got {baseline}')
    # events farther than `resync` cannot be used for resynchronization,
    # so `resync` must be at least `exclude_shift` (same check as
    # `parse_audio`, which takes the same two parameters)
    if resync < exclude_shift:
        raise ValueError(f'`exclude_shift` ({exclude_shift}) cannot be longer '
                         f'than `resync` ({resync})')
    # load raw data file with the photodiode data
    raw = _read_raw(raw, verbose=verbose)
    # check if already parsed
    _check_overwrite(raw, add_events, overwrite)
    # use keyword argument if given, otherwise get the user
    # to enter pd names and get data
    pd, pd_ch_names = _get_data(raw, pd_ch_names)
    candidates = _find_pd_candidates(
        pd=pd, max_len=max_len, baseline=baseline, zscore=zscore,
        max_flip_i=max_flip_i, sfreq=raw.info['sfreq'], verbose=verbose)[0]
    # load behavioral data with which to validate event timing
    if beh is None:
        if verbose:
            print('No behavioral tsv file was provided so the photodiode '
                  'events will be returned without validation by task '
                  'timing')
        # return the result so the documented (annot, samples) comes back
        # on this path too (previously the annotations were discarded)
        return _save_data(raw=raw, events=candidates,
                          event_id=pd_event_name, ch_names=pd_ch_names,
                          overwrite=overwrite)
    # if behavior is given use it to synchronize and exclude events
    beh_events, beh = _load_beh(beh=beh, beh_key=beh_key)
    beh_events *= raw.info['sfreq']  # convert to samples
    beh_events_adjusted, alignment, events = _find_best_alignment(
        beh_events=beh_events, candidates=candidates,
        exclude_shift=exclude_shift, resync=resync,
        sfreq=raw.info['sfreq'], verbose=verbose)
    events = _exclude_ambiguous_events(
        beh_events=beh_events_adjusted, alignment=alignment, events=events,
        ch_data=pd, candidates=candidates, exclude_shift=exclude_shift,
        max_len=max_len, sfreq=raw.info['sfreq'], recover=recover,
        zscore=zscore, verbose=verbose)
    return _save_data(raw=raw, events=events, event_id=pd_event_name,
                      ch_names=pd_ch_names, beh=beh, add_events=add_events,
                      overwrite=overwrite)
def parse_audio(raw, audio_event_name='Tone', beh=None,
                beh_key='tone_onset_time', audio_ch_names=None,
                exclude_shift=0.03, resync=0.075, max_len=0.25,
                zscore=None, add_events=False, recover=False,
                overwrite=False, verbose=True):
    """Parse audio events.

    Parses audio events from a likely very corrupted channel
    using behavioral data to sync events to determine which behavioral
    events don't have a match and are thus corrupted and should be
    excluded (while ignoring events that look like audio events
    but don't match behavior).

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    audio_event_name: str
        The name of the event corresponding to the audio.
    beh: str | dict
        The dictionary or filepath to a tsv file with the behavioral timing.
    beh_key: str
        The key (column name) of the beh dictionary that corresponds
        to the events.
    audio_ch_names: list
        Names of the channel(s) containing the audio data.
        One channel is to be given for a common reference and
        two for a bipolar reference. If no channels are provided,
        the data will be plotted and the user will provide them.
    exclude_shift: float
        How many seconds different than expected from the behavior events
        to exclude that event.
    resync: float
        The number of seconds of difference allowed to still use an audio
        event to resynchronize with time-stamped events.
        See :func:`pd_parser.parse_pd` for more information.
    max_len: float
        The longest audio event can be.
    zscore: float
        How large of a z-score difference to use to threshold the
        correlation of the audio with the sound. If None is passed
        a plot will be shown to pick a reasonable zscore. 25 is a
        typical value that works.
    add_events: bool
        Whether to add the events found from the current call of
        `parse_audio` to events found previously (e.g. first parse with
        `audio_event_name='Tone'` and then parse with
        `audio_event_name='Response'`).
        Note: `pd_parser.add_relative_events` will be relative to the
        first event added.
    recover: bool
        Whether to recover corrupted events manually.
    overwrite: bool
        Whether to overwrite existing data if it exists.
    verbose: bool
        Whether to display or supress text output on the progress
        of the function.

    Returns
    -------
    annot: mne.Annotations
        The annotations with the added events.
    samples: list
        The samples corresponding to the events, with 'n/a' if no event
        is found.
    """
    if resync < exclude_shift:
        raise ValueError(f'`exclude_shift` ({exclude_shift}) cannot be longer '
                         f'than `resync` ({resync})')
    # load raw data file with the audio data
    raw = _read_raw(raw, verbose=verbose)
    # check if already parsed
    _check_overwrite(raw, add_events, overwrite)
    # use keyword argument if given, otherwise get the user
    # to enter audio channel names and get data
    audio, audio_ch_names = _get_data(raw, audio_ch_names)
    candidates = _find_audio_candidates(
        audio=audio, max_len=max_len, zscore=zscore,
        sfreq=raw.info['sfreq'], verbose=verbose)
    # load behavioral data with which to validate event timing
    if beh is None:
        if verbose:
            # message corrected from "photodiode events" -- this is the
            # audio parser
            print('No behavioral tsv file was provided so the audio '
                  'events will be returned without validation by task '
                  'timing')
        # return the result so the documented (annot, samples) comes back
        # on this path too (previously the annotations were discarded)
        return _save_data(raw=raw, events=candidates,
                          event_id=audio_event_name,
                          ch_names=audio_ch_names, overwrite=overwrite)
    # if behavior is given use it to synchronize and exclude events
    beh_events, beh = _load_beh(beh=beh, beh_key=beh_key)
    beh_events *= raw.info['sfreq']  # convert to samples
    beh_events_adjusted, alignment, events = _find_best_alignment(
        beh_events=beh_events, candidates=candidates,
        exclude_shift=exclude_shift, resync=resync,
        sfreq=raw.info['sfreq'], verbose=verbose)
    events = _exclude_ambiguous_events(
        beh_events=beh_events_adjusted, alignment=alignment, events=events,
        ch_data=audio, candidates=candidates, exclude_shift=exclude_shift,
        max_len=max_len, sfreq=raw.info['sfreq'], recover=recover,
        zscore=zscore, verbose=verbose)
    return _save_data(raw=raw, events=events, event_id=audio_event_name,
                      ch_names=audio_ch_names, beh=beh,
                      add_events=add_events, overwrite=overwrite)
def add_pd_off_events(raw, off_event_name='Stim Off', max_len=1.,
                      zscore=10, max_flip_i=40, baseline=0.25,
                      verbose=True, overwrite=False):
    """Add events for when the photodiode deflection returns to baseline.

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    off_event_name : str
        The name assigned to the events marking the cessation of the
        photodiode deflection.
    max_len: float
        The maximum length of the photodiode events.
    zscore: float
        How large of a z-score difference to use to threshold photodiode
        events.
    max_flip_i: int
        The maximum number of samples the photodiode event may take to
        transition.
    baseline: float
        How much relative to max_len to use to identify the time before
        the photodiode event.
    verbose : bool
        Whether to display or supress text output on the progress
        of the function.
    overwrite : bool
        Whether to overwrite existing data if it exists.

    Returns
    -------
    annot: mne.Annotations
        The annotations with the added events.

    .. note:: The same parameters must be used for
              :func:`pd_parser.parse_pd`.
    """
    raw = _read_raw(raw, verbose=verbose)
    # NOTE(review): assumes `parse_pd` was run with behavior data so that
    # `beh` is not None here -- confirm against `_load_data`
    annot, pd_ch_names, beh = _load_data(raw)
    max_len_i = np.round(raw.info['sfreq'] * max_len).astype(int)
    pd = _get_channel_data(raw, pd_ch_names)
    # map each matched onset sample back to its behavior trial index
    events = {samp: i for i, samp in enumerate(beh['pd_parser_sample'])
              if samp != 'n/a'}
    on_candidates, off_candidates = _find_pd_candidates(
        pd=pd, max_len=max_len, baseline=baseline, zscore=zscore,
        max_flip_i=max_flip_i, sfreq=raw.info['sfreq'], verbose=verbose)
    off_events = {events[onset]: offset
                  for onset, offset in zip(on_candidates, off_candidates)
                  if onset in events}
    # trials whose onset was found manually have no candidate offset yet
    recovered = [event_idx for event_idx in events.values()
                 if event_idx not in off_events.keys()]
    if recovered:  # some events found manually, recover
        for idx in recovered:
            # from half of max length, look backward and forward half
            beh_e = \
                beh['pd_parser_sample'][idx] + max_flip_i + max_len_i // 2
            event, text = _recover_event(idx, pd, beh_e, max_len / 2,
                                         zscore, max_len, raw.info['sfreq'])
            if not np.isnan(event):
                off_events[idx] = event
            if verbose:
                print(text.replace('\n', ' '))
    onsets = np.array(list(off_events.values()))
    annot += mne.Annotations(
        onset=raw.times[onsets],
        duration=np.repeat(0.1, onsets.size),
        description=np.repeat(off_event_name, onsets.size))
    # save modified data
    basename = op.splitext(op.basename(raw.filenames[0]))[0]
    out_dir = op.join(op.dirname(raw.filenames[0]),
                      basename + '_pd_parser_data')
    annot.save(op.join(out_dir, basename + '_annot.fif'), overwrite=True)
    return annot
def add_relative_events(raw, beh, relative_event_keys,
                        relative_event_names=None,
                        overwrite=False, verbose=True):
    """Add events relative to those determined from the photodiode.

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    beh: str | dict
        The dictionary or filepath to a tsv file with the behavioral timing.
    relative_event_keys : list
        The names of the keys where time data is stored relative to the
        photodiode event.
    relative_event_names : list
        The names of the events in `relative_event_keys`.
    overwrite: bool
        Whether to overwrite existing data if it exists.
    verbose: bool
        Whether to display or supress text output on the progress
        of the function.

    Returns
    -------
    annot: mne.Annotations
        The annotations with the added events.
    """
    if relative_event_names is None:
        if verbose:
            print('Using relative event keys {} as relative event '
                  'names'.format(', '.join(relative_event_keys)))
        relative_event_names = relative_event_keys
    if len(relative_event_keys) != len(relative_event_names):
        raise ValueError(
            'Mismatched length of relative event behavior '
            f'file keys (column names), {len(relative_event_keys)} and '
            f'names of the events {len(relative_event_names)}')
    raw = _read_raw(raw, verbose=verbose)
    # per-event relative timings, keyed by the event name they produce
    relative_events = {
        name: _load_beh(beh, rel_event)[0]
        for name, rel_event in zip(relative_event_names,
                                   relative_event_keys)}
    annot, _, beh = _load_data(raw)
    # drop (or refuse to clobber) any previously saved events of same name
    for event_name in relative_event_names:
        if event_name not in annot.description:
            continue
        if not overwrite:
            raise ValueError(f'Event name {event_name} already exists in '
                             'saved events and `overwrite=False`, use '
                             '`overwrite=True` to overwrite')
        annot.delete([i for i, desc in enumerate(annot.description)
                      if desc == event_name])
    # trial index -> photodiode sample for each matched trial
    events = {i: samp for i, samp in enumerate(beh['pd_parser_sample'])
              if samp != 'n/a'}
    for name, beh_events in relative_events.items():
        onsets = np.array(
            [events[i] + (beh_events[i] * raw.info['sfreq'])
             for i in sorted(events.keys())
             if not np.isnan(beh_events[i])]).round().astype(int)
        annot += mne.Annotations(onset=raw.times[onsets],
                                 duration=np.repeat(0.1, onsets.size),
                                 description=np.repeat(name, onsets.size))
    # save modified data
    basename = op.splitext(op.basename(raw.filenames[0]))[0]
    out_dir = op.join(op.dirname(raw.filenames[0]),
                      basename + '_pd_parser_data')
    annot.save(op.join(out_dir, basename + '_annot.fif'), overwrite=True)
    return annot
def add_events_to_raw(raw, keep_pd_channels=False, verbose=True):
    """Save out a new raw file with photodiode events.

    Note: this function is not recommended, rather just skip it and
    use `save_to_bids` which doesn't modify the underlying raw data
    especially converting it to fif if it isn't fif already. In
    `save_to_bids` the raw file itself doesn't contain the event
    information, it's only stored in the sidecar.

    Parameters
    ----------
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    keep_pd_channels : bool
        Whether to keep the channel(s) the photodiode data was on.
    verbose: bool
        Whether to display or supress text output on the progress
        of the function.

    Returns
    -------
    raw : mne.io.Raw
        The modified raw object with events.
    """
    raw = _read_raw(raw, verbose=verbose)
    annot, pd_ch_names, _ = _load_data(raw)
    raw.set_annotations(annot)
    # only drop channels that are present, and never drop them all
    present = [ch for ch in pd_ch_names if ch in raw.ch_names]
    if not keep_pd_channels and present and present != raw.ch_names:
        raw.drop_channels(present)
    return raw
def save_to_bids(bids_dir, raw, sub, task, ses=None, run=None,
                 data_type=None, eogs=None, ecgs=None, emgs=None,
                 verbose=True, overwrite=False):
    """Convert data to BIDS format with events found from the photodiode.

    Parameters
    ----------
    bids_dir: str
        The subject directory in the bids directory where the data
        should be saved.
    raw: str | mne Raw object
        The object or filepath of the time-series data file
        (e.g. meg, eeg, ieeg).
    sub: str
        The name of the subject.
    task: str
        The name of the task.
    ses: str
        The name of the session (optional).
    run: str
        The name of the run (optional).
    data_type: str
        The type of the channels containing data, i.e. 'eeg' or 'seeg'.
    eogs: list | None
        The channels recording eye electrophysiology.
    ecgs: list | None
        The channels recording heart electrophysiology.
    emgs: list | None
        The channels recording muscle electrophysiology.
    verbose: bool
        Whether to display or supress text output on the progress
        of the function.
    overwrite: bool
        Whether to overwrite existing data if it exists.
    """
    import mne_bids
    if not op.isdir(bids_dir):
        os.makedirs(bids_dir)
    raw = _read_raw(raw, preload=False, verbose=verbose)
    # mark any auxiliary electrophysiology channels with their type
    aux_chs = list()
    for name, ch_list in zip(['eog', 'ecg', 'emg'], [eogs, ecgs, emgs]):
        if ch_list is not None:
            aux_chs += ch_list
            raw.set_channel_types({ch: name for ch in ch_list})
    if data_type is not None:
        raw.set_channel_types({ch: data_type for ch in raw.ch_names
                               if ch not in aux_chs})
    annot, pd_channels, beh = _load_data(raw)
    raw.set_annotations(annot)
    # mark the photodiode channels bad so they are flagged in the sidecar
    raw.info['bads'] += [ch for ch in pd_channels
                         if ch not in raw.info['bads']]
    bids_path = mne_bids.BIDSPath(subject=sub, session=ses, task=task,
                                  run=run, root=bids_dir)
    mne_bids.write_raw_bids(raw, bids_path, verbose=verbose,
                            overwrite=overwrite)
    # save the behavior tsv alongside, in the BIDS `beh` datatype
    beh_path = bids_path.copy().update(datatype='beh')
    if not op.isdir(op.dirname(beh_path.fpath)):
        os.makedirs(op.dirname(beh_path.fpath))
    if beh is not None:
        _to_tsv(str(beh_path.fpath) + '_beh.tsv', beh)
def simulate_pd_data(n_events=10, n_secs_on=1.0, amp=300., iti=6.,
                     iti_jitter=1.5, rc_decay=0.0001, prop_corrupted=0.1,
                     sfreq=1000., seed=11, show=False):
    """Simulate photodiode data.

    Simulate data that is a square wave with a linear change in
    deflection `drift` amount towards zero that then over shoots
    and drifts back as photodiodes tend to do. Some events are
    also corrupted.

    Parameters
    ----------
    n_events: float
        The number of events to simulate.
    n_secs_on: float | np.array
        The number of seconds each event is on. If a float is provided,
        the time is the same for each event. If an array is provided,
        it must be the length of the number of events, and it
        determines the length of each event respectively.
    amp: float
        The amplitude of the photodiode in standard deviations above
        baseline.
    iti: float
        The interval in between events.
    iti_jitter: float
        The jitter displacing the events from exactly `iti` distance
        away from each other.
    rc_decay: float
        The factor controlling how much the photodiode decays back to
        baseline over time with no external simulus
        (0. == perfect square wave).
    prop_corrupted: float
        The proportion of events to corrupt in the simulation.
    sfreq: float
        The sampling frequency of the data.
    seed: int
        The seed for numpy's random number generator.
    show: bool
        Whether to plot the data.

    Returns
    -------
    raw: mne.io.Raw
        The raw object containing the photodiode data.
    beh: dict
        A dictionary with keys (columns names):

        `trial` : int
            The index of the event.

        `time` : float
            The time that both the corrupted and uncorrupted events
            occurred in seconds.
    events: np.array
        The uncorrupted events where the first column is the time stamp,
        the second column is unused (zero) and the third column is the
        event identifier.
    corrupted_indices: np.array
        The indices of the events which were corrupted in the simulation.
    """
    if isinstance(n_secs_on, list) and len(n_secs_on) != n_events:
        raise ValueError('If a list of `n_secs_on` is provided, it must '
                         f'match the number of events, {n_events}, got '
                         f'{len(n_secs_on)}')
    # raise instead of assert: asserts are stripped under `python -O`
    if rc_decay < 0 or iti <= 0 or iti_jitter <= 0:
        raise ValueError('`rc_decay` must be non-negative and `iti` and '
                         '`iti_jitter` must be positive, got '
                         f'rc_decay={rc_decay}, iti={iti}, '
                         f'iti_jitter={iti_jitter}')
    # n_secs on as list is okay, just make an array
    if isinstance(n_secs_on, list):
        n_secs_on = np.array(n_secs_on)
    # convert events to samples
    if isinstance(n_secs_on, np.ndarray):
        n_samp_on = np.round(n_secs_on * sfreq).astype(int)
    else:
        n_samp_on = np.repeat(np.round(n_secs_on * sfreq).astype(int),
                              n_events)
    iti_samp = np.round(iti * sfreq).astype(int)
    iti_jitter_samp = np.round(iti_jitter * sfreq).astype(int)
    if iti_samp - iti_jitter_samp <= n_samp_on.min():
        raise ValueError(
            f'Events will run into each other because `iti` ({iti})'
            f' - `iti_jitter` ({iti_jitter}) is less the than minimum'
            f' `n_secs_on` ({n_samp_on.min() / sfreq})')
    # seed random number generator
    np.random.seed(seed)
    # make events
    events = np.zeros((n_events, 3), dtype=int)
    events[:, 0] = iti_samp + np.cumsum(np.round(
        (np.random.random(n_events) * iti_jitter + iti) *
        sfreq)).astype(int)
    events[:, 2] = 1
    # make pink noise; leave room after the last event for its decay
    n_secs_on_max = n_secs_on if isinstance(n_secs_on, (float, int)) else \
        np.array(n_secs_on).max()
    n_points = events[:, 0].max() + int(10 * n_secs_on_max * sfreq)
    n_points += n_points % 2  # must be even
    x = np.random.randn(n_points // 2) + np.random.randn(n_points // 2) * 1j
    x /= np.sqrt(np.arange(1, x.size + 1))
    pd_data = np.fft.irfft(x).real
    pd_data /= pd_data.std()
    # add photodiode square waves to pink noise
    flip_data = np.zeros(pd_data.shape)
    for i in range(n_events):
        event = events[i, 0]
        n_on = n_samp_on[i]
        next_event = events[i + 1, 0] - n_on if i < n_events - 1 else \
            pd_data.size - n_on
        flip_data[event] += amp
        for j in range(1, n_on):  # on decay
            flip_data[event + j] += flip_data[event + j - 1] * (1 - rc_decay)
        # overshoot below baseline at offset
        flip_data[event + n_on - 1] -= amp * ((1 - rc_decay)**(n_on / 2))
        # off decay back toward baseline; rc_decay == 0 means no decay,
        # which previously crashed with ZeroDivisionError (the original
        # also reused the stale on-decay loop variable here -- now the
        # last on-decay index, n_on - 1, is used explicitly)
        decay_i = (n_on - 1) / rc_decay if rc_decay > 0 else np.inf
        max_i = int(min(next_event - event, decay_i, int(10 * sfreq)))
        for j in range(max_i):
            flip_data[event + n_on + j] += \
                flip_data[event + n_on + j - 1] * (1 - rc_decay)
    pd_data += flip_data
    # corrupt some events
    n_events_corrupted = np.round(n_events * prop_corrupted).astype(int)
    corrupted_indices = np.random.choice(range(n_events),
                                         n_events_corrupted,
                                         replace=False)
    for i in corrupted_indices:
        n_on = n_samp_on[i]
        samp_range = range(events[i, 0] - iti_jitter_samp,
                           events[i, 0] + n_on + iti_jitter_samp)
        # about 2% of times corrupted
        ts_cor = int(len(samp_range) * np.random.random() * 0.02 + 0.005)
        for ts in np.random.choice(samp_range, ts_cor, replace=False):
            # disrupt 1 / 5 of on time, 5 times amplitude
            pd_data[ts - n_on // 10: ts + n_on // 10] += \
                (np.random.random() - 0.5) * 5 * amp - amp
    beh = dict(trial=np.arange(n_events),
               time=events[:, 0].astype(float) / sfreq)
    events = np.delete(events, corrupted_indices, axis=0)
    # plot if show
    if show:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        # time axis in seconds: total duration is n_points / sfreq
        # (was erroneously sfreq * n_points)
        ax.plot(np.linspace(0, n_points / sfreq, pd_data.size), pd_data)
        ax.set_xlabel('time (s)')
        ax.set_ylabel('amp')
        ax.set_title('Photodiode Data')
        fig.show()
    # create mne.io.Raw object
    info = mne.create_info(['pd'], sfreq, ['stim'])
    raw = mne.io.RawArray(pd_data[np.newaxis], info)
    return raw, beh, events, corrupted_indices