Skip to content

Commit

Permalink
Merge pull request #5137 from larsoner/pick
Browse files (browse the repository at this point in the history)
MRG: Refactor EDF channel picking
  • Loading branch information
cbrnr authored Apr 18, 2018
2 parents 3a109ba + e9abba8 commit 1dc1502
Showing 1 changed file with 66 additions and 81 deletions.
147 changes: 66 additions & 81 deletions mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,15 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
# and for efficiency we want to be able to combine mult and cals
# so proj support will have to wait until this is resolved
raise NotImplementedError('mult is not supported yet')
exclude = self._raw_extras[fi]['exclude']
sel = np.arange(self.info['nchan'])[idx]
n_samps = self._raw_extras[fi]['n_samps']
buf_len = int(self._raw_extras[fi]['max_samp'])
sfreq = self.info['sfreq']
dtype = self._raw_extras[fi]['dtype_np']
dtype_byte = self._raw_extras[fi]['dtype_byte']
data_offset = self._raw_extras[fi]['data_offset']
stim_channel = self._raw_extras[fi]['stim_channel']
tal_channels = self._raw_extras[fi]['tal_channel']
tal_sel = self._raw_extras[fi]['tal_sel']
orig_sel = self._raw_extras[fi]['sel']
annot = self._raw_extras[fi]['annot']
annotmap = self._raw_extras[fi]['annotmap']
subtype = self._raw_extras[fi]['subtype']
Expand All @@ -202,6 +201,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
# gain constructor
physical_range = np.array([ch['range'] for ch in self.info['chs']])
cal = np.array([ch['cal'] for ch in self.info['chs']])
assert cal.shape == (len(self.info['chs']),)
cal = np.atleast_2d(physical_range / cal) # physical / digital
gains = np.atleast_2d(self._raw_extras[fi]['units'])

Expand All @@ -210,21 +210,8 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
digital_min = self._raw_extras[fi]['digital_min']

offsets = np.atleast_2d(physical_min - (digital_min * cal)).T
if tal_channels is not None:
for tal_channel in tal_channels:
offsets[tal_channel] = 0

# This is needed to rearrange the indices to correspond to correct
# chunks on the file if excluded channels exist:
selection = sel.copy()
idx_map = np.argsort(selection)
for ei in sorted(exclude):
for ii, si in enumerate(sorted(selection)):
if si >= ei:
selection[idx_map[ii]] += 1
if tal_channels is not None:
tal_channels = [tc + 1 if tc >= ei else tc for tc in
sorted(tal_channels)]
offsets[np.in1d(orig_sel, tal_sel)] = 0
this_sel = orig_sel[idx]

# We could read this one EDF block at a time, which would be this:
ch_offsets = np.cumsum(np.concatenate([[0], n_samps]))
Expand All @@ -245,7 +232,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
# Read and reshape to (n_chunks_read, ch0_ch1_ch2_ch3...)
many_chunk = _read_ch(fid, subtype, ch_offsets[-1] * n_read,
dtype_byte, dtype).reshape(n_read, -1)
for ii, ci in enumerate(selection):
for ii, ci in enumerate(this_sel):
# This now has size (n_chunks_read, n_samps[ci])
ch_data = many_chunk[:, ch_offsets[ci]:ch_offsets[ci + 1]]
r_sidx = r_lims[ai][0]
Expand All @@ -254,7 +241,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
d_sidx = d_lims[ai][0]
d_eidx = d_lims[ai + n_read - 1][1]
if n_samps[ci] != buf_len:
if tal_channels is not None and ci in tal_channels:
if ci in tal_sel:
# don't resample tal_channels, zero-pad instead.
if n_samps[ci] < buf_len:
z = np.zeros((len(ch_data),
Expand All @@ -264,7 +251,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
ch_data = ch_data[:, :buf_len]
elif ci == stim_channel:
if (annot and annotmap or stim_data is not None or
tal_channels is not None):
len(tal_sel) > 0):
# don't resample, it gets overwritten later
ch_data = np.zeros((len(ch_data), buf_len))
else:
Expand All @@ -285,21 +272,22 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
assert ch_data.shape == (len(ch_data), buf_len)
data[ii, d_sidx:d_eidx] = ch_data.ravel()[r_sidx:r_eidx]

data *= cal.T[sel] # scale
data += offsets[sel] # offset
data *= gains.T[sel] # apply units gain last
data *= cal.T[idx] # scale
data += offsets[idx] # offset
data *= gains.T[idx] # apply units gain last

# only try to read the stim channel if it's not None and it's
# actually one of the requested channels
idx = np.arange(self.info['nchan'])[idx] # slice -> ints
read_size = len(r_lims) * buf_len
if stim_channel is not None and (sel == stim_channel).sum() > 0:
stim_channel_idx = np.where(sel == stim_channel)[0]
stim_channel_idx = np.where(idx == stim_channel)[0]
if stim_channel is not None and len(stim_channel_idx) > 0:
if annot and annotmap:
evts = _read_annot(annot, annotmap, sfreq,
self._last_samps[fi])
data[stim_channel_idx, :] = evts[start:stop + 1]
elif tal_channels is not None:
tal_channel_idx = np.intersect1d(sel, tal_channels)
elif len(tal_sel) > 0:
tal_channel_idx = np.in1d(orig_sel[idx], tal_sel)
evts = _parse_tal_channel(np.atleast_2d(data[tal_channel_idx]))
self._raw_extras[fi]['events'] = evts

Expand Down Expand Up @@ -418,9 +406,9 @@ def _get_info(fname, stim_channel, annot, annotmap, eog, misc, exclude,
raise NotImplementedError(
'Only GDF, EDF, and BDF files are supported, got %s.' % ext)

include = edf_info['include']
sel = edf_info['sel']
ch_names = edf_info['ch_names']
n_samps = edf_info['n_samps'][include]
n_samps = edf_info['n_samps'][sel]
nchan = edf_info['nchan']
physical_ranges = edf_info['physical_max'] - edf_info['physical_min']
cals = edf_info['digital_max'] - edf_info['digital_min']
Expand All @@ -432,7 +420,7 @@ def _get_info(fname, stim_channel, annot, annotmap, eog, misc, exclude,
if 'stim_data' in edf_info and stim_channel == 'auto': # For GDF events.
cals = np.append(cals, 1)
if stim_channel is not None:
stim_channel = _check_stim_channel(stim_channel, ch_names, include)
stim_channel = _check_stim_channel(stim_channel, ch_names, sel)

# Annotations
tal_ch_name = 'EDF Annotations'
Expand All @@ -446,12 +434,10 @@ def _get_info(fname, stim_channel, annot, annotmap, eog, misc, exclude,
% tal_ch_name)
for idx, tal_ch in enumerate(tal_chs, 1):
ch_names[tal_ch] = ch_names[tal_ch] + '-%s' % idx
tal_channel = tal_chs
else:
tal_channel = None
edf_info['tal_channel'] = tal_channel
tal_sel = edf_info['sel'][tal_chs]
edf_info['tal_sel'] = tal_sel

if tal_channel is not None and stim_channel is not None and not preload:
if len(tal_sel) > 0 and stim_channel is not None and not preload:
raise RuntimeError('%s' % ('EDF+ Annotations (TAL) channel needs to be'
' parsed completely on loading.'
' You must set preload parameter to True.'))
Expand Down Expand Up @@ -497,7 +483,7 @@ def _get_info(fname, stim_channel, annot, annotmap, eog, misc, exclude,
edf_info['units'][idx] = 1
if isinstance(stim_channel, str):
stim_channel = idx
if tal_channel is not None and idx in tal_channel:
if edf_info['sel'][idx] in tal_sel:
chan_info['range'] = 1
chan_info['cal'] = 1
chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE
Expand Down Expand Up @@ -569,7 +555,7 @@ def _get_info(fname, stim_channel, annot, annotmap, eog, misc, exclude,

# These are the conditions under which a stim channel will be interpolated
if stim_channel is not None and not (annot and annotmap) and \
tal_channel is None and n_samps[stim_channel] != int(max_samp):
len(tal_sel) == 0 and n_samps[stim_channel] != int(max_samp):
warn('Interpolating stim channel. Events may jitter.')
info._update_redundant()

Expand Down Expand Up @@ -626,29 +612,28 @@ def _read_edf_header(fname, annot, annotmap, exclude):
channels = list(range(nchan))
ch_names = [fid.read(16).strip().decode() for ch in channels]
exclude = _find_exclude_idx(ch_names, exclude)
sel = np.setdiff1d(np.arange(len(ch_names)), exclude)
for ch in channels:
fid.read(80) # transducer
units = [fid.read(8).strip().decode() for ch in channels]
edf_info['units'] = list()
include = list()
for i, unit in enumerate(units):
if i in exclude:
continue
if unit == 'uV':
edf_info['units'].append(1e-6)
else:
edf_info['units'].append(1)
include.append(i)
ch_names = [ch_names[idx] for idx in include]
ch_names = [ch_names[idx] for idx in sel]

physical_min = np.array([float(fid.read(8).decode())
for ch in channels])[include]
for ch in channels])[sel]
physical_max = np.array([float(fid.read(8).decode())
for ch in channels])[include]
for ch in channels])[sel]
digital_min = np.array([float(fid.read(8).decode())
for ch in channels])[include]
for ch in channels])[sel]
digital_max = np.array([float(fid.read(8).decode())
for ch in channels])[include]
for ch in channels])[sel]
prefiltering = [fid.read(80).decode().strip(' \x00')
for ch in channels][:-1]
highpass = np.ravel([re.findall(r'HP:\s+(\w+)', filt)
Expand All @@ -663,8 +648,8 @@ def _read_edf_header(fname, annot, annotmap, exclude):
# Populate edf_info
edf_info.update(
ch_names=ch_names, data_offset=header_nbytes,
digital_max=digital_max, digital_min=digital_min, exclude=exclude,
highpass=highpass, include=include, lowpass=lowpass,
digital_max=digital_max, digital_min=digital_min,
highpass=highpass, sel=sel, lowpass=lowpass,
meas_date=calendar.timegm(date.utctimetuple()),
n_records=n_records, n_samps=n_samps, nchan=nchan,
subject_info=patient, physical_max=physical_max,
Expand Down Expand Up @@ -766,15 +751,15 @@ def _read_gdf_header(fname, stim_channel, exclude):
for ch in channels]

exclude = _find_exclude_idx(ch_names, exclude)
include = list()
sel = list()
for i, unit in enumerate(units):
if unit[:2] == 'uV':
units[i] = 1e-6
else:
units[i] = 1
include.append(i)
sel.append(i)

ch_names = [ch_names[idx] for idx in include]
ch_names = [ch_names[idx] for idx in sel]
physical_min = np.fromfile(fid, np.float64, len(channels))
physical_max = np.fromfile(fid, np.float64, len(channels))
digital_min = np.fromfile(fid, np.int64, len(channels))
Expand Down Expand Up @@ -803,7 +788,7 @@ def _read_gdf_header(fname, stim_channel, exclude):
digital_max=digital_max,
dtype_byte=[gdftype_byte[t] for t in dtype],
dtype_np=[gdftype_np[t] for t in dtype], exclude=exclude,
highpass=highpass, include=include, lowpass=lowpass,
highpass=highpass, sel=sel, lowpass=lowpass,
meas_date=calendar.timegm(date.utctimetuple()),
meas_id=meas_id, n_records=n_records, n_samps=n_samps,
nchan=nchan, subject_info=patient, physical_max=physical_max,
Expand Down Expand Up @@ -958,7 +943,7 @@ def _read_gdf_header(fname, stim_channel, exclude):
""" # noqa
units = np.fromfile(fid, np.uint16, len(channels)).tolist()
unitcodes = np.array(units[:])
include = list()
sel = list()
for i, unit in enumerate(units):
if unit == 4275: # microvolts
units[i] = 1e-6
Expand All @@ -971,9 +956,9 @@ def _read_gdf_header(fname, stim_channel, exclude):
'(assuming dimensionless). Please contact the '
'MNE-Python developers for support.' % i)
units[i] = 1
include.append(i)
sel.append(i)

ch_names = [ch_names[idx] for idx in include]
ch_names = [ch_names[idx] for idx in sel]
physical_min = np.fromfile(fid, np.float64, len(channels))
physical_max = np.fromfile(fid, np.float64, len(channels))
digital_min = np.fromfile(fid, np.float64, len(channels))
Expand Down Expand Up @@ -1025,7 +1010,7 @@ def _read_gdf_header(fname, stim_channel, exclude):
dtype_byte=[gdftype_byte[t] for t in dtype],
dtype_np=[gdftype_np[t] for t in dtype],
digital_min=digital_min, digital_max=digital_max,
exclude=exclude, gnd=gnd, highpass=highpass, include=include,
exclude=exclude, gnd=gnd, highpass=highpass, sel=sel,
impedance=impedance, lowpass=lowpass,
meas_date=calendar.timegm(date.utctimetuple()),
meas_id=meas_id, n_records=n_records, n_samps=n_samps,
Expand Down Expand Up @@ -1071,29 +1056,29 @@ def _read_gdf_header(fname, stim_channel, exclude):
if stim_channel == 'auto' and edf_info['nchan'] not in exclude:
if len(events) == 0:
warn('No events found. Cannot construct a stimulus channel.')
edf_info['events'] = list()
return edf_info
edf_info['include'].append(edf_info['nchan'])
edf_info['n_samps'] = np.append(edf_info['n_samps'], 0)
edf_info['units'] = np.append(edf_info['units'], 1)
edf_info['ch_names'] += [u'STI 014']
edf_info['physical_min'] = np.append(edf_info['physical_min'], 0)
edf_info['digital_min'] = np.append(edf_info['digital_min'], 0)
vmax = np.max(events[2])
edf_info['physical_max'] = np.append(edf_info['physical_max'], vmax)
edf_info['digital_max'] = np.append(edf_info['digital_max'], vmax)

data = np.zeros(np.max(n_samps * n_records))
warn_overlap = False
for samp, id, dur in zip(events[1], events[2], events[4]):
if np.sum(data[samp:samp + dur]) > 0:
warn_overlap = True # Warn only once.
data[samp:samp + dur] += id
if warn_overlap:
warn('Overlapping events detected. Use find_edf_events for the '
'original events.')
edf_info['stim_data'] = data
edf_info['events'] = events
else:
edf_info['sel'].append(edf_info['nchan'])
edf_info['n_samps'] = np.append(edf_info['n_samps'], 0)
edf_info['units'] = np.append(edf_info['units'], 1)
edf_info['ch_names'] += [u'STI 014']
edf_info['physical_min'] = np.append(edf_info['physical_min'], 0)
edf_info['digital_min'] = np.append(edf_info['digital_min'], 0)
vmax = np.max(events[2])
edf_info['physical_max'] = np.append(edf_info['physical_max'],
vmax)
edf_info['digital_max'] = np.append(edf_info['digital_max'], vmax)

data = np.zeros(np.max(n_samps * n_records))
warn_overlap = False
for samp, id, dur in zip(events[1], events[2], events[4]):
if np.sum(data[samp:samp + dur]) > 0:
warn_overlap = True # Warn only once.
data[samp:samp + dur] += id
if warn_overlap:
warn('Overlapping events detected. Use find_edf_events for '
'the original events.')
edf_info['stim_data'] = data
edf_info.update(events=events, sel=np.arange(len(edf_info['ch_names'])))
return edf_info


Expand Down Expand Up @@ -1137,14 +1122,14 @@ def _read_annot(annot, annotmap, sfreq, data_length):
return stim_channel


def _check_stim_channel(stim_channel, ch_names, include):
def _check_stim_channel(stim_channel, ch_names, sel):
"""Check that the stimulus channel exists in the current datafile."""
if isinstance(stim_channel, str):
if stim_channel == 'auto':
if 'auto' in ch_names:
raise ValueError("'auto' exists as a channel name. Change "
"stim_channel parameter!")
stim_channel = len(include) - 1
stim_channel = len(sel) - 1
elif stim_channel not in ch_names:
err = 'Could not find a channel named "{}" in datafile.' \
.format(stim_channel)
Expand All @@ -1156,7 +1141,7 @@ def _check_stim_channel(stim_channel, ch_names, include):
raise ValueError(err)
else:
if stim_channel == -1:
stim_channel = len(include) - 1
stim_channel = len(sel) - 1
elif stim_channel > len(ch_names):
raise ValueError('Requested stim_channel index ({}) exceeds total '
'number of channels in datafile ({})'
Expand Down

0 comments on commit 1dc1502

Please sign in to comment.