Skip to content

Commit

Permalink
Improvement: new lines in variable names (#79)
Browse files Browse the repository at this point in the history
Description and Motivation
cai_causal_graph.utils.get_variable_name_and_lag now allows new lines in the names of variables.
  • Loading branch information
maxelliott-causalens authored Jul 16, 2024
1 parent 14c3b33 commit 652a7f0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
12 changes: 11 additions & 1 deletion cai_causal_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ def get_variable_name_and_lag(node_name: NodeLike) -> Tuple[str, int]:
if not isinstance(node_name, str):
raise TypeError(f'Expected node name to be a string, got type {type(node_name)}.')

is_match = re.match(r'^(.+?)(?: lag\(n=(\d+)\))?(?: future\(n=(\d+)\))?$', node_name)
# This matches three groups:
# (?s:(.+?\n?)) - The variable name. The main bulk of it - (?s:(.+?)) - matches any characters (including new lines)
# in a non-greedy fashion, meaning it won't capture ' lag(n=X)' or ' future(n=X)' at the end as part of the
# variable, as they will be captured by the other matching groups. The extra \n* in the group is there because
# for some reason the main bulk fails to match trailing new lines if there is no lag or future afterward. Not
# sure why, but this fixes it.
# (?: lag\(n=(\d+)\))? - Optionally the lag in the past. This matches the whole 'lag(n=X)' section, but the captured
# group is only the lag value.
# (?: future\(n=(\d+)\))? - Optionally the lag in the future. This matches the whole 'future(n=X)' section, but the
# captured group is only the future lag value.
is_match = re.match(r'^(?s:(.+?\n*))(?: lag\(n=(\d+)\))?(?: future\(n=(\d+)\))?$', node_name)

lag_matches = re.findall(r'lag\(n=(\d+)\)', node_name)
future_matches = re.findall(r'future\(n=(\d+)\)', node_name)
Expand Down
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## NEXT

- `cai_causal_graph.utils.get_variable_name_and_lag` now allows new lines in the names of variables.
- Upgraded `poetry` version from `1.8.2` to `1.8.3` in the GitHub workflows.

## 0.5.2
Expand Down
15 changes: 15 additions & 0 deletions tests/test_time_series_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ def test_get_variable_name_and_lag(self):
self.assertTupleEqual(get_variable_name_and_lag('x1 lag(n=1 lag(n=1)'), ('x1 lag(n=1', -1))
self.assertTupleEqual(get_variable_name_and_lag('x1 lag(n=1 future(n=1)'), ('x1 lag(n=1', 1))

# New lines match
self.assertTupleEqual(get_variable_name_and_lag('x\n1'), ('x\n1', 0))
self.assertTupleEqual(get_variable_name_and_lag('x\n1 lag(n=1)'), ('x\n1', -1))
self.assertTupleEqual(get_variable_name_and_lag('x\n1 future(n=1)'), ('x\n1', 1))

# New line at the end of a variable name matches (edge case regression test)
self.assertTupleEqual(get_variable_name_and_lag('x1\n'), ('x1\n', 0))
self.assertTupleEqual(get_variable_name_and_lag('x1\n lag(n=1)'), ('x1\n', -1))
self.assertTupleEqual(get_variable_name_and_lag('x1\n future(n=1)'), ('x1\n', 1))

# Multiple new line at the end of a variable name matches (edge case regression test)
self.assertTupleEqual(get_variable_name_and_lag('x1\n\n\n'), ('x1\n\n\n', 0))
self.assertTupleEqual(get_variable_name_and_lag('x1\n\n\n lag(n=1)'), ('x1\n\n\n', -1))
self.assertTupleEqual(get_variable_name_and_lag('x1\n\n\n future(n=1)'), ('x1\n\n\n', 1))

def test_get_variable_name_and_lag_raises(self):
self.assertRaises(ValueError, get_variable_name_and_lag, 'x1 lag(n=1) lag(n=2)')
self.assertRaises(ValueError, get_variable_name_and_lag, 'x1 future(n=1) future(n=2)')
Expand Down

0 comments on commit 652a7f0

Please sign in to comment.