Skip to content

Commit

Permalink
Merge pull request #38 from chrishavlin/phaseplot_fixes
Browse files Browse the repository at this point in the history
Phaseplot fixes
  • Loading branch information
samwalkow authored Sep 9, 2022
2 parents 69cfbc1 + b0d295b commit 43266bb
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
python -m pip install -e .
- name: Test with pytest
run: |
pytest
pytest -v
24 changes: 22 additions & 2 deletions analysis_schema/_model_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ def process_pydantic(self, pydantic_instance, ds=None):
return yt_func(*the_args, **kwarg_dict)


class PhasePlot(YTGeneric):
def process_pydantic(self, pydantic_instance: data_classes.PhasePlot, ds=None):
if ds is None:
raise RuntimeError("ds required for a PhasePlot")

# before running the usual process_pydantic, need to set data_source
# if it does not exist, using values from the instantiated dataset
if pydantic_instance.data_source is None:
# this is equivalent to ds.all_data()
reg = data_classes.Region(
center=ds.domain_center.d.tolist(),
left_edge=ds.domain_left_edge.d.tolist(),
right_edge=ds.domain_right_edge.d.tolist(),
)
pydantic_instance.data_source = reg

return super().process_pydantic(pydantic_instance, ds=ds)


class Visualizations(YTRunner):
def _sanitize_viz(self, viz_model, yt_viz):
if viz_model.output_type == "file":
Expand All @@ -156,12 +175,12 @@ def _sanitize_viz(self, viz_model, yt_viz):
return yt_viz._repr_html_()

def process_pydantic(self, pydantic_instance: data_classes.Visualizations, ds=None):
generic_runner = YTGeneric()
viz_results = {}
for attr in pydantic_instance.__fields__.keys():
viz_model = getattr(pydantic_instance, attr) # SlicePlot, etc.
viz_runner = yt_registry.get(viz_model)
if viz_model is not None:
result = generic_runner.run(viz_model, ds=ds)
result = viz_runner.run(viz_model, ds=ds)
nme = f"{ds.basename}_{attr}"
viz_results[nme] = self._sanitize_viz(viz_model, result)
return viz_results
Expand Down Expand Up @@ -192,3 +211,4 @@ def _is_yt_schema_instance(obj):
yt_registry.register(data_classes.Visualizations, Visualizations())
yt_registry.register(data_classes.Dataset, Dataset())
yt_registry.register(data_classes.DataSource3D, DataSource3D())
yt_registry.register(data_classes.PhasePlot, PhasePlot())
4 changes: 2 additions & 2 deletions analysis_schema/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def force_in_mem_dstore(wkflow, field_list: list = None, units_list: list = None
# replace the data store datasets with in-memory datasets
new_store = DataStore()
if field_list is None:
field_list = [("gas", "density"), ("gas", "temperature")]
units_list = ["g/cm**3", "K"]
field_list = [("gas", "density"), ("gas", "temperature"), ("gas", "mass")]
units_list = ["g/cm**3", "K", "kg"]
for dsname, dscon in wkflow.data_store.available_datasets.items():
ds_ = fake_amr_ds(fields=field_list, units=units_list)
new_store.store(dscon.filename, dataset_name=dsname, in_memory_ds=ds_)
Expand Down
5 changes: 3 additions & 2 deletions analysis_schema/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ def axis(self):
class PhasePlot(ytVisualization):
"""A yt phase plot"""

data_source: Optional[Dataset] = Field(alias="Dataset")
ds: Optional[List[Dataset]] = Field(alias="Dataset")
data_source: Optional[DataSource3D] = Field(alias="DataSource")
x_field: ytField = Field(alias="xField")
y_field: ytField = Field(alias="yField")
z_fields: Union[ytField, List[ytField]] = Field(alias="zField(s)")
z_fields: Union[ytField, List[ytField]] = Field(alias="zFields")
weight_field: Optional[ytField] = Field(alias="WegihtFieldName")
x_bins: Optional[int] = Field(alias="xBins")
y_bins: Optional[int] = Field(alias="yBins")
Expand Down
15 changes: 11 additions & 4 deletions analysis_schema/yt_analysis_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,23 @@
"type": "string"
},
"Dataset": {
"$ref": "#/definitions/Dataset"
"title": "Dataset",
"type": "array",
"items": {
"$ref": "#/definitions/Dataset"
}
},
"DataSource": {
"$ref": "#/definitions/DataSource3D"
},
"xField": {
"$ref": "#/definitions/ytField"
},
"yField": {
"$ref": "#/definitions/ytField"
},
"zField(s)": {
"title": "Zfield(S)",
"zFields": {
"title": "Zfields",
"anyOf": [
{
"$ref": "#/definitions/ytField"
Expand Down Expand Up @@ -460,7 +467,7 @@
"output_type",
"xField",
"yField",
"zField(s)"
"zFields"
]
},
"Visualizations": {
Expand Down
23 changes: 23 additions & 0 deletions tests/viz_phaseplot_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"$schema": "../analysis_schema/yt_analysis_schema.json",
"Plot": [
{
"PhasePlot": {
"Dataset": [
{
"FileName": "IsolatedGalaxy/galaxy0030/galaxy0030",
"DatasetName": "IG"
}
],
"xField": {
"field": "density", "field_type": "gas"
},
"yField": {
"field_type": "gas", "field": "temperature"
},
"zFields": {"field": "mass", "field_type": "gas"},
"output_type": "file"
}
}
]
}
29 changes: 29 additions & 0 deletions tests/viz_phaseplot_2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"$schema": "../analysis_schema/yt_analysis_schema.json",
"Plot": [
{
"PhasePlot": {
"Dataset": [
{
"FileName": "IsolatedGalaxy/galaxy0030/galaxy0030",
"DatasetName": "IG"
}
],
"DataSource": {
"sphere": {
"Radius": 0.25,
"Center": [0.25, 0.25, 0.25]
}
},
"xField": {
"field": "density", "field_type": "gas"
},
"yField": {
"field_type": "gas", "field": "temperature"
},
"zFields": {"field": "mass", "field_type": "gas"},
"output_type": "file"
}
}
]
}

0 comments on commit 43266bb

Please sign in to comment.