diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a94a800 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,27 @@ +name: CI +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/dev.txt + pip install . + - name: unit testing + run: | + cd tests + pytest . + - name: Code formatting + run: | + black --check libpyvinyl/ diff --git a/.gitignore b/.gitignore index 7c5019f..fd798ed 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,16 @@ *.code-workspace .ropeproject tags -pyvinyl.egg-info/ doc/build +libpyvinyl.egg-info +.#*.* +.readthedocs.yaml +**/notebooks/*.json +**/notebooks/*.h5 +**/notebooks/tmp* +**/notebooks/.ipynb_checkpoints/ +tests/*.json +tests/*.h5 +tests/tmp* +dist/ +build/ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index b7a0d4f..0000000 --- a/.travis.yml +++ /dev/null @@ -1,31 +0,0 @@ -language: python -python: - - "3.5" - - "3.6" - - "3.7" - - "3.8" - - "nightly" - -sudo: required -dist: xenial - -matrix: - allow_failures: - - python: "3.5" - - python: "3.6" - - python: "nightly" - - -cache: - apt: false - directories: - - $HOME/.cache/pip - - $HOME/lib - -install: - - cd $TRAVIS_BUILD_DIR - - pip install -r requirements.txt - - pip install . - -script: - - python tests/Test.py diff --git a/DEVEL.md b/DEVEL.md new file mode 100644 index 0000000..f72c24d --- /dev/null +++ b/DEVEL.md @@ -0,0 +1,22 @@ +How to test +------------------------------ + +Minimally needed: +``` +pip install -e ./ +cd tests/unit +python Test.py +``` + +Recommended: +``` +pip install --user pytest +pip install -e ./ +cd tests +# Test all +pytest ./ +# Unit test only +pytest ./unit +# Integration test only +pytest ./integration +``` diff --git a/README.md b/README.md index 48bc2af..9d01fb4 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,110 @@ # libpyvinyl - The python APIs for Virtual Neutron and x-raY Laboratory -[![Build Status](https://travis-ci.com/PaNOSC-ViNYL/libpyvinyl.svg?branch=master)](https://travis-ci.com/PaNOSC-ViNYL/libpyvinyl) +[![CI](https://github.com/PaNOSC-ViNYL/libpyvinyl/actions/workflows/ci.yml/badge.svg)](https://github.com/PaNOSC-ViNYL/libpyvinyl/actions/workflows/ci.yml) [![Documentation Status](https://readthedocs.org/projects/libpyvinyl/badge/?version=latest)](https://libpyvinyl.readthedocs.io/en/latest/?badge=latest) - + ## Summary + The python package `libpyvinyl` exposes the high level API for simulation codes under -the umbrella of the Virtual Neutron and x-raY Laboratory (ViNYL). +the umbrella of the Virtual Neutron and x-raY Laboratory (ViNYL). The fundamental class is the `BaseCalculator` and its sister class `Parameters`. While `Parameters` is a pure state engine, i.e. it's sole purpose is to encapsulate the physical, numerical, and computational parameters of a simulation, the `BaseCalculator` -exposes the interface to +exposes the interface to -- Configure a simulation (through the corresponding `Parameters` instance) -- Launch the simulation run -- Collect the simulation output data and make it queriable as a class attribute -- Snapshoot a simulation by dumping the object to disk (using the `dill` library). +- Configure a simulation. +- Launch the simulation run. +- Collect the simulation output data. +- Construct a `Data` instance that represents the simulation output data. +- Snapshoot a simulation by dumping the object to disk. - Reload a simulation run from disk and continue the run with optionally modified parameters. -The `BaseCalculaton` is an abstract base class, it shall not be instantiated as such. +The `BaseCalculator` is an abstract base class, it shall not be instantiated as such. The anticipated use is to inherit specialised `Calculators` from `BaseCalculator` and to implement the core functionality in the derived class. In particular, this is required -for the methods responsible to launch a simulation (`run()`) . +for the methods responsible to launch a simulation through the `backengine()` method. As an example, we demonstrate in an [accompanying notebook](https://github.com/PaNOSC-ViNYL/libpyvinyl/blob/master/doc/source/include/notebooks/example-01.ipynb) how to declare a derived `Calculator` and implement a `backengine` method. The example then -shows how to run the simulation, store the results in a `hdf5` file, snapshot the simulation +shows how to run the simulation, store the results in a `hdf5` file, snapshot the simulation and reload the simulation into memory. -## Acknowledgement -This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 823852. +## Installation + +We recommend installation in a virtual environment, either `conda` or `pyenv`. + +### Create a `conda` environment + +``` +$> conda create -n libpyvinyl +``` + +### Common users + +``` +$> pip install libpyvinyl +``` + +### Developers + +We provide a requirements file for developers in _requirements/dev.txt_. + +``` +$> cd requirements +$> pip install -r dev.txt +``` + +`conda install` is currently not supported. + +Then, install `libpyvinyl` into the same environment. The `-e` flag links the installed library to +the source code in the repository, such that changes in the latter are immediately effective in the installed version. +``` +$> cd .. +$> pip install -e . +``` +## Testing + +We recommend to run the unittests and integration tests. + +``` +$> pytest tests +``` + +You should see a test report similar to this: + +``` +=============================================================== test session starts ================================================================ +platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1 +rootdir: /home/juncheng/Projects/libpyvinyl +collected 100 items + +integration/plusminus/tests/test_ArrayCalculators.py . [ 1%] +integration/plusminus/tests/test_Instrument.py . [ 2%] +integration/plusminus/tests/test_NumberCalculators.py ... [ 5%] +integration/plusminus/tests/test_NumberData.py ........... [ 16%] +unit/test_BaseCalculator.py .......... [ 26%] +unit/test_BaseData.py ........................... [ 53%] +unit/test_Instrument.py ....... [ 60%] +unit/test_Parameters.py ........................................ [100%] + +=============================================================== 100 passed in 0.56s ================================================================ +``` + +You can also run unittests only: + +``` +pytest tests/unit +``` + +Or to run integration tests only: + +``` +pytest tests/integration +``` + +## Acknowledgement + +This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 823852. diff --git a/doc/source/_templates/footer.html b/doc/source/_templates/footer.html index ccb00fa..af67528 100644 --- a/doc/source/_templates/footer.html +++ b/doc/source/_templates/footer.html @@ -10,7 +10,7 @@

The software - libpyvinyl is licensed under the LGPL version 3 or later. + libpyvinyl is licensed under the LGPL version 3 or later.

This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 823852 diff --git a/doc/source/conf.py b/doc/source/conf.py index d7951ea..41ccf63 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -15,22 +15,23 @@ import os import sys -sys.path.insert(0,'../libpyvinyl') +sys.path.insert(0, "../libpyvinyl") import libpyvinyl import sphinx_rtd_theme # -- Project information ----------------------------------------------------- -project = 'libpyvinyl' -copyright = '2020, Carsten Fortmann-Grote, Mads Bertelsen, Juncheng E' -author = 'Carsten Fortmann-Grote, Mads Bertelsen, Juncheng E' +project = "libpyvinyl" +copyright = ( + "2020-2021, Carsten Fortmann-Grote, Mads Bertelsen, Juncheng E, Shervin Nourbakhsh" +) +author = "Carsten Fortmann-Grote, Mads Bertelsen, Juncheng E, Shervin Nourbakhsh" # The short X.Y version -version = '0.0.2' +version = libpyvinyl.__version__ # The full version, including alpha/beta/rc tags -release = '0.0.2-alpha1' - +release = version # -- General configuration --------------------------------------------------- @@ -42,29 +43,30 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx_rtd_theme', - 'nbsphinx', - 'recommonmark', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_rtd_theme", + "nbsphinx", + "recommonmark", + "sphinx_autodoc_typehints", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -87,7 +89,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -99,7 +101,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -115,7 +117,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'libpyvinyldoc' +htmlhelp_basename = project + "doc" # -- Options for LaTeX output ------------------------------------------------ @@ -124,15 +126,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -142,8 +141,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'libpyvinyl.tex', 'libpyvinyl Documentation', - 'Carsten Fortmann-Grote, Mads Bertelsen, Juncheng E', 'manual'), + ( + master_doc, + project + ".tex", + project + " Documentation", + author, + "manual", + ), ] @@ -151,10 +155,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'libpyvinyl', 'libpyvinyl Documentation', - [author], 1) -] +man_pages = [(master_doc, project, project + " Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -163,9 +164,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'libpyvinyl', 'libpyvinyl Documentation', - author, 'libpyvinyl', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + project, + project + " Documentation", + author, + project, + "One line description of project.", + "Miscellaneous", + ), ] @@ -184,7 +191,7 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- @@ -194,7 +201,7 @@ # -- Options for intersphinx extension --------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} # -- Options for todo extension ---------------------------------------------- diff --git a/doc/source/include/notebooks/example-01.ipynb b/doc/source/include/notebooks/example-01.ipynb index 8709a7e..6a95225 100644 --- a/doc/source/include/notebooks/example-01.ipynb +++ b/doc/source/include/notebooks/example-01.ipynb @@ -9,107 +9,93 @@ }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Imports " - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, + "metadata": {}, "outputs": [], "source": [ - "from libpyvinyl.BaseCalculator import BaseCalculator, Parameters\n", - "import os\n", - "import h5py\n", - "import numpy" - ], + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "from libpyvinyl import BaseCalculator, CalculatorParameters, Parameter\n", + "import os\n", + "import h5py\n", + "import numpy" + ] }, { "cell_type": "markdown", - "source": [ - "## Implement a Calculator that derives from `BaseCalculator`." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } - }, - { - "cell_type": "code", - "execution_count": 27, - "outputs": [], + }, "source": [ - "from libpyvinyl.BaseCalculator import BaseCalculator, Parameters\n", - "import numpy\n", - "import h5py\n", - "import sys" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + "## Implement a Calculator that derives from `BaseCalculator`." + ] }, { "cell_type": "code", - "execution_count": 30, - "outputs": [], - "source": [ - "sys.path.insert(0,'../../../../tests')\n", - "\n", - "from RandomImageCalculator import RandomImageCalculator" - ], + "execution_count": 3, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": null, + }, "outputs": [], "source": [ - "class RandomImageCalculatorNB(BaseCalculator):\n", + "class RandomImageCalculator(BaseCalculator):\n", " \"\"\" class: Implements simulation of a rondom image for demonstration purposes. \"\"\"\n", - " def __init__(self, parameters=None, dumpfile=None, input_path=None, output_path=None):\n", + " def __init__(self, \n", + " *args,\n", + " **kwargs\n", + " ):\n", " \"\"\" Constructor of the RandomImageCalculator class. \"\"\"\n", - " super().__init__(parameters=parameters, dumpfile=dumpfile, output_path=output_path)\n", + " super().__init__(*args, **kwargs)\n", + " \n", + " self.__data = None\n", + " \n", "\n", " def backengine(self):\n", " \"\"\" Method to do the actual calculation.\"\"\"\n", - " tmpdata = numpy.random.random((self.parameters.grid_size_x, self.parameters.grid_size_y))\n", - "\n", - " self._set_data(tmpdata)\n", - " return 0\n", + " self.__data = [numpy.random.random((self.parameters.grid_size_x, self.parameters.grid_size_y))]\n", + " \n", + " self.saveH5()\n", "\n", " def saveH5(self, openpmd=False):\n", " \"\"\" Save image to hdf5 file 'output_path'. \"\"\"\n", - " with h5py.File(self.output_path, \"w\") as h5:\n", - " ds = h5.create_dataset(\"/data\", data=self.data)\n", - "\n", - " h5.close()\n", - "\n" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + " for i,fname in enumerate(self.output_filenames):\n", + " with h5py.File(fname, \"w\") as h5:\n", + " ds = h5.create_dataset(\"/data\", data=self.__data[i])\n", + " \n", + " def init_parameters():\n", + " pass" + ] }, { "cell_type": "markdown", @@ -120,52 +106,101 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "parameters = Parameters(photon_energy=6e3, pulse_energy=1.0e-6, grid_size_x=128, grid_size_y=128)" + "photon_energy = Parameter(name='photon_energy', unit='keV', comment=\"The photon energy in units of kilo electronvolt (keV)\")\n", + "pulse_energy = Parameter(name='pulse_energy', unit='J')\n", + "grid_size_x = Parameter(name='grid_size_x', unit=\"\")\n", + "grid_size_y = Parameter(name='grid_size_y', unit=\"\")\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], "source": [ - "### Setup the calculator" - ], - "metadata": { - "collapsed": false - } + "photon_energy.value = 6.0\n", + "pulse_energy.value = 5.0e-6\n", + "grid_size_x.value = 128\n", + "grid_size_y.value = 256\n" + ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 6, + "metadata": {}, "outputs": [], "source": [ - "calculator = RandomImageCalculator(parameters, output_path=\"out.h5\")" - ], + "parameters = CalculatorParameters([photon_energy, pulse_energy, grid_size_x, grid_size_y])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup the calculator" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "len(output_keys) = 1 is not equal to len(output_data_types) = 1", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_3697194/3811637111.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcalculator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mRandomImageCalculator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'RandomImageCalculator'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_keys\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'random_image.h5'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_data_types\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/tmp/ipykernel_3697194/1285912060.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 6\u001b[0m ):\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\"\"\" Constructor of the RandomImageCalculator class. \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Repositories/libpyvinyl/libpyvinyl/BaseCalculator.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, input, output_keys, output_data_types, output_filenames, instrument_base_dir, calculator_base_dir, parameters)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__check_consistency\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;31m# Create output data objects according to the output_data_classes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Repositories/libpyvinyl/libpyvinyl/BaseCalculator.py\u001b[0m in \u001b[0;36m__check_consistency\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;34m\"\"\"Check the consistency of the input parameters\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_keys\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_data_types\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 136\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0;34mf\"len(output_keys) = {len(self.output_keys)} is not equal to len(output_data_types) = {len(self.output_keys)}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: len(output_keys) = 1 is not equal to len(output_data_types) = 1" + ] + } + ], + "source": [ + "calculator = RandomImageCalculator(parameters=parameters, name='RandomImageCalculator', output_keys='random_image.h5', output_data_types=[], input=[] )" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Run the backengine" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 33, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "0" + "text/plain": [ + "0" + ] }, "execution_count": 33, "metadata": {}, @@ -174,30 +209,45 @@ ], "source": [ "calculator.backengine()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Look at the data and store as hdf5" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 34, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "array([[0.07471297, 0.00703423, 0.92835525, ..., 0.83050226, 0.93011749,\n 0.10575778],\n [0.30465388, 0.36400513, 0.48903381, ..., 0.0438568 , 0.39087367,\n 0.97940832],\n [0.61994805, 0.84566645, 0.42535347, ..., 0.94919735, 0.17939005,\n 0.74872113],\n ...,\n [0.22103305, 0.07844426, 0.8127275 , ..., 0.4273249 , 0.78210725,\n 0.59653636],\n [0.00889755, 0.40566176, 0.33960702, ..., 0.2634355 , 0.34068678,\n 0.99275201],\n [0.99495603, 0.18621833, 0.25057866, ..., 0.33598942, 0.10660242,\n 0.20565293]])" + "text/plain": [ + "array([[0.07471297, 0.00703423, 0.92835525, ..., 0.83050226, 0.93011749,\n", + " 0.10575778],\n", + " [0.30465388, 0.36400513, 0.48903381, ..., 0.0438568 , 0.39087367,\n", + " 0.97940832],\n", + " [0.61994805, 0.84566645, 0.42535347, ..., 0.94919735, 0.17939005,\n", + " 0.74872113],\n", + " ...,\n", + " [0.22103305, 0.07844426, 0.8127275 , ..., 0.4273249 , 0.78210725,\n", + " 0.59653636],\n", + " [0.00889755, 0.40566176, 0.33960702, ..., 0.2634355 , 0.34068678,\n", + " 0.99275201],\n", + " [0.99495603, 0.18621833, 0.25057866, ..., 0.33598942, 0.10660242,\n", + " 0.20565293]])" + ] }, "execution_count": 34, "metadata": {}, @@ -206,90 +256,98 @@ ], "source": [ "calculator.data" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", "execution_count": 35, - "outputs": [], - "source": [ - "calculator.saveH5(calculator.output_path)" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "calculator.saveH5(calculator.output_path)" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Save the parameters to a human readable json file." - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 36, - "outputs": [], - "source": [ - "parameters.to_json(\"my_parameters.json\")" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "parameters.to_json(\"my_parameters.json\")" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Save calculator to binary dump." - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 37, - "outputs": [], - "source": [ - "dumpfile = calculator.dump()" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "dumpfile = calculator.dump()" + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "### Load back parameters" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": 38, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "6000.0" + "text/plain": [ + "6000.0" + ] }, "execution_count": 38, "metadata": {}, @@ -298,111 +356,121 @@ ], "source": [ "new_parameters = Parameters.from_json(\"my_parameters.json\")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "### Look ath the photon energy of the restored parameters." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "### Look ath the photon energy of the restored parameters." + ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [ - "new_parameters.photon_energy" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "new_parameters.photon_energy" + ] }, { "cell_type": "markdown", - "source": [ - "### Reconstruct the dumped calculator." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "### Reconstruct the dumped calculator." + ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [ - "reloaded_calculator = RandomImageCalculator(dumpfile=dumpfile)" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "reloaded_calculator = RandomImageCalculator(dumpfile=dumpfile)" + ] }, { "cell_type": "markdown", - "source": [ - "### Query the data from the reconstructed calculator." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "### Query the data from the reconstructed calculator." + ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [ - "reloaded_calculator.data" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "reloaded_calculator.data" + ] }, { "cell_type": "markdown", - "source": [ - "### Look at the photon energy of the reconstructed calculator." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "### Look at the photon energy of the reconstructed calculator." + ] }, { "cell_type": "code", "execution_count": 39, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "6000.0" + "text/plain": [ + "6000.0" + ] }, "execution_count": 39, "metadata": {}, @@ -411,32 +479,29 @@ ], "source": [ "reloaded_calculator.parameters.photon_energy\n" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "name": "pycharm-44f7cfec", + "display_name": "pyvinyl", "language": "python", - "display_name": "PyCharm (libpyvinyl)" + "name": "pyvinyl" }, "language_info": { "codemirror_mode": { @@ -448,9 +513,16 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.7" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/doc/source/include/refman.rst b/doc/source/include/refman.rst index 38cf3be..42c40a2 100644 --- a/doc/source/include/refman.rst +++ b/doc/source/include/refman.rst @@ -1,12 +1,22 @@ API Reference Manual ==================== -.. autoclass:: libpyvinyl.BaseCalculator +.. automodule:: libpyvinyl.BaseCalculator :members: :undoc-members: :show-inheritance: -.. autoclass:: libpyvinyl.Parameters +.. automodule:: libpyvinyl.Parameters.Collections :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + +.. automodule:: libpyvinyl.Parameters.Parameter + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: libpyvinyl.Instrument + :members: + :undoc-members: + :show-inheritance: diff --git a/libpyvinyl/AbstractBaseClass.py b/libpyvinyl/AbstractBaseClass.py index 475ff98..c5bf935 100644 --- a/libpyvinyl/AbstractBaseClass.py +++ b/libpyvinyl/AbstractBaseClass.py @@ -25,6 +25,7 @@ #################################################################################### from abc import ABCMeta, abstractmethod + class AbstractBaseClass(object, metaclass=ABCMeta): """ :class AbstractBaseClass: Base class of libpyvinyl @@ -33,4 +34,3 @@ class AbstractBaseClass(object, metaclass=ABCMeta): @abstractmethod def __init__(self): pass - diff --git a/libpyvinyl/BaseCalculator.py b/libpyvinyl/BaseCalculator.py index 12d5716..c4685e0 100644 --- a/libpyvinyl/BaseCalculator.py +++ b/libpyvinyl/BaseCalculator.py @@ -1,14 +1,14 @@ """ -:module BaseCalculator: Module hosting the BaseCalculator and Parameters classes. +:module BaseCalculator: Module hosting the BaseCalculator class. """ #################################################################################### # # -# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # +# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # # Laboratory. # # # -# Copyright (C) 2020 Carsten Fortmann-Grote # +# Copyright (C) 2021 Carsten Fortmann-Grote, Juncheng E # # # # This program is free software: you can redistribute it and/or modify it under # # the terms of the GNU Lesser General Public License as published by the Free # @@ -25,19 +25,22 @@ #################################################################################### from abc import abstractmethod -from libpyvinyl.AbstractBaseClass import AbstractBaseClass -from libpyvinyl.Parameters import CalculatorParameters +from typing import Union, Optional from tempfile import mkstemp import copy import dill -import h5py -import sys +from pathlib import Path import logging -import numpy import os -logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', - level=logging.WARNING) +from libpyvinyl.AbstractBaseClass import AbstractBaseClass +from libpyvinyl.BaseData import BaseData, DataCollection +from libpyvinyl.Parameters import CalculatorParameters + + +logging.basicConfig( + format="%(asctime)s %(levelname)s:%(message)s", level=logging.WARNING +) class BaseCalculator(AbstractBaseClass): @@ -50,77 +53,305 @@ class BaseCalculator(AbstractBaseClass): This class is to be used as a base class for all calculators that implement a special simulation module, such as a photon diffraction calculator. Such a - specialized Calculator than has the same interface to the simulation - backengine as all other ViNyL Calculators. + specialized Calculator has the same interface to the simulation + backengine as all other ViNYL Calculators. + + A Complete example including a instrument and calculators can be found in + `test/integration/plusminus` """ - @abstractmethod - def __init__(self, name: str, parameters=None, dumpfile=None, **kwargs): + + def __init__( + self, + name: str, + input: Union[DataCollection, list, BaseData], + output_keys: Union[list, str], + output_data_types: list, + output_filenames: Union[list, str, None] = None, + instrument_base_dir: str = "./", + calculator_base_dir: str = "BaseCalculator", + parameters: CalculatorParameters = None, + ): """ - :param name: The name for this calculator. - :type name: str + :param name: The name of this calculator. + :type name: str + + :param name: The input of this calculator. It can be a `DataCollection`, + a list of `DataCollection`s or a single Data Object. + :type name: DataCollection, list or BaseData + + :param output_keys: The key(s) of this calculator's output data. It's a list of `str`s or + a single str. + :type output_keys: list or str + + :param output_data_types: The data type(s), i.e., classes, of each output. It's a list of the + data classes or a single data class. The available data classes are based on `BaseData`. + :type output_data_types: list or DataClass + + :param output_filenames: The name(s) of the output file(s). It can be a str of a filename or + a list of filenames. If the mapping is dict mapping, the name is `None`. Defaults to None. + :type output_filenames: list or str + + :param instrument_base_dir: The base directory for the instrument to which this calculator + belongs. Defaults to "./". The final exact output file path depends on `instrument_base_dir` + and `calculator_base_dir`: `instrument_base_dir`/`calculator_base_dir`/filename + :type instrument_base_dir: str + + :param calculator_base_dir: The base directory for this calculator. Defaults to "./". The final + exact output file path depends on `instrument_base_dir` and + `calculator_base_dir`: `instrument_base_dir`/`calculator_base_dir`/filename + :type instrument_base_dir: str :param parameters: The parameters for this calculator. :type parameters: Parameters - :param dumpfile: If given, load a previously dumped (aka pickled) calculator. + """ + # Initialize the variables + self.__name = None + self.__instrument_base_dir = None + self.__calculator_base_dir = None + self.__input = None + self.__output_keys = None + self.__output_data_types = None + self.__output_filenames = None + self.__parameters = None + + self.name = name + self.input = input + self.output_keys = output_keys + self.output_data_types = output_data_types + self.output_filenames = output_filenames + self.instrument_base_dir = instrument_base_dir + self.calculator_base_dir = calculator_base_dir + self.parameters = parameters + + self.__check_consistency() + # Create output data objects according to the output_data_classes + self.__init_output() + + def __check_consistency(self): + """Check the consistency of the input parameters""" + if len(self.output_keys) != len(self.output_data_types): + raise ValueError( + f"len(output_keys) = {len(self.output_keys)} is not equal to len(output_data_types) = {len(self.output_data_types)}" + ) + + def __check_output_filenames(self): + """Since output_filenames can be None for output in dict mapping, only check output_files when necessary""" + if len(self.output_data_types) != len(self.output_filenames): + raise ValueError( + f"len(output_filenames) = {len(self.output_filenames)} is not equal to len(output_data_types) = {len(self.output_data_types)}" + ) + + @property + def name(self) -> str: + return self.__name + + @name.setter + def name(self, value): + if isinstance(value, str): + self.__name = value + else: + raise TypeError( + f"Calculator: `name` is expected to be a str, not {type(value)}" + ) + + @property + def parameters(self) -> CalculatorParameters: + return self.__parameters + + @parameters.setter + def parameters(self, value: CalculatorParameters): + self.reset_parameters(value) + + def reset_parameters(self, value: CalculatorParameters): + """Resets the calculator parameters""" + if isinstance(value, CalculatorParameters): + self.__parameters = value + elif value is None: + self.init_parameters() + else: + raise TypeError( + f"Calculator: `parameters` is expected to be CalculatorParameters, not {type(value)}" + ) + + def set_parameters(self, args_as_dict=None, **kwargs): + """ + Sets parameters contained in this calculator using dict or kwargs + """ + if args_as_dict is not None: + parameter_dict = args_as_dict + else: + parameter_dict = kwargs - :param kwargs: (key, value) pairs of further arguments to the calculator, e.g input_path, output_path. + for key, parameter_value in parameter_dict.items(): + self.parameters[key].value = parameter_value - If both 'parameters' and 'dumpfile' are given, the dumpfile is loaded - first. Passing a parameters object may be used to update some - parameters. + @property + def instrument_base_dir(self) -> str: + return self.__instrument_base_dir - Example: - ``` - # Define a specialized calculator. - class MyCalculator(BaseCalculator): + @instrument_base_dir.setter + def instrument_base_dir(self, value): + self.set_instrument_base_dir(value) - def __init__(self, parameters=None, dumpfile=None, **kwargs): - super()__init__(parameters, dumpfile, **kwargs) + def set_instrument_base_dir(self, value: str): + """Set the instrument base directory""" + if isinstance(value, str): + self.__instrument_base_dir = value + else: + raise TypeError( + f"Calculator: `instrument_base_dir` is expected to be a str, not {type(value)}" + ) - def backengine(self): - os.system("my_simulation_backengine_call") + @property + def calculator_base_dir(self) -> str: + return self.__calculator_base_dir - def saveH5(self): - # Format output into openpmd hdf5 format. + @calculator_base_dir.setter + def calculator_base_dir(self, value): + self.set_calculator_base_dir(value) + + def set_calculator_base_dir(self, value: str): + """Set the calculator base directory""" + if isinstance(value, str): + self.__calculator_base_dir = value + else: + raise TypeError( + f"Calculator: `calculator_base_dir` is expected to be a str, not {type(value)}" + ) - class MyParameters(Parameters): - pass + @property + def input(self) -> DataCollection: + return self.__input + + @input.setter + def input(self, value): + self.set_input(value) + + def set_input(self, value: Union[DataCollection, list, BaseData, None]): + """Set the calculator input data. It can be a DataCollection, list or BaseData object.""" + if isinstance(value, (DataCollection, type(None))): + self.__input = value + elif isinstance(value, list): + self.__input = DataCollection(*value) + elif isinstance(value, BaseData): + self.__input = DataCollection(value) + else: + raise TypeError( + f"Calculator: `input` can be a DataCollection, list or BaseData object, and will be treated as a DataCollection. Your input type: {type(value)} is not accepted." + ) - my_calculator = MyCalculator(my_parameters) + @property + def output_keys(self) -> list: + return self.__output_keys - my_calculator.backengine() + @output_keys.setter + def output_keys(self, value): + self.set_output_keys(value) - my_calculator.saveH5("my_sim_output.h5") - my_calculater.dump("my_calculator.dill") - ``` + @property + def base_dir(self): + """The base path for the output files of this calculator in consideration of instrument_base_dir and calculator_base_dir""" + base_dir = Path(self.instrument_base_dir) / self.calculator_base_dir + return str(base_dir) - """ + @property + def output_file_paths(self): + """The final output file paths considering base_dir""" + self.__check_output_filenames() + paths = [] + + for filename in self.output_filenames: + path = Path(self.base_dir) / filename + # Make sure the file directory exists + path.parent.mkdir(parents=True, exist_ok=True) + paths.append(str(path)) + return paths + + def set_output_keys(self, value: Union[list, str]): + """Set the calculator output keys. It can be a list of str or a single str.""" + if isinstance(value, list): + for item in value: + assert type(item) is str + self.__output_keys = value + elif isinstance(value, str): + self.__output_keys = [value] + else: + raise TypeError( + f"Calculator: `output_keys` can be a list or str, and will be treated as a list. Your input type: {type(value)} is not accepted." + ) - if isinstance(name, str): - self.name = name + @property + def output_data_types(self) -> list: + return self.__output_data_types + + @output_data_types.setter + def output_data_types(self, value): + self.set_output_data_types(value) + + def set_output_data_types(self, value: Union[list, BaseData]): + """Set the calculator output data type. It can be a list of DataClass or a single DataClass.""" + if isinstance(value, list): + for item in value: + assert issubclass(item, BaseData) + self.__output_data_types = value + elif issubclass(value, BaseData): + self.__output_data_types = [value] else: - raise TypeError("name should be in str type.") - # Set data - self.__data = None + raise TypeError( + f"Calculator: `output_data_types` can be a list or a subclass of BaseData, and will be treated as a list. Your input type: {type(value)} is not accepted." + ) - if isinstance(parameters, (type(None), CalculatorParameters)): - self.parameters = parameters + @property + def output_filenames(self) -> list: + return self.__output_filenames + + @output_filenames.setter + def output_filenames(self, value): + self.set_output_filenames(value) + + def set_output_filenames(self, value: Union[list, str, None]): + """Set the calculator output filenames. It can be a list of filenames or just a single str.""" + if isinstance(value, list): + for item in value: + assert type(item) is str or type(None) + self.__output_filenames = value + elif isinstance(value, (str, type(None))): + self.__output_filenames = [value] else: raise TypeError( - "parameters should be in CalculatorParameters type.") + f"Calculator: `output_filenames` can be a list or just a str or None, and will be treated as a list. Your input type: {type(value)} is not accepted." + ) + + @property + def output(self): + """The output of this calculator""" + return self.__output - # Must load after setting paramters to avoid being overrode by empty parameters. - if dumpfile is not None: - self.__load_from_dump(dumpfile) + @property + def data(self): + """The alias of output. It's not recommended to use this variable name due to it's ambiguity.""" + return self.__output + + @abstractmethod + def init_parameters(self): + """Virtual method to initialize all parameters. Must be implemented on the + specialized class.""" - if "output_path" in kwargs: - self.output_path = kwargs["output_path"] + raise NotImplementedError + + def __init_output(self): + """Create output data objects according to the output_data_types""" + output = DataCollection() + for i, key in enumerate(self.output_keys): + output_data = self.output_data_types[i](key) + output.add_data(output_data) + self.__output = output def __call__(self, parameters=None, **kwargs): - """ The copy constructor + """The copy constructor :param parameters: The parameters for the new calculator. :type parameters: CalculatorParameters @@ -135,51 +366,42 @@ def __call__(self, parameters=None, **kwargs): new.__dict__.update(kwargs) - if parameters is not None: + if parameters is None: + new.parameters = copy.deepcopy(new.parameters) + else: new.parameters = parameters - return new - def __load_from_dump(self, dumpfile): - """ """ - """ - Load a dill dump and initialize self's internals. + @classmethod + def from_dump(cls, dumpfile: str): + """Load a dill dump from a dumpfile. + :param dumpfile: The file name of the dumpfile. + :type dumpfile: str + :return: The calculator object restored from the dumpfile. + :rtype: CalcualtorClass """ - with open(dumpfile, 'rb') as fhandle: + with open(dumpfile, "rb") as fhandle: try: tmp = dill.load(fhandle) except: - raise IOError( - "Cannot load calculator from {}.".format(dumpfile)) - - self.__dict__ = copy.deepcopy(tmp.__dict__) - - del tmp + raise IOError("Cannot load calculator from {}.".format(dumpfile)) - @property - def parameters(self): - """ The parameters of this calculator. """ - - return self.__parameters - - @parameters.setter - def parameters(self, val): - - if not isinstance(val, (type(None), CalculatorParameters)): - raise TypeError( - """Passed argument 'parameters' has wrong type. Expected CalculatorParameters, found {}.""" - .format(type(val))) + if not isinstance(tmp, cls): + raise TypeError(f"The object in the file {dumpfile} is not a {cls}") - self.__parameters = val + return tmp - def dump(self, fname=None): + def dump(self, fname: Optional[str] = None) -> str: """ Dump class instance to file. :param fname: Filename (path) of the file to write. + :type fname: str + :return: The filename of the dumpfile + :rtype: str """ if fname is None: @@ -188,109 +410,14 @@ def dump(self, fname=None): prefix=self.__class__.__name__[-1], dir=os.getcwd(), ) - try: - with open(fname, "wb") as file_handle: - dill.dump(self, file_handle) - except: - raise + with open(fname, "wb") as file_handle: + dill.dump(self, file_handle) return fname @abstractmethod - def saveH5(self, fname: str, openpmd: bool = True): - """ Save the simulation data to hdf5 file. - - :param fname: The filename (path) of the file to write the data to. - :type fname: str - - :param openpmd: Flag that controls whether the data is to be written in according to the openpmd metadata standard. Default is True. - - """ - - @property - def data(self): - return self.__data - - @data.setter - def data(self, val): - raise AttributeError("Attribute 'data' is read-only.") - - @abstractmethod - def backengine(self): - pass - - @classmethod - def run_from_cli(cls): - """ - Method to start calculator computations from command line. - - :return: exit with status code - - """ - if len(sys.argv) == 2: - fname = sys.argv[1] - calculator = cls(fname) - status = calculator._run() - sys.exit(status) - - def _run(self): - """ - Method to do computations. By default starts backengine. - - :return: status code. - - """ - result = self.backengine() - - if result is None: - result = 0 - - return result - - def _set_data(self, data): - """ """ - """ Private method to store the data on the object. - - :param data: The data to store. - - """ - - self.__data = data - - -# Mocks for testing. Have to be here to work around bug in dill that does not -# like classes to be defined outside of __main__. -class SpecializedCalculator(BaseCalculator): - def __init__(self, name, parameters=None, dumpfile=None, **kwargs): - - super().__init__(name, parameters, dumpfile, **kwargs) - - def setParams(self, photon_energy: float = 10, pulse_energy: float = 1e-3): - if not isinstance(self.parameters, CalculatorParameters): - self.parameters = CalculatorParameters() - self.parameters.new_parameter("photon_energy", - unit="eV", - comment="Photon energy") - self.parameters['photon_energy'].set_value(photon_energy) - - self.parameters.new_parameter("pulse_energy", - unit="joule", - comment="Pulse energy") - self.parameters['pulse_energy'].set_value(pulse_energy) - def backengine(self): - self._BaseCalculator__data = numpy.random.normal( - loc=self.parameters['photon_energy'].value, - scale=0.001 * self.parameters['photon_energy'].value, - size=(100, )) - - return 0 - - def saveH5(self, openpmd=False): - with h5py.File(self.output_path, "w") as h5: - ds = h5.create_dataset("/data", data=self.data) - - h5.close() + raise NotImplementedError # This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 823852. diff --git a/libpyvinyl/BaseData.py b/libpyvinyl/BaseData.py new file mode 100644 index 0000000..cb06948 --- /dev/null +++ b/libpyvinyl/BaseData.py @@ -0,0 +1,489 @@ +""" :module BaseData: Module hosts the BaseData class.""" +from typing import Union, Optional +from abc import abstractmethod, ABCMeta +from libpyvinyl.AbstractBaseClass import AbstractBaseClass + + +class BaseData(AbstractBaseClass): + """The abstract data class. Inheriting classes represent simulation input and/or output + data and provide a harmonized user interface to simulation data of various kinds rather than a data format. + Their purpose is to provide a harmonized user interface to common data operations such as reading/writing from/to disk. + + :param key: The key to identify the Data Object. + :type key: str + :param expected_data: A placeholder dict for expected data. The keys of this dict are expected to be found during the execution of `get_data()`. + The value for each key can be `None`. + :type expected_data: dict + :param data_dict: The dict to map by this DataClass. It has to be `None` if a file mapping was already set, defaults to None. + :type data_dict: dict, optional + :param filename: The filename of the file to map by this DataClass. It has to be `None` if a dict mapping was already set, defaults to None. + :type filename: str, optional + :param file_format_class: The FormatClass to map the file by this DataClass, It has to be `None` if a dict mapping was already set, defaults to None + :type file_format_class: class, optional + :param file_format_kwargs: The kwargs needed to map the file, defaults to None. + :type file_format_kwargs: dict, optional + """ + + def __init__( + self, + key: str, + expected_data: dict, + data_dict: Optional[dict] = None, + filename: Optional[str] = None, + file_format_class=None, + file_format_kwargs: Optional[dict] = None, + ): + self.__key = None + self.__expected_data = None + self.__data_dict = None + self.__filename = None + self.__file_format_class = None + self.__file_format_kwargs = None + + self.key = key + # Expected_data is checked when `self.get_data()` + self.expected_data = expected_data + # This will be always be None if the data class is mapped to a file + self.data_dict = data_dict + # These will be always be None if the data class is mapped to a python data dict object + self.filename = filename + self.file_format_class = file_format_class + self.file_format_kwargs = file_format_kwargs + + self.__check_consistency() + + @property + def key(self) -> str: + """The key of the class instance for calculator usage""" + return self.__key + + @key.setter + def key(self, value: str): + if isinstance(value, str): + self.__key = value + else: + raise TypeError(f"Data Class: key should be a str, not {type(value)}") + + @property + def expected_data(self): + """The expected_data of the class instance for calculator usage""" + return self.__expected_data + + @expected_data.setter + def expected_data(self, value): + if isinstance(value, dict): + self.__expected_data = value + else: + raise TypeError( + f"Data Class: expected_data should be a dict, not {type(value)}" + ) + + @property + def data_dict(self): + """The data_dict of the class instance for calculator usage""" + return self.__data_dict + + @data_dict.setter + def data_dict(self, value): + if isinstance(value, dict): + self.__data_dict = value + elif value is None: + self.__data_dict = None + else: + raise TypeError( + f"Data Class: data_dict should be None or a dict, not {type(value)}" + ) + self.__check_consistency() + + def set_dict(self, data_dict: dict): + """Set a mapping dict for this DataClass. + + :param data_dict: The data dict to map + :type data_dict: dict + """ + self.data_dict = data_dict + + def set_file(self, filename: str, format_class, **kwargs): + """Set a mapping file for this DataClass. + + :param filename: The filename of the file to map. + :type filename: str + :param format_class: The FormatClass to map the file + :type format_class: class + """ + self.filename = filename + self.file_format_class = format_class + self.file_format_kwargs = kwargs + self.__check_consistency() + + @property + def filename(self): + """The filename of the file to map by this DataClass.""" + return self.__filename + + @filename.setter + def filename(self, value): + if isinstance(value, str): + self.__filename = value + elif value is None: + self.__filename = None + else: + raise TypeError( + f"Data Class: filename should be None or a str, not {type(value)}" + ) + + @property + def file_format_class(self): + """The FormatClass to map the file by this DataClass""" + return self.__file_format_class + + @file_format_class.setter + def file_format_class(self, value): + if isinstance(value, ABCMeta): + self.__file_format_class = value + elif value is None: + self.__file_format_class = None + else: + raise TypeError( + f"Data Class: format_class should be None or a format class, not {type(value)}" + ) + + @property + def file_format_kwargs(self): + """The kwargs needed to map the file""" + return self.__file_format_kwargs + + @file_format_kwargs.setter + def file_format_kwargs(self, value): + if isinstance(value, dict): + self.__file_format_kwargs = value + elif value is None: + self.__file_format_kwargs = None + else: + raise TypeError( + f"Data Class: file_format_kwargs should be None or a dict, not {type(value)}" + ) + + @property + def mapping_type(self): + """If this data class is a file mapping or python dict mapping.""" + return self.__check_mapping_type() + + def __check_mapping_type(self): + """Check the mapping_type of this class.""" + if self.data_dict is not None: + return dict + elif self.filename is not None: + return self.file_format_class + else: + raise TypeError("Neither self.__data_dict or self.__filename was found.") + + @property + def mapping_content(self): + """Returns an overview of the keys of the mapped dict or the filename of the mapped file""" + if self.mapping_type == dict: + return self.data_dict.keys() + else: + return self.filename + + @staticmethod + def _add_ioformat(format_dict, format_class): + """Register an ioformat to a `format_dict` listing the formats supported by this DataClass. + + :param format_dict: The dict listing the supported formats. + :type format_dict: dict + :param format_class: The FormatClass to add. + :type format_class: class + """ + register = format_class.format_register() + for key, val in register.items(): + if key == "key": + this_format = val + format_dict[val] = {} + else: + format_dict[this_format][key] = val + + @classmethod + @abstractmethod + def supported_formats(self): + format_dict = {} + # Add the supported format classes when creating a concrete class. + # See the example at `tests/BaseDataTest.py` + self._add_ioformat(format_dict, FormatClass) + return format_dict + + @classmethod + def list_formats(self): + """Print supported formats""" + out_string = "" + supported_formats = self.supported_formats() + for key in supported_formats: + dicts = supported_formats[key] + format_class = dicts["format_class"] + if format_class: + out_string += "Format class: {}\n".format(format_class) + out_string += "Key: {}\n".format(key) + out_string += "Description: {}\n".format(dicts["description"]) + ext = dicts["ext"] + if ext != "": + out_string += "File extension: {}\n".format(ext) + kwargs = dicts["read_kwargs"] + if kwargs != [""]: + out_string += "Extra reading keywords: {}\n".format(kwargs) + kwargs = dicts["write_kwargs"] + if kwargs != [""]: + out_string += "Extra writing keywords: {}\n".format(kwargs) + out_string += "\n" + print(out_string) + + def __check_consistency(self): + # If all of the file-related parameters are set: + if all([self.filename, self.file_format_class]): + # If the data_dict is also set: + if self.data_dict is not None: + raise RuntimeError( + "self.data_dict and self.filename can not be set for one data class at the same time." + ) + else: + pass + # If any one of the file-related parameters is None: + elif ( + self.filename is None + and self.file_format_class is None + and self.file_format_kwargs is None + ): + pass + # If some of the file-related parameters is None and some is not None: + else: + raise RuntimeError( + "self.filename, self.file_format_class, self.file_format_kwargs are not consistent." + ) + + @classmethod + def from_file(cls, filename: str, format_class, key: dict, **kwargs): + """Create a Data Object mapping a file. + + :param filename: The filename of the file to map by this DataClass. It has to be `None` if a dict mapping was already set, defaults to None. + :type filename: str, optional + :param file_format_class: The FormatClass to map the file by this DataClass, It has to be `None` if a dict mapping was already set, defaults to None + :type file_format_class: class, optional + :param file_format_kwargs: The kwargs needed to map the file, defaults to None. + :type file_format_kwargs: dict, optional + :param key: The key to identify the Data Object. + :type key: str + + :return: A Data Object + :rtype: BaseData + """ + return cls( + key, + filename=filename, + file_format_class=format_class, + file_format_kwargs=kwargs, + ) + + @classmethod + def from_dict(cls, data_dict: dict, key: str): + """Create a Data Object mapping a data dict. + + :param data_dict: The dict to map by this DataClass. It has to be `None` if a file mapping was already set, defaults to None. + :type data_dict: dict + :param key: The key to identify the Data Object. + :type key: str + :return: A Data Object + :rtype: BaseData + """ + return cls(key, data_dict=data_dict) + + def write(self, filename: str, format_class, key: str = None, **kwargs): + """Write the data mapped by the Data Object into a file and return a Data Object + mapping the file. It converts either a file or a python object to a file + The behavior related to a file will always be handled by the format class. + If it's a python dictionary mapping, write with the specified format_class + directly. + + :param filename: The filename of the file to be written. + :type filename: str + :param file_format_class: The FormatClass to write the file. + :type file_format_class: class + :param key: The identification key of the new Data Object. When it's `None`, a new key will + be generated with a suffix added to the previous identification key by the FormatClass. Defaults to None. + :type key: str, optional + :return: A Data Object + :rtype: BaseData + """ + if self.mapping_type == dict: + return format_class.write(self, filename, key, **kwargs) + elif format_class in self.file_format_class.direct_convert_formats(): + return self.file_format_class.convert( + self, filename, format_class, key, **kwargs + ) + # If it's a file mapping and would like to write in the same file format of the + # mapping, it will let the user know that a file containing the data in the same format already existed. + elif format_class == self.file_format_class: + print( + f"Hint: This data already existed in the file {self.__filename} in format {self.__file_format_class}. `cp {self.__filename} {filename}` could be faster." + ) + print( + f"Will still write the data into the file {filename} in format {format_class}" + ) + return format_class.write(self, filename, key, **kwargs) + else: + return format_class.write(self, filename, key, **kwargs) + + def __check_for_expected_data(self, data_to_read): + """Check if the `data_to_read` contains the data we have""" + for key in self.expected_data.keys(): + try: + data_to_read[key] + except KeyError: + raise KeyError( + f"Expected data dict key '{key}' is not found." + ) from None + + def __get_dict_data(self): + """Get the data dict from a dict mapping""" + if self.__data_dict is not None: + # It will automatically check the data needed to be extracted. + self.__check_for_expected_data(self.__data_dict) + return self.data_dict + else: + raise RuntimeError( + "__get_dict_data() should not be called when self.__data_dict is None" + ) + + def __get_file_data(self): + """Get the data dict from a file mapping""" + if self.__filename is not None: + data_to_read = self.__file_format_class.read( + self.__filename, **self.__file_format_kwargs + ) + # It will automatically check the data needed to be extracted. + self.__check_for_expected_data(data_to_read) + data_to_return = {} + for key in data_to_read.keys(): + data_to_return[key] = data_to_read[key] + return data_to_return + else: + raise RuntimeError( + "__get_file_data() should not be called when self.__filename is None" + ) + + def get_data(self): + """Return the data in a dictionary""" + # From either a file or a python object to a python object + if self.__data_dict is not None: + return self.__get_dict_data() + elif self.__filename is not None: + return self.__get_file_data() + + def __str__(self): + """Returns strings of Data objects info""" + string = f"key = {self.key}\n" + string += f"mapping = {self.mapping_type}: {self.mapping_content}" + return string + + +# DataCollection class +class DataCollection: + """A collection of Data Objects""" + + def __init__(self, *args): + self.data_object_dict = {} + self.add_data(*args) + + def __len__(self): + return len(self.data_object_dict) + + def __setitem__(self, key, value): + if key != value.key: + print( + f"Warning: the key '{key}' of this DataCollection will be replaced by the key '{value.key}' set in the input data." + ) + del self.data_object_dict[key] + self.add_data(value) + + def __getitem__(self, keys): + if isinstance(keys, str): + return self.get_data_object(keys) + elif isinstance(keys, list): + subset = [] + for key in keys: + subset.append(self.get_data_object(key)) + return DataCollection(*subset) + + def add_data(self, *args): + """Add data objects to the data colletion""" + for data in args: + assert isinstance(data, BaseData) + self.data_object_dict[data.key] = data + + def get_data(self): + """Get the data of the data object(s). + When there is only one item in the DataCollection, it returns the data dict, + When there are more then one items, it returns a dictionary of the data dicts""" + if len(self.data_object_dict) == 1: + return next(iter(self.data_object_dict.values())).get_data() + else: + data_dicts = {} + for key, obj in self.data_object_dict.items(): + data_dicts[key] = obj.get_data() + return data_dicts + + def write( + self, + filename: Union[str, dict], + format_class, + key: Union[str, dict] = None, + **kwargs, + ): + """Write the data object(s) to the file(s). + When there is only one item in the DataCollection, it returns the data object mapping the file which was wirttern, + When there are more then one items, it returns a dictionary of the data objects. + + :param filename: The name(s) of the file(s) to write. When there are multiple items, they are expected in + a dict where the keys corresponding to the data in this collection. + :type filename: str or dict + :param format_class: The format class of the file(s). When there are multiple items, they are expected in + a dict where the keys corresponding to the data in this collection. + :type format_class: class or dict + :param key: The key(s) of the data object(s) mapping the written file(s), defaults to None. + :type key: str or dict, optional + :return: A data object or a dict of data objects. + :rtype: DataClass or dict + """ + + if len(self.data_object_dict) == 1: + obj = next(iter(self.data_object_dict.values())) + return obj.write(filename, format_class, key, **kwargs) + else: + assert isinstance(key, dict) + data_dicts = {} + for col_key, obj in self.data_object_dict.items(): + written_data = obj.write( + filename[col_key], format_class[col_key], key[col_key], **kwargs + ) + data_dicts[written_data.key] = written_data + return data_dicts + + def get_data_object(self, key: str): + """Get one data object by its key + + :param key: The key of the data object to get. + :type key: str + :return: A data object + :rtype: DataClass + """ + return self.data_object_dict[key] + + def to_list(self): + """Return a list of the data objects in the data collection""" + return [value for value in self.data_object_dict.values()] + + def __str__(self): + """Returns strings of the data object info""" + string = "Data collection:\n" + string += "key - mapping\n\n" + for data_object in self.data_object_dict.values(): + string += f"{data_object.key} - {data_object.mapping_type}: {data_object.mapping_content}\n" + return string diff --git a/libpyvinyl/BaseFormat.py b/libpyvinyl/BaseFormat.py new file mode 100644 index 0000000..abdaac3 --- /dev/null +++ b/libpyvinyl/BaseFormat.py @@ -0,0 +1,109 @@ +from abc import abstractmethod +from libpyvinyl.AbstractBaseClass import AbstractBaseClass +from libpyvinyl.BaseData import BaseData + + +class BaseFormat(AbstractBaseClass): + """The abstract format class. It's the interface of a certain data format.""" + + def __init__(self): + # Nothing needs to be done here. + pass + + @classmethod + @abstractmethod + def format_register(self): + # Override this `format_register` method in a concrete format class. + key = "Base" + desciption = "Base data format" + file_extension = "base" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @classmethod + def _create_format_register( + cls, + key: str, + desciption: str, + file_extension: str, + read_kwargs=[""], + write_kwargs=[""], + ): + format_register = { + "key": key, # FORMAT KEY + "description": desciption, # FORMAT DESCRIPTION + "ext": file_extension, # FORMAT EXTENSION + "format_class": cls, # CLASS NAME OF THE FORMAT + "read_kwargs": read_kwargs, # KEYWORDS LIST NEEDED TO READ + "write_kwargs": write_kwargs, # KEYWORDS LIST NEEDED TO WRITE + } + return format_register + + @classmethod + @abstractmethod + def read(self, filename: str, **kwargs) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + # Example codes. Override this function in a concrete class. + data_dict = {} + with h5py.File(filename, "r") as h5: + for key, val in h5.items(): + data_dict[key] = val[()] + return data_dict + + @classmethod + @abstractmethod + def write(cls, object: BaseData, filename: str, key: str, **kwargs): + """Save the data with the `filename`.""" + # Example codes. Override this function in a concrete class. + data_dict = object.get_data() + arr = np.array([data_dict["number"]]) + np.savetxt(filename, arr, fmt="%.3f") + if key is None: + original_key = object.key + key = original_key + "_to_TXTFormat" + return object.from_file(filename, cls, key) + else: + return object.from_file(filename, cls, key) + + @staticmethod + @abstractmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Override this `direct_convert_formats` in a concrete format class + return [Aformat, BFormat] + + @classmethod + @abstractmethod + def convert( + cls, obj: BaseData, output: str, output_format_class: str, key, **kwargs + ): + """Direct convert method, if the default converting would be too slow or not suitable for the output_format""" + # If there is no direct converting supported: + raise NotImplementedError + if output_format_class is AFormat: + return cls.convert_to_AFormat(obj.filename, output) + else: + raise TypeError( + "Direct converting to format {} is not supported".format( + output_format_class + ) + ) + # Set the key of the returned object + if key is None: + original_key = obj.key + key = original_key + "_from_BaseFormat" + return obj.from_file(output, output_format_class, key) + + # Example convert_to_AFormat() + # @classmethod + # def convert_to_AFormat(cls, input: str, output: str): + # """The engine of convert method.""" + # print("Directly converting BaseFormat to AFormat") + # number = float(np.loadtxt(input)) + # with h5py.File(output, "w") as h5: + # h5["number"] = number diff --git a/libpyvinyl/BeamlinePropagator.py b/libpyvinyl/BeamlinePropagator.py deleted file mode 100644 index b5f3193..0000000 --- a/libpyvinyl/BeamlinePropagator.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -:module BeamlinePropagator: Module hosting the BeamlinePropagator and BeamlinePropagatorParameters -abstract classes. -""" - - -#################################################################################### -# # -# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # -# Laboratory. # -# # -# Copyright (C) 2020 Carsten Fortmann-Grote # -# # -# This program is free software: you can redistribute it and/or modify it under # -# the terms of the GNU Lesser General Public License as published by the Free # -# Software Foundation, either version 3 of the License, or (at your option) any # -# later version. # -# # -# This program is distributed in the hope that it will be useful, but WITHOUT ANY # -# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A # -# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # -# # -# You should have received a copy of the GNU Lesser General Public License along # -# with this program. If not, see None: + """ + Add a new parameter with the given name as master parameter. + The goal is to link parameters in multiple calculators that represent the same quantity and that should be all changed at the same time when any of them is changed. This is obtained creating the link and by changing the value of the newly created master parameter. + + :param name: name of the master parameter + :param links: dictionary with the names of the calculators and calculator parameters that represent the same quantity and hence can be changed all at once modifying the master parameter" + """ self.parameters.add_master_parameter(name, links, **kwargs) @property - def name(self): + def name(self) -> str: """The name of this instrument.""" return self.__name @name.setter - def name(self, value: str): - self.__name = value + def name(self, value: str) -> None: + if isinstance(value, str): + self.__name = value + else: + raise TypeError( + f"Instrument: name is expecting a str rather than {type(value)}" + ) @property - def calculators(self): + def calculators(self) -> Dict[str, BaseCalculator]: """The list of calculators. It's modified either when constructing the class instance - or using the `add_calculator` function. + or using the :meth:`~libpyvinyl.Instrument.add_calculator` function. """ return self.__calculators @property - def parameters(self): - """The parameter collection of each calculator in the instrument. These parameters are links to the - exact parameters of each calculator.""" + def parameters(self) -> InstrumentParameters: + """ + The parameter collection of each calculator in the instrument. + These parameters are links to the + exact parameters of each calculator. + """ return self.__parameters @property - def master(self): + def master(self) -> MasterParameters: + """Return the master parameters""" return self.parameters.master - def set_base_path(self, base: str): - """Set each calculator's output_path as 'base_path/calculator.name'. + def set_instrument_base_dir(self, base: str) -> None: + """Set each calculator's `instrument_base_dir` to '`base`. Each calculator's data file ouput directory + will be "`instrument_base_dir`/`calculator_base_dir`". - :param base: The base path to be set. + :param base: The base directory to be set. :type base: str """ - self.base_path = base - basePath = Path(base) - for key in self.calculators: - outputPath = basePath / self.calculators[key].name - calculator = self.calculators[key] - calculator.output_path = str(outputPath) - - def list_calculators(self): + if isinstance(base, str): + self.__instrument_base_dir = base + for calculator in self.calculators.values(): + calculator.instrument_base_dir = self.__instrument_base_dir + else: + raise TypeError( + f"Instrument: instrument_base_dir is expecting a str rather than {type(base)}" + ) + + def list_calculators(self) -> None: + """ + Print the list of all defined calculators for this instrument + """ string = f"- Instrument: {self.name} -\n" string += "Calculators:\n" for key in self.calculators: string += f"{key}\n" print(string) - def list_parameters(self): + def list_parameters(self) -> None: + """ + Print the list of all calculator parameters + """ print(self.parameters) - def add_calculator(self, calculator): + def add_calculator(self, calculator: BaseCalculator) -> None: + """ + Append one calculator to the list of calculators. + + N.B. calculators are executed in the same order as they are provided + + :param calculator: calculator + """ self.__calculators[calculator.name] = calculator self.__parameters.add(calculator.name, calculator.parameters) - def remove_calculator(self, calculator_name): + def remove_calculator(self, calculator_name: str) -> None: + """ + Remove the calculator with the given name from the list of calculators + + :param calculator_name: name of one calculator already added to the list + """ + del self.__calculators[calculator_name] del self.__parameters[calculator_name] + + def run(self) -> None: + """ + Run the entire simulation, + i.e. all the calculators in the order they have been provided + """ + for calculator in self.calculators.values(): + calculator.backengine() diff --git a/libpyvinyl/Parameters/Collections.py b/libpyvinyl/Parameters/Collections.py index d943557..e915753 100644 --- a/libpyvinyl/Parameters/Collections.py +++ b/libpyvinyl/Parameters/Collections.py @@ -1,8 +1,49 @@ # Created by Mads Bertelsen and modified by Juncheng E -import json +import json_tricks as json +from collections import OrderedDict +import copy + from libpyvinyl.AbstractBaseClass import AbstractBaseClass from .Parameter import Parameter +from pint.quantity import Quantity +from pint.unit import Unit, UnitsContainer + +from typing import Union + + +def quantity_encode( + obj: Union[Quantity, Unit, UnitsContainer, any], primitives: bool = False +): + """ + Function to encode pint.Quantity and pint.Unit objects in json + + It returns obj if the encoding was not possible. + """ + if isinstance(obj, Quantity): + return {"__quantity__": str(obj)} + elif isinstance(obj, Unit): + return str(obj) + elif isinstance(obj, UnitsContainer): + return "" + else: + return obj + + +def quantity_decode(dct): + """ + Function to decode pint.Quantity object from json + """ + if "__quantity__" in dct: + a = dct["__quantity__"] + if "inf" in a: + return Quantity("inf", a.strip("inf")) + else: + return Quantity(dct["__quantity__"]) + elif "__units__" in dct: + return dct["__units__"] + else: + return dct class CalculatorParameters(AbstractBaseClass): @@ -11,11 +52,12 @@ class CalculatorParameters(AbstractBaseClass): Parameters are stored in a dict using their name as key """ + def __init__(self, parameters=None): """ Creates a Parameters object, optionally with list of parameter objects """ - self.parameters = {} + self.parameters = OrderedDict() if parameters is not None: self.add(parameters) @@ -25,7 +67,8 @@ def check_type(self, parameter): """ if not isinstance(parameter, Parameter): raise RuntimeError( - "A non-Parameter object was given to Parameters class.") + "Object of type Parameter expected, received {}".format(type(parameter)) + ) def check_list_type(self, parameter_list): """ @@ -46,8 +89,7 @@ def add(self, parameter): self.check_list_type(parameter) for par in parameter: if par.name in self.parameters: - raise RuntimeError( - "Duplicate parameter name in parameters!") + raise RuntimeError("Duplicate parameter name in parameters!") self.parameters[par.name] = par return @@ -74,13 +116,13 @@ def __getitem__(self, key): try: return self.parameters[key] except KeyError: - raise KeyError("Call parameters by parameters[key], it doesn't support list function.") + raise KeyError(f"{key} is not a valid parameter name.") def __setitem__(self, key, value): """ Sets value of parameter with given key to given value """ - self.parameters[key].set_value(value) + self.parameters[key].value = value def __delitem__(self, key): """ @@ -88,6 +130,24 @@ def __delitem__(self, key): """ del self.parameters[key] + def __iter__(self): + """ + Facilitates looping through the contained parameters + + Uses the built in iterator in the return of dict.values() so one can + iterate through the parameters with a for loop. + """ + return self.parameters.values().__iter__() + + def __next__(self): + """ + Facilitates looping through the contained parameters + + Uses the built in next method in the return of dict.values() so one can + iterate through the parameters with a for loop. + """ + return self.parameters.values().__next__() + def print_indented(self, indents): """ returns string describing this object, can optionally be indented @@ -110,8 +170,10 @@ def from_json(cls, fname: str): :type fname: str """ - with open(fname, 'r') as fp: - instance = cls.from_dict(json.load(fp)) + with open(fname, "r") as fp: + instance = cls.from_dict( + json.load(fp, extra_obj_pairs_hooks=[quantity_decode]), + ) return instance @@ -133,7 +195,12 @@ def from_dict(cls, params_dict: dict): def to_dict(self): params = {} for key in self.parameters: - params[key] = self.parameters[key].__dict__ + # Deepcopy to not modify the original parameters + params[key] = copy.deepcopy(self.parameters[key].__dict__) + a = params[key] + if "_Parameter__value_type" in a: + del a["_Parameter__value_type"] + return params def to_json(self, fname: str): @@ -144,8 +211,14 @@ def to_json(self, fname: str): :type fname: str """ - with open(fname, 'w') as fp: - json.dump(self.to_dict(), fp, indent=4) + with open(fname, "w") as fp: + json.dump( + self.to_dict(), + fp, + indent=4, + allow_nan=True, + extra_obj_encoders=[quantity_encode], + ) class MasterParameter(Parameter): @@ -155,6 +228,7 @@ class MasterParameter(Parameter): system is added that contains information on which Parameters objects this master parameter should control parameters from. """ + def __init__(self, *args, **kwargs): """ Create MasterParameter with uninitialized links @@ -175,6 +249,7 @@ class MasterParameters(CalculatorParameters): additional ability to set values for the other parameters this master should control. """ + def __init__(self, parameters_dict, *args, **kwargs): """ Create MasterParameters object with given parameters dict @@ -198,7 +273,7 @@ def __setitem__(self, key, value): calculator_par_name = master_parameter.links[calculator] self.parameters_dict[calculator][calculator_par_name] = value - self.parameters[key].set_value(value) + self.parameters[key].value = value class InstrumentParameters(AbstractBaseClass): @@ -209,6 +284,7 @@ class InstrumentParameters(AbstractBaseClass): have master parameters which control parameters for a number of calculators at once. """ + def __init__(self): """ Create an empty ParametersCollection instance @@ -225,8 +301,10 @@ def from_json(cls, fname: str): :type fname: str """ - with open(fname, 'r') as fp: - instance = cls.from_dict(json.load(fp)) + with open(fname, "r") as fp: + instance = cls.from_dict( + json.load(fp, extra_obj_pairs_hooks=[quantity_decode]) + ) return instance @@ -242,15 +320,19 @@ def from_dict(cls, instrument_dict: dict): parameters = cls() for key in instrument_dict: if key != "Master": - parameters.add(key, CalculatorParameters.from_dict(instrument_dict[key])) + parameters.add( + key, CalculatorParameters.from_dict(instrument_dict[key]) + ) if "Master" in instrument_dict.keys(): - parameters.master = CalculatorParameters.from_dict(instrument_dict["Master"]) + parameters.master = CalculatorParameters.from_dict( + instrument_dict["Master"] + ) return parameters def to_dict(self): params_collect = {} - params_collect['Master'] = self.master.to_dict() + params_collect["Master"] = self.master.to_dict() for key in self.parameters_dict: params_collect[key] = self.parameters_dict[key].to_dict() return params_collect @@ -263,8 +345,14 @@ def to_json(self, fname: str): :type fname: str """ - with open(fname, 'w') as fp: - json.dump(self.to_dict(), fp, indent=4) + with open(fname, "w") as fp: + json.dump( + self.to_dict(), + fp, + indent=4, + allow_nan=True, + extra_obj_encoders=[quantity_encode], + ) def add(self, key, parameters): """ @@ -273,7 +361,8 @@ def add(self, key, parameters): if not isinstance(parameters, CalculatorParameters): raise RuntimeError( "ParametersCollection holds objects of type Parameters," - + " was provided with something else.") + + " was provided with something else." + ) self.parameters_dict[key] = parameters @@ -288,8 +377,7 @@ def add_master_parameter(self, name, links, **kwargs): for link_key in links: if link_key not in self.parameters_dict: - raise RuntimeError( - "A link had a key which was not recognized.") + raise RuntimeError("A link had a key which was not recognized.") master_parameter.add_links(links) self.master.add(master_parameter) diff --git a/libpyvinyl/Parameters/Parameter.py b/libpyvinyl/Parameters/Parameter.py index 966109f..322beba 100644 --- a/libpyvinyl/Parameters/Parameter.py +++ b/libpyvinyl/Parameters/Parameter.py @@ -1,173 +1,470 @@ # Created by Mads Bertelsen and modified by Juncheng E +# Further modified by Shervin Nourbakhsh import math +import numpy from libpyvinyl.AbstractBaseClass import AbstractBaseClass +# importing units using the pint package from the __init__.py of this module +# from . import ureg, Q_ + +from pint.unit import Unit +from pint.quantity import Quantity +import pint.errors + +# typing +from typing import Union, Any, Tuple, List, Dict, Optional + +# ValueTypes: TypeAlias = [str, bool, int, float, object, pint.Quantity] +ValueTypes = Union[str, bool, int, float, pint.Quantity] + class Parameter(AbstractBaseClass): """ - Description of a single parameter + Description of a single parameter. + + The parameter is defined by: + - name: when added to a parameter collection, it can be accessed by this name + - value: can be a boolean, a string, a pint.Quantity, an int or float (the latter internally converted to pint.Quantity) + - unit: a string that is internally converted into a pint.Unit + - comment: a string with a brief description of the parameter and additional informations """ - def __init__(self, name, unit=None, comment=None): + + def __init__( + self, + name: str, + unit: str = "", + comment: Union[str, None] = None, + ): """ Creates parameter with given name, optionally unit and comment + + :param name: name of the parameter + :param unit: physical units returning the parameter value + :param comment: brief description of the parameter + """ - self.name = name - self.unit = unit - self.comment = comment - self.value = None - self.legal_intervals = [] - self.illegal_intervals = [] - self.options = [] + self.name: str = name + self.__unit: Union[str, Unit] = Unit(unit) if unit != None else "" + self.comment: Union[str, None] = comment + self.__value: Union[ValueTypes, None] = None + self.__intervals: List[Tuple[Quantity, Quantity]] = [] + self.__intervals_are_legal: Union[bool, None] = None + self.__options: List = [] + self.__options_are_legal: Union[bool, None] = None + self.__value_type: Union[ValueTypes, None] = None @classmethod - def from_dict(cls, param_dict): - param = cls(param_dict['name'], param_dict['unit'], - param_dict['comment']) + def from_dict(cls, param_dict: Dict): + """ + Helper class method creating a new object from a dictionary providing + - name: str MANDATORY + - unit: str + - comment: str + - ... + + This class method is mainly used to allow dumping and loading the class from json + """ + if "name" not in param_dict: + raise KeyError( + "name is a mandatory element of the dictionary, but has not been found" + ) + param = cls( + param_dict["name"], param_dict["_Parameter__unit"], param_dict["comment"] + ) for key in param_dict: param.__dict__[key] = param_dict[key] + + # set the value type, making the necessary promotions + param.__set_value_type(param.value) + for interval in param.__intervals: + param.__set_value_type(interval[0]) + param.__set_value_type(interval[1]) + for option in param.__options: + param.__set_value_type(option) return param - def add_legal_interval(self, min_value, max_value): + @property + def unit(self) -> str: + """Returning the units as a string""" + return str(self.__unit) + + @unit.setter + def unit(self, uni: str) -> None: """ - Sets a legal interval for this parameter, None for infinite + Assignment of the units + + :param uni: unit + + A pint.Unit is used if the string is recognized as a valid unit in the registry. + It is stored as a string otherwise. """ - if min_value is None: - min_value = -math.inf - if max_value is None: - max_value = math.inf + try: + self.__unit = Unit(uni) + except pint.errors.UndefinedUnitError: + self.__unit = uni + + @property + def value_no_conversion(self) -> ValueTypes: + """ + Returning the object stored in value with no conversions + """ + return self.__value + + @property + def pint_value(self) -> Quantity: + """Returning the value as a pint object if available, an error otherwise""" + if not isinstance(self.__value, Quantity): + raise TypeError("The parameter value is not of pint.Quantity type") + return self.__value + + @property + def value(self) -> ValueTypes: + """ + Returns the magnitude of a Quantity or the stored value otherwise + """ + if isinstance(self.__value, Quantity): + return self.__value.m_as(self.__unit) + else: + return self.__value + + @staticmethod + def __is_type_compatible(t1: type, t2: Union[None, type]) -> bool: + """ + Check type compatibility + + :param t1: first type + :type t1: type + + :param t2: second type + :type t2: type + + :return: bool + + True if t1 and t2 are of the same type or if one is int and the other float + False otherwise + """ + if t1 == type(None) or t2 == type(None): + return True + if t1 == None or t2 == None: + return True + + # promote any int or float to pint.Quantity + if t1 == float or t1 == int or t1 == numpy.float64: + t1 = Quantity + if t2 == float or t2 == int or t2 == numpy.float64: + t2 = Quantity + + if "quantity" in str(t1): + t1 = Quantity + if "quantity" in str(t2): + t2 = Quantity + + if t1 == t2: + return True + + return False + + def __to_quantity(self, value: Any) -> Union[Quantity, Any]: + """ + Converts value into a pint.Quantity if this Parameter is defined to be a Quantity. + It returns value unaltered otherwise. + """ + + if self.__value_type == Quantity and not isinstance(value, Quantity): + return Quantity(value, self.__unit) + + return value - self.legal_intervals.append([min_value, max_value]) + def __set_value_type(self, value: Any) -> None: + """ + Sets the type for the parameter. + It should always be preceded by a __check_compatibility to avoid chaning the type for the Parameter + + :param value: a value that might be assigned as Parameter value or in an interval or option + :type value: any type + + It will raise an exception if the type is not coherent to what previously is declared. + """ + if ( + hasattr(value, "__iter__") + and not isinstance(value, str) + and not isinstance(value, Quantity) + ): + value = value[0] + + # if an integer has units, then it is a quantity -> promotion + if isinstance(value, int) and self.__unit != "": + self.__value_type = Quantity + # if value is a float, than can be used as a quantity -> promotion + elif isinstance(value, float): + self.__value_type = Quantity + else: # cannot be treated as a quantity + self.__value_type = type(value) + + def __check_compatibility(self, value: Any) -> None: + """ + Raises an error if this parameter and the given value are not of the same type or compatible + :param value: a value that might be assigned as Parameter value or in an interval or option + :type value: any type + + It will raise an exception if the type is not coherent to what previously is declared. + """ + + vtype = type(value) + assert vtype != None + v = value + # First case: value is a list, it might be good to double check + # that all the members are of the same type + if isinstance(value, list): + vtype = type(value[0]) + for v in value: + if not self.__is_type_compatible(vtype, type(v)): + raise TypeError( + "Iterable object passed as value for the parameter, but it is made of inhomogeneous types: ", + vtype, + type(v), + ) + elif isinstance(value, dict): + raise NotImplementedError("Dictionaries are not accepted") + + # check that the value is compatible with what previously defined + if not self.__is_type_compatible(vtype, self.__value_type): + raise TypeError( + "New value of type {} is different from {} previously defined".format( + type(value), self.__value_type + ) + ) + + @value.setter + def value(self, value: ValueTypes) -> None: + """ + Sets value of this parameter if value is legal, + an exception is raised otherwise. - def add_illegal_interval(self, min_value, max_value): + :param value: value + :type value: str | boolean | int | float | object | pint.Quantity + If value is a float, it is internally converted to a pint.Quantity + """ + self.__check_compatibility(value) + self.__set_value_type(value) + value = self.__to_quantity(value) + + if self.is_legal(value): + self.__value = value + else: + raise ValueError("Value of parameter '" + self.name + "' illegal.") + + def add_interval( + self, + min_value: Union[ValueTypes, None], + max_value: Union[ValueTypes, None], + intervals_are_legal: bool, + ) -> None: """ - Sets an illegal interval for this parameter, None for infinite + Sets an interval for this parameter: [min_value, max_value] + The interval is closed on both sides: min_value and and max_value are included. + + + :param min_value: minimum value of the interval, None for infinity + :param max_value: maximum value of the interval, None for infinity + + :param intervals_are_legal: if not done previously, it defines if all the intervals of this parameter should be considered as allowed or forbidden intervals. + """ + if min_value is None: min_value = -math.inf if max_value is None: max_value = math.inf - self.illegal_intervals.append([min_value, max_value]) + self.__check_compatibility(min_value) + self.__check_compatibility(max_value) - def add_option(self, option): + self.__set_value_type(min_value) # it could have been max_value + + if self.__intervals_are_legal is None: + self.__intervals_are_legal = intervals_are_legal + else: + if self.__intervals_are_legal != intervals_are_legal: + print("WARNING: All intervals should be either legal or illegal.") + print( + " Interval: [" + + str(min_value) + + ":" + + str(max_value) + + "] is declared differently w.r.t. to the previous intervals" + ) + # should it throw an expection? + raise ValueError("Parameter", "interval", "multiple validities") + + self.__intervals.append( + (self.__to_quantity(min_value), self.__to_quantity(max_value)) + ) + + # if the interval has been added after assignement of the value of the parameter, + # the latter should be checked + if not self.value_no_conversion is None: + if self.is_legal(self.value) is False: + raise ValueError( + "Value " + + str(self.value) + + " is now illegal based on the newly added interval" + ) + + def add_option(self, option: Any, options_are_legal: bool) -> None: """ Sets allowed values for this parameter + + :param option: a discrete allowed or forbidden value + :param options_are_legal: defines if the given option is for a legal or illegal discrete value """ - if isinstance(option, list): - self.options += option + + if self.__options_are_legal is None: + self.__options_are_legal = options_are_legal else: - self.options.append(option) + if self.__options_are_legal != options_are_legal: + print("ERROR: All options should be either legal or illegal.") + print( + " This option is declared differently w.r.t. to the previous ones" + ) + # should it throw an expection? + raise ValueError("Parameter", "options", "multiple validities") - def set_value(self, value): - """ - Sets value of this parameter if value is legal, otherwise warning is shown + self.__check_compatibility(option) + self.__set_value_type(option) # it could have been max_value - This could be expanded to raise an exception, or such could be in is_legal - """ - if self.is_legal(value): - self.value = value + if isinstance(option, list): + for op in option: + self.__options.append(self.__to_quantity(op)) else: - print("WARNING: Value of parameter '" + self.name - + "' illegal, ignored.") + self.__options.append(self.__to_quantity(option)) + + # if the option has been added after assignement of the value of the parameter, + # the latter should be checked + if not self.value_no_conversion is None: + if self.is_legal(self.value) is False: + raise ValueError( + "Value " + + str(self.value) + + " is now illegal based on the newly added option" + ) + + def get_options(self): + return self.__options + + def get_options_are_legal(self): + return self.__options_are_legal - def is_legal(self, value=None): + def get_intervals(self): + return self.__intervals + + def get_intervals_are_legal(self): + return self.__intervals_are_legal + + def is_legal(self, values: Union[ValueTypes, None] = None) -> bool: """ Checks whether or not given or contained value is legal given constraints. - Illegal intervals have the highest priority to be checked. Then it will check the - legal intervals and options. The overlaps among the constrains will be overridden by - the constrain of higher priority. """ - if value is None: - value = self.value - # Check illegal intervals - for illegal_interval in self.illegal_intervals: - if illegal_interval[0] < value < illegal_interval[1]: - return False + if values is None: + values = self.__value - # Check legal intervals - is_inside_a_legal_interval = False - for legal_interval in self.legal_intervals: - if legal_interval[0] < value < legal_interval[1]: - is_inside_a_legal_interval = True + if ( + not hasattr(values, "__iter__") + or isinstance(values, str) + or isinstance(values, Quantity) + ): + # print(str(hasattr(values, "__iter__")) + str(values)) + # first if types are compatible - if not is_inside_a_legal_interval and len(self.legal_intervals) > 0: - return False + if self.__is_type_compatible(type(values), self.__value_type) is False: + return False - # checked intervals, can return if options not used (frequent case) - if len(self.options) == 0: - return True + value = self.__to_quantity(values) - for option in self.options: - if option == value: - # If the value matches any option, it is legal + # obvious, if no conditions are defined, the value is always legal + if len(self.__options) == 0 and len(self.__intervals) == 0: return True - # Since no options matched the parameter, it is illegal - return False + # first check if the value is in any defined discrete value + for option in self.__options: + if option == value: + return self.__options_are_legal - def print_paramter_constraints(self): - """ - Print the legal and illegal intervals of this parameter. - """ - print(self.name) - print('illegal intervals:', self.illegal_intervals) - print('legal intervals:', self.legal_intervals) - print('options', self.options) + # secondly check if it is in any defined interval + for interval in self.__intervals: + if interval[0] <= value <= interval[1]: + return self.__intervals_are_legal + + # at this point the value has not been found in any interval + # if intervals where defined and were forbidden intervals, the value should be accepted + if len(self.__intervals) > 0: + return not self.__intervals_are_legal + + # if there where no intervals defined, then it depends if the discrete values were forbidden or allowed + return not self.__options_are_legal + + # else + # all values have to be True - def clear_legal_intervals(self): + for value in values: + if not self.is_legal(value): + return False + + return True + + def print_parameter_constraints(self) -> None: """ - Clear the legal intervals of this parameter. + Print the legal and illegal intervals of this parameter. FIXME """ - self.legal_intervals = [] + print(self.name) + print("intervals:", self.__intervals) + print("intervals are legal:", self.__intervals_are_legal) + print("options", self.__options) + print("options are legal:", self.__options_are_legal) - def clear_illegal_intervals(self): + def clear_intervals(self) -> None: """ - Clear the illegal intervals of this parameter. + Clear the intervals of this parameter. """ - self.illegal_intervals = [] + self.__intervals = [] - def clear_options(self): + def clear_options(self) -> None: """ Clear the option values of this parameter. """ - self.options = [] + self.__options = [] - def print_line(self): + def print_line(self) -> str: """ returns string with one line description of parameter """ - if self.unit is None: + if self.__unit is None or self.__unit == Unit(""): unit_string = "" else: - unit_string = "[" + self.unit + "]" + unit_string = "[" + str(self.__unit) + "] " - if self.value is None: - string = self.name.ljust(20) + if self.value_no_conversion is None: + string = self.name.ljust(40) + " " else: - string = self.name.ljust(15) - string += str(self.value).ljust(5) + string = self.name.ljust(35) + " " + string += str(self.value).ljust(10) + " " - string += unit_string.ljust(10) + string += unit_string.ljust(20) + " " if self.comment is not None: string += self.comment string += 3 * " " - for legal_interval in self.legal_intervals: - interval = "L[" + str(legal_interval[0]) + ", " + str( - legal_interval[1]) + "]" - string += interval.ljust(10) - - for illegal_interval in self.illegal_intervals: - interval = "I[" + str(illegal_interval[0]) + ", " + str( - illegal_interval[1]) + "]" - string += interval.ljust(10) + for interval in self.__intervals: + legal = "L" if self.__intervals_are_legal else "I" + intervalstr = legal + "[" + str(interval[0]) + ", " + str(interval[1]) + "]" + string += intervalstr.ljust(10) - if len(self.options) > 0: + if len(self.__options) > 0: values = "(" - for option in self.options: + for option in self.__options: values += str(option) + ", " values = values.strip(", ") values += ")" @@ -175,37 +472,34 @@ def print_line(self): return string - def __repr__(self): + def __repr__(self) -> str: """ Returns string with thorough description of parameter """ string = "Parameter named: '" + self.name + "'" - if self.value is None: + if self.value_no_conversion is None: string += " without set value.\n" else: string += " with value: " + str(self.value) + "\n" - if self.unit is not None: - string += " [" + self.unit + "]\n" + if self.__unit is not None: + string += " [" + str(self.__unit) + "]\n" if self.comment is not None: string += " " + self.comment + "\n" - if len(self.legal_intervals) > 0: - string += " Legal intervals:\n" - for legal_interval in self.legal_intervals: - string += " [" + str(legal_interval[0]) + "," + str( - legal_interval[1]) + "]\n" - - if len(self.illegal_intervals) > 0: - string += " Illegal intervals:\n" - for illegal_interval in self.illegal_intervals: - string += " [" + str(illegal_interval[0]) + ", " + str( - illegal_interval[1]) + "]\n" - - if len(self.options) > 0: - string += " Allowed values:\n" - for option in self.options: + if len(self.__intervals) > 0: + string += ( + " Legal intervals:\n" + if self.__intervals_are_legal + else " Illegal intervals:\n" + ) + for interval in self.__intervals: + string += " [" + str(interval[0]) + "," + str(interval[1]) + "]\n" + + if len(self.__options) > 0: + string += " Allowed values:\n" # FIXME + for option in self.__options: string += " " + str(option) + "\n" return string diff --git a/libpyvinyl/Parameters/__init__.py b/libpyvinyl/Parameters/__init__.py index b6f1c9a..3f66957 100644 --- a/libpyvinyl/Parameters/__init__.py +++ b/libpyvinyl/Parameters/__init__.py @@ -1,2 +1,2 @@ from .Parameter import Parameter -from .Collections import InstrumentParameters, CalculatorParameters, MasterParameter \ No newline at end of file +from .Collections import InstrumentParameters, CalculatorParameters, MasterParameter diff --git a/libpyvinyl/SignalGenerator.py b/libpyvinyl/SignalGenerator.py deleted file mode 100644 index 2a5304b..0000000 --- a/libpyvinyl/SignalGenerator.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -:module SignalGenerator: Module hosting the SignalGenerator and SignalGeneratorParameters -abstract classes. -""" - - -#################################################################################### -# # -# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # -# Laboratory. # -# # -# Copyright (C) 2020 Carsten Fortmann-Grote # -# # -# This program is free software: you can redistribute it and/or modify it under # -# the terms of the GNU Lesser General Public License as published by the Free # -# Software Foundation, either version 3 of the License, or (at your option) any # -# later version. # -# # -# This program is distributed in the hope that it will be useful, but WITHOUT ANY # -# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A # -# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. # -# # -# You should have received a copy of the GNU Lesser General Public License along # -# with this program. If not, see OK <---') - sys.exit(0) - - sys.exit(1) diff --git a/tests/integration/plusminus/.github/ISSUE_TEMPLATE.md b/tests/integration/plusminus/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..92fcec9 --- /dev/null +++ b/tests/integration/plusminus/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,15 @@ +* PlusMinus version: +* Python version: +* Operating System: + +### Description + +Describe what you were trying to get done. +Tell us what happened, what went wrong, and what you expected to happen. + +### What I Did + +``` +Paste the command(s) you ran and the output. +If there was a crash, please include the traceback here. +``` diff --git a/tests/integration/plusminus/.gitignore b/tests/integration/plusminus/.gitignore new file mode 100644 index 0000000..43091aa --- /dev/null +++ b/tests/integration/plusminus/.gitignore @@ -0,0 +1,105 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# IDE settings +.vscode/ \ No newline at end of file diff --git a/tests/integration/plusminus/README.rst b/tests/integration/plusminus/README.rst new file mode 100644 index 0000000..f362d3b --- /dev/null +++ b/tests/integration/plusminus/README.rst @@ -0,0 +1,18 @@ +========= +PlusMinus +========= + +An example of a small platform implementing libpyvinyl. + +Data structure +############## +.. image:: ./docs/01-data_structure.png + +Instrument example +################## +.. image:: ./docs/02-instrument_example.png + +This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. + +.. _Cookiecutter: https://github.com/audreyr/cookiecutter +.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage diff --git a/tests/integration/plusminus/docs/01-data_structure.png b/tests/integration/plusminus/docs/01-data_structure.png new file mode 100644 index 0000000..3977831 Binary files /dev/null and b/tests/integration/plusminus/docs/01-data_structure.png differ diff --git a/tests/integration/plusminus/docs/02-instrument_example.png b/tests/integration/plusminus/docs/02-instrument_example.png new file mode 100644 index 0000000..391311c Binary files /dev/null and b/tests/integration/plusminus/docs/02-instrument_example.png differ diff --git a/tests/integration/plusminus/plusminus/ArrayCalculators/ArrayCalculator.py b/tests/integration/plusminus/plusminus/ArrayCalculators/ArrayCalculator.py new file mode 100644 index 0000000..c2f66de --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayCalculators/ArrayCalculator.py @@ -0,0 +1,59 @@ +from typing import Union +from pathlib import Path +import numpy as np +from libpyvinyl.BaseData import DataCollection +from libpyvinyl.BaseCalculator import BaseCalculator, CalculatorParameters +from plusminus.NumberData import NumberData +from plusminus.ArrayData import ArrayData + + +class ArrayCalculator(BaseCalculator): + def __init__( + self, + name: str, + input: Union[DataCollection, list, NumberData], + output_keys: Union[list, str] = ["array_result"], + output_data_types=[ArrayData], + output_filenames=[], + instrument_base_dir="./", + calculator_base_dir="ArrayCalculator", + parameters: CalculatorParameters = None, + ): + """A python dict calculator to create an array from two inputs.""" + super().__init__( + name, + input, + output_keys, + output_data_types=output_data_types, + output_filenames=output_filenames, + instrument_base_dir=instrument_base_dir, + calculator_base_dir=calculator_base_dir, + parameters=parameters, + ) + + def init_parameters(self): + parameters = CalculatorParameters() + # Calculator developer edit + multiply = parameters.new_parameter( + "multiply", comment="Multiply the array by a value" + ) + multiply.value = 1 + # Calculator developer end + self.parameters = parameters + + def backengine(self): + Path(self.base_dir).mkdir(parents=True, exist_ok=True) + input_data0 = self.input.to_list()[0] + assert type(input_data0) is NumberData + input_num0 = input_data0.get_data()["number"] + input_data1 = self.input.to_list()[1] + assert type(input_data1) is NumberData + input_num1 = input_data1.get_data()["number"] + output_arr = ( + np.array([input_num0, input_num1]) * self.parameters["multiply"].value + ) + data_dict = {"array": output_arr} + key = self.output_keys[0] + output_data = self.output[key] + output_data.set_dict(data_dict) + return self.output diff --git a/tests/integration/plusminus/plusminus/ArrayCalculators/__init__.py b/tests/integration/plusminus/plusminus/ArrayCalculators/__init__.py new file mode 100644 index 0000000..6f13f6f --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayCalculators/__init__.py @@ -0,0 +1 @@ +from .ArrayCalculator import ArrayCalculator diff --git a/tests/integration/plusminus/plusminus/ArrayData/ArrayData.py b/tests/integration/plusminus/plusminus/ArrayData/ArrayData.py new file mode 100644 index 0000000..1e196af --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayData/ArrayData.py @@ -0,0 +1,36 @@ +from libpyvinyl.BaseData import BaseData +from plusminus.ArrayData import TXTFormat, H5Format + + +class ArrayData(BaseData): + def __init__( + self, + key, + data_dict=None, + filename=None, + file_format_class=None, + file_format_kwargs=None, + ): + + ### DataClass developer's job start + expected_data = {} + expected_data["array"] = None + ### DataClass developer's job end + + super().__init__( + key, + expected_data, + data_dict, + filename, + file_format_class, + file_format_kwargs, + ) + + @classmethod + def supported_formats(self): + format_dict = {} + ### DataClass developer's job start + self._add_ioformat(format_dict, TXTFormat) + self._add_ioformat(format_dict, H5Format) + ### DataClass developer's job end + return format_dict diff --git a/tests/integration/plusminus/plusminus/ArrayData/H5Format.py b/tests/integration/plusminus/plusminus/ArrayData/H5Format.py new file mode 100644 index 0000000..4155f19 --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayData/H5Format.py @@ -0,0 +1,47 @@ +import h5py +from libpyvinyl.BaseFormat import BaseFormat +from plusminus.ArrayData import ArrayData + + +class H5Format(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "H5" + desciption = "H5 format for ArrayData" + file_extension = ".h5" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [] + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + with h5py.File(filename, "r") as h5: + array = h5["array"][()] + data_dict = {"array": array} + return data_dict + + @classmethod + def write(cls, object: ArrayData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + array = data_dict["array"] + with h5py.File(filename, "w") as h5: + h5["array"] = array + if key is None: + original_key = object.key + key = original_key + "_to_H5Format" + return object.from_file(filename, cls, key) diff --git a/tests/integration/plusminus/plusminus/ArrayData/TXTFormat.py b/tests/integration/plusminus/plusminus/ArrayData/TXTFormat.py new file mode 100644 index 0000000..b0454f1 --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayData/TXTFormat.py @@ -0,0 +1,45 @@ +import numpy as np +from libpyvinyl.BaseFormat import BaseFormat +from plusminus.ArrayData import ArrayData + + +class TXTFormat(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "TXT" + desciption = "TXT format for ArrayData" + file_extension = ".txt" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [] + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + array = np.loadtxt(filename) + data_dict = {"array": array} + return data_dict + + @classmethod + def write(cls, object: ArrayData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + arr = data_dict["array"] + np.savetxt(filename, arr, fmt="%.3f") + if key is None: + original_key = object.key + key = original_key + "_to_TXTFormat" + return object.from_file(filename, cls, key) diff --git a/tests/integration/plusminus/plusminus/ArrayData/__init__.py b/tests/integration/plusminus/plusminus/ArrayData/__init__.py new file mode 100644 index 0000000..75b656a --- /dev/null +++ b/tests/integration/plusminus/plusminus/ArrayData/__init__.py @@ -0,0 +1,3 @@ +from .ArrayData import ArrayData +from .H5Format import H5Format +from .TXTFormat import TXTFormat diff --git a/tests/integration/plusminus/plusminus/BaseCalculator.py b/tests/integration/plusminus/plusminus/BaseCalculator.py new file mode 100644 index 0000000..27e4c3c --- /dev/null +++ b/tests/integration/plusminus/plusminus/BaseCalculator.py @@ -0,0 +1,243 @@ +""" :module BaseCalculator: Module hosts the BaseData class.""" +from abc import abstractmethod, ABCMeta +from typing import Union +from pathlib import Path +from libpyvinyl.AbstractBaseClass import AbstractBaseClass +from libpyvinyl.BaseData import BaseData, DataCollection +from libpyvinyl.Parameters import CalculatorParameters + + +class BaseCalculator(AbstractBaseClass): + def __init__( + self, + name: str, + input: Union[DataCollection, list, BaseData], + output_keys: Union[list, str], + output_data_types: list, + output_filenames: Union[list, str], + instrument_base_dir="./", + calculator_base_dir="BaseCalculator", + parameters: CalculatorParameters = None, + ): + """A python object calculator example""" + # Initialize properties + self.__name = None + self.__instrument_base_dir = None + self.__calculator_base_dir = None + self.__input = None + self.__input_keys = None + self.__output_keys = None + self.__output_data_types = None + self.__output_filenames = None + self.__parameters = None + + self.name = name + self.input = input + self.output_keys = output_keys + self.output_data_types = output_data_types + self.output_filenames = output_filenames + self.instrument_base_dir = instrument_base_dir + self.calculator_base_dir = calculator_base_dir + self.parameters = parameters + + self.__init_output() + + @abstractmethod + def init_parameters(self): + raise NotImplementedError + + def __init_output(self): + output = DataCollection() + for i, key in enumerate(self.output_keys): + output_data = self.output_data_types[i](key) + output.add_data(output_data) + self.output = output + + @property + def name(self): + return self.__name + + @name.setter + def name(self, value): + if isinstance(value, str): + self.__name = value + else: + raise TypeError( + f"Calculator: `name` is expected to be a str, not {type(value)}" + ) + + @property + def parameters(self): + return self.__parameters + + @parameters.setter + def parameters(self, value): + if isinstance(value, CalculatorParameters): + self.__parameters = value + elif value is None: + self.init_parameters() + else: + raise TypeError( + f"Calculator: `parameters` is expected to be CalculatorParameters, not {type(value)}" + ) + + @property + def input(self): + return self.__input + + @input.setter + def input(self, value): + self.set_input(value) + + def set_input(self, value: Union[DataCollection, list, BaseData]): + if isinstance(value, DataCollection): + self.__input = value + elif isinstance(value, list): + self.__input = DataCollection(*value) + elif isinstance(value, BaseData): + self.__input = DataCollection(value) + else: + raise TypeError( + f"Calculator: `input` can be a DataCollection, list or BaseData object, and will be treated as a DataCollection, but not {type(value)}" + ) + + @property + def input_keys(self): + return self.__input_keys + + @input_keys.setter + def input_keys(self, value): + self.set_input_keys(value) + + def set_input_keys(self, value: Union[list, str]): + if isinstance(value, list): + for item in value: + assert type(item) is str + self.__input_keys = value + elif isinstance(value, str): + self.__input_keys = [value] + else: + raise TypeError( + f"Calculator: `input_keys` can be a list or a string, and will be treated as a list, but not {type(value)}" + ) + + @property + def output_keys(self): + return self.__output_keys + + @output_keys.setter + def output_keys(self, value): + self.set_output_keys(value) + + def set_output_keys(self, value: Union[list, str]): + if isinstance(value, list): + for item in value: + assert type(item) is str + self.__output_keys = value + elif isinstance(value, str): + self.__output_keys = [value] + else: + raise TypeError( + f"Calculator: `output_keys` can be a list or a string, and will be treated as a list, but not {type(value)}" + ) + + @property + def output_data_types(self): + return self.__output_data_types + + @output_data_types.setter + def output_data_types(self, value): + self.set_output_data_types(value) + + def set_output_data_types(self, value): + if isinstance(value, list): + for item in value: + assert type(item) is ABCMeta + self.__output_data_types = value + elif isinstance(value, ABCMeta): + self.__output_data_types = [value] + else: + raise TypeError( + f"Calculator: `output_data_types` can be a list or a DataClass, and will be treated as a list, but not {type(value)}" + ) + + @property + def output_filenames(self): + """Native calculator file names""" + return self.__output_filenames + + @output_filenames.setter + def output_filenames(self, value): + self.set_output_filenames(value) + + def set_output_filenames(self, value: Union[list, str]): + if isinstance(value, str): + self.__output_filenames = [value] + elif isinstance(value, list): + self.__output_filenames = value + else: + raise TypeError( + f"Calculator: `output_filenames` can to be a str or a list, and will be treated as a list, but not {type(value)}" + ) + + @property + def instrument_base_dir(self): + return self.__instrument_base_dir + + @instrument_base_dir.setter + def instrument_base_dir(self, value): + self.set_instrument_base_dir(value) + + def set_instrument_base_dir(self, value: str): + if isinstance(value, str): + self.__instrument_base_dir = value + else: + raise TypeError( + f"Calculator: `instrument_base_dir` is expected to be a str, not {type(value)}" + ) + + @property + def calculator_base_dir(self): + return self.__calculator_base_dir + + @calculator_base_dir.setter + def calculator_base_dir(self, value): + self.set_calculator_base_dir(value) + + def set_calculator_base_dir(self, value: str): + if isinstance(value, str): + self.__calculator_base_dir = value + else: + raise TypeError( + f"Calculator: `calculator_base_dir` is expected to be a str, not {type(value)}" + ) + + @property + def base_dir(self): + base_dir = Path(self.instrument_base_dir) / self.calculator_base_dir + return str(base_dir) + + @property + def output_file_paths(self): + paths = [] + for filename in self.output_filenames: + path = Path(self.base_dir) / filename + # Make sure the file directory exists + path.parent.mkdir(parents=True, exist_ok=True) + paths.append(str(path)) + return paths + + @abstractmethod + def backengine(self): + Path(self.base_dir).mkdir(parents=True, exist_ok=True) + input_num0 = self.input[self.input_keys[0]].get_data()["number"] + input_num1 = self.input[self.input_keys[1]].get_data()["number"] + output_num = float(input_num0) + float(input_num1) + if self.parameters["plus_times"].value > 1: + for i in range(self.parameters["plus_times"].value - 1): + output_num += input_num1 + data_dict = {"number": output_num} + key = self.output_keys[0] + output_data = NumberData.from_dict(data_dict, key) + self.output = DataCollection(output_data) + return self.output diff --git a/tests/integration/plusminus/plusminus/NumberCalculators/MinusCalculator.py b/tests/integration/plusminus/plusminus/NumberCalculators/MinusCalculator.py new file mode 100644 index 0000000..f4b0e41 --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberCalculators/MinusCalculator.py @@ -0,0 +1,55 @@ +from typing import Union +from pathlib import Path +import numpy as np +from libpyvinyl.BaseData import DataCollection +from plusminus.NumberData import NumberData, TXTFormat +from libpyvinyl.BaseCalculator import BaseCalculator, CalculatorParameters + + +class MinusCalculator(BaseCalculator): + def __init__( + self, + name: str, + input: Union[DataCollection, list, NumberData], + output_keys: Union[list, str] = ["minus_result"], + output_data_types=[NumberData], + output_filenames: Union[list, str] = ["minus_result.txt"], + instrument_base_dir="./", + calculator_base_dir="MinusCalculator", + parameters=None, + ): + """A python object calculator example""" + super().__init__( + name, + input, + output_keys, + output_data_types=output_data_types, + output_filenames=output_filenames, + instrument_base_dir=instrument_base_dir, + calculator_base_dir=calculator_base_dir, + parameters=parameters, + ) + + def init_parameters(self): + parameters = CalculatorParameters() + times = parameters.new_parameter( + "minus_times", comment="How many times to do the minus" + ) + times.value = 1 + self.parameters = parameters + + def backengine(self): + Path(self.base_dir).mkdir(parents=True, exist_ok=True) + input_num0 = self.input.to_list()[0].get_data()["number"] + input_num1 = self.input.to_list()[1].get_data()["number"] + output_num = float(input_num0) - float(input_num1) + if self.parameters["minus_times"].value > 1: + for i in range(self.parameters["minus_times"].value - 1): + output_num -= input_num1 + arr = np.array([output_num]) + file_path = self.output_file_paths[0] + np.savetxt(file_path, arr, fmt="%.3f") + key = self.output_keys[0] + output_data = self.output[key] + output_data.set_file(file_path, TXTFormat) + return self.output diff --git a/tests/integration/plusminus/plusminus/NumberCalculators/PlusCalculator.py b/tests/integration/plusminus/plusminus/NumberCalculators/PlusCalculator.py new file mode 100644 index 0000000..22f5c7f --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberCalculators/PlusCalculator.py @@ -0,0 +1,52 @@ +from typing import Union +from pathlib import Path +from libpyvinyl.BaseData import DataCollection +from plusminus.NumberData import NumberData +from libpyvinyl.BaseCalculator import BaseCalculator, CalculatorParameters + + +class PlusCalculator(BaseCalculator): + def __init__( + self, + name: str, + input: Union[DataCollection, list, NumberData], + output_keys: Union[list, str] = ["plus_result"], + output_data_types=[NumberData], + output_filenames: Union[list, str] = [], + instrument_base_dir="./", + calculator_base_dir="PlusCalculator", + parameters=None, + ): + """A python object calculator example""" + super().__init__( + name, + input, + output_keys, + output_data_types=output_data_types, + output_filenames=output_filenames, + instrument_base_dir=instrument_base_dir, + calculator_base_dir=calculator_base_dir, + parameters=parameters, + ) + + def init_parameters(self): + parameters = CalculatorParameters() + times = parameters.new_parameter( + "plus_times", comment="How many times to do the plus" + ) + times.value = 1 + self.parameters = parameters + + def backengine(self): + Path(self.base_dir).mkdir(parents=True, exist_ok=True) + input_num0 = self.input.to_list()[0].get_data()["number"] + input_num1 = self.input.to_list()[1].get_data()["number"] + output_num = float(input_num0) + float(input_num1) + if self.parameters["plus_times"].value > 1: + for i in range(self.parameters["plus_times"].value - 1): + output_num += input_num1 + data_dict = {"number": output_num} + key = self.output_keys[0] + output_data = self.output[key] + output_data.set_dict(data_dict) + return self.output diff --git a/tests/integration/plusminus/plusminus/NumberCalculators/__init__.py b/tests/integration/plusminus/plusminus/NumberCalculators/__init__.py new file mode 100644 index 0000000..68ed672 --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberCalculators/__init__.py @@ -0,0 +1,2 @@ +from .MinusCalculator import MinusCalculator +from .PlusCalculator import PlusCalculator diff --git a/tests/integration/plusminus/plusminus/NumberData/H5Format.py b/tests/integration/plusminus/plusminus/NumberData/H5Format.py new file mode 100644 index 0000000..836d1e5 --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberData/H5Format.py @@ -0,0 +1,47 @@ +import h5py +from libpyvinyl.BaseFormat import BaseFormat +from plusminus.NumberData import NumberData + + +class H5Format(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "H5" + desciption = "H5 format for NumberData" + file_extension = ".h5" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [] + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + with h5py.File(filename, "r") as h5: + number = h5["number"][()] + data_dict = {"number": number} + return data_dict + + @classmethod + def write(cls, object: NumberData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + number = data_dict["number"] + with h5py.File(filename, "w") as h5: + h5["number"] = number + if key is None: + original_key = object.key + key = original_key + "_to_H5Format" + return object.from_file(filename, cls, key) diff --git a/tests/integration/plusminus/plusminus/NumberData/NumberData.py b/tests/integration/plusminus/plusminus/NumberData/NumberData.py new file mode 100644 index 0000000..cb19668 --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberData/NumberData.py @@ -0,0 +1,52 @@ +from libpyvinyl.BaseData import BaseData +from plusminus.NumberData import TXTFormat, H5Format + + +class NumberData(BaseData): + def __init__( + self, + key, + data_dict=None, + filename=None, + file_format_class=None, + file_format_kwargs=None, + ): + + expected_data = {} + + ### DataClass developer's job start + expected_data["number"] = None + ### DataClass developer's job end + + super().__init__( + key, + expected_data, + data_dict, + filename, + file_format_class, + file_format_kwargs, + ) + + @classmethod + def supported_formats(self): + format_dict = {} + ### DataClass developer's job start + self._add_ioformat(format_dict, TXTFormat.TXTFormat) + self._add_ioformat(format_dict, H5Format.H5Format) + ### DataClass developer's job end + return format_dict + + @classmethod + def from_file(cls, filename: str, format_class, key, **kwargs): + """Create the data class by the file in the `format`.""" + return cls( + key, + filename=filename, + file_format_class=format_class, + file_format_kwargs=kwargs, + ) + + @classmethod + def from_dict(cls, data_dict, key): + """Create the data class by a python dictionary.""" + return cls(key, data_dict=data_dict) diff --git a/tests/integration/plusminus/plusminus/NumberData/TXTFormat.py b/tests/integration/plusminus/plusminus/NumberData/TXTFormat.py new file mode 100644 index 0000000..c614aa6 --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberData/TXTFormat.py @@ -0,0 +1,45 @@ +import numpy as np +from libpyvinyl.BaseFormat import BaseFormat +from plusminus.NumberData import NumberData + + +class TXTFormat(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "TXT" + desciption = "TXT format for NumberData" + file_extension = ".txt" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [] + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + number = float(np.loadtxt(filename)) + data_dict = {"number": number} + return data_dict + + @classmethod + def write(cls, object: NumberData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + arr = np.array([data_dict["number"]]) + np.savetxt(filename, arr, fmt="%.3f") + if key is None: + original_key = object.key + key = original_key + "_to_TXTFormat" + return object.from_file(filename, cls, key) diff --git a/tests/integration/plusminus/plusminus/NumberData/__init__.py b/tests/integration/plusminus/plusminus/NumberData/__init__.py new file mode 100644 index 0000000..f43e39f --- /dev/null +++ b/tests/integration/plusminus/plusminus/NumberData/__init__.py @@ -0,0 +1,3 @@ +from .H5Format import H5Format +from .NumberData import NumberData +from .TXTFormat import TXTFormat diff --git a/tests/integration/plusminus/plusminus/__init__.py b/tests/integration/plusminus/plusminus/__init__.py new file mode 100644 index 0000000..09c834a --- /dev/null +++ b/tests/integration/plusminus/plusminus/__init__.py @@ -0,0 +1,8 @@ +"""Top-level package for PlusMinus.""" + +__author__ = """Juncheng E""" +__email__ = "juncheng.e@xfel.eu" +__version__ = "0.1.0" + + +from libpyvinyl.BaseData import DataCollection diff --git a/tests/integration/plusminus/plusminus/plusminus.py b/tests/integration/plusminus/plusminus/plusminus.py new file mode 100644 index 0000000..dd0b80e --- /dev/null +++ b/tests/integration/plusminus/plusminus/plusminus.py @@ -0,0 +1 @@ +"""Main module.""" diff --git a/tests/__init__Test.py b/tests/integration/plusminus/requirements.txt similarity index 100% rename from tests/__init__Test.py rename to tests/integration/plusminus/requirements.txt diff --git a/tests/integration/plusminus/requirements_dev.txt b/tests/integration/plusminus/requirements_dev.txt new file mode 100644 index 0000000..8a62c17 --- /dev/null +++ b/tests/integration/plusminus/requirements_dev.txt @@ -0,0 +1,13 @@ +pip==19.2.3 +bump2version==0.5.11 +wheel==0.33.6 +watchdog==0.9.0 +flake8==3.7.8 +tox==3.14.0 +coverage==4.5.4 +Sphinx==3.5.2 +twine==1.14.0 +sphinx_rtd_theme==0.5.1 + +pytest==4.6.5 +pytest-runner==5.1 \ No newline at end of file diff --git a/tests/integration/plusminus/setup.cfg b/tests/integration/plusminus/setup.cfg new file mode 100644 index 0000000..476a919 --- /dev/null +++ b/tests/integration/plusminus/setup.cfg @@ -0,0 +1,23 @@ +[bumpversion] +current_version = 0.1.0 +commit = True +tag = True + +[bumpversion:file:setup.py] +search = version='{current_version}' +replace = version='{new_version}' + +[bumpversion:file:plusminus/__init__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + +[bdist_wheel] +universal = 1 + +[flake8] +exclude = docs + +[aliases] +# Define setup.py command aliases here +test = pytest + diff --git a/tests/integration/plusminus/setup.py b/tests/integration/plusminus/setup.py new file mode 100644 index 0000000..4047cd4 --- /dev/null +++ b/tests/integration/plusminus/setup.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +"""The setup script.""" + +from setuptools import setup, find_packages + +with open("README.rst") as readme_file: + readme = readme_file.read() + +with open("HISTORY.rst") as history_file: + history = history_file.read() + +with open("requirements.txt") as requirements_file: + require = requirements_file.read() + requirements = require.split() + +setup_requirements = [ + "pytest-runner", +] + +test_requirements = [ + "pytest>=3", +] + +setup( + author="Juncheng E", + author_email="juncheng.e@xfel.eu", + python_requires=">=3.6", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], + description="An example of a small platform implementing libpynyl", + install_requires=requirements, + license="MIT license", + long_description=readme + "\n\n" + history, + include_package_data=True, + keywords="PlusMinus", + name="PlusMinus", + packages=find_packages(include=["plusminus", "plusminus.*"]), + setup_requires=setup_requirements, + test_suite="tests", + tests_require=test_requirements, + url="https://github.com/JunCEEE/PlusMinus", + version="0.1.0", + zip_safe=False, +) diff --git a/tests/integration/plusminus/tests/__init__.py b/tests/integration/plusminus/tests/__init__.py new file mode 100644 index 0000000..c1a5b10 --- /dev/null +++ b/tests/integration/plusminus/tests/__init__.py @@ -0,0 +1 @@ +"""Unit test package for plusminus.""" diff --git a/tests/integration/plusminus/tests/test_ArrayCalculators.py b/tests/integration/plusminus/tests/test_ArrayCalculators.py new file mode 100644 index 0000000..8ffa5dc --- /dev/null +++ b/tests/integration/plusminus/tests/test_ArrayCalculators.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +"""Tests for `plusminus.NumberCalculators` package.""" + +import pytest +from plusminus.ArrayCalculators import ArrayCalculator +from plusminus.NumberData import NumberData +from plusminus.ArrayData import TXTFormat +from plusminus import DataCollection + + +def test_ArrayCalculator(tmpdir): + """PlusCalculator test function, the native output of ArrayCalculator is a python dictionary""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 2}, "input2") + input_data = [input1, input2] # This could also be allowed. + input_data = DataCollection(input1, input2) + calculator = ArrayCalculator("plus", input_data) + calculator.set_instrument_base_dir(str(tmpdir)) + output = calculator.backengine() + assert output.get_data()["array"][0] == 1 + assert output.get_data()["array"][1] == 2 + calculator.parameters["multiply"] = 5 + output = calculator.backengine() + file_output = output.write( + calculator.base_dir + "/array_5.txt", TXTFormat, key="file_output" + ) + assert file_output.get_data()["array"][0] == 5 + assert file_output.get_data()["array"][1] == 10 diff --git a/tests/integration/plusminus/tests/test_Instrument.py b/tests/integration/plusminus/tests/test_Instrument.py new file mode 100644 index 0000000..dc11a09 --- /dev/null +++ b/tests/integration/plusminus/tests/test_Instrument.py @@ -0,0 +1,38 @@ +import pytest +from libpyvinyl.Instrument import Instrument +from plusminus.ArrayCalculators import ArrayCalculator +from plusminus.NumberCalculators import PlusCalculator, MinusCalculator +from plusminus.NumberData import NumberData +import plusminus.ArrayData as AD +from plusminus import DataCollection + + +def test_CalculationInstrument(tmpdir): + """PlusCalculator test function, the native output of MinusCalculator is a python dictionary""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 2}, "input2") + input_collection = [input1, input2] # This could also be allowed. + input_collection = DataCollection(input1, input2) + calculator1 = PlusCalculator("plus", input_collection, output_keys=["plus_result"]) + calculator2 = MinusCalculator( + "minus", input_collection, output_keys=["minus_result"] + ) + + input_collection = DataCollection( + calculator1.output["plus_result"], calculator2.output["minus_result"] + ) + calculator3 = ArrayCalculator( + "array", input_collection, output_keys=["array_result"] + ) + + calculation_instrument = Instrument("calculation_instrument") + instrument_path = tmpdir / "calculation_instrument" + calculation_instrument.add_calculator(calculator1) + calculation_instrument.add_calculator(calculator2) + calculation_instrument.add_calculator(calculator3) + calculation_instrument.set_instrument_base_dir(str(instrument_path)) + calculation_instrument.run() + print(calculator3.output.get_data()) + calculator3.output.write(str(tmpdir / "final_result.txt"), AD.TXTFormat) + calculator3.output.write(str(tmpdir / "final_result.h5"), AD.H5Format) diff --git a/tests/integration/plusminus/tests/test_NumberCalculators.py b/tests/integration/plusminus/tests/test_NumberCalculators.py new file mode 100644 index 0000000..198e8a2 --- /dev/null +++ b/tests/integration/plusminus/tests/test_NumberCalculators.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +"""Tests for `plusminus.NumberCalculators` package.""" + +import pytest +from plusminus.NumberCalculators import PlusCalculator, MinusCalculator +from plusminus.NumberData import NumberData, TXTFormat +from plusminus import DataCollection + + +def test_PlusCalculator(tmpdir): + """PlusCalculator test function, the native output of PlusCalculator is a python dictionary""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 1}, "input2") + input_data = [input1, input2] # This could also be allowed. + input_data = DataCollection(input1, input2) + plus = PlusCalculator("plus", input_data) + plus.set_instrument_base_dir(str(tmpdir)) + plus_output = plus.backengine() + assert plus_output.get_data()["number"] == 2 + plus_output.write(plus.base_dir + "/1_time.txt", TXTFormat) + plus.parameters["plus_times"] = 5 + plus_output = plus.backengine() + file_output = plus_output.write( + plus.base_dir + "/5_time.txt", TXTFormat, key="file_output" + ) + assert file_output.get_data()["number"] == 6 + + +def test_MinusCalculator(tmpdir): + """MinusCalculator test function. The native output of MinusCalculator is a txt file""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 1}, "input2") + input_data = DataCollection(input1, input2) + calculator = MinusCalculator("minus", input_data) + calculator.set_instrument_base_dir(str(tmpdir)) + assert "MinusCalculator" in calculator.base_dir + calculator.set_output_filenames("minus_res.txt") + output = calculator.backengine() + assert output.get_data()["number"] == 0 + calculator.parameters["minus_times"] = 5 + plus_output = calculator.backengine() + assert plus_output.get_data()["number"] == -4 + + +def test_DataCollection_multiple(): + """PlusCalculator test function""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 1}, "input2") + input_data = DataCollection(input1, input2) + data = input_data.get_data() + assert data["input1"]["number"] == 1 + assert data["input2"]["number"] == 1 diff --git a/tests/integration/plusminus/tests/test_NumberData.py b/tests/integration/plusminus/tests/test_NumberData.py new file mode 100644 index 0000000..fea1774 --- /dev/null +++ b/tests/integration/plusminus/tests/test_NumberData.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +"""Tests for `plusminus.NumberCalculators` package.""" + +import pytest +import h5py +from plusminus.NumberData import NumberData, TXTFormat, H5Format + + +def test_construct_NumberData(): + """Test the construction of NumberData""" + my_data = NumberData.from_dict({"number": 1}, "input1") + + +def test_list_formats(): + """Test the construction of NumberData""" + my_data = NumberData.from_dict({"number": 1}, "input1") + my_data.list_formats() + + +def test_write_read_txt(tmpdir): + """Test writing to a txt file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + my_data.write(file_name, TXTFormat) + with open(file_name, "r") as f: + assert float(f.read()) == 1 + read_data = NumberData.from_file(file_name, TXTFormat, "read_data") + assert read_data.get_data()["number"] == 1 + + +def test_write_read_h5(tmpdir): + """Test writing to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.h5") + my_data.write(file_name, H5Format) + with h5py.File(file_name, "r") as h5: + assert h5["number"][()] == 1 + read_data = NumberData.from_file(file_name, H5Format, "read_data") + assert read_data.get_data()["number"] == 1 + + +def test_read_txt_write_h5(tmpdir): + """Test read a txt file and write to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + my_data.write(file_name, TXTFormat) + read_data = NumberData.from_file(file_name, TXTFormat, "read_data") + file_name = str(tmpdir / "test.h5") + read_data.write(file_name, H5Format) + with h5py.File(file_name, "r") as h5: + assert h5["number"][()] == 1 + + +def test_txt_file_write_h5(tmpdir): + """Test write a txt file and write to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + read_data = my_data.write(file_name, TXTFormat, "read_data") + file_name = str(tmpdir / "test.h5") + read_data.write(file_name, H5Format) + with h5py.File(file_name, "r") as h5: + assert h5["number"][()] == 1 + + +def test_set_dict(tmpdir): + """Test setting a dict mapping""" + my_data = NumberData("input1") + my_data.set_dict({"number": 1}) + file_name = str(tmpdir / "test.txt") + my_data.write(file_name, TXTFormat) + + +def test_set_file_(tmpdir): + """Test setting a file mapping""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + my_data.write(file_name, TXTFormat) + new_data = NumberData("new_data") + new_data.set_file(file_name, TXTFormat) + assert new_data.get_data()["number"] == 1 + + +def test_set_file_report_double_setting(tmpdir): + """Test write a txt file and write to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + my_data.write(file_name, TXTFormat) + with pytest.raises(RuntimeError): + my_data.set_file(file_name, TXTFormat) + + +def test_return_object_without_key(tmpdir): + """Test write a txt file and write to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + new_data = my_data.write(file_name, TXTFormat) + assert new_data.key == "input1_to_TXTFormat" + + +def test_return_object_with_key(tmpdir): + """Test write a txt file and write to a h5 file""" + my_data = NumberData.from_dict({"number": 1}, "input1") + file_name = str(tmpdir / "test.txt") + key = "test" + new_data = my_data.write(file_name, TXTFormat, key) + assert new_data.key == key diff --git a/tests/integration/plusminus/tox.ini b/tests/integration/plusminus/tox.ini new file mode 100644 index 0000000..6a29de0 --- /dev/null +++ b/tests/integration/plusminus/tox.ini @@ -0,0 +1,26 @@ +[tox] +envlist = py36, py37, py38, flake8 + +[travis] +python = + 3.8: py38 + 3.7: py37 + 3.6: py36 + +[testenv:flake8] +basepython = python +deps = flake8 +commands = flake8 plusminus tests + +[testenv] +setenv = + PYTHONPATH = {toxinidir} +deps = + -r{toxinidir}/requirements_dev.txt +; If you want to make tox run the tests with the same versions, create a +; requirements.txt with the pinned versions and uncomment the following line: +; -r{toxinidir}/requirements.txt +commands = + pip install -U pip + pytest --basetemp={envtmpdir} + diff --git a/libpyvinyl/RadiationSampleInteractor.py b/tests/unit/Test.py similarity index 63% rename from libpyvinyl/RadiationSampleInteractor.py rename to tests/unit/Test.py index 9411bab..6cd35b2 100644 --- a/libpyvinyl/RadiationSampleInteractor.py +++ b/tests/unit/Test.py @@ -1,17 +1,16 @@ +#! /usr/bin/env python3 """ -:module RadiationSampleInteractor: Module hosting the RadiationSampleInteractor and RadiationSampleInteractorParameters -abstract classes. +:module Test: Top level test module hosting all unittest suites. """ - #################################################################################### # # -# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # +# This file is part of libpyvinyl - The APIs for Virtual Neutron and x-raY # # Laboratory. # # # # Copyright (C) 2020 Carsten Fortmann-Grote # # # -# This program is free software: you can redistribute it and/or modify it under # +# This program is free software: you can redistribute it and/or modify it under # # the terms of the GNU Lesser General Public License as published by the Free # # Software Foundation, either version 3 of the License, or (at your option) any # # later version. # @@ -25,24 +24,34 @@ # # #################################################################################### -from libpyvinyl.BaseCalculator import BaseCalculator, CalculatorParameters +import unittest +import sys + +from test_BaseCalculator import BaseCalculatorTest +from test_Parameters import Test_Parameter, Test_Parameters, Test_Instruments +from test_Instrument import InstrumentTest + + +def suite(): + suites = [ + unittest.makeSuite(BaseCalculatorTest, "test"), + unittest.makeSuite(Test_Parameter, "test"), + unittest.makeSuite(Test_Parameters, "test"), + unittest.makeSuite(Test_Instruments, "test"), + unittest.makeSuite(InstrumentTest, "test"), + ] -class RadiationSampleInteractorParameters(CalculatorParameters): - def __init__(self, **kwargs): - - super().__init__(**kwargs) + return unittest.TestSuite(suites) +# Run the suite and return a success status code. This enables running an automated git-bisect. +if __name__ == "__main__": -class RadiationSampleInteractor(BaseCalculator): - def __init__(self,name, parameters=None, dumpfile=None, **kwargs): - - super().__init__(name, parameters, dumpfile, **kwargs) + result = unittest.TextTestRunner(verbosity=2).run(suite()) - def backengine(self): - pass + if result.wasSuccessful(): + print("---> OK <---") + sys.exit(0) - def saveH5(self, fname, openpmd=True): - pass + sys.exit(1) -# This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 823852. diff --git a/tests/unit/test_BaseCalculator.py b/tests/unit/test_BaseCalculator.py new file mode 100644 index 0000000..bcb5dd5 --- /dev/null +++ b/tests/unit/test_BaseCalculator.py @@ -0,0 +1,316 @@ +import unittest +import pytest +import os +import shutil +from typing import Union +from pathlib import Path + +from libpyvinyl.BaseCalculator import BaseCalculator +from libpyvinyl.BaseData import BaseData, DataCollection +from libpyvinyl.Parameters import CalculatorParameters +from libpyvinyl.AbstractBaseClass import AbstractBaseClass + + +class NumberData(BaseData): + """Example dict mapping data""" + + def __init__( + self, + key, + data_dict=None, + filename=None, + file_format_class=None, + file_format_kwargs=None, + ): + + expected_data = {} + + # DataClass developer's job start + expected_data["number"] = None + # DataClass developer's job end + + super().__init__( + key, + expected_data, + data_dict, + filename, + file_format_class, + file_format_kwargs, + ) + + @classmethod + def supported_formats(self): + return {} + + @classmethod + def from_file(cls, filename: str, format_class, key, **kwargs): + raise NotImplementedError() + + @classmethod + def from_dict(cls, data_dict, key): + """Create the data class by a python dictionary.""" + return cls(key, data_dict=data_dict) + + +class PlusCalculator(BaseCalculator): + """:class: Specialized calculator, calculates the sum of two datasets.""" + + def __init__( + self, + name: str, + input: Union[DataCollection, list, NumberData], + output_keys: Union[list, str] = ["plus_result"], + output_data_types=[NumberData], + output_filenames: Union[list, str] = [], + instrument_base_dir="./", + calculator_base_dir="PlusCalculator", + parameters=None, + ): + """A python object calculator example""" + super().__init__( + name, + input, + output_keys, + output_data_types=output_data_types, + output_filenames=output_filenames, + instrument_base_dir=instrument_base_dir, + calculator_base_dir=calculator_base_dir, + parameters=parameters, + ) + + def init_parameters(self): + parameters = CalculatorParameters() + times = parameters.new_parameter( + "plus_times", comment="How many times to do the plus" + ) + # Set defaults + times.value = 1 + + self.parameters = parameters + + def backengine(self): + Path(self.base_dir).mkdir(parents=True, exist_ok=True) + input_num0 = self.input.to_list()[0].get_data()["number"] + input_num1 = self.input.to_list()[1].get_data()["number"] + output_num = float(input_num0) + float(input_num1) + if self.parameters["plus_times"].value > 1: + for i in range(self.parameters["plus_times"].value - 1): + output_num += input_num1 + data_dict = {"number": output_num} + key = self.output_keys[0] + output_data = self.output[key] + output_data.set_dict(data_dict) + return self.output + + +class BaseCalculatorTest(unittest.TestCase): + """ + Test class for the BaseCalculator class. + """ + + @classmethod + def setUpClass(cls): + """Setting up the test class.""" + + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 1}, "input2") + input_data = [input1, input2] + plus = PlusCalculator("plus", input_data) + cls.__default_calculator = plus + cls.__default_input = input_data + + @classmethod + def tearDownClass(cls): + """Tearing down the test class.""" + del cls.__default_calculator + del cls.__default_input + + def setUp(self): + """Setting up a test.""" + self.__files_to_remove = [] + self.__dirs_to_remove = [] + + def tearDown(self): + """Tearing down a test.""" + + for f in self.__files_to_remove: + if os.path.isfile(f): + os.remove(f) + for d in self.__dirs_to_remove: + if os.path.isdir(d): + shutil.rmtree(d) + + def test_base_class_constructor_raises(self): + """Test that we cannot construct instances of the base class.""" + + self.assertRaises(TypeError, BaseCalculator, "name") + + def test_default_construction(self): + """Testing the default construction of the class.""" + + # Test positional arguments + calculator = PlusCalculator("test", self.__default_input) + + self.assertIsInstance(calculator, PlusCalculator) + self.assertIsInstance(calculator, BaseCalculator) + self.assertIsInstance(calculator, AbstractBaseClass) + + def test_deep_copy(self): + """Test the copy constructor behaves as expected.""" + # Parameters are not deepcopied by itself + calculator_copy = self.__default_calculator() + self.assertEqual(calculator_copy.parameters["plus_times"].value, 1) + new_parameters = calculator_copy.parameters + new_parameters["plus_times"] = 5 + self.assertEqual(new_parameters["plus_times"].value, 5) + self.assertEqual(calculator_copy.parameters["plus_times"].value, 5) + + # Parameters are deepcopied when copy the calculator + calculator_copy = self.__default_calculator() + self.assertEqual(calculator_copy.parameters["plus_times"].value, 1) + calculator_copy.parameters["plus_times"] = 10 + self.assertEqual(calculator_copy.parameters["plus_times"].value, 10) + self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1) + calculator_copy.input["input1"] = NumberData.from_dict({"number": 5}, "input1") + self.assertEqual(calculator_copy.input["input1"].get_data()["number"], 5) + self.assertEqual( + self.__default_calculator.input["input1"].get_data()["number"], 1 + ) + + # Calculator reference + self.assertEqual(calculator_copy.parameters["plus_times"].value, 10) + calculator_reference = calculator_copy + self.assertEqual(calculator_reference.parameters["plus_times"].value, 10) + calculator_reference.parameters["plus_times"] = 3 + self.assertEqual(calculator_reference.parameters["plus_times"].value, 3) + self.assertEqual(calculator_copy.parameters["plus_times"].value, 3) + + # New parameters can be set while caculator deepcopy + new_parameters = CalculatorParameters() + times = new_parameters.new_parameter( + "plus_times", comment="How many times to do the plus" + ) + times.value = 1 + new_parameters["plus_times"].value = 5 + new_calculator = self.__default_calculator(parameters=new_parameters) + self.assertIsInstance(new_calculator, PlusCalculator) + self.assertIsInstance(new_calculator, BaseCalculator) + self.assertIsInstance(new_calculator, AbstractBaseClass) + self.assertEqual(new_calculator.parameters["plus_times"].value, 5) + self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1) + + def test_dump(self): + """Test dumping to file.""" + calculator = self.__default_calculator + + self.__files_to_remove.append(calculator.dump()) + self.__files_to_remove.append(calculator.dump("dump.dill")) + + def test_parameters_in_copied_calculator(self): + """Test parameters in a copied calculator""" + + calculator = self.__default_calculator + self.assertEqual(calculator.parameters["plus_times"].value, 1) + calculator.parameters["plus_times"] = 5 + self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 5) + calculator.parameters["plus_times"] = 1 + self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1) + + def test_resurrect_from_dump(self): + """Test loading from dumpfile.""" + + calculator = self.__default_calculator() + + self.assertEqual(calculator.parameters["plus_times"].value, 1) + output = calculator.backengine() + self.assertEqual(output.get_data()["number"], 2) + self.__dirs_to_remove.append("PlusCalculator") + + # dump + dump = calculator.dump() + self.__files_to_remove.append(dump) + + del calculator + + calculator = PlusCalculator.from_dump(dump) + + self.assertEqual( + calculator.input.get_data(), + self.__default_calculator.input.get_data(), + ) + + calculator.parameters.to_dict() + self.assertEqual( + calculator.parameters.to_dict(), + self.__default_calculator.parameters.to_dict(), + ) + + calculator.parameters["plus_times"] = 5 + self.assertNotEqual( + calculator.parameters.to_dict(), + self.__default_calculator.parameters.to_dict(), + ) + + self.assertIsNotNone(calculator.data) + + def test_attributes(self): + """Test that all required attributes are present.""" + + calculator = self.__default_calculator + + self.assertTrue(hasattr(calculator, "name")) + self.assertTrue(hasattr(calculator, "input")) + self.assertTrue(hasattr(calculator, "output")) + self.assertTrue(hasattr(calculator, "parameters")) + self.assertTrue(hasattr(calculator, "instrument_base_dir")) + self.assertTrue(hasattr(calculator, "calculator_base_dir")) + self.assertTrue(hasattr(calculator, "base_dir")) + self.assertTrue(hasattr(calculator, "backengine")) + self.assertTrue(hasattr(calculator, "data")) + self.assertTrue(hasattr(calculator, "dump")) + self.assertTrue(hasattr(calculator, "from_dump")) + + def test_set_param_values(self): + calculator = self.__default_calculator + + calculator.parameters["plus_times"] = 5 + self.assertEqual(calculator.parameters["plus_times"].value, 5) + + def test_set_param_values_with_set_parameters(self): + calculator = self.__default_calculator + + calculator.set_parameters(plus_times=7) + self.assertEqual(calculator.parameters["plus_times"].value, 7) + + def test_set_param_values_with_set_parameters_with_dict(self): + calculator = self.__default_calculator + + calculator.set_parameters({"plus_times": 9}) + self.assertEqual(calculator.parameters["plus_times"].value, 9) + + def test_collection_get_data(self): + calculator = self.__default_calculator + print(calculator.input) + input_dict = calculator.input.get_data() + self.assertEqual(input_dict["input1"]["number"], 1) + self.assertEqual(input_dict["input2"]["number"], 1) + + def test_output_file_paths(self): + calculator = self.__default_calculator + with self.assertRaises(ValueError) as exception: + calculator.output_file_paths + + calculator.output_filenames = "bingo.txt" + self.assertEqual(calculator.output_file_paths[0], "PlusCalculator/bingo.txt") + self.__dirs_to_remove.append("PlusCalculator") + + def test_calculator_output_set_inconsistent(self): + input1 = NumberData.from_dict({"number": 1}, "input1") + with self.assertRaises(ValueError) as exception: + calculator = PlusCalculator( + "test", input1, output_keys=["result"], output_data_types=[] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_BaseData.py b/tests/unit/test_BaseData.py new file mode 100644 index 0000000..d70493e --- /dev/null +++ b/tests/unit/test_BaseData.py @@ -0,0 +1,422 @@ +import pytest +import numpy as np +import h5py +from libpyvinyl.BaseData import BaseData, DataCollection +from libpyvinyl.BaseFormat import BaseFormat + + +class NumberData(BaseData): + def __init__( + self, + key, + data_dict=None, + filename=None, + file_format_class=None, + file_format_kwargs=None, + ): + + ### DataClass developer's job start + expected_data = {} + expected_data["number"] = None + ### DataClass developer's job end + + super().__init__( + key, + expected_data, + data_dict, + filename, + file_format_class, + file_format_kwargs, + ) + + @classmethod + def supported_formats(self): + format_dict = {} + ### DataClass developer's job start + self._add_ioformat(format_dict, TXTFormat) + self._add_ioformat(format_dict, H5Format) + ### DataClass developer's job end + return format_dict + + +class TXTFormat(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "TXT" + desciption = "TXT format for NumberData" + file_extension = ".txt" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + number = float(np.loadtxt(filename)) + data_dict = {"number": number} + return data_dict + + @classmethod + def write(cls, object: NumberData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + arr = np.array([data_dict["number"]]) + np.savetxt(filename, arr, fmt="%.3f") + if key is None: + original_key = object.key + key = original_key + "_to_TXTFormat" + return object.from_file(filename, cls, key) + else: + return object.from_file(filename, cls, key) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [H5Format] + + @classmethod + def convert( + cls, obj: NumberData, output: str, output_format_class: str, key=None, **kwargs + ): + """Direct convert method, if the default converting would be too slow or not suitable for the output_format""" + if output_format_class is H5Format: + cls.convert_to_H5Format(obj.filename, output) + else: + raise TypeError( + "Direct converting to format {} is not supported".format( + output_format_class + ) + ) + # Set the key of the returned object + if key is None: + original_key = obj.key + key = original_key + "_from_TXTFormat" + return obj.from_file(output, output_format_class, key) + else: + return obj.from_file(output, output_format_class, key) + + @classmethod + def convert_to_H5Format(cls, input: str, output: str): + """The engine of convert method.""" + print("Directly converting TXTFormat to H5Format") + number = float(np.loadtxt(input)) + with h5py.File(output, "w") as h5: + h5["number"] = number + + +class H5Format(BaseFormat): + def __init__(self) -> None: + super().__init__() + + @classmethod + def format_register(self): + key = "H5" + desciption = "H5 format for NumberData" + file_extension = ".h5" + read_kwargs = [""] + write_kwargs = [""] + return self._create_format_register( + key, desciption, file_extension, read_kwargs, write_kwargs + ) + + @classmethod + def read(cls, filename: str) -> dict: + """Read the data from the file with the `filename` to a dictionary. The dictionary will + be used by its corresponding data class.""" + with h5py.File(filename, "r") as h5: + number = h5["number"][()] + data_dict = {"number": number} + return data_dict + + @classmethod + def write(cls, object: NumberData, filename: str, key: str = None): + """Save the data with the `filename`.""" + data_dict = object.get_data() + number = data_dict["number"] + with h5py.File(filename, "w") as h5: + h5["number"] = number + if key is None: + original_key = object.key + key = original_key + "_to_H5Format" + return object.from_file(filename, cls, key) + else: + return object.from_file(filename, cls, key) + + @staticmethod + def direct_convert_formats(): + # Assume the format can be converted directly to the formats supported by these classes: + # AFormat, BFormat + # Redefine this `direct_convert_formats` for a concrete format class + return [] + + +@pytest.fixture() +def txt_file(tmp_path_factory): + fn_path = tmp_path_factory.mktemp("test_data") / "test.txt" + txt_file = str(fn_path) + with open(txt_file, "w") as f: + f.write("4") + return txt_file + + +# Data class section +def test_list_formats(capsys): + """Test listing registered format classes""" + NumberData.list_formats() + captured = capsys.readouterr() + assert "Key: TXT" in captured.out + assert "Key: H5" in captured.out + + +def test_create_empty_data_instance(): + """Test creating an empty data instance""" + with pytest.raises(TypeError): + number_data = NumberData() + test_data = NumberData(key="test_data") + assert isinstance(test_data, NumberData) + + +def test_create_data_with_set_dict(): + """Test set dict after in an empty data instance""" + test_data = NumberData(key="test_data") + my_dict = {"number": 4} + test_data.set_dict(my_dict) + assert test_data.get_data()["number"] == 4 + + +def test_create_data_with_set_file(txt_file): + """Test set file after in an empty data instance""" + test_data = NumberData(key="test_data") + test_data.set_file(txt_file, TXTFormat) + assert test_data.get_data()["number"] == 4 + + +def test_create_data_with_set_file_inconsistensy(txt_file): + """Test set dict and file for one data object: expecting an error""" + test_data = NumberData(key="test_data") + my_dict = {"number": 4} + test_data.set_dict(my_dict) + with pytest.raises(RuntimeError): + test_data.set_file(txt_file, TXTFormat) + + +def test_create_data_with_set_file_wrong_param(txt_file): + """Test set file after in an empty data instance with wrong `format_class` param""" + test_data = NumberData(key="test_data") + with pytest.raises(TypeError): + test_data.set_file(txt_file, "txt") + + +def test_create_data_with_set_file_wrong_format(txt_file): + """Test set file after in an empty data instance with wrong `format_class`""" + test_data = NumberData(key="test_data") + test_data.set_file(txt_file, H5Format) + with pytest.raises(OSError): + test_data.get_data() + + +def test_create_data_with_file(): + """Test set dict after in an empty data instance""" + test_data = NumberData(key="test_data") + assert isinstance(test_data, NumberData) + my_dict = {"number": 4} + test_data.set_dict(my_dict) + assert test_data.get_data()["number"] == 4 + + +def test_create_data_from_dict(): + """Test creating a data instance from a dict""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + + +def test_check_key_from_dict(): + """Test checking expected data key from dict""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + test_data.get_data() + my_dict = {"numberr": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + with pytest.raises(KeyError): + test_data.get_data() + + +def test_create_data_from_file_wrong_param(txt_file): + """Test creating a data instance from a file in a wrong file format type""" + with pytest.raises(TypeError): + test_data = NumberData.from_file(txt_file, "txt", "test_data") + + +def test_create_data_from_TXTFormat(txt_file): + """Test creating a data instance from a file in TXTFormat""" + test_data = NumberData.from_file(txt_file, TXTFormat, "test_data") + assert test_data.get_data()["number"] == 4 + + +def test_create_data_from_wrong_format(txt_file): + """Test creating a data instance from a file in TXTFormat""" + test_data = NumberData.from_file(txt_file, H5Format, "test_data") + with pytest.raises(OSError): + test_data.get_data() + +def test_duplicate_data_TXTFormat(txt_file, tmpdir, capsys): + """Test creating a data instance from a file in TXTFormat""" + test_data = NumberData.from_file(txt_file, TXTFormat, "test_data") + test_data.write(str(tmpdir/"new_data.txt"),TXTFormat) + captured = capsys.readouterr() + assert "data already existed" in captured.out + +def test_save_dict_data_in_TXTFormat(tmpdir): + """Test saving a dict data in TXTFormat""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + fn = str(tmpdir / "test.txt") + test_data.write(fn, TXTFormat) + read_data = NumberData.from_file(fn, TXTFormat, "read_data") + assert read_data.get_data()["number"] == 4 + + +def test_save_dict_data_in_TXTFormat_return_data_object(tmpdir): + """Test saving a dict data in TXTFormat returning data object with default key""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + fn = str(tmpdir / "test.txt") + return_data = test_data.write(fn, TXTFormat) + assert return_data.get_data()["number"] == 4 + assert return_data.key == "test_data_to_TXTFormat" + + +def test_save_dict_data_in_TXTFormat_return_data_object_key(tmpdir): + """Test saving a dict data in TXTFormat returning data object with custom key""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + print(test_data) + # assert False + fn = str(tmpdir / "test.txt") + return_data = test_data.write(fn, TXTFormat, "custom") + assert return_data.get_data()["number"] == 4 + assert return_data.key == "custom" + + +def test_save_file_data_in_another_format_direct(txt_file, tmpdir, capsys): + """Test directly converting a TXTFormat data to H5Format""" + test_data = NumberData.from_file(txt_file, TXTFormat, "test_data") + # print(test_data) + fn = str(tmpdir / "test.h5") + return_data = test_data.write(fn, H5Format) + captured = capsys.readouterr() + assert "Directly converting TXTFormat to H5Format" in captured.out + assert return_data.get_data()["number"] == 4 + assert return_data.key == "test_data_from_TXTFormat" + return_data = test_data.write(fn, H5Format, "txt2h5") + assert return_data.key == "txt2h5" + # print(return_data) + # assert False + + +def test_save_file_data_in_another_format_indirect(tmpdir): + """Test directly converting a TXTFormat data to H5Format""" + my_dict = {"number": 4} + test_data = NumberData.from_dict(my_dict, "test_data") + fn = str(tmpdir / "test.h5") + h5_data = test_data.write(fn, H5Format, "test_data") + fn = str(tmpdir / "test.txt") + return_data = h5_data.write(fn, TXTFormat) + print(return_data) + assert return_data.get_data()["number"] == 4 + assert return_data.key == "test_data_to_TXTFormat" + return_data = test_data.write(fn, H5Format, "txt2h5") + assert return_data.key == "txt2h5" + # print(return_data) + # assert False + + +# Data collection section +def test_DataCollection_instance(): + """Test creating a DataCollection instance""" + collection = DataCollection() + assert isinstance(collection, DataCollection) + + +def test_DataCollection_one_data(txt_file): + """Test a DataCollection instance with one dataset""" + test_data = NumberData.from_file(txt_file, TXTFormat, "test_data") + collection = DataCollection(test_data) + data_in_collection = collection["test_data"] + assert collection.get_data() == data_in_collection.get_data() + + +def test_DataCollection_one_data_write(txt_file, tmpdir): + """Test a DataCollection instance with one dataset""" + test_data = NumberData.from_file(txt_file, TXTFormat, "test_data") + collection = DataCollection(test_data) + fn = str(tmpdir / "data.h5") + written_data = collection.write(fn, H5Format) + assert written_data.mapping_type == H5Format + assert written_data.get_data()["number"] == 4 + + +def test_DataCollection_two_data(txt_file): + """Test creating a DataCollection instance with two datasets""" + my_dict = {"number": 5} + test_data_txt = NumberData.from_file(txt_file, TXTFormat, "test_txt") + test_data_dict = NumberData.from_dict(my_dict, "test_dict") + collection = DataCollection(test_data_txt, test_data_dict) + assert collection["test_dict"].get_data()["number"] == 5 + assert collection["test_txt"].get_data()["number"] == 4 + value_collection = collection.get_data() + assert value_collection["test_dict"]["number"] == 5 + assert value_collection["test_txt"]["number"] == 4 + + +def test_DataCollection_two_data_write(txt_file, tmpdir): + """Test writing a DataCollection instance with two datasets""" + my_dict = {"number": 5} + test_data_txt = NumberData.from_file(txt_file, TXTFormat, "test_txt") + test_data_dict = NumberData.from_dict(my_dict, "test_dict") + collection = DataCollection(test_data_txt, test_data_dict) + fn_txt = str(tmpdir / "data_new.txt") + fn_h5 = str(tmpdir / "data_new.h5") + filenames = {"test_txt": fn_h5, "test_dict": fn_txt} + format_classes = {"test_txt": H5Format, "test_dict": TXTFormat} + keys = {"test_txt": None, "test_dict": None} + written_collection = collection.write(filenames, format_classes, keys) + # Create a new data collection from the collection dict + new_collection = DataCollection(*written_collection.values()) + assert new_collection["test_dict_to_TXTFormat"].get_data()["number"] == 5 + + +def test_DataCollection_add_data(txt_file): + """Test adding data to a DataCollection instance""" + my_dict = {"number": 5} + test_data_dict = NumberData.from_dict(my_dict, "test_dict") + test_data_txt = NumberData.from_file(txt_file, TXTFormat, "test_txt") + collection = DataCollection() + collection.add_data(test_data_dict, test_data_txt) + print(collection) + + +def test_DataCollection_add_wrong_data_type(): + """Test adding data in wrong type to a DataCollection instance""" + collection = DataCollection() + with pytest.raises(AssertionError): + collection.add_data(0) + + +def test_DataCollection_to_list(txt_file): + """Test returning a DataCollection as a list""" + my_dict = {"number": 5} + test_data_dict = NumberData.from_dict(my_dict, "test_dict") + test_data_txt = NumberData.from_file(txt_file, TXTFormat, "test_txt") + collection = DataCollection(test_data_dict, test_data_txt) + my_list = collection.to_list() + assert my_list[0].get_data()["number"] == 5 + assert my_list[1].get_data()["number"] == 4 diff --git a/tests/unit/test_Instrument.py b/tests/unit/test_Instrument.py new file mode 100644 index 0000000..0e6dc7f --- /dev/null +++ b/tests/unit/test_Instrument.py @@ -0,0 +1,119 @@ +import unittest +import os +import shutil + +from test_BaseCalculator import PlusCalculator, NumberData +from libpyvinyl.Instrument import Instrument + + +class InstrumentTest(unittest.TestCase): + """ + Test class for the Detector class. + """ + + @classmethod + def setUpClass(cls): + """Setting up the test class.""" + input1 = NumberData.from_dict({"number": 1}, "input1") + input2 = NumberData.from_dict({"number": 1}, "input2") + calculator1 = PlusCalculator("test1", [input1, input2]) + cls.calculator1 = calculator1 + calculator2 = PlusCalculator("test2", [input1, input2]) + calculator2.parameters["plus_times"] = 12 + cls.calculator2 = calculator2 + + @classmethod + def tearDownClass(cls): + """Tearing down the test class.""" + pass + + def setUp(self): + """Setting up a test.""" + self.__files_to_remove = [] + self.__dirs_to_remove = [] + + def tearDown(self): + """Tearing down a test.""" + + for f in self.__files_to_remove: + if os.path.isfile(f): + os.remove(f) + for d in self.__dirs_to_remove: + if os.path.isdir(d): + shutil.rmtree(d) + + def testInstrumentConstruction(self): + """Testing the default construction of the class.""" + + # Construct the object. + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + + def testListCalculator(self): + """Testing list calculators""" + + # Construct the object. + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + my_instrument.list_calculators() + + def testListParams(self): + """Testing listing parameters""" + + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + my_instrument.list_parameters() + + def testRemoveCalculator(self): + """Testing remove calculator""" + + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + self.assertEqual(len(my_instrument.calculators), 2) + my_instrument.remove_calculator(self.calculator1.name) + self.assertEqual(len(my_instrument.calculators), 1) + + def testEditCalculator(self): + """Testing edit calculator""" + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.parameters["test1"]["plus_times"] = 10 + my_instrument.parameters["test1"]["plus_times"] = 15 + energy1 = my_instrument.calculators["test1"].parameters["plus_times"].value + self.assertEqual(energy1, 15) + + def testAddMaster(self): + """Testing remove calculator""" + + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + links = {"test1": "plus_times", "test2": "plus_times"} + my_instrument.add_master_parameter("plus_times", links) + my_instrument.master["plus_times"] = 10 + tims1 = my_instrument.calculators["test1"].parameters["plus_times"].value + tims2 = my_instrument.calculators["test2"].parameters["plus_times"].value + self.assertEqual(tims1, 10) + self.assertEqual(tims2, 10) + + def testSetBasePath(self): + """Testing setup base path for calculators""" + + my_instrument = Instrument("myInstrument") + my_instrument.add_calculator(self.calculator1) + my_instrument.add_calculator(self.calculator2) + my_instrument.set_instrument_base_dir("test") + self.assertEqual( + my_instrument.calculators["test1"].base_dir, "test/PlusCalculator" + ) + self.assertEqual( + my_instrument.calculators["test2"].base_dir, "test/PlusCalculator" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_Parameters.py b/tests/unit/test_Parameters.py new file mode 100644 index 0000000..8ae6792 --- /dev/null +++ b/tests/unit/test_Parameters.py @@ -0,0 +1,627 @@ +import unittest +import numpy +import pytest +import os +import tempfile +from pint.quantity import Quantity +from pint.unit import Unit +from libpyvinyl.Parameters import Parameter +from libpyvinyl.Parameters import CalculatorParameters +from libpyvinyl.Parameters import InstrumentParameters + + +class Test_Parameter(unittest.TestCase): + def test_initialize_parameter_simple(self): + par = Parameter("test") + self.assertEqual(par.name, "test") + + def test_initialize_parameter_complex(self): + par = Parameter("test", unit="cm", comment="comment string") + self.assertEqual(par.name, "test") + assert par.unit == str(Unit("cm")) + self.assertEqual(par.comment, "comment string") + + def test_units_assignment(self): + par = Parameter("test", unit="kg") + assert par.unit == Unit("kg") + par.unit = "cm" + assert par.unit == Unit("cm") + par.unit = "meter" + assert par.unit == Unit("m") + assert par.unit == str(Unit("m")) + assert par.unit == "meter" + par.unit = "nounit" + assert par.unit == "nounit" + + def test_check_value_type(self): + par = Parameter("test") + v = 1 + par._Parameter__check_compatibility(v) + par._Parameter__set_value_type(v) + assert par._Parameter__value_type == int + + v = 1.0 + par._Parameter__check_compatibility(v) + par._Parameter__set_value_type(v) + assert par._Parameter__value_type == Quantity + + v = "string" + with pytest.raises(TypeError): + par._Parameter__check_compatibility(v) + par._Parameter__set_value_type(v) + assert par._Parameter__value_type == str + + v = True + par = Parameter("test") + par._Parameter__set_value_type("string") + assert par._Parameter__value_type == str + with pytest.raises(TypeError): + par._Parameter__check_compatibility(v) + + par = Parameter("test") + par._Parameter__set_value_type(True) + assert par._Parameter__value_type == bool + + v = ["ciao", True] + par = Parameter("test") + with pytest.raises(TypeError): + par._Parameter__check_compatibility(v) + par._Parameter__set_value_type([False, True]) + + v = {"ciao": True, "bye": "string"} + par = Parameter("test") + with pytest.raises(NotImplementedError): + par._Parameter__check_compatibility(v) + v = {"ciao": True, "bye": False} + with pytest.raises(NotImplementedError): + par._Parameter__check_compatibility(v) + + v = numpy.random.uniform(0, 1, 10) + par = Parameter("test") + par._Parameter__check_compatibility(v) + par._Parameter__set_value_type(v) + assert par._Parameter__value_type == Quantity + + # no conditions + def test_parameter_no_legal_conditions(self): + par = Parameter("test") + self.assertTrue(par.is_legal(None)) # FIXME how is this supposed to work? + self.assertTrue(par.is_legal(-999)) + self.assertTrue(par.is_legal(-1)) + self.assertTrue(par.is_legal(0)) + self.assertTrue(par.is_legal(1)) + self.assertTrue(par.is_legal("This is a string")) + self.assertTrue(par.is_legal(True)) + self.assertTrue(par.is_legal(False)) + self.assertTrue(par.is_legal([0, "A", True])) + + # case 1: only legal interval + def test_parameter_legal_interval(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + + self.assertTrue(par.is_legal(3.5)) + self.assertFalse(par.is_legal(1.0)) + + # case 2: only illegal interval + def test_parameter_illegal_interval(self): + par = Parameter("test") + par.add_interval(3, 4.5, False) + + self.assertFalse(par.is_legal(3.5)) + self.assertTrue(par.is_legal(1.0)) + + def test_parameter_multiple_intervals(self): + par = Parameter("test") + par.add_interval(None, 8.5, True) # minus infinite to 8.5 + + self.assertRaises(ValueError, par.add_interval, 3, 4.5, False) + self.assertTrue(par.is_legal(-831.0)) + self.assertTrue(par.is_legal(3.5)) + self.assertTrue(par.is_legal(5.0)) + self.assertFalse(par.is_legal(10.0)) + + def test_values_different_types(self): + par = Parameter("test") + par.add_option(9.8, True) + with pytest.raises(TypeError): + par.add_option(True, True) + + # case 1: only legal option + def test_parameter_legal_option_float(self): + par = Parameter("test") + par.add_option(9.8, True) + + self.assertFalse(par.is_legal(10)) + self.assertTrue(par.is_legal(9.8)) + self.assertFalse(par.is_legal(True)) + self.assertFalse(par.is_legal("A")) + self.assertFalse(par.is_legal(38)) + + # case 1: only legal option + def test_parameter_legal_option_bool(self): + par = Parameter("test") + par.add_option(True, True) + + self.assertFalse(par.is_legal(10)) + self.assertFalse(par.is_legal(9.8)) + self.assertTrue(par.is_legal(True)) + self.assertFalse(par.is_legal("A")) + self.assertFalse(par.is_legal(38)) + + # case 1: only legal option + def test_parameter_legal_option_float_and_int(self): + par = Parameter("test") + par.add_option(9.8, True) + par.add_option(38, True) + + self.assertFalse(par.is_legal(10)) + self.assertTrue(par.is_legal(9.8)) + self.assertFalse(par.is_legal(True)) + self.assertFalse(par.is_legal("A")) + self.assertTrue(par.is_legal(38)) + + # case 1: only legal option + def test_parameter_legal_option_int_and_float(self): + par = Parameter("test") + par.add_option(38, True) + par.add_option(9.8, True) + + self.assertFalse(par.is_legal(10)) + self.assertTrue(par.is_legal(9.8)) + self.assertFalse(par.is_legal(True)) + self.assertFalse(par.is_legal("A")) + self.assertTrue(par.is_legal(38)) + + # case 1: only legal option + def test_parameter_legal_option_fromlist(self): + par = Parameter("test") + par.add_option([9, 8, 38], True) + + self.assertFalse(par.is_legal(10)) + self.assertFalse(par.is_legal(9.8)) + self.assertFalse(par.is_legal(True)) + self.assertFalse(par.is_legal("A")) + self.assertTrue(par.is_legal(38)) + self.assertTrue(par.is_legal(38.0)) + self.assertTrue(par.is_legal(8)) + + # case 1: only legal option + + def test_parameter_legal_option_string(self): + par = Parameter("test") + par.add_option(["B", "A"], True) + + self.assertFalse(par.is_legal(10)) + self.assertFalse(par.is_legal(9.8)) + self.assertFalse(par.is_legal(True)) + self.assertTrue(par.is_legal("A")) + self.assertTrue(par.is_legal("B")) + self.assertFalse(par.is_legal("C")) + self.assertFalse(par.is_legal(38)) + + def test_parameter_multiple_options(self): + par = Parameter("test") + par.add_option(9.8, True) + + self.assertRaises(ValueError, par.add_option, 3, False) + self.assertFalse(par.is_legal(-831.0)) + self.assertTrue(par.is_legal(9.8)) + self.assertFalse(par.is_legal(3)) + + # case 1: legal interval + legal option + def test_parameter_legal_interval_plus_legal_option(self): + par = Parameter("test") + par.add_interval(None, 8.5, True) # minus infinite to 8.5 + par.add_option(5, True) # this is stupid, already accounted in the interval + par.add_option(11, True) + + self.assertTrue(par.is_legal(-831.0)) + self.assertTrue(par.is_legal(8.5)) + self.assertTrue(par.is_legal(5.0)) + self.assertFalse(par.is_legal(10.0)) + self.assertTrue(par.is_legal(11.0)) + + # case 2: illegal interval + illegal option + def test_parameter_illegal_interval_plus_illegal_option(self): + par = Parameter("test") + par.add_interval(None, 8.5, False) # minus infinite to 8.5 + par.add_option(5, False) # this is stupid, already accounted in the interval + par.add_option(11, False) + + self.assertFalse(par.is_legal(-831.0)) + self.assertFalse(par.is_legal(8.5)) # illegal because closed interval + self.assertFalse(par.is_legal(5.0)) + self.assertTrue(par.is_legal(10.0)) + self.assertFalse(par.is_legal(11.0)) + + # case 3: legal interval + illegal option + def test_parameter_legal_interval_plus_illegal_option(self): + par = Parameter("test") + par.add_interval(None, 8.5, True) # minus infinite to 8.5 + par.add_option(5, False) + + self.assertTrue(par.is_legal(-831.0)) + self.assertTrue(par.is_legal(8.5)) + self.assertFalse(par.is_legal(5.0)) + self.assertFalse(par.is_legal(10.0)) + self.assertFalse(par.is_legal(11.0)) + + # case 4: illegal interval + legal option + def test_parameter_illegal_interval_plus_legal_option(self): + par = Parameter("test") + par.add_interval(None, 8.5, False) # minus infinite to 8.5 + par.add_option(5, True) + + self.assertFalse(par.is_legal(-831.0)) + self.assertFalse(par.is_legal(8.5)) + self.assertTrue(par.is_legal(5.0)) + self.assertTrue(par.is_legal(10.0)) + self.assertTrue(par.is_legal(11.0)) + + # case 2: illegal interval + illegal option + def test_parameter_get_options(self): + """ + Ensure get_options returns the options as required + """ + par = Parameter("test") + par.add_interval(None, 8.5, False) # minus infinite to 8.5 + par.add_option(5, True) # this is stupid, already accounted in the interval + par.add_option(11, True) + + retrieved_options = par.get_options() + + self.assertEqual(len(retrieved_options), 2) + self.assertEqual(retrieved_options[0], 5.0) + self.assertEqual(retrieved_options[1], 11.0) + self.assertTrue(par.get_options_are_legal()) + + def test_parameter_value_type(self): + par = Parameter("test") + par.value = 4.0 + assert par._Parameter__value_type == Quantity + + par1 = Parameter("test") + par1.value = 4 + assert par1._Parameter__value_type == int + + par2 = Parameter("test", unit="meV") + par2.value = 4 + assert par2._Parameter__value_type == Quantity + + par3 = Parameter("test", unit="meV") + par3.add_interval(0, 1e6, True) + assert par3._Parameter__value_type == Quantity + + def test_parameter_set_value(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + + par.value = 4.0 + self.assertEqual(par.value, 4.0) + + with self.assertRaises(ValueError): + par.value = 5.0 # Should throw an error and be ignored + + self.assertEqual(par.value, 4.0) + + def test_add_interval_after_value(self): + par = Parameter("test") + par.value = 4.0 + par.add_interval(3, 4.5, True) + + par.clear_intervals() + par.value = 5.0 + with self.assertRaises(ValueError): + par.add_interval(3, 4.5, True) + + def test_parameter_from_dict(self): + par = Parameter("test") + + par.add_interval(3, 4.5, True) + par.value = 4.0 + par_from_dict = Parameter.from_dict(par.__dict__) + self.assertEqual(par_from_dict.value, 4.0) + + def test_print_legal_interval(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + par.add_option(9.8, True) + par.print_parameter_constraints() + + def test_clear_intervals(self): # FIXME + par = Parameter("test") + par.add_interval(3, 4.5, True) + # self.assertEqual(par.__intervals, [[3, 4.5]]) #FIXME + + par.clear_intervals() + par.add_option(9.7, True) + # self.assertEqual(par.__options, [9.7]) + par.clear_options() + + # self.assertEqual(par.__options, []) + + def test_print_line(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + par.add_option(9.8, True) + par.print_line() + + def test_print(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + par.add_option(9.8, True) + print(par) + + def test_parameter_iterable(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + par.add_option(7, True) + self.assertFalse(par.is_legal([0.5, 3.2, 5.0])) + self.assertTrue(par.is_legal([3.1, 4.2, 4.4])) + self.assertTrue(par.is_legal([3.1, 4.2, 4.4, 7])) + + def test_get_intervals(self): + par = Parameter("test") + par.add_interval(3, 4.5, True) + par.add_interval(8, 10, True) + + retrived_intervals = par.get_intervals() + self.assertEqual(len(retrived_intervals), 2) + self.assertEqual(retrived_intervals[0][0], 3) + self.assertEqual(retrived_intervals[0][1], 4.5) + self.assertEqual(retrived_intervals[1][0], 8) + self.assertEqual(retrived_intervals[1][1], 10) + + self.assertTrue(par.get_intervals_are_legal()) + + def test_parameters_with_quantity(self): + """Test if we can construct and use a Parameter instance passing pint.Quantity and pint.Unit objects to the constructor and interval setter.""" + + # Define the base unit of my parameter object. + meter = Unit("meter") + self.assertIsInstance(meter, Unit) + + minimum_undulator_length = 10.0 * meter + undulator_length = Parameter("undulator_length", meter) + + self.assertIsInstance(undulator_length, Parameter) + self.assertEqual(undulator_length.unit, Unit("meter")) + + undulator_length.add_interval( + min_value=minimum_undulator_length, + max_value=numpy.inf * meter, + intervals_are_legal=True, + ) + + self.assertTrue(undulator_length.is_legal(10.1 * meter)) + self.assertFalse(undulator_length.is_legal(9.0 * meter)) + self.assertTrue(undulator_length.is_legal(5.5e4 * Unit("centimeter"))) + + def test_parameter_set_numpy_value(self): + par = Parameter("test", unit="eV") + par.value = 1e-4 + par.value = numpy.log(10) + + def test_parameters_with_quantity_powers(self): + """Test if we can construct and use a Parameter instance passing pint.Quantity and pint.Unit objects to the constructor and interval setter. Use different powers of 10 in parameter initialization and value assignment.""" + + # Define the base unit of my parameter object. + meter = Unit("meter") + centimeter = Unit("centimeter") + self.assertIsInstance(meter, Unit) + + minimum_undulator_length = 10.0 * meter + undulator_length = Parameter("undulator_length", centimeter) + + self.assertIsInstance(undulator_length, Parameter) + self.assertEqual(undulator_length.unit, Unit("centimeter")) + + undulator_length.add_interval( + min_value=minimum_undulator_length, + max_value=numpy.inf * meter, + intervals_are_legal=True, + ) + + print(undulator_length) + + self.assertTrue(undulator_length.is_legal(10.1 * meter)) + self.assertFalse(undulator_length.is_legal(9.0 * centimeter)) + self.assertTrue(undulator_length.is_legal(5.5e4 * Unit("centimeter"))) + + +class Test_Parameters(unittest.TestCase): + def test_initialize_parameters_from_list(self): + par1 = Parameter("test") + par1.value = 8 + par2 = Parameter("test2", unit="meV") + + parameters = CalculatorParameters([par1, par2]) + + self.assertEqual(parameters["test"].value, 8) + + def test_initialize_parameters_from_add(self): + par1 = Parameter("test") + par1.value = 8 + par2 = Parameter("test2", unit="meV") + par2.value = 10 + + parameters = CalculatorParameters() + parameters.add(par1) + parameters.add(par2) + + self.assertEqual(parameters["test"].value, 8) + self.assertEqual(parameters["test2"].value, 10) + + def test_print_parameters(self): + par1 = Parameter("test") + par1.value = 8 + par2 = Parameter("test2", unit="meV") + par2.value = 10 + parameters = CalculatorParameters() + parameters.add(par1) + parameters.add(par2) + print(parameters) + + def test_json(self): + par1 = Parameter("test") + par1.value = 8.0 + par2 = Parameter("test2", unit="meV") + par2.value = 10 + + parameters = CalculatorParameters() + parameters.add(par1) + parameters.add(par2) + + with tempfile.TemporaryDirectory() as d: + tmp_file = os.path.join(d, "test.json") + parameters.to_json(tmp_file) + params_json = CalculatorParameters.from_json(tmp_file) + self.assertEqual(params_json["test2"].value, 10) + assert params_json["test2"].value_no_conversion == Quantity(10, "meV") + with pytest.raises(TypeError): + params_json["test2"].value = "A" + + def test_json_with_objects(self): + par1 = Parameter("test") + par1.value = 8 + par2 = Parameter("test2", unit="meV") + par2.value = 10 + par3 = Parameter("test3", unit="meV") + par3.value = 3.14 + + parameters = CalculatorParameters() + parameters.add(par1) + parameters.add(par2) + parameters.add(par3) + with tempfile.TemporaryDirectory() as d: + tmp_file = os.path.join(d, "test.json") + tmp_file = "/tmp/test.json" + parameters.to_json(tmp_file) + params_json = CalculatorParameters.from_json(tmp_file) + self.assertEqual(params_json["test2"].value, 10) + print(params_json["test3"]) + assert params_json["test3"].value == par3.value + assert params_json["test3"].value == 3.14 + assert params_json["test3"].value_no_conversion == Quantity(3.14, "meV") + + def test_get_item(self): + par1 = Parameter("test") + par1.value = 8 + par2 = Parameter("test2", unit="meV") + par2.value = 10 + + parameters = CalculatorParameters() + self.assertRaises(KeyError, parameters.__getitem__, "test3") + + +def source_calculator(): + """ + Little dummy calculator that sets up a parameters object for a source + """ + parameters = CalculatorParameters() + parameters.new_parameter("energy", unit="eV", comment="Source energy setting") + parameters["energy"].add_interval(0, 1e6, True) + parameters["energy"].value = 4000 + + parameters.new_parameter("delta_energy", unit="eV", comment="Energy spread fwhm") + parameters["delta_energy"].add_interval(0, 400, True) + + parameters.new_parameter("position", unit="cm", comment="Source center") + parameters["position"].add_interval(-1.5, 1.5, True) + + parameters.new_parameter("gaussian", comment="False for flat, True for gaussian") + parameters["gaussian"].add_option([False, True], True) + + return parameters + + +def sample_calculator(): + """ + Little dummy calculator that sets up a parameters object for a sample + """ + parameters = CalculatorParameters() + parameters.new_parameter("radius", unit="cm", comment="Sample radius") + parameters["radius"].add_interval(0, None, True) # To infinite + + parameters.new_parameter("height", unit="cm", comment="Sample height") + parameters["height"].add_interval(0, None, True) + + absporption = parameters.new_parameter( + "absorption", unit="barns", comment="absorption cross section" + ) + absporption.add_interval(0, None, True) + + return parameters + + +class Test_Instruments(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Setting up the test class.""" + cls.d = tempfile.TemporaryDirectory() + + def setUp(self): + # We start creating our instrument with a InstrumentParameters + self.instr_parameters = InstrumentParameters() + + # We insert a source and get some parameters out + source_pars = source_calculator() + # These are added to the instr_parameters so they can be controlled + self.instr_parameters.add("Source", source_pars) + + # We also add a few sample objects with their parameter objects + top_sample_pars = sample_calculator() + self.instr_parameters.add("Sample top", top_sample_pars) + + bottom_sample_pars = sample_calculator() + self.instr_parameters.add("Sample bottom", bottom_sample_pars) + + @classmethod + def tearDownClass(cls): + """Tearing down the test class.""" + cls.d.cleanup() + + def test_link(self): + description = "Absorption cross section for both samples" + links = {"Sample top": "absorption", "Sample bottom": "absorption"} + master_value = 3.4 + self.instr_parameters.add_master_parameter( + "absorption", links, unit="barns", comment=description + ) + self.instr_parameters.master["absorption"] = master_value + top_value = self.instr_parameters["Sample top"]["absorption"].value + bottom_value = self.instr_parameters["Sample bottom"]["absorption"].value + self.assertEqual(top_value, master_value) + self.assertEqual(bottom_value, master_value) + master_params = self.instr_parameters.master.parameters + self.assertIn("absorption", master_params.keys()) + self.assertEqual(master_value, master_params["absorption"].value) + self.assertEqual(self.instr_parameters.master["absorption"].links, links) + + def test_print(self): + print(self.instr_parameters) + + def test_json(self): + description = "Absorption cross section for both samples" + links = {"Sample top": "absorption", "Sample bottom": "absorption"} + master_value = 3.4 + self.instr_parameters.add_master_parameter( + "absorption", links, unit="barns", comment=description + ) + self.instr_parameters.master["absorption"] = master_value + temp_file = os.path.join(self.d.name, "test.json") + self.instr_parameters.to_json(temp_file) + print(self.instr_parameters) + + # From json + instr_json = InstrumentParameters.from_json(temp_file) + self.assertEqual(instr_json["Source"]["energy"].value, 4000) + master_params = instr_json.master.parameters + self.assertIn("absorption", master_params.keys()) + self.assertEqual(master_value, master_params["absorption"].value) + self.assertEqual(self.instr_parameters.master["absorption"].links, links) + + +if __name__ == "__main__": + unittest.main()