Skip to content

Commit

Permalink
psychometric returns nan on empty block (incomplete set of trials)
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Nov 11, 2020
1 parent 90c01cc commit 5bc7a11
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 31 deletions.
3 changes: 3 additions & 0 deletions brainbox/behavior/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
else:
block_idx = trials.probabilityLeft == block

if not np.any(block_idx):
return np.nan * np.zeros(4)

contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
rightward = trials.choice == -1
# Calculate the proportion rightward for each contrast type
Expand Down
58 changes: 27 additions & 31 deletions brainbox/tests/test_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,23 @@ def setUp(self):
self.trial_data = pickle.load(f)
np.random.seed(0)

def test_concatenate_and_computations(self):
sess_dates = ['2020-08-25', '2020-08-24', '2020-08-21']
def _get_trials(self, sess_dates):
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
_ = [trials[k].pop('task_protocol') for k in trials.keys()]
trials_total = np.sum([len(trials[k]['contrastRight']) for k in trials.keys()])
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
return trials, task_protocol

def test_psychometric_insufficient_data(self):
# the psychometric aggregate should return NaN when there is no data for a given contrast
trials, _ = self._get_trials(sess_dates=['2020-08-25', '2020-08-24', '2020-08-21'])
trials_all = train.concatenate_trials(trials)
trials_all['probability_left'] = trials_all['contrastLeft'] * 0 + 80
psych_nan = train.compute_psychometric(trials_all, block=100)
assert np.sum(np.isnan(psych_nan)) == 4

def test_concatenate_and_computations(self):
trials, _ = self._get_trials(sess_dates=['2020-08-25', '2020-08-24', '2020-08-21'])
trials_total = np.sum([len(trials[k]['contrastRight']) for k in trials.keys()])
trials_all = train.concatenate_trials(trials)
assert (len(trials_all.contrastRight) == trials_total)

Expand All @@ -148,40 +158,32 @@ def test_concatenate_and_computations(self):
assert (np.isclose(rt, 0.83655))

def test_in_training(self):
sess_dates = ['2020-08-25', '2020-08-24', '2020-08-21']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(
sess_dates=['2020-08-25', '2020-08-24', '2020-08-21'])
assert (np.all(np.array(task_protocol) == 'training'))
status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
n_delay=0)
status, info = train.get_training_status(
trials, task_protocol, ephys_sess_dates=[], n_delay=0)
assert (status == 'in training')

def test_trained_1a(self):
sess_dates = ['2020-08-26', '2020-08-25', '2020-08-24']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(
sess_dates=['2020-08-26', '2020-08-25', '2020-08-24'])
assert (np.all(np.array(task_protocol) == 'training'))
status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
n_delay=0)
assert (status == 'trained 1a')

def test_trained_1b(self):
sess_dates = ['2020-08-27', '2020-08-26', '2020-08-25']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(
sess_dates=['2020-08-27', '2020-08-26', '2020-08-25'])
assert (np.all(np.array(task_protocol) == 'training'))
status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
n_delay=0)
assert(status == 'trained 1b')

def test_training_to_bias(self):
sess_dates = ['2020-08-31', '2020-08-28', '2020-08-27']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(
sess_dates=['2020-08-31', '2020-08-28', '2020-08-27'])
assert (~np.all(np.array(task_protocol) == 'training') and
np.any(np.array(task_protocol) == 'training'))
status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
Expand All @@ -190,29 +192,23 @@ def test_training_to_bias(self):

def test_ready4ephys(self):
sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(sess_dates=sess_dates)
assert (np.all(np.array(task_protocol) == 'biased'))
status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
n_delay=0)
assert (status == 'ready4ephysrig')

def test_ready4delay(self):
sess_dates = ['2020-09-03', '2020-09-02', '2020-08-31']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(sess_dates=sess_dates)
assert (np.all(np.array(task_protocol) == 'biased'))
status, info = train.get_training_status(trials, task_protocol,
ephys_sess_dates=['2020-09-03'], n_delay=0)
assert (status == 'ready4delay')

def test_ready4recording(self):
sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28']
trials_copy = copy.deepcopy(self.trial_data)
trials = Bunch(zip(sess_dates, [trials_copy[k] for k in sess_dates]))
task_protocol = [trials[k].pop('task_protocol') for k in trials.keys()]
trials, task_protocol = self._get_trials(sess_dates=sess_dates)
assert (np.all(np.array(task_protocol) == 'biased'))
status, info = train.get_training_status(trials, task_protocol,
ephys_sess_dates=sess_dates, n_delay=1)
Expand Down

0 comments on commit 5bc7a11

Please sign in to comment.