Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Backports #7890

Merged
merged 2 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mne/io/ctf/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _conv_comp(comp, first, last, chs):
col_names = comp[first]['sensors'][:n_col]
row_names = [comp[p]['sensor_name'] for p in range(first, last + 1)]
mask = np.in1d(col_names, ch_names) # missing channels excluded
col_names = np.array(col_names)[mask]
col_names = np.array(col_names)[mask].tolist()
n_col = len(col_names)
n_row = len(row_names)
ccomp = dict(ctfkind=np.array([comp[first]['coeff_type']]),
Expand Down
1 change: 0 additions & 1 deletion mne/io/egi/egi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def my_fread(*x, **y):
for event in range(info['n_events']):
event_codes = ''.join(np.fromfile(fid, 'S1', 4).astype('U1'))
info['event_codes'].append(event_codes)
info['event_codes'] = np.array(info['event_codes'])
else:
raise NotImplementedError('Only continuous files are supported')
info['unsegmented'] = unsegmented
Expand Down
2 changes: 1 addition & 1 deletion mne/io/egi/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _read_events(input_fname, info):
mff_events, event_codes = _read_mff_events(input_fname, info['sfreq'],
info['n_samples'])
info['n_events'] = len(event_codes)
info['event_codes'] = np.asarray(event_codes).astype('<U4')
info['event_codes'] = event_codes
events = np.zeros([info['n_events'],
info['n_segments'] * info['n_samples']])
for n, event in enumerate(event_codes):
Expand Down
50 changes: 40 additions & 10 deletions mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from .open import fiff_open
from .tree import dir_tree_find
from .tag import read_tag, find_tag, _coord_dict
from .proj import _read_proj, _write_proj, _uniquify_projs, _normalize_proj
from .proj import (_read_proj, _write_proj, _uniquify_projs, _normalize_proj,
Projection)
from .ctf_comp import read_ctf_comp, write_ctf_comp
from .write import (start_file, end_file, start_block, end_block,
write_string, write_dig_points, write_float, write_int,
Expand All @@ -31,7 +32,7 @@
from ..transforms import invert_transform, Transform, _coord_frame_name
from ..utils import (logger, verbose, warn, object_diff, _validate_type,
_stamp_to_dt, _dt_to_stamp, _pl, _is_numeric)
from ._digitization import (_format_dig_points, _dig_kind_proper,
from ._digitization import (_format_dig_points, _dig_kind_proper, DigPoint,
_dig_kind_rev, _dig_kind_ints, _read_dig_fif)
from ._digitization import write_dig as _dig_write_dig
from .compensator import get_current_comp
Expand Down Expand Up @@ -190,6 +191,16 @@ def set_montage(self, montage, raise_if_subset=DEPRECATED_PARAM,
return self


def _format_trans(obj, key):
try:
t = obj[key]
except KeyError:
pass
else:
if t is not None:
obj[key] = Transform(t['from'], t['to'], t['trans'])


# XXX Eventually this should be de-duplicated with the MNE-MATLAB stuff...
class Info(dict, MontageMixin):
"""Measurement information.
Expand Down Expand Up @@ -528,9 +539,24 @@ class Info(dict, MontageMixin):

def __init__(self, *args, **kwargs):
super(Info, self).__init__(*args, **kwargs)
t = self.get('dev_head_t', None)
if t is not None and not isinstance(t, Transform):
self['dev_head_t'] = Transform(t['from'], t['to'], t['trans'])
# Deal with h5io writing things as dict
for key in ('dev_head_t', 'ctf_head_t', 'dev_ctf_t'):
_format_trans(self, key)
for res in self.get('hpi_results', []):
_format_trans(res, 'coord_trans')
if self.get('dig', None) is not None and len(self['dig']) and \
not isinstance(self['dig'][0], DigPoint):
self['dig'] = _format_dig_points(self['dig'])
for pi, proj in enumerate(self.get('projs', [])):
if not isinstance(proj, Projection):
self['projs'][pi] = Projection(proj)
# Old files could have meas_date as tuple instead of datetime
try:
meas_date = self['meas_date']
except KeyError:
pass
else:
self['meas_date'] = _ensure_meas_date_none_or_dt(meas_date)

def copy(self):
"""Copy the instance.
Expand Down Expand Up @@ -1362,11 +1388,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
info['proj_name'] = proj_name
if meas_date is None:
meas_date = (info['meas_id']['secs'], info['meas_id']['usecs'])
if np.array_equal(meas_date, DATE_NONE):
meas_date = None
else:
meas_date = _stamp_to_dt(meas_date)
info['meas_date'] = meas_date
info['meas_date'] = _ensure_meas_date_none_or_dt(meas_date)
info['utc_offset'] = utc_offset

info['sfreq'] = sfreq
Expand Down Expand Up @@ -1408,6 +1430,14 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
return info, meas


def _ensure_meas_date_none_or_dt(meas_date):
if meas_date is None or np.array_equal(meas_date, DATE_NONE):
meas_date = None
elif not isinstance(meas_date, datetime.datetime):
meas_date = _stamp_to_dt(meas_date)
return meas_date


def _check_dates(info, prepend_error=''):
"""Check dates before writing as fif files.

Expand Down
13 changes: 11 additions & 2 deletions mne/io/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from mne import concatenate_raws, create_info, Annotations
from mne.datasets import testing
from mne.io import read_raw_fif, RawArray, BaseRaw
from mne.utils import _TempDir, catch_logging, _raw_annot, _stamp_to_dt
from mne.externals.h5io import read_hdf5, write_hdf5
from mne.io import read_raw_fif, RawArray, BaseRaw, Info
from mne.utils import (_TempDir, catch_logging, _raw_annot, _stamp_to_dt,
object_diff, check_version)
from mne.io.meas_info import _get_valid_units
from mne.io._digitization import DigPoint

Expand Down Expand Up @@ -169,6 +171,13 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True,
for ch_name, unit in raw._orig_units.items():
assert unit.lower() in valid_units_lower, ch_name

# Make sure that writing info to h5 format
# (all fields should be compatible)
if check_version('h5py'):
fname_h5 = op.join(tempdir, 'info.h5')
write_hdf5(fname_h5, raw.info)
new_info = Info(read_hdf5(fname_h5))
assert object_diff(new_info, raw.info) == ''
return raw


Expand Down
6 changes: 6 additions & 0 deletions mne/tests/test_bem.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,28 +247,34 @@ def test_fit_sphere_to_headshape():
# Top of the head (extra point)
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EXTRA,
'ident': 0,
'r': np.array([0.0, 0.0, 1.0])},

# EEG points
# Fz
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EEG,
'ident': 0,
'r': np.array([0, .72, .69])},
# F3
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EEG,
'ident': 1,
'r': np.array([-.55, .67, .50])},
# F4
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EEG,
'ident': 2,
'r': np.array([.55, .67, .50])},
# Cz
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EEG,
'ident': 3,
'r': np.array([0.0, 0.0, 1.0])},
# Pz
{'coord_frame': FIFF.FIFFV_COORD_HEAD,
'kind': FIFF.FIFFV_POINT_EEG,
'ident': 4,
'r': np.array([0, -.72, .69])},
]
for d in dig:
Expand Down
6 changes: 5 additions & 1 deletion mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def test_io():

info = mne.create_info(['MEG 001', 'MEG 002', 'MEG 003'], 1000.,
['mag', 'mag', 'mag'])
info['meas_date'] = datetime.datetime(year=2020, month=2, day=5)
info['meas_date'] = datetime.datetime(year=2020, month=2, day=5,
tzinfo=datetime.timezone.utc)
info._check_consistency()
tfr = AverageTFR(info, data=data, times=times, freqs=freqs,
nave=20, comment='test', method='crazy-tfr')
tfr.save(fname)
Expand All @@ -426,6 +428,8 @@ def test_io():
pytest.raises(IOError, tfr.save, fname)

tfr.comment = None
# test old meas_date
info['meas_date'] = (1, 2)
tfr.save(fname, overwrite=True)
assert_equal(read_tfrs(fname, condition=0).comment, tfr.comment)
tfr.comment = 'test-A'
Expand Down
1 change: 1 addition & 0 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,7 @@ def read_tfrs(fname, condition=None):
tfr_data = read_hdf5(fname, title='mnepython', slash='replace')
for k, tfr in tfr_data:
tfr['info'] = Info(tfr['info'])
tfr['info']._check_consistency()
if 'metadata' in tfr:
tfr['metadata'] = _prepare_read_metadata(tfr['metadata'])
is_average = 'nave' in tfr
Expand Down