Skip to content

Commit

Permalink
feat: split_gaze_data into trial
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Oct 23, 2024
1 parent a8caf95 commit 16b7ae7
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
34 changes: 34 additions & 0 deletions src/pymovements/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,40 @@ def load_precomputed_reading_measures(self) -> None:
self.paths,
)

def _split_gaze_data(
self,
by: list[str] | str,
) -> None:
"""Split gaze data into seperated GazeDataFrame's.
Parameters
----------
by: list[str] | str
Column's to split dataframe by.
"""
if isinstance(by, str):
by = [by]
new_data = [
(
GazeDataFrame(
new_frame,
experiment=_frame.experiment,
trial_columns=self.definition.trial_columns,
time_column=self.definition.time_column,
time_unit=self.definition.time_unit,
position_columns=self.definition.position_columns,
velocity_columns=self.definition.velocity_columns,
acceleration_columns=self.definition.acceleration_columns,
distance_column=self.definition.distance_column,
),
fileinfo_row,
)
for (_frame, fileinfo_row) in zip(self.gaze, self.fileinfo['gaze'].to_dicts())
for new_frame in _frame.frame.partition_by(by=by)
]
self.gaze = [data[0] for data in new_data]
self.fileinfo['gaze'] = pl.concat([pl.from_dict(data[1]) for data in new_data])

def split_precomputed_events(
self,
by: list[str] | str,
Expand Down
53 changes: 51 additions & 2 deletions tests/unit/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def mock_toy(
'y_left_pix': np.zeros(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
Expand All @@ -154,6 +156,8 @@ def mock_toy(
'y_left_pix': pl.Float64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix', 'x_right_pix', 'y_right_pix']
Expand All @@ -169,6 +173,8 @@ def mock_toy(
'y_right_pix': np.zeros(1000),
'x_avg_pix': np.zeros(1000),
'y_avg_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
Expand All @@ -179,6 +185,8 @@ def mock_toy(
'y_right_pix': pl.Float64,
'x_avg_pix': pl.Float64,
'y_avg_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = [
Expand All @@ -192,12 +200,16 @@ def mock_toy(
'time': np.arange(1000),
'x_left_pix': np.zeros(1000),
'y_left_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_left_pix': pl.Float64,
'y_left_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix']
Expand All @@ -208,12 +220,16 @@ def mock_toy(
'time': np.arange(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_right_pix', 'y_right_pix']
Expand All @@ -224,12 +240,16 @@ def mock_toy(
'time': np.arange(1000),
'x_pix': np.zeros(1000),
'y_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_pix': pl.Float64,
'y_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_pix', 'y_pix']
Expand Down Expand Up @@ -1000,7 +1020,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'position' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'custom_position', 'velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'custom_position', 'velocity']"
),
id='no_position',
),
Expand All @@ -1012,7 +1033,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'velocity' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'position', 'custom_velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'position', 'custom_velocity']"
),
id='no_velocity',
),
Expand Down Expand Up @@ -1930,3 +1952,30 @@ def test_load_split_precomputed_events(precomputed_dataset_configuration, by, ex
dataset.load()
dataset.split_precomputed_events(by)
assert len(dataset.precomputed_events) == expected_len


@pytest.mark.parametrize(
('by', 'expected_len'),
[
pytest.param(
'trial_id_1',
40,
id='subset_int',
),
pytest.param(
'trial_id_2',
60,
id='subset_int',
),
pytest.param(
['trial_id_1', 'trial_id_2'],
80,
id='subset_int',
),
],
)
def test_load_split_gaze(gaze_dataset_configuration, by, expected_len):
dataset = pm.Dataset(**gaze_dataset_configuration['init_kwargs'])
dataset.load()
dataset._split_gaze_data(by)
assert len(dataset.gaze) == expected_len

0 comments on commit 16b7ae7

Please sign in to comment.