Skip to content

Commit

Permalink
Merge pull request #236 from hanhou/master
Browse files Browse the repository at this point in the history
Foraging Oct21 before sfn
  • Loading branch information
hanhou authored Oct 28, 2021
2 parents a8780a6 + 16ac28c commit 7b0e075
Show file tree
Hide file tree
Showing 16 changed files with 2,051 additions and 495 deletions.
30 changes: 30 additions & 0 deletions pipeline/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,36 @@ class Period(dj.Lookup):
('delay', 'delay', 0, 'go', 0),
('response', 'go', 0, 'go', 1.2)]


@schema
class PeriodForaging(dj.Lookup):
# Totally different from delay-response task, so create a new table
definition = """ # time period between any two TrialEvent (eg the delay period is between delay and go)
period: varchar(30)
---
-> TrialEventType.proj(start_event_type='trial_event_type')
start_trial_shift=0: smallint # see psth_foraging.AlignType.trial_offset
start_time_shift: float # (s) any time-shift amount with respect to the start_event_type
-> TrialEventType.proj(end_event_type='trial_event_type')
end_trial_shift=0: smallint # see psth_foraging.AlignType.trial_offset
end_time_shift: float # (s) any time-shift amount with respect to the end_event_type
"""

contents = [
('before_2', 'bitcodestart', 0, -2, 'bitcodestart', 0, 0), # = iti_last_2 of the *last* trial
('delay', 'zaberready', 0, 0, 'go', 0, 0),
('go_to_end', 'go', 0, 0, 'trialend', 0, 0),
('go_1.2', 'go', 0, 0, 'go', 0, 1.2), # Is ths reasonable?

('iti_all', 'trialend', 0, 0, 'bitcodestart', 1, 0), # To be precise, should use 'zaberstart'??
('iti_first_2', 'trialend', 0, 0, 'trialend', 0, 2),
('iti_last_2', 'bitcodestart', 1, -2, 'bitcodestart', 1, 0),

('delay_bitcode', 'bitcodestart', 0, 0.146, 'go', 0, 0),
# TODO ('delay_effective') # If early lick, from the last lick before go cue to go cue
]


# ============================= PROJECTS ==================================================


Expand Down
245 changes: 214 additions & 31 deletions pipeline/foraging_model.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pipeline/ingest/behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ def key_source(self):

def populate(self, *args, **kwargs):
# Load project info (just once)
log.info('------ Loading pybpod project -------')
self.projects = self.get_bpod_projects()
log.info('------------ Done! ----------------')

# 'populate' which won't require upstream tables
# 'reserve_jobs' not parallel, overloaded to mean "don't exit on error"
Expand Down
43 changes: 29 additions & 14 deletions pipeline/ingest/ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,28 +147,41 @@ def _load(self, data, probe, npx_meta, rigpath, probe_insertion_exists=False, in
# when needed for different-length recordings
# - otherwise, use trial number correction array (bf['trialNum'])

sync_behav_start = np.where(sync_behav == sync_ephys[0])[0][0]
sync_behav_range = sync_behav[sync_behav_start:][:len(sync_ephys)]

if not np.all(np.equal(sync_ephys, sync_behav_range)):
# First, find the first Ephys trial in behavior data (sometimes ephys is not started in time)
sync_behav_start = np.where(sync_behav == sync_ephys[0])[0][0]

# Ephys trial will never start BEFORE behavioral trial, so the above line is always correct.
# But due to pybpod bug, sometimes the last behavioral trial is invalid, making Ephys even LONGER than behavior.
# Therefore, we must find out both sync_behav_range AND sync_ephys_range (otherwise the next `if not np.all` could fail)
sync_behav_range = sync_behav[sync_behav_start:][:len(sync_ephys)] # Note that this will not generate error even if len(sync_ephys) > len(behavior)
shared_trial_num = len(sync_behav_range)
sync_ephys_range = sync_ephys[:shared_trial_num] # Now they must have the same length

if not np.all(np.equal(sync_ephys_range, sync_behav_range)):
if trial_fix is not None:
log.info('ephys/bitcode trial mismatch - fix using "trialNum"')
trials = trial_fix
else:
raise Exception('Bitcode Mismatch - Fix with "trialNum" not available')
else:
# TODO: recheck the logic here!
if len(sync_behav) < len(sync_ephys):
start_behav = np.where(sync_behav[0] == sync_ephys)[0][0]
start_behav = np.where(sync_behav[0] == sync_ephys)[0][0] # TODO: This is problematic because ephys never leads behavior, otherwise the logic above is wrong
elif len(sync_behav) > len(sync_ephys):
start_behav = - np.where(sync_ephys[0] == sync_behav)[0][0]
else:
start_behav = 0
trial_indices = np.arange(len(sync_behav_range)) - start_behav
trial_indices = np.arange(shared_trial_num) - start_behav

# mapping to the behav-trial numbering
# "trials" here is just the 0-based indices of the behavioral trials
behav_trials = (experiment.SessionTrial & skey).fetch('trial', order_by='trial')
trials = behav_trials[trial_indices]

# TODO: this is a workaround to deal with the case where ephys stops later than behavior
# but with the assumption that ephys will NEVER start earlier than behavior
trial_start = trial_start[:shared_trial_num] # Truncate ephys 'trial_start' at the tail
# And also truncate the ingestion of digital markers (see immediate below)

assert len(trial_start) == len(trials), 'Unequal number of bitcode "trial_start" ({}) and ingested behavior trials ({})'.format(len(trial_start), len(trials))

Expand All @@ -177,7 +190,7 @@ def _load(self, data, probe, npx_meta, rigpath, probe_insertion_exists=False, in
# But this is critical for the foraging task, because we need global session-wise times to plot flexibly-aligned PSTHs (in particular, spikes during ITI).
# However, we CANNOT get this from behavior pybpod .csv files (PC-TIME is inaccurate, whereas BPOD-TIME is trial-based)
if probe == 1 and 'digMarkerPerTrial' in bitcode_raw: # Only import once for one session
insert_ephys_events(skey, bitcode_raw)
insert_ephys_events(skey, bitcode_raw, shared_trial_num)

# trialize the spikes & subtract go cue
t, trial_spikes, trial_units = 0, [], []
Expand All @@ -189,7 +202,7 @@ def _load(self, data, probe, npx_meta, rigpath, probe_insertion_exists=False, in
s0, s1 = trial_start[t], trial_start[t+1]

trial_idx = np.where((spikes > s0) & (spikes < s1))
spike_trial_num[trial_idx] = trials[t]
spike_trial_num[trial_idx] = trials[t] # Assign (behavioral) trial number to each spike

trial_spikes.append(spikes[trial_idx] - trial_go[t])
trial_units.append(units[trial_idx])
Expand Down Expand Up @@ -1105,7 +1118,7 @@ def read_bitcode(bitcode_dir, h2o, skey):
return behavior_bitcodes, ephys_bitcodes, trial_numbers, ephys_trial_ref_times, ephys_trial_start_times, bf


def insert_ephys_events(skey, bf):
def insert_ephys_events(skey, bf, trial_trunc=None):
'''
all times are session-based
'''
Expand All @@ -1119,8 +1132,10 @@ def insert_ephys_events(skey, bf):
headings = bf['headings'][0]
digMarkerPerTrial = bf['digMarkerPerTrial']

if trial_trunc is None: trial_trunc = digMarkerPerTrial.shape[0]

for col, event_type in enumerate(headings):
times = digMarkerPerTrial[:, col]
times = digMarkerPerTrial[:trial_trunc, col]
not_nan = np.where(~np.isnan(times))[0]
trials = not_nan + 1 # Trial all starts from 1
df = df.append(pd.DataFrame({**skey,
Expand All @@ -1133,7 +1148,7 @@ def insert_ephys_events(skey, bf):

# --- Zaber pulses (only available from ephys NIDQ) ---
if 'zaberPerTrial' in bf:
for trial, pulses in enumerate(bf['zaberPerTrial'][0]):
for trial, pulses in enumerate(bf['zaberPerTrial'][0][:trial_trunc]):
df = df.append(pd.DataFrame({**skey,
'trial': trial + 1, # Trial all starts from 1
'trial_event_id': np.arange(len(pulses)) + len(headings),
Expand All @@ -1152,8 +1167,8 @@ def insert_ephys_events(skey, bf):

if len(exist_lick):
log.info(f' loading licks from NIDQ ...')

for trial, *licks in enumerate(zip(*(bf[lick_wrapper[ltype]][0] for ltype in exist_lick))):
for trial, *licks in enumerate(zip(*(bf[lick_wrapper[ltype]][0][:trial_trunc] for ltype in exist_lick))):
lick_times = {ltype: ltime for ltype, ltime in zip(exist_lick, *licks)}
all_lick_types = np.concatenate(
[[ltype] * len(ltimes) for ltype, ltimes in lick_times.items()])
Expand All @@ -1179,7 +1194,7 @@ def insert_ephys_events(skey, bf):
_idx = [_idx for _idx, field in enumerate(bf['chan'].dtype.descr) if 'cameraNameInDJ' in field][0]
cameras = bf['chan'][0,0][_idx][0,:]
for camera, all_frames in zip(cameras, bf['cameraPerTrial'][0]):
for trial, frames in enumerate(all_frames[0]):
for trial, frames in enumerate(all_frames[0][:trial_trunc]):
key = {**skey,
'trial': trial + 1, # Trial all starts from 1
'tracking_device': camera[0]}
Expand Down
18 changes: 14 additions & 4 deletions pipeline/ingest/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def compare_pc_and_bpod_times(q_sess=dj.AndList(['water_restriction_number = "HH
ax[0].legend()


def compare_ni_and_bpod_times(q_sess=dj.AndList(['subject_id = "473361"', 'session >= 57']), event_to_align='bitcodestart'):
def compare_ni_and_bpod_times(q_sess=dj.AndList(['subject_id = "473361"', 'session >= 57']), event_to_align='bitcodestart', legend=False):
'''
Compare NI-TIME and BPOD-TIME
This is a critical validation for ephys timing alignment
Expand All @@ -325,23 +325,30 @@ def compare_ni_and_bpod_times(q_sess=dj.AndList(['subject_id = "473361"', 'sessi
'''

event_all = ['bitcodestart', 'go', 'choice', 'trialend']
event_all = ['bitcodestart', 'go', 'choice', 'trialend'] if not legend else ['go', 'choice', 'trialend']
event_to_compare = [e for e in event_all if e != event_to_align]

# 1. -- all events, Bpod vs NIXD --
# Ephys time
ephys_times = (ephys.TrialEvent & q_sess & f'trial_event_type IN {tuple(event_all)}').fetch(format='frame')
ephys_times = ephys_times.reset_index().pivot(index = ['subject_id', 'session', 'trial'], columns='trial_event_type').trial_event_time.astype(float)
session_times = ephys_times.bitcodestart
session_times = ephys_times.bitcodestart if not legend else ephys_times.go
ephys_times = ephys_times.sub(ephys_times[event_to_align], axis=0).drop(columns=event_to_align) # To each bitcode start (bpodstart is sometimes problematic if a new bpod session is started)

# Bpod time
bpod_times = (experiment.TrialEvent & q_sess & f'trial_event_type IN {tuple(event_all)}' ).fetch(format='frame')
bpod_times = bpod_times.reset_index().pivot(index = ['subject_id', 'session', 'trial'], columns='trial_event_type').trial_event_time.astype(float) # Already related to bpod start
bpod_times = bpod_times.sub(bpod_times[event_to_align], axis=0).drop(columns=event_to_align) # To each trial's bpod start

# Truncate bpod time, if len(bpod) > len(ephys), but assuming their first trial has been correctly aligned in ingest.ephys
# (this is actually a sanity check)
if len(bpod_times) > len(ephys_times):
print('Bpod length > ephys length!! Bpod truncated...')
bpod_times = bpod_times[:len(ephys_times)]


# Plot: Bpod vs NIXD, distribution of differences
fig = plt.figure(figsize=(8,13))
fig = plt.figure(figsize=(8,len(event_all)*3))
ax = fig.subplots(len(event_to_compare), 2)
max_error = 0
for n, event in enumerate(event_to_compare):
Expand All @@ -368,6 +375,9 @@ def compare_ni_and_bpod_times(q_sess=dj.AndList(['subject_id = "473361"', 'sessi
(ephys.ActionEvent * ephys_go_cue) & 'action_event_time >= ephys_go',
nixa='min(action_event_time)') # Session-time of first lick of each trial

if not len(nixa_first_lick):
return

q_all = (nixa_first_lick.proj(..., tmp='trial_event_id') # NIXA
* (ephys.TrialEvent & 'trial_event_type = "choice"').proj(nixd='trial_event_time', tmp1='trial_event_id') # NIXD
* (ephys.TrialEvent & f'trial_event_type = "{event_to_align}"').proj(ni_align='trial_event_time')
Expand Down
Loading

0 comments on commit 7b0e075

Please sign in to comment.