Skip to content

Commit

Permalink
fix file path when render model (pyro-ppl#1857)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Sep 16, 2024
1 parent f5aca91 commit a7a2f31
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
9 changes: 7 additions & 2 deletions numpyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,14 @@ def render_model(

if filename is not None:
filename = Path(filename)
# remove leading period from suffix
filename_without_suffix = filename.with_suffix("")
graph.render(
filename.stem, view=False, cleanup=True, format=filename.suffix[1:]
) # remove leading period from suffix
filename_without_suffix,
view=False,
cleanup=True,
format=filename.suffix[1:],
)

return graph

Expand Down
17 changes: 16 additions & 1 deletion test/test_model_rendering.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import os

import numpy as np
import pytest

import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer.inspect import generate_graph_specification, get_model_relations
from numpyro.infer.inspect import (
generate_graph_specification,
get_model_relations,
render_model,
)


def simple(data):
Expand Down Expand Up @@ -129,3 +135,12 @@ def test_model_transformation(test_model, model_kwargs, expected_graph_spec):
graph_spec = generate_graph_specification(relations)

assert graph_spec == expected_graph_spec


def test_render_model_filename():
def model():
numpyro.sample("x", dist.Normal(0, 1))

render_model(model, filename="graph.png")
assert os.path.exists("graph.png")
os.remove("graph.png")

0 comments on commit a7a2f31

Please sign in to comment.