Skip to content

Commit

Permalink
feat: add CSVSummary callback (#173)
Browse files Browse the repository at this point in the history
Co-authored-by: Remco de Boer <[email protected]>
  • Loading branch information
sebastianJaeger and redeboer authored Nov 20, 2020
1 parent f8c4bd6 commit fe61e94
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 1 deletion.
1 change: 1 addition & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
"tensorwaves",
"toctree",
"topness",
"traceback",
"unbinned",
"venv",
"weisskopf",
Expand Down
1 change: 1 addition & 0 deletions docs/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.csv
*.doctree
*.inv
*build/
Expand Down
45 changes: 45 additions & 0 deletions docs/usage/3_perform_fit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@
"source": [
"from tensorwaves.optimizer.callbacks import (\n",
" CallbackList,\n",
" CSVSummary,\n",
" ProgressBar,\n",
" TFSummary,\n",
" YAMLSummary,\n",
Expand All @@ -235,6 +236,7 @@
" ProgressBar(),\n",
" TFSummary(),\n",
" YAMLSummary(\"current_fit_result.yaml\", estimator),\n",
" CSVSummary(\"fit_traceback.csv\", estimator, step_size=2),\n",
" ]\n",
" )\n",
")\n",
Expand Down Expand Up @@ -379,6 +381,49 @@
"See more info [here](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks#tensorboard_in_notebooks)\n",
"````"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative would be to use the output of the {class}`.CSVSummary` callback. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"fit_traceback = pd.read_csv(\"fit_traceback.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fit_traceback.plot(\"function_call\", \"estimator_value\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fit_traceback.plot(\n",
" \"function_call\",\n",
" [\n",
" \"Phase_J/psi(1S)_to_f(0)(1500)_0+gamma_1;f(0)(1500)_to_pi0_0+pi0_0;\",\n",
" \"Mass_f(0)(1710)\",\n",
" \"Width_f(0)(1710)\",\n",
" ],\n",
");"
]
}
],
"metadata": {
Expand Down
47 changes: 46 additions & 1 deletion src/tensorwaves/optimizer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import IO, Iterable, List, Optional

import pandas as pd
import tensorflow as tf
import yaml
from tqdm import tqdm
Expand Down Expand Up @@ -62,7 +63,7 @@ def __call__(self, parameters: dict, estimator_value: float) -> None:
return
output_dict = {
"Time": datetime.now(),
"Iteration": self.__function_call,
"FunctionCalls": self.__function_call,
"Estimator": {
"Type": self.__estimator_type,
"Value": float(estimator_value),
Expand All @@ -84,6 +85,50 @@ def finalize(self) -> None:
self.__stream.close()


class CSVSummary(Callback):
def __init__(
self,
filename: str,
estimator: Estimator,
step_size: int = 10,
) -> None:
"""Log fit parameters and the estimator value to a CSV file."""
self.__function_call = -1
self.__step_size = step_size
self.__first_call = True
self.__stream = open(filename, "w")
_empty_file(self.__stream)
if not isinstance(estimator, Estimator):
raise TypeError(f"Requires an in {Estimator.__name__} instance")
self.__estimator_type: str = estimator.__class__.__name__

def __call__(self, parameters: dict, estimator_value: float) -> None:
self.__function_call += 1
if self.__function_call % self.__step_size != 0:
return
output_dict = {
"time": datetime.now(),
"function_call": self.__function_call,
"estimator_type": self.__estimator_type,
"estimator_value": float(estimator_value),
}
output_dict.update(
{name: float(value) for name, value in parameters.items()}
)

data_frame = pd.DataFrame(output_dict, index=[self.__function_call])
data_frame.to_csv(
self.__stream,
mode="a",
header=self.__first_call,
index=False,
)
self.__first_call = False

def finalize(self) -> None:
self.__stream.close()


class TFSummary(Callback):
def __init__(
self,
Expand Down

0 comments on commit fe61e94

Please sign in to comment.