Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: #448 print flush #527

Merged
merged 9 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def equations(self, precision=3):
precision=precision,
)

def print(self, lhs=None, precision=3):
def print(self, lhs=None, precision=3, **kwargs):
"""Print the SINDy model equations.

Parameters
Expand All @@ -362,6 +362,8 @@ def print(self, lhs=None, precision=3):

precision: int, optional (default 3)
Precision to be used when printing out model coefficients.

**kwargs: Additional keyword arguments passed to the builtin print function
"""
eqns = self.equations(precision)
if sindy_pi_flag and isinstance(self.optimizer, SINDyPI):
Expand All @@ -370,17 +372,15 @@ def print(self, lhs=None, precision=3):
feature_names = self.feature_names
for i, eqn in enumerate(eqns):
if self.discrete_time:
names = "(" + feature_names[i] + ")"
print(names + "[k+1] = " + eqn)
names = f"({feature_names[i]})[k+1]"
elif lhs is None:
if not sindy_pi_flag or not isinstance(self.optimizer, SINDyPI):
names = "(" + feature_names[i] + ")"
print(names + "' = " + eqn)
names = f"({feature_names[i]})'"
else:
names = feature_names[i]
print(names + " = " + eqn)
names = f"({feature_names[i]})"
else:
print(lhs[i] + " = " + eqn)
names = f"{lhs[i]}"
print(f"{names} = {eqn}", **kwargs)

def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws):
"""
Expand Down
17 changes: 4 additions & 13 deletions test/test_pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def test_equations(data, capsys):
model.print(precision=2)

out, _ = capsys.readouterr()

assert len(out) > 0
assert "(x0)' = " in out


def test_print_discrete_time(data_discrete_time, capsys):
Expand All @@ -500,20 +502,9 @@ def test_print_discrete_time(data_discrete_time, capsys):
model.print()

out, _ = capsys.readouterr()
assert len(out) > 0


def test_print_discrete_time_multiple_trajectories(
data_discrete_time_multiple_trajectories, capsys
):
x = data_discrete_time_multiple_trajectories
model = SINDy(discrete_time=True)
model.fit(x)

model.print()

out, _ = capsys.readouterr()
assert len(out) > 1
assert len(out) > 0
assert "(x0)[k+1] = " in out


def test_differentiate(data_lorenz, data_multiple_trajectories):
Expand Down
Loading