diff --git a/mxfusion/inference/inference.py b/mxfusion/inference/inference.py index 4feb2ab..fe6e686 100644 --- a/mxfusion/inference/inference.py +++ b/mxfusion/inference/inference.py @@ -14,13 +14,18 @@ import warnings +import io +import json import numpy as np import mxnet as mx +import zipfile from .inference_parameters import InferenceParameters from ..common.config import get_default_device, get_default_dtype from ..common.exceptions import InferenceError from ..util.inference import discover_shape_constants, init_outcomes from ..models import FactorGraph, Model, Posterior +from ..util.serialization import ModelComponentEncoder, make_numpy, load_json_from_zip, load_parameters, \ + FILENAMES, DEFAULT_ZIP, ENCODINGS, SERIALIZATION_VERSION class Inference(object): @@ -44,7 +49,6 @@ class Inference(object): def __init__(self, inference_algorithm, constants=None, hybridize=False, dtype=None, context=None): - self.dtype = dtype if dtype is not None else get_default_dtype() self.mxnet_context = context if context is not None else get_default_device() self._hybridize = hybridize @@ -58,7 +62,7 @@ def __init__(self, inference_algorithm, constants=None, def print_params(self): """ Returns a string with the inference parameters nicely formatted for display, showing which model they came from and their name + uuid. - + Format: > infr.print_params() Variable(1ab23)(name=y) - (Model/Posterior(123ge2)) - (first mxnet values/shape) @@ -171,34 +175,36 @@ def set_initializer(self): pass def load(self, - graphs_file=None, - inference_configuration_file=None, - parameters_file=None, - mxnet_constants_file=None, - variable_constants_file=None): + zip_filename=DEFAULT_ZIP): """ - Loads back everything needed to rerun an inference algorithm. - The major pieces of this are the InferenceParameters, FactorGraphs, and - InferenceConfiguration. - - :param graphs_file: The file containing the graphs to load back for this inference algorithm. The first of these is the primary graph. - :type graphs_file: str of filename - :param inference_configuration_file: The file containing any inference specific configuration needed to - reload this inference algorithm. e.g. observation patterns used to train it. - :type inference_configuration_file: str of filename - :param parameters_file: These are the parameters of the previous inference algorithm. - These are in a {uuid: mx.nd.array} mapping. - :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved - in a binary format. - :param mxnet_constants_file: These are the constants in mxnet format from the previous inference algorithm. - These are in a {uuid: mx.nd.array} mapping. - :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved - in a binary format. - :param variable_constants_file: These are the constants in primitive format from the previous - inference algorithm. - :type variable_constants_file: json dict of {uuid: constant_primitive} + Loads back everything needed to rerun an inference algorithm from a zip file. + See the save function for details on the structure of the zip file. + + :param zip_filename: Path to the zip file of the inference method to load back in. + Defaults to the default name of inference.zip + :type zip_filename: str of zip filename """ - graphs = FactorGraph.load_graphs(graphs_file) + + # Check version is correct + zip_version = load_json_from_zip(zip_filename, FILENAMES['version_file']) + if zip_version['serialization_version'] != SERIALIZATION_VERSION: + raise SerializationError("Serialization version of saved inference \ + and running code are note the same.") + + # Load parameters back in + + with zipfile.ZipFile(zip_filename, 'r') as zip_file: + mxnet_parameters = load_parameters(FILENAMES['mxnet_params'], zip_file, context=self.mxnet_context) + mxnet_constants = load_parameters(FILENAMES['mxnet_constants'], zip_file, context=self.mxnet_context) + + variable_constants = load_json_from_zip(zip_filename, + FILENAMES['variable_constants']) + + # Load graphs + from ..util.serialization import ModelComponentDecoder + graphs_list = load_json_from_zip(zip_filename, FILENAMES['graphs'], + decoder=ModelComponentDecoder) + graphs = FactorGraph.load_graphs(graphs_list) primary_model = graphs[0] secondary_graphs = graphs[1:] @@ -207,56 +213,100 @@ def load(self, current_graphs=self.graphs, primary_previous_graph=primary_model, secondary_previous_graphs=secondary_graphs) + new_parameters = InferenceParameters.load_parameters( uuid_map=self._uuid_map, - parameters_file=parameters_file, - variable_constants_file=variable_constants_file, - mxnet_constants_file=mxnet_constants_file, + mxnet_parameters=mxnet_parameters, + variable_constants=variable_constants, + mxnet_constants=mxnet_constants, current_params=self.params._params) self.params = new_parameters - self.load_configuration(inference_configuration_file, self._uuid_map) - def load_configuration(self, config_file, uuid_map): + configuration = load_json_from_zip(zip_filename, FILENAMES['configuration']) + self.load_configuration(configuration, self._uuid_map) + + def load_configuration(self, configuration, uuid_map): """ Loads relevant inference configuration back from a file. - Currently only loads the observed variables UUIDs back in, using the uuid_map - parameter to store the correct current observed variables. + Currently only loads the observed variables UUIDs back in, + using the uuid_map parameter to store the + correct current observed variables. - :param config_file: The file to save the configuration down into. + :param config_file: The loaded configuration dictionary :type config_file: str - :param uuid_map: A map of previous/loaded model component uuids to their current variable in the loaded graph. + :param uuid_map: A map of previous/loaded model component + uuids to their current variable in the loaded graph. :type uuid_map: { current_model_uuid : loaded_previous_uuid} """ - import json - with open(config_file) as f: - configuration = json.load(f) + pass # loaded_uuid = [uuid_map[uuid] for uuid in configuration['observed']] - def save_configuration(self, config_file): + def get_serializable(self): """ - Saves relevant inference configuration down into a file. - Currently only saves the observed variables UUIDs as {'observed': [observed_uuids]}. - - :param config_file: The file to save the configuration down into. - :type config_file: str + Returns the mimimum set of properties that the object needs to save in order to be + serialized down and loaded back in properly. + :returns: A dictionary of configuration properties needed to serialize and reload this inference method. + :rtypes: Dictionary that is JSON serializable. """ - import json - with open(config_file, 'w') as f: - json.dump({'observed': self.observed_variable_UUIDs}, f, ensure_ascii=False) + return {'observed': self.observed_variable_UUIDs} - def save(self, prefix=None): + def save(self, zip_filename=DEFAULT_ZIP): """ Saves down everything needed to reload an inference algorithm. - The two primary pieces of this are the InferenceParameters and FactorGraphs. - - :param prefix: The directory and any appending tag for the files to save this Inference as. - :type prefix: str , ex. "../saved_inferences/experiment_1" + This method writes everything into a single zip archive, with 6 internal files. + 1. version.json - This has the version of serialization used to create the zip file. + 2. graphs.json - This is a networkx representation of all FactorGraphs used during Inference. + See mxfusion.models.FactorGraph.save for more information. + 3. mxnet_parameters.npz - This is a numpy zip file saved using numpy.savez(), containing one file for each + mxnet parameter in the InferenceParameters object. Each parameter is saved in a binary file named by the + parameter's UUID. + 4. mxnet_constants.npz - The same as mxnet_parameters, except only for constant mxnet parameters. + 5. variable_constants.json - Parameters file of primitive data type constants, such as ints or floats. + I.E. { UUID : int/float} + 6. configuration.json - This has other configuration related to inference such as the observation pattern. + :param zip_filename: Path to and name of the zip archive to save the inference method as. + :type zip_filename: str """ - prefix = prefix if prefix is not None else "inference" - self.params.save(prefix=prefix) - self.save_configuration(prefix + '_configuration.json') + # Retrieve dictionary representations of things to save + mxnet_parameters, mxnet_constants, variable_constants = self.params.get_serializable() + configuration = self.get_serializable() graphs = [g.as_json()for g in self._graphs] - FactorGraph.save(prefix + "_graphs.json", graphs) + version_dict = {"serialization_version": + SERIALIZATION_VERSION} + + files_to_save = [] + objects = [graphs, mxnet_parameters, mxnet_constants, + variable_constants, configuration, version_dict] + ordered_filenames = [FILENAMES['graphs'], FILENAMES['mxnet_params'], FILENAMES['mxnet_constants'], + FILENAMES['variable_constants'], FILENAMES['configuration'], FILENAMES['version_file']] + encodings = [ENCODINGS['json'], ENCODINGS['numpy'], ENCODINGS['numpy'], + ENCODINGS['json'], ENCODINGS['json'], ENCODINGS['json']] + + # Form each individual file buffer. + for filename, obj, encoding in zip(ordered_filenames, objects, encodings): + # For the FactorGraphs, configuration, and variable constants just write them as regular json. + if encoding == ENCODINGS['json']: + buffer = io.StringIO() + json.dump(obj, buffer, ensure_ascii=False, + cls=ModelComponentEncoder) + # For MXNet parameters, save them as numpy compressed zip files of arrays. + # So a numpy-zip within the bigger zip. + elif encoding == ENCODINGS['numpy']: + buffer = io.BytesIO() + np_obj = make_numpy(obj) + np.savez(buffer, **np_obj) + files_to_save.append((filename, buffer)) + + # Form the overall zipfile stream + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "a", + zipfile.ZIP_DEFLATED, False) as zip_file: + for base_name, data in files_to_save: + zip_file.writestr(base_name, data.getvalue()) + + # Finally save the actual zipfile stream to disk + with open(zip_filename, 'wb') as f: + f.write(zip_buffer.getvalue()) class TransferInference(Inference): @@ -268,7 +318,8 @@ class TransferInference(Inference): :type inference_algorithm: InferenceAlgorithm :param constants: Specify a list of model variables as constants :type constants: {Variable: mxnet.ndarray} - :param hybridize: Whether to hybridize the MXNet Gluon block of the inference method. + :param hybridize: Whether to hybridize + the MXNet Gluon block of the inference method. :type hybridize: boolean :param dtype: data type for internal numerical representation :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'} diff --git a/mxfusion/inference/inference_parameters.py b/mxfusion/inference/inference_parameters.py index 1ef8627..398c309 100644 --- a/mxfusion/inference/inference_parameters.py +++ b/mxfusion/inference/inference_parameters.py @@ -184,24 +184,25 @@ def __contains__(self, k): @staticmethod def load_parameters(uuid_map=None, - parameters_file=None, - variable_constants_file=None, - mxnet_constants_file=None, + mxnet_parameters=None, + variable_constants=None, + mxnet_constants=None, context=None, dtype=None, current_params=None): """ - Loads back a sest of InferenceParameters from files. - :param parameters_file: These are the parameters of the previous inference algorithm. + Loads back a set of InferenceParameters from files. + :param mxnet_parameters: These are the parameters of + the previous inference algorithm. These are in a {uuid: mx.nd.array} mapping. - :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved - in a binary format. - :param mxnet_constants_file: These are the constants in mxnet format from the previous inference algorithm. + :type mxnet_parameters: Dict of {uuid: mx.nd.array} + :param mxnet_constants: These are the constants in mxnet format + from the previous inference algorithm. These are in a {uuid: mx.nd.array} mapping. - :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved - in a binary format. - :param variable_constants_file: These are the constants in primitive format from the previous + :type mxnet_constants: Dict of {uuid: mx.nd.array} + :param variable_constants: These are the constants in + primitive format from the previous inference algorithm. - :type variable_constants_file: json dict of {uuid: constant_primitive} + :type variable_constants: dict of {uuid: constant primitive} """ def with_uuid_map(item, uuid_map): if uuid_map is not None: @@ -210,62 +211,49 @@ def with_uuid_map(item, uuid_map): return item ip = InferenceParameters(context=context, dtype=dtype) - if parameters_file is not None: - old_params = ndarray.load(parameters_file) - mapped_params = {with_uuid_map(k, uuid_map): v - for k, v in old_params.items()} + mapped_params = {with_uuid_map(k, uuid_map): v + for k, v in mxnet_parameters.items()} - new_paramdict = ParameterDict() - if current_params is not None: - new_paramdict.update(current_params) + new_paramdict = ParameterDict() + if current_params is not None: + new_paramdict.update(current_params) - # Do this because we need to map the uuids to the new Model - # before loading them into the ParamDict - for name, mapped_param in mapped_params.items(): - new_paramdict[name]._load_init(mapped_param, ip.mxnet_context) - ip._params = new_paramdict + # Do this because we need to map the uuids to the new Model + # before loading them into the ParamDict + for name, mapped_param in mapped_params.items(): + new_paramdict[name]._load_init(mapped_param, ip.mxnet_context) + ip._params = new_paramdict new_mxnet_constants = {} new_variable_constants = {} - if variable_constants_file is not None: - import json - with open(variable_constants_file) as f: - old_constants = json.load(f) - new_variable_constants = {with_uuid_map(k, uuid_map): v for k, v in old_constants.items()} - if mxnet_constants_file is not None: - mxnet_constants = ndarray.load(mxnet_constants_file) - if isinstance(mxnet_constants, dict): - new_mxnet_constants = {with_uuid_map(k, uuid_map): v for k, v in mxnet_constants.items()} - else: - new_mxnet_constants = {} + new_variable_constants = {with_uuid_map(k, uuid_map): v + for k, v in variable_constants.items()} + new_mxnet_constants = {with_uuid_map(k, uuid_map): v + for k, v in mxnet_constants.items()} + ip._constants = {} ip._constants.update(new_variable_constants) ip._constants.update(new_mxnet_constants) return ip - def save(self, prefix): + def get_serializable(self): """ - Saves the parameters and constants down to json files as maps from {uuid : value}, - where value is an mx.ndarray for parameters and either primitive number types or mx.ndarray for constants. - Saves up to 3 files: prefix+["_params.json", "_variable_constants.json", "_mxnet_constants.json"] - - :param prefix: The directory and any appending tag for the files to save this Inference as. - :type prefix: str , ex. "../saved_inferences/experiment_1" + Returns three dicts: + 1. MXNet parameters {uuid: mxnet parameters, mx.nd.array}. + 2. MXNet constants {uuid: mxnet parameter (only constant types), mx.nd.array} + 3. Other constants {uuid: primitive numeric types (int, float)} + :returns: Three dictionaries: MXNet parameters, MXNet constants, and other constants (in that order) + :rtypes: {uuid: mx.nd.array}, {uuid: mx.nd.array}, {uuid: primitive (int/float)} """ - param_file = prefix + "_params.json" - variable_constants_file = prefix + "_variable_constants.json" - mxnet_constants_file = prefix + "_mxnet_constants.json" - to_save = {key: value._reduce() for key, value in self._params.items()} - ndarray.save(param_file, to_save) + + mxnet_parameters = {key: value._reduce() for key, value in self._params.items()} mxnet_constants = {uuid: value for uuid, value in self._constants.items() if isinstance(value, mx.ndarray.ndarray.NDArray)} - ndarray.save(mxnet_constants_file, mxnet_constants) variable_constants = {uuid: value for uuid, value in self._constants.items() if uuid not in mxnet_constants} - import json - with open(variable_constants_file, 'w') as f: - json.dump(variable_constants, f, ensure_ascii=False) + + return mxnet_parameters, mxnet_constants, variable_constants diff --git a/mxfusion/models/factor_graph.py b/mxfusion/models/factor_graph.py index 5f27fde..c6fc0c3 100644 --- a/mxfusion/models/factor_graph.py +++ b/mxfusion/models/factor_graph.py @@ -573,18 +573,15 @@ def load_from_json(self, json_graph): return self @staticmethod - def load_graphs(graphs_file, existing_graphs=None): + def load_graphs(graphs_list, existing_graphs=None): """ - Method to load back in a graph. The graph file should be saved down using the save method, and is a JSON representation of the graph + Method to load back in a graph. The graphs should have been saved down using the save method, and be a JSON representation of the graph generated by the [networkx](https://networkx.github.io) library. - :param graph_file: The file containing the primary model to load back for this inference algorithm. - :type graph_file: str of filename + :param graphs_list: A list of raw json dicts loaded in from memory representing the FactorGraphs to create. + :type graphs_list: list of dicts loaded in using the ModelComponentDecoder class. """ import json - from ..util.graph_serialization import ModelComponentDecoder - with open(graphs_file) as f: - graphs_list = json.load(f, cls=ModelComponentDecoder) existing_graphs = existing_graphs if existing_graphs is not None else [FactorGraph(graph['name']) for graph in graphs_list] return [existing_graph.load_from_json(graph) for existing_graph, graph in zip(existing_graphs, graphs_list)] @@ -592,7 +589,7 @@ def as_json(self): """ Returns the FactorGraph in a form suitable for JSON serialization. This is assuming a JSON serializer that knows how to handle ModelComponents - such as the one defined in mxfusion.util.graph_serialization. + such as the one defined in mxfusion.util.serialization. """ json_graph = nx.readwrite.json_graph.node_link_data(self._components_graph) json_graph['name'] = self.name @@ -609,7 +606,7 @@ def save(graph_file, json_graphs): """ json_graphs = [json_graphs] if not isinstance(json_graphs, type([])) else json_graphs import json - from ..util.graph_serialization import ModelComponentEncoder + from ..util.serialization import ModelComponentEncoder if graph_file is not None: with open(graph_file, 'w') as f: json.dump(json_graphs, f, ensure_ascii=False, cls=ModelComponentEncoder) diff --git a/mxfusion/util/__init__.py b/mxfusion/util/__init__.py index 37eca96..7f58ce7 100644 --- a/mxfusion/util/__init__.py +++ b/mxfusion/util/__init__.py @@ -22,7 +22,7 @@ :toctree: _autosummary customop - graph_serialization + serialization inference testutils util diff --git a/mxfusion/util/graph_serialization.py b/mxfusion/util/graph_serialization.py deleted file mode 100644 index 8242fbd..0000000 --- a/mxfusion/util/graph_serialization.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. -# ============================================================================== - - -import json -import mxfusion as mf -from ..common.exceptions import SerializationError - - -__GRAPH_JSON_VERSION__ = '1.0' - - -class ModelComponentEncoder(json.JSONEncoder): - - def default(self, obj): - """ - Serializes a ModelComponent object. Note: does not serialize the successor attribute as it isn't necessary for serialization. - """ - if isinstance(obj, mf.components.ModelComponent): - object_dict = obj.as_json() - object_dict["version"] = __GRAPH_JSON_VERSION__ - object_dict["type"] = obj.__class__.__name__ - return object_dict - return super(ModelComponentEncoder, self).default(obj) - - -class ModelComponentDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__( - self, object_hook=self.object_hook, *args, **kwargs) - - def object_hook(self, obj): - """ - Reloads a ModelComponent object. Note: does not reload the successor attribute as it isn't necessary for serialization. - """ - if not isinstance(obj, type({})) or 'uuid' not in obj: - return obj - if obj['version'] != __GRAPH_JSON_VERSION__: - raise SerializationError('The format of the stored model component '+str(obj['name'])+' is from an old version '+str(obj['version'])+'. The current version is '+__GRAPH_JSON_VERSION__+'. Backward compatibility is not supported yet.') - if 'graphs' in obj: - v = mf.modules.Module(None, None, None, None) - v.load_module(obj) - else: - v = mf.components.ModelComponent() - v.inherited_name = obj['inherited_name'] if 'inherited_name' in obj else None - v.name = obj['name'] - v._uuid = obj['uuid'] - v.attributes = obj['attributes'] - v.type = obj['type'] - return v diff --git a/mxfusion/util/serialization.py b/mxfusion/util/serialization.py new file mode 100644 index 0000000..b059f06 --- /dev/null +++ b/mxfusion/util/serialization.py @@ -0,0 +1,135 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + + +import io +import json +import mxfusion as mf +import mxnet as mx +import numpy as np +import zipfile +from ..common.exceptions import SerializationError +from ..common.config import get_default_device + + +__GRAPH_JSON_VERSION__ = '1.0' +SERIALIZATION_VERSION = '2.0' +DEFAULT_ZIP = 'inference.zip' +FILENAMES = { + 'graphs' : 'graphs.json', + 'mxnet_params' : 'mxnet_parameters.npz', + 'mxnet_constants' : 'mxnet_constants.npz', + 'variable_constants' : 'variable_constants.json', + 'configuration' : 'configuration.json', + 'version_file' : 'version.json' +} +ENCODINGS = { + 'json' : 'json', + 'numpy' : 'numpy' +} + +class ModelComponentEncoder(json.JSONEncoder): + + def default(self, obj): + """ + Serializes a ModelComponent object. Note: does not serialize the successor attribute as it isn't necessary for serialization. + """ + if isinstance(obj, mf.components.ModelComponent): + object_dict = obj.as_json() + object_dict["version"] = __GRAPH_JSON_VERSION__ + object_dict["type"] = obj.__class__.__name__ + return object_dict + return super(ModelComponentEncoder, self).default(obj) + + +class ModelComponentDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__( + self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, obj): + """ + Reloads a ModelComponent object. Note: does not reload the successor attribute as it isn't necessary for serialization. + """ + if not isinstance(obj, type({})) or 'uuid' not in obj: + return obj + if obj['version'] != __GRAPH_JSON_VERSION__: + raise SerializationError('The format of the stored model component '+str(obj['name'])+' is from an old version '+str(obj['version'])+'. The current version is '+__GRAPH_JSON_VERSION__+'. Backward compatibility is not supported yet.') + if 'graphs' in obj: + v = mf.modules.Module(None, None, None, None) + v.load_module(obj) + else: + v = mf.components.ModelComponent() + v.inherited_name = obj['inherited_name'] if 'inherited_name' in obj else None + v.name = obj['name'] + v._uuid = obj['uuid'] + v.attributes = obj['attributes'] + v.type = obj['type'] + return v + +def load_json_file(target_file, decoder=None): + with open(target_file) as f: + return json.load(f, cls=decoder) + +def load_json_from_zip(zip_filename, target_file, decoder=None): + """ + Utility function that loads a json file from inside a zip file without unzipping the zip file + and returns the loaded json as a dictionary. + :param encoder: optional. a JSONDecoder class to pass to the json.load function for loading back in the dict. + """ + with zipfile.ZipFile(zip_filename, 'r') as zip_file: + json_file = zip_file.open(target_file) + # json.load only takes str in 3.4/3.5 so we read, decode to UTF-8, and convert to a StringIO + loaded = json.load(io.StringIO(json_file.read().decode()), cls=decoder) + return loaded + +def make_numpy(obj): + """ + Utility function that takes a dictionary of numpy or MXNet arrays and + returns a dictionary of numpy arrays. Used to standardize serialization. + """ + ERR_MSG = "This function shouldn't be called on anything except " + \ + " dictionaries of numpy and MXNet arrays." + if not isinstance(obj, type({})): + raise SerializationError(ERR_MSG) + + np_obj = {} + for k,v in obj.items(): + if isinstance(v, np.ndarray): + np_obj[k] = v + elif isinstance(v, mx.ndarray.ndarray.NDArray): + np_obj[k] = v.asnumpy() + else: + raise SerializationError(ERR_MSG) + return np_obj + +def load_parameters(npz_filename, zip_file, context=None): + """ + Helper function to load the parameters from a npz file directly into a dictionary as mxnet arrays. + """ + context = context if context is not None else get_default_device() + params_file = zip_file.read(npz_filename) + try: + loaded = np.load(io.BytesIO(params_file)) + except OSError as e: + """ + Numpy load doesn't handle reloading an empty .npz directory after savez so just continue with an empty + dict if it throws an OSError here when loading back. + See https://github.com/chainer/chainer/issues/4542 + """ + return {} + parameters = {} + for k,v in loaded.items(): + parameters[k] = mx.nd.array(v, dtype=v.dtype, ctx=context) + return parameters diff --git a/testing/inference/inference_parameters_test.py b/testing/inference/inference_parameters_test.py deleted file mode 100644 index 16141b8..0000000 --- a/testing/inference/inference_parameters_test.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. -# ============================================================================== - - -import unittest -import mxnet as mx -from mxfusion.components import Variable -from mxfusion.inference import InferenceParameters - - -class InferenceParametersTests(unittest.TestCase): - """ - Test class that tests the MXFusion.inference.InferenceParameters methods. - """ - - def remove_saved_files(self, prefix): - import os, glob - for filename in glob.glob(prefix+"*"): - os.remove(filename) - - def test_save_reload_constants(self): - constants = {Variable(): 5, 'uuid': mx.nd.array([1])} - ip = InferenceParameters(constants=constants) - ip.save(prefix="constants_test") - # assert the file is there - - ip2 = InferenceParameters.load_parameters( - mxnet_constants_file='constants_test_mxnet_constants.json', - variable_constants_file='constants_test_variable_constants.json') - print(ip.constants) - print(ip2.constants) - assert ip.constants == ip2.constants - - self.remove_saved_files("constants_test") diff --git a/testing/inference/inference_serialization_test.py b/testing/inference/inference_serialization_test.py index 712f161..64c5a3a 100644 --- a/testing/inference/inference_serialization_test.py +++ b/testing/inference/inference_serialization_test.py @@ -19,6 +19,7 @@ import mxnet as mx import mxnet.gluon.nn as nn import mxfusion as mf +import os from mxfusion.components.variables.var_trans import PositiveTransformation from mxfusion.components.functions import MXFusionGluonFunction from mxfusion.common.config import get_default_dtype @@ -31,13 +32,9 @@ class InferenceSerializationTests(unittest.TestCase): Test class that tests the MXFusion.utils methods. """ - def remove_saved_files(self, prefix): - import os, glob - for filename in glob.glob(prefix+"*"): - os.remove(filename) - def setUp(self): self.PREFIX = 'test_' + str(uuid.uuid4()) + self.ZIPNAME = self.PREFIX + '_inference.zip' def make_model(self, net): dtype = get_default_dtype() @@ -105,8 +102,8 @@ def test_meanfield_saving(self): infr.initialize(y=y_nd, x=x_nd) infr.run(max_iter=1, learning_rate=1e-2, y=y_nd, x=x_nd) - infr.save(prefix=self.PREFIX) - self.remove_saved_files(self.PREFIX) + infr.save(self.ZIPNAME) + os.remove(self.ZIPNAME) def test_meanfield_save_and_load(self): dtype = get_default_dtype() @@ -131,7 +128,7 @@ def test_meanfield_save_and_load(self): infr.initialize(y=y_nd, x=x_nd) infr.run(max_iter=1, learning_rate=1e-2, y=y_nd, x=x_nd) - infr.save(prefix=self.PREFIX) + infr.save(self.ZIPNAME) net2 = self.make_net() net2(x_nd) @@ -145,11 +142,7 @@ def test_meanfield_save_and_load(self): infr2.initialize(y=y_nd, x=x_nd) # Load previous parameters - infr2.load(graphs_file=self.PREFIX+'_graphs.json', - parameters_file=self.PREFIX+'_params.json', - inference_configuration_file=self.PREFIX+'_configuration.json', - mxnet_constants_file=self.PREFIX+'_mxnet_constants.json', - variable_constants_file=self.PREFIX+'_variable_constants.json') + infr2.load(self.ZIPNAME) for original_uuid, original_param in infr.params.param_dict.items(): original_data = original_param.data().asnumpy() @@ -167,7 +160,7 @@ def test_meanfield_save_and_load(self): assert np.all(np.isclose(original_data, reloaded_data)) infr2.run(max_iter=1, learning_rate=1e-2, y=y_nd, x=x_nd) - self.remove_saved_files(self.PREFIX) + os.remove(self.ZIPNAME) def test_gp_module_save_and_load(self): @@ -187,7 +180,7 @@ def test_gp_module_save_and_load(self): loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype)) - infr.save(prefix=self.PREFIX) + infr.save(self.ZIPNAME) m2 = self.make_gpregr_model(lengthscale, variance, noise_var) @@ -197,11 +190,7 @@ def test_gp_module_save_and_load(self): infr2.initialize(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype)) # Load previous parameters - infr2.load(graphs_file=self.PREFIX+'_graphs.json', - parameters_file=self.PREFIX+'_params.json', - inference_configuration_file=self.PREFIX+'_configuration.json', - mxnet_constants_file=self.PREFIX+'_mxnet_constants.json', - variable_constants_file=self.PREFIX+'_variable_constants.json') + infr2.load(self.ZIPNAME) for original_uuid, original_param in infr.params.param_dict.items(): original_data = original_param.data().asnumpy() @@ -220,4 +209,4 @@ def test_gp_module_save_and_load(self): loss2, _ = infr2.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype)) - self.remove_saved_files(self.PREFIX) + os.remove(self.ZIPNAME) diff --git a/testing/models/factor_graph_test.py b/testing/models/factor_graph_test.py index 638a243..641a9a9 100644 --- a/testing/models/factor_graph_test.py +++ b/testing/models/factor_graph_test.py @@ -322,7 +322,8 @@ def test_save_reload_bnn_graph(self): m1, _ = self.make_bnn_model(self.make_net()) FactorGraph.save(self.TESTFILE, m1.as_json()) m1_loaded = Model() - FactorGraph.load_graphs(self.TESTFILE, [m1_loaded]) + from mxfusion.util.serialization import ModelComponentDecoder, load_json_file + FactorGraph.load_graphs(load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded]) m1_loaded_edges = set(m1_loaded.components_graph.edges()) m1_edges = set(m1.components_graph.edges()) @@ -336,7 +337,8 @@ def test_save_reload_then_reconcile_simple_graph(self): m1 = self.make_simple_model() FactorGraph.save(self.TESTFILE, m1.as_json()) m1_loaded = Model() - FactorGraph.load_graphs(self.TESTFILE, [m1_loaded]) + from mxfusion.util.serialization import ModelComponentDecoder, load_json_file + FactorGraph.load_graphs(load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded]) self.assertTrue(set(m1.components) == set(m1_loaded.components)) m2 = self.make_simple_model() @@ -366,7 +368,8 @@ def test_save_reload_then_reconcile_gp_module(self): m1 = self.make_gpregr_model() FactorGraph.save(self.TESTFILE, m1.as_json()) m1_loaded = Model() - FactorGraph.load_graphs(self.TESTFILE, [m1_loaded]) + from mxfusion.util.serialization import ModelComponentDecoder, load_json_file + FactorGraph.load_graphs(load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded]) self.assertTrue(set(m1.components) == set(m1_loaded.components)) self.assertTrue(len(set(m1.Y.factor._module_graph.components)) == len(set(m1_loaded[m1.Y.factor.uuid]._module_graph.components))) self.assertTrue(len(set(m1.Y.factor._extra_graphs[0].components)) == len(set(m1_loaded[m1.Y.factor.uuid]._extra_graphs[0].components))) @@ -397,7 +400,8 @@ def test_save_reload_then_reconcile_bnn_graph(self): m1, _ = self.make_bnn_model(self.make_net()) FactorGraph.save(self.TESTFILE, m1.as_json()) m1_loaded = Model() - FactorGraph.load_graphs(self.TESTFILE, [m1_loaded]) + from mxfusion.util.serialization import ModelComponentDecoder, load_json_file + FactorGraph.load_graphs(load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded]) self.assertTrue(set(m1.components) == set(m1_loaded.components)) m2, _ = self.make_bnn_model(self.make_net()) diff --git a/testing/util/graph_serialization_test.py b/testing/util/graph_serialization_test.py index 6a3c0ec..7245df5 100644 --- a/testing/util/graph_serialization_test.py +++ b/testing/util/graph_serialization_test.py @@ -16,12 +16,12 @@ import unittest import json from mxfusion.components import Variable -from mxfusion.util.graph_serialization import ModelComponentDecoder, ModelComponentEncoder +from mxfusion.util.serialization import ModelComponentDecoder, ModelComponentEncoder class GraphSerializationTests(unittest.TestCase): """ - Tests the mxfusion.util.graph_serialization classes for ModelComponent encoding and decoding. + Tests the mxfusion.util.serialization classes for ModelComponent encoding and decoding. """ def test_encode_component(self):