GP replicate_self fixes #185

Merged 10 commits on Jul 5, 2019

Changes from 5 commits
1 change: 1 addition & 0 deletions mxfusion/inference/inference_alg.py
@@ -106,6 +106,7 @@ def replicate_self(self, model, extra_graphs=None):
         replicant._observed = set(observed)
         replicant._observed_uuid = variables_to_UUID(observed)
         replicant._observed_names = [v.name for v in observed]
+        replicant._graphs = self._graphs
         return replicant

     def __init__(self, model, observed, extra_graphs=None):
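Note: this one-line change is the base-class half of the fix. `InferenceAlgorithm.replicate_self` already copied the observed-variable bookkeeping, but it dropped the algorithm's graph list, so a replicated algorithm had no `_graphs` to look up (see the `_clone_algorithms` change in module.py below, which iterates over `algorithm.graphs`). A minimal sketch of the restored behaviour; `MyAlg` is a hypothetical subclass used only for illustration:

    # Hypothetical InferenceAlgorithm subclass, for illustration only.
    alg = MyAlg(model=m, observed=[m.X, m.Y])
    replica = alg.replicate_self(m.clone())
    # With the fix, the replica shares the original's graph list, so
    # iterating over replica.graphs matches iterating over alg.graphs.
    assert replica._graphs is alg._graphs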
25 changes: 25 additions & 0 deletions mxfusion/modules/gp_modules/gp_regression.py
@@ -75,6 +75,11 @@ def compute(self, F, variables):
         self.set_parameter(variables, self.posterior.LinvY, LinvY[0])
         return logL

+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.jitter = self.jitter
+        return rep
+

 class GPRegressionSampling(SamplingAlgorithm):
     """
@@ -134,6 +139,11 @@ def compute(self, F, variables):
         else:
             return samples

+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep._rand_gen = self._rand_gen
+        return rep
+

 class GPRegressionMeanVariancePrediction(SamplingAlgorithm):
     def __init__(self, model, posterior, observed, noise_free=True,
@@ -195,6 +205,12 @@ def compute(self, F, variables):
         else:
             return outcomes

+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep.diagonal_variance = self.diagonal_variance
+        rep.noise_free = self.noise_free
+        return rep
+

 class GPRegressionSamplingPrediction(SamplingAlgorithm):
     """
@@ -274,6 +290,14 @@ def compute(self, F, variables):
         else:
             return outcomes

+    def replicate_self(self, model, extra_graphs=None):
+        rep = super().replicate_self(model, extra_graphs)
+        rep._rand_gen = self._rand_gen
+        rep.diagonal_variance = self.diagonal_variance
+        rep.jitter = self.jitter
+        rep.noise_free = self.noise_free
+        return rep
+

 class GPRegression(Module):
     """
@@ -425,4 +449,5 @@ def replicate_self(self, attribute_map=None):

         rep.kernel = self.kernel.replicate_self(attribute_map)
         rep._has_mean = self._has_mean
+        rep._module_graph.kernel = rep.kernel
         return rep
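The added `rep._module_graph.kernel = rep.kernel` line closes a subtle hole: the kernel itself was already replicated onto `rep`, but the replicant's internal module graph still pointed at the original kernel object, so the clone and the original shared kernel state. A sketch of the invariant this establishes, where `gp` stands for any existing `GPRegression` module:

    clone = gp.replicate_self()
    # The cloned module graph must reference the cloned kernel,
    # not the kernel of the module it was copied from.
    assert clone._module_graph.kernel is clone.kernel
    assert clone._module_graph.kernel is not gp.kernel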
1 change: 1 addition & 0 deletions mxfusion/modules/module.py
@@ -424,6 +424,7 @@ def _clone_algorithms(self, algorithms, replicant):
         algs = {}
         for conditionals, algorithms in algorithms.items():
             for targets, algorithm, alg_name in algorithms:
+                alg_name = replicant._set_algorithm_name(alg_name, algorithm)
                 graphs_index = {g: i for i,g in enumerate(self._extra_graphs)}
                 extra_graphs = [replicant._extra_graphs[graphs_index[graph]] for graph in algorithm.graphs
                                 if graph in graphs_index]
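`_clone_algorithms` now routes each algorithm name through `replicant._set_algorithm_name(...)` before re-attaching the algorithm. Assuming `_set_algorithm_name` registers the algorithm as a named attribute of the replicated module (the same mechanism `attach_prediction_algorithms` uses with `alg_name='gp_predict'` in the tests below), named algorithms stay reachable on clones:

    cloned_model = m.clone()
    # Reaching the prediction algorithm by name on the clone, as the
    # tests below do, depends on this re-registration.
    alg = cloned_model.Y.factor.gp_predict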
86 changes: 80 additions & 6 deletions testing/modules/gpregression_test.py
@@ -366,12 +366,86 @@ def test_prediction_print(self):
         print = infr.print_params()
         assert (len(print) > 1)

-    def test_module_clone(self):
+    def test_module_clone_prediction(self):
         D, X, Y, noise_var, lengthscale, variance = self.gen_data()
         dtype = 'float64'

-        m = Model()
-        m.N = Variable()
-        kernel = RBF(input_dim=3, ARD=True, variance=mx.nd.array(variance, dtype=dtype), lengthscale=mx.nd.array(lengthscale, dtype=dtype), dtype=dtype)
-        m.Y = GPRegression.define_variable(X=mx.nd.zeros((2, 3)), kernel=kernel, noise_var=mx.nd.ones((1,)), dtype=dtype)
-        m.clone()
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf, var_mf = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2_clone.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        mu_mf_clone, var_mf_clone = res[0].asnumpy()[0], res[1].asnumpy()[0]
+
+        assert np.allclose(mu_mf, mu_mf_clone)
+        assert np.allclose(var_mf, var_mf_clone)
+
+    def test_module_clone_sampling(self):
+        D, X, Y, noise_var, lengthscale, variance = self.gen_data()
+        dtype = 'float64'
+
+        # Predict from original model
+        m = self.gen_mxfusion_model(dtype, D, noise_var, lengthscale, variance)
+
+        observed = [m.X, m.Y]
+        infr = Inference(MAP(model=m, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+
+        gp = m.Y.factor
+        gp.attach_prediction_algorithms(
+            targets=gp.output_names, conditionals=gp.input_names,
+            algorithm=GPRegressionSamplingPrediction(
+                gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X]),
+            alg_name='gp_predict')
+        mx.random.seed(123)
+        infr2 = TransferInference(ModulePredictionAlgorithm(m, observed=[m.X], target_variables=[m.Y]),
+                                  infr_params=infr.params, dtype=np.float64)
+        infr2.inference_algorithm.model.Y.factor.gp_predict.diagonal_variance = False
+        infr2.inference_algorithm.model.Y.factor.gp_predict.noise_free = False
+        res = infr2.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples = res[0].asnumpy()
+
+        # Clone model
+        cloned_model = m.clone()
+
+        # Predict from cloned model
+        observed = [cloned_model.X, cloned_model.Y]
+        infr = Inference(MAP(model=cloned_model, observed=observed), dtype=dtype)
+
+        loss, _ = infr.run(X=mx.nd.array(X, dtype=dtype), Y=mx.nd.array(Y, dtype=dtype), max_iter=1)
+        mx.random.seed(123)
+        infr2_clone = TransferInference(ModulePredictionAlgorithm(cloned_model, observed=[cloned_model.X],
+                                                                  target_variables=[cloned_model.Y]),
+                                        infr_params=infr.params, dtype=np.float64)
+
+        res = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0]
+        samples_cloned = res[0].asnumpy()
+
+        assert np.allclose(samples, samples_cloned)
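One detail worth noting in the sampling test: the prediction draws are stochastic, so the test reseeds MXNet's global RNG immediately before each prediction run. With identical seeds and identical transferred parameters, the original and the clone should produce elementwise-identical draws, which is what makes the final `np.allclose` comparison meaningful. Condensed, under the assumption that `_rand_gen` draws through MXNet's global RNG:

    mx.random.seed(123)
    samples = infr2.run(X=mx.nd.array(X, dtype=dtype))[0][0].asnumpy()
    mx.random.seed(123)
    samples_cloned = infr2_clone.run(X=mx.nd.array(X, dtype=dtype))[0][0].asnumpy()
    assert np.allclose(samples, samples_cloned)  # same seed, same params, same draws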