diff --git a/mxfusion/inference/inference_alg.py b/mxfusion/inference/inference_alg.py
index 573fbe9..6871175 100644
--- a/mxfusion/inference/inference_alg.py
+++ b/mxfusion/inference/inference_alg.py
@@ -106,6 +106,7 @@ def replicate_self(self, model, extra_graphs=None):
         replicant._observed = set(observed)
         replicant._observed_uuid = variables_to_UUID(observed)
         replicant._observed_names = [v.name for v in observed]
+        replicant._graphs = self._graphs
         return replicant
 
     def __init__(self, model, observed, extra_graphs=None):
diff --git a/mxfusion/modules/gp_modules/gp_regression.py b/mxfusion/modules/gp_modules/gp_regression.py
index 7ba80c6..b3c3fe6 100644
--- a/mxfusion/modules/gp_modules/gp_regression.py
+++ b/mxfusion/modules/gp_modules/gp_regression.py
@@ -75,6 +75,11 @@ def compute(self, F, variables):
         self.set_parameter(variables, self.posterior.LinvY, LinvY[0])
         return logL
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        return rep
+
 
 class GPRegressionSampling(SamplingAlgorithm):
     """
@@ -134,6 +139,11 @@ def compute(self, F, variables):
         else:
             return samples
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep._rand_gen = self._rand_gen
+        return rep
+
 
 class GPRegressionMeanVariancePrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, noise_free=True,
@@ -195,6 +205,12 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.diagonal_variance = self.diagonal_variance
+        rep.noise_free = self.noise_free
+        return rep
+
 
 class GPRegressionSamplingPrediction(SamplingAlgorithm):
     """
@@ -274,6 +290,14 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep._rand_gen = self._rand_gen
+        rep.diagonal_variance = self.diagonal_variance
+        rep.jitter = self.jitter
+        rep.noise_free = self.noise_free
+        return rep
+
 
 class GPRegression(Module):
     """
@@ -425,4 +449,5 @@ def replicate_self(self, attribute_map=None):
         rep.kernel = self.kernel.replicate_self(attribute_map)
         rep._has_mean = self._has_mean
+        rep._module_graph.kernel = rep.kernel  # TODO: put this into factor graph clone method
         return rep
diff --git a/mxfusion/modules/gp_modules/sparsegp_regression.py b/mxfusion/modules/gp_modules/sparsegp_regression.py
index 57b7e48..44d8b37 100644
--- a/mxfusion/modules/gp_modules/sparsegp_regression.py
+++ b/mxfusion/modules/gp_modules/sparsegp_regression.py
@@ -107,6 +107,11 @@ def compute(self, F, variables):
 
         return logL
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        return rep
+
 
 class SparseGPRegressionMeanVariancePrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, target_variables=None,
@@ -173,6 +178,12 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.diagonal_variance = self.diagonal_variance
+        rep.noise_free = self.noise_free
+        return rep
+
 
 class SparseGPRegressionSamplingPrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, rand_gen=None,
@@ -254,6 +265,14 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.noise_free = self.noise_free
+        rep._rand_gen = self._rand_gen
+        rep.diagonal_variance = self.diagonal_variance
+        rep.jitter = self.jitter
+        return rep
+
 
 class SparseGPRegression(Module):
     """
@@ -427,4 +446,5 @@ def replicate_self(self, attribute_map=None):
         rep.kernel = self.kernel.replicate_self(attribute_map)
         rep._has_mean = self._has_mean
+        rep._module_graph.kernel = rep.kernel
         return rep
diff --git a/mxfusion/modules/gp_modules/svgp_regression.py b/mxfusion/modules/gp_modules/svgp_regression.py
index bfbda08..efad825 100644
--- a/mxfusion/modules/gp_modules/svgp_regression.py
+++ b/mxfusion/modules/gp_modules/svgp_regression.py
@@ -108,6 +108,11 @@ def compute(self, F, variables):
         logL = self.log_pdf_scaling*logL + KL_u
         return logL
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        return rep
+
 
 class SVGPRegressionMeanVariancePrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, noise_free=True,
@@ -188,6 +193,13 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        rep.noise_free = self.noise_free
+        rep.diagonal_variance = self.diagonal_variance
+        return rep
+
 
 class SVGPRegressionSamplingPrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, rand_gen=None,
@@ -279,6 +291,14 @@ def compute(self, F, variables):
         else:
             return outcomes
 
+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        rep.noise_free = self.noise_free
+        rep.diagonal_variance = self.diagonal_variance
+        rep._rand_gen = self._rand_gen
+        return rep
+
 
 class SVGPRegression(Module):
     """
@@ -454,4 +474,5 @@ def replicate_self(self, attribute_map=None):
         rep.kernel = self.kernel.replicate_self(attribute_map)
         rep._has_mean = self._has_mean
+        rep._module_graph.kernel = rep.kernel
         return rep
diff --git a/mxfusion/modules/module.py b/mxfusion/modules/module.py
index 75a247c..5d2c94b 100644
--- a/mxfusion/modules/module.py
+++ b/mxfusion/modules/module.py
@@ -424,6 +424,7 @@ def _clone_algorithms(self, algorithms, replicant):
         algs = {}
         for conditionals, algorithms in algorithms.items():
             for targets, algorithm, alg_name in algorithms:
+                alg_name = replicant._set_algorithm_name(alg_name, algorithm)
                 graphs_index = {g: i for i,g in enumerate(self._extra_graphs)}
                 extra_graphs = [replicant._extra_graphs[graphs_index[graph]]
                                 for graph in algorithm.graphs if graph in graphs_index]
diff --git a/testing/modules/gpregression_test.py b/testing/modules/gpregression_test.py
index b982a1e..e649ed2 100644
--- a/testing/modules/gpregression_test.py
+++ b/testing/modules/gpregression_test.py
@@ -366,12 +366,126 @@ def test_prediction_print(self):
         print = infr.print_params()
         assert (len(print) > 1)
 
-    def test_module_clone(self):
+    def test_module_clone_prediction(self):
         D, X, Y, noise_var, lengthscale, variance = self.gen_data()
         dtype = 'float64'
 
-        m = Model()
-        m.N = Variable()
-        kernel = RBF(input_dim=3, ARD=True, variance=mx.nd.array(variance, dtype=dtype), lengthscale=mx.nd.array(lengthscale, dtype=dtype), dtype=dtype)
-        m.Y = GPRegression.define_variable(X=mx.nd.zeros((2, 3)), kernel=kernel, noise_var=mx.nd.ones((1,)), dtype=dtype)
-        m.clone()
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf, var_mf = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf_clone, var_mf_clone = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        assert np.allclose(mu_mf, mu_mf_clone)
+        assert np.allclose(var_mf, var_mf_clone)
+
+    def test_module_clone_sampling(self):
+        D, X, Y, noise_var, lengthscale, variance = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        gp = m.Y.factor
+        gp.attach_prediction_algorithms(
+            targets=gp.output_names, conditionals=gp.input_names,
+            algorithm=GPRegressionSamplingPrediction(
+                gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X]),
+            alg_name='gp_predict')
+
+        mx.random.seed(123)
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples = res[0].asnumpy()
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        mx.random.seed(123)
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples_cloned = res[0].asnumpy()
+
+        assert np.allclose(samples, samples_cloned)
+
+    def test_module_clone_prediction_w_mean(self):
+        D, X, Y, noise_var, lengthscale, variance = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m, net = self.gen_mxfusion_model_w_mean(dtype, D, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf, var_mf = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf_clone, var_mf_clone = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        assert np.allclose(mu_mf, mu_mf_clone)
+        assert np.allclose(var_mf, var_mf_clone)
diff --git a/testing/modules/sparsegpregression_test.py b/testing/modules/sparsegpregression_test.py
index 274ec80..78e0785 100644
--- a/testing/modules/sparsegpregression_test.py
+++ b/testing/modules/sparsegpregression_test.py
@@ -338,3 +338,93 @@ def test_module_clone(self):
         kernel = RBF(input_dim=3, ARD=True, variance=mx.nd.array(variance, dtype=dtype), lengthscale=mx.nd.array(lengthscale, dtype=dtype), dtype=dtype)
         m.Y = SparseGPRegression.define_variable(X=mx.nd.zeros((2, 3)), kernel=kernel, noise_var=mx.nd.ones((1,)), dtype=dtype)
         m.clone()
+
+    def test_module_clone_prediction(self):
+        D, X, Y, Z, noise_var, lengthscale, variance = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, Z, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.sgp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.sgp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf, var_mf = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.sgp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.sgp_predict.noise_free = False
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf_clone, var_mf_clone = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        assert np.allclose(mu_mf, mu_mf_clone)
+        assert np.allclose(var_mf, var_mf_clone)
+
+    def test_module_clone_samples(self):
+        D, X, Y, Z, noise_var, lengthscale, variance = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, Z, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+
+        gp = m.Y.factor
+        gp.attach_prediction_algorithms(
+            targets=gp.output_names, conditionals=gp.input_names,
+            algorithm=SparseGPRegressionSamplingPrediction(
+                gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X]),
+            alg_name='sgp_predict')
+
+        infr2.inference_algorithm.model.Y.factor.sgp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.sgp_predict.noise_free = False
+
+        mx.random.seed(123)
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples = res.asnumpy()
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.sgp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.sgp_predict.noise_free = False
+
+        mx.random.seed(123)
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples_clone = res.asnumpy()
+        assert np.allclose(samples, samples_clone)
diff --git a/testing/modules/svgpregression_test.py b/testing/modules/svgpregression_test.py
index 06ce5b8..2c1276c 100644
--- a/testing/modules/svgpregression_test.py
+++ b/testing/modules/svgpregression_test.py
@@ -416,3 +416,116 @@ def test_module_clone(self):
         kernel = RBF(input_dim=3, ARD=True, variance=mx.nd.array(variance, dtype=dtype), lengthscale=mx.nd.array(lengthscale, dtype=dtype), dtype=dtype)
         m.Y = SVGPRegression.define_variable(X=mx.nd.zeros((2, 3)), kernel=kernel, noise_var=mx.nd.ones((1,)), dtype=dtype)
         m.clone()
+
+    def test_module_clone_prediction(self):
+        D, X, Y, Z, noise_var, lengthscale, variance, qU_mean, \
+            qU_cov_W, qU_cov_diag, qU_chol = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m, gp = self.gen_mxfusion_model(dtype, D, Z, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        infr.initialize(X=X.shape, Y=Y.shape)
+        infr.params[gp._extra_graphs[0].qU_mean] = mx.nd.array(qU_mean, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_W] = mx.nd.array(qU_cov_W, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_diag] = mx.nd.array(qU_cov_diag, dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.svgp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.svgp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf, var_mf = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        infr.initialize(X=X.shape, Y=Y.shape)
+        infr.params[gp._extra_graphs[0].qU_mean] = mx.nd.array(qU_mean, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_W] = mx.nd.array(qU_cov_W, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_diag] = mx.nd.array(qU_cov_diag, dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.svgp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.svgp_predict.noise_free = False
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf_clone, var_mf_clone = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        assert np.allclose(mu_mf, mu_mf_clone)
+        assert np.allclose(var_mf, var_mf_clone)
+
+    def test_module_clone_samples(self):
+        D, X, Y, Z, noise_var, lengthscale, variance, qU_mean, \
+            qU_cov_W, qU_cov_diag, qU_chol = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m, gp = self.gen_mxfusion_model(dtype, D, Z, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        infr.initialize(X=X.shape, Y=Y.shape)
+        infr.params[gp._extra_graphs[0].qU_mean] = mx.nd.array(qU_mean, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_W] = mx.nd.array(qU_cov_W, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_diag] = mx.nd.array(qU_cov_diag, dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+
+        gp = m.Y.factor
+        gp.attach_prediction_algorithms(
+            targets=gp.output_names, conditionals=gp.input_names,
+            algorithm=SVGPRegressionSamplingPrediction(
+                gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X]),
+            alg_name='svgp_predict')
+
+        infr2.inference_algorithm.model.Y.factor.svgp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.svgp_predict.noise_free = False
+
+        mx.random.seed(123)
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples = res.asnumpy()
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        infr.initialize(X=X.shape, Y=Y.shape)
+        infr.params[gp._extra_graphs[0].qU_mean] = mx.nd.array(qU_mean, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_W] = mx.nd.array(qU_cov_W, dtype=dtype)
+        infr.params[gp._extra_graphs[0].qU_cov_diag] = mx.nd.array(qU_cov_diag, dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.svgp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.svgp_predict.noise_free = False
+
+        mx.random.seed(123)
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples_clone = res.asnumpy()
+        assert np.allclose(samples, samples_clone)
+