From 03b952732c3c459fbb599c160e9e923d348ec4bb Mon Sep 17 00:00:00 2001
From: kaneko
Date: Mon, 2 Mar 2020 16:23:08 +0900
Subject: [PATCH 1/3] remove relaxing option

---
 pixyz/distributions/distributions.py          | 18 ++++---
 .../exponential_distributions.py              | 48 +++++++------------
 2 files changed, 24 insertions(+), 42 deletions(-)

diff --git a/pixyz/distributions/distributions.py b/pixyz/distributions/distributions.py
index a7f74481..930da68f 100644
--- a/pixyz/distributions/distributions.py
+++ b/pixyz/distributions/distributions.py
@@ -692,7 +692,7 @@ def dist(self):
         """Return the instance of PyTorch distribution."""
         return self._dist

-    def set_dist(self, x_dict={}, relaxing=False, batch_n=None, **kwargs):
+    def set_dist(self, x_dict={}, batch_n=None, **kwargs):
         """Set :attr:`dist` as PyTorch distributions given parameters.

         This requires that :attr:`get_params_keys` and :attr:`get_distribution_torch_class` are set.
@@ -701,8 +701,6 @@ def set_dist(self, x_dict={}, relaxing=False, batch_n=None, **kwargs):
         ----------
         x_dict : :obj:`dict`, defaults to {}.
             Parameters of this distribution.
-        relaxing : :obj:`bool`, defaults to False.
-            Choose whether to use relaxed_* in PyTorch distribution.
         batch_n : :obj:`int`, defaults to None.
             Set batch size of parameters.
         **kwargs
@@ -712,11 +710,11 @@ def set_dist(self, x_dict={}, relaxing=False, batch_n=None, **kwargs):
         -------

         """
-        params = self.get_params(x_dict, relaxing=relaxing, **kwargs)
-        if set(self.get_params_keys(relaxing=relaxing, **kwargs)) != set(params.keys()):
+        params = self.get_params(x_dict, **kwargs)
+        if set(self.get_params_keys(**kwargs)) != set(params.keys()):
             raise ValueError()

-        self._dist = self.get_distribution_torch_class(relaxing=relaxing, **kwargs)(**params)
+        self._dist = self.get_distribution_torch_class(**kwargs)(**params)

         # expand batch_n
         if batch_n:
@@ -759,7 +757,7 @@ def has_reparam(self):

     def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
         _x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
-        self.set_dist(_x_dict, relaxing=False)
+        self.set_dist(_x_dict)

         x_targets = get_dict_values(x_dict, self._var)
         log_prob = self.dist.log_prob(*x_targets)
@@ -783,7 +781,7 @@ def get_params(self, params_dict={}, **kwargs):

     def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
         _x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
-        self.set_dist(_x_dict, relaxing=False)
+        self.set_dist(_x_dict)

         entropy = self.dist.entropy()
         if sum_features:
@@ -1074,9 +1072,9 @@ def get_params(self, params_dict={}):
         params_dict = replace_dict_keys(params_dict, self._replace_inv_cond_var_dict)
         return self.p.get_params(params_dict)

-    def set_dist(self, x_dict={}, sampling=False, batch_n=None, **kwargs):
+    def set_dist(self, x_dict={}, batch_n=None, **kwargs):
         x_dict = replace_dict_keys(x_dict, self._replace_inv_cond_var_dict)
-        return self.p.set_dist(x_dict=x_dict, relaxing=sampling, batch_n=batch_n, **kwargs)
+        return self.p.set_dist(x_dict=x_dict, batch_n=batch_n, **kwargs)

     def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False, **kwargs):
         input_dict = get_dict_values(x_dict, self.cond_var, return_dict=True)
diff --git a/pixyz/distributions/exponential_distributions.py b/pixyz/distributions/exponential_distributions.py
index ec61f767..e7d61ba0 100644
--- a/pixyz/distributions/exponential_distributions.py
+++ b/pixyz/distributions/exponential_distributions.py
@@ -65,33 +65,25 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super(Bernoulli, self).__init__(cond_var, var, name, features_shape, **_valid_param_dict({
             'probs': probs, 'temperature': temperature}))

-    def get_params_keys(self, relaxing=True, **kwargs):
-        if relaxing:
-            return ["probs", "temperature"]
-        else:
-            return ["probs"]
+    def get_params_keys(self, **kwargs):
+        return ["probs", "temperature"]

-    def get_distribution_torch_class(self, relaxing=True, **kwargs):
+    def get_distribution_torch_class(self, **kwargs):
         """Use relaxed version only when sampling"""
-        if relaxing:
-            return RelaxedBernoulliTorch
-        else:
-            return BernoulliTorch
+        return RelaxedBernoulliTorch

     @property
     def distribution_name(self):
         return "RelaxedBernoulli"

-    def set_dist(self, x_dict={}, relaxing=True, batch_n=None, **kwargs):
-        super().set_dist(x_dict, relaxing, batch_n, **kwargs)
+    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
+        raise NotImplementedError()

     def sample_mean(self, x_dict={}):
-        self.set_dist(x_dict, relaxing=False)
-        return self.dist.mean
+        raise NotImplementedError()

     def sample_variance(self, x_dict={}):
-        self.set_dist(x_dict, relaxing=False)
-        return self.dist.variance
+        raise NotImplementedError()

     @property
     def has_reparam(self):
@@ -153,33 +145,25 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super(Categorical, self).__init__(cond_var, var, name, features_shape,
                                           **_valid_param_dict({'probs': probs, 'temperature': temperature}))

-    def get_params_keys(self, relaxing=True, **kwargs):
-        if relaxing:
-            return ['probs', 'temperature']
-        else:
-            return ['probs']
+    def get_params_keys(self, **kwargs):
+        return ['probs', 'temperature']

-    def get_distribution_torch_class(self, relaxing=True, **kwargs):
+    def get_distribution_torch_class(self, **kwargs):
         """Use relaxed version only when sampling"""
-        if relaxing:
-            return RelaxedOneHotCategoricalTorch
-        else:
-            return CategoricalTorch
+        return RelaxedOneHotCategoricalTorch

     @property
     def distribution_name(self):
         return "RelaxedCategorical"

-    def set_dist(self, x_dict={}, relaxing=True, batch_n=None, **kwargs):
-        super().set_dist(x_dict, relaxing, batch_n, **kwargs)
+    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
+        raise NotImplementedError()

     def sample_mean(self, x_dict={}):
-        self.set_dist(x_dict, relaxing=False)
-        return self.dist.mean
+        raise NotImplementedError()

     def sample_variance(self, x_dict={}):
-        self.set_dist(x_dict, relaxing=False)
-        return self.dist.variance
+        raise NotImplementedError()

     @property
     def has_reparam(self):
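
Reviewer note on PATCH 1/3 (not part of the diff): after this change the relaxed torch classes are used unconditionally, set_dist() no longer takes a relaxing flag, and the mean/variance/entropy shortcuts of the relaxed distributions are gone. A minimal doctest-style sketch of the resulting behaviour, with illustrative probs/temperature values:

>>> import torch
>>> from pixyz.distributions import RelaxedBernoulli
>>> p = RelaxedBernoulli(var=["x"], probs=torch.tensor([0.3, 0.8]),
...                      temperature=torch.tensor(0.5))
>>> sample = p.sample()   # sampling always goes through torch's RelaxedBernoulli now
>>> "x" in sample
True
>>> p.sample_mean()       # no silent fall-back to the non-relaxed Bernoulli any more
Traceback (most recent call last):
    ...
NotImplementedError
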
From 2146b98e983c842b4854152a79dab6bf4513bd4c Mon Sep 17 00:00:00 2001
From: kaneko
Date: Tue, 3 Mar 2020 18:12:44 +0900
Subject: [PATCH 2/3] changed methods back to properties

---
 pixyz/distributions/distributions.py          | 14 +++--
 .../exponential_distributions.py              | 60 ++++++++++++-------
 pixyz/losses/divergences.py                   |  2 +-
 pixyz/losses/entropy.py                       |  2 +-
 4 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/pixyz/distributions/distributions.py b/pixyz/distributions/distributions.py
index 930da68f..b35c62c1 100644
--- a/pixyz/distributions/distributions.py
+++ b/pixyz/distributions/distributions.py
@@ -679,11 +679,13 @@ def _check_features_shape(self, features):
             raise ValueError("the shape of a given parameter {} and features_shape {} "
                              "do not match.".format(features.size(), self.features_shape))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         """list: Return the list of parameter names for this distribution."""
         raise NotImplementedError()

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         """Return the class of PyTorch distribution."""
         raise NotImplementedError()

@@ -695,7 +697,7 @@ def dist(self):
     def set_dist(self, x_dict={}, batch_n=None, **kwargs):
         """Set :attr:`dist` as PyTorch distributions given parameters.

-        This requires that :attr:`get_params_keys` and :attr:`get_distribution_torch_class` are set.
+        This requires that :attr:`params_keys` and :attr:`distribution_torch_class` are set.

         Parameters
         ----------
@@ -711,10 +713,10 @@

         """
         params = self.get_params(x_dict, **kwargs)
-        if set(self.get_params_keys(**kwargs)) != set(params.keys()):
+        if set(self.params_keys) != set(params.keys()):
             raise ValueError()

-        self._dist = self.get_distribution_torch_class(**kwargs)(**params)
+        self._dist = self.distribution_torch_class(**params)

         # expand batch_n
         if batch_n:
@@ -773,7 +775,7 @@ def get_params(self, params_dict={}, **kwargs):
         output_dict.update(params_dict)

         # append constant parameters to output_dict
-        constant_params_dict = get_dict_values(dict(self.named_buffers()), self.get_params_keys(**kwargs),
+        constant_params_dict = get_dict_values(dict(self.named_buffers()), self.params_keys,
                                                return_dict=True)
         output_dict.update(constant_params_dict)

diff --git a/pixyz/distributions/exponential_distributions.py b/pixyz/distributions/exponential_distributions.py
index e7d61ba0..71350d43 100644
--- a/pixyz/distributions/exponential_distributions.py
+++ b/pixyz/distributions/exponential_distributions.py
@@ -23,10 +23,12 @@ class Normal(DistributionBase):
     def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), loc=None, scale=None):
         super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'loc': loc, 'scale': scale}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["loc", "scale"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return NormalTorch

     @property
@@ -43,10 +45,12 @@ class Bernoulli(DistributionBase):
     def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), probs=None):
         super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'probs': probs}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["probs"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return BernoulliTorch

     @property
@@ -65,10 +69,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super(Bernoulli, self).__init__(cond_var, var, name, features_shape, **_valid_param_dict({
             'probs': probs, 'temperature': temperature}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["probs", "temperature"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         """Use relaxed version only when sampling"""
         return RelaxedBernoulliTorch

@@ -120,10 +126,12 @@ def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size()
         super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
                          **_valid_param_dict({'probs': probs}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["probs"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return CategoricalTorch

     @property
@@ -145,10 +153,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super(Categorical, self).__init__(cond_var, var, name, features_shape,
                                           **_valid_param_dict({'probs': probs, 'temperature': temperature}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ['probs', 'temperature']

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         """Use relaxed version only when sampling"""
         return RelaxedOneHotCategoricalTorch

@@ -183,10 +193,12 @@ def __init__(self, total_count=1, cond_var=[], var=["x"], name="p", features_sha
     def total_count(self):
         return self._total_count

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["probs"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return MultinomialTorch

     @property
@@ -204,10 +216,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
                          **_valid_param_dict({'concentration': concentration}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["concentration"]

-    def get_distribution_torch_class(self, kwargs):
+    @property
+    def distribution_torch_class(self):
         return DirichletTorch

     @property
@@ -226,10 +240,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
                          **_valid_param_dict({'concentration1': concentration1, 'concentration0': concentration0}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["concentration1", "concentration0"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return BetaTorch

     @property
@@ -249,10 +265,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
                          **_valid_param_dict({'loc': loc, 'scale': scale}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["loc", "scale"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return LaplaceTorch

     @property
@@ -272,10 +290,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
         super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
                          **_valid_param_dict({'concentration': concentration, 'rate': rate}))

-    def get_params_keys(self, **kwargs):
+    @property
+    def params_keys(self):
         return ["concentration", "rate"]

-    def get_distribution_torch_class(self, **kwargs):
+    @property
+    def distribution_torch_class(self):
         return GammaTorch

     @property
diff --git a/pixyz/losses/divergences.py b/pixyz/losses/divergences.py
index 2ea89168..4848bd1b 100644
--- a/pixyz/losses/divergences.py
+++ b/pixyz/losses/divergences.py
@@ -53,7 +53,7 @@ def _symbol(self):
         return sympy.Symbol("D_{{KL}} \\left[{}||{} \\right]".format(self.p.prob_text, self.q.prob_text))

     def forward(self, x_dict, **kwargs):
-        if (not hasattr(self.p, 'get_distribution_torch_class')) or (not hasattr(self.q, 'get_distribution_torch_class')):
+        if (not hasattr(self.p, 'distribution_torch_class')) or (not hasattr(self.q, 'distribution_torch_class')):
             raise ValueError("Divergence between these two distributions cannot be evaluated, "
                              "got %s and %s."
                              % (self.p.distribution_name, self.q.distribution_name))
diff --git a/pixyz/losses/entropy.py b/pixyz/losses/entropy.py
index 367544e2..90a0de2c 100644
--- a/pixyz/losses/entropy.py
+++ b/pixyz/losses/entropy.py
@@ -54,7 +54,7 @@ def _symbol(self):
         return sympy.Symbol(f"H \\left[ {p_text} \\right]")

     def forward(self, x_dict, **kwargs):
-        if not hasattr(self.p, 'get_distribution_torch_class'):
+        if not hasattr(self.p, 'distribution_torch_class'):
             raise ValueError("Entropy of this distribution cannot be evaluated, "
                              "got %s." % self.p.distribution_name)
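
Reviewer note on PATCH 2/3 (not part of the diff): parameter keys and the backing torch class are plain properties again, and the loss classes only check that the attribute exists. A small doctest-style sketch of the intended usage after this patch, using only the constructor keywords visible in the diff:

>>> import torch
>>> from pixyz.distributions import Normal
>>> p = Normal(var=["x"], loc=torch.tensor(0.), scale=torch.tensor(1.))
>>> p.params_keys                           # property access, no call and no **kwargs
['loc', 'scale']
>>> p.distribution_torch_class.__name__     # the raw torch.distributions class
'Normal'
>>> hasattr(p, 'distribution_torch_class')  # the guard now used by KullbackLeibler / Entropy
True
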
From f2b6d32b42989e0b2c2e676aa02641012512a843 Mon Sep 17 00:00:00 2001
From: kaneko
Date: Tue, 17 Mar 2020 18:18:52 +0900
Subject: [PATCH 3/3] bug fix

---
 pixyz/distributions/distributions.py | 5 ++---
 pixyz/losses/losses.py               | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/pixyz/distributions/distributions.py b/pixyz/distributions/distributions.py
index a12599d6..427d7a04 100644
--- a/pixyz/distributions/distributions.py
+++ b/pixyz/distributions/distributions.py
@@ -656,9 +656,8 @@ def _set_buffers(self, **params_dict):
                 if params_dict[key] in self._cond_var:
                     self.replace_params_dict[params_dict[key]] = key
                 else:
-                    raise ValueError("parameter setting {}:{} is not valid"
-                                     " because cond_var does not contains {}.".format(
-                                         key, params_dict[key], params_dict[key]))
+                    raise ValueError("parameter setting {}:{} is not valid because cond_var does not contain {}."
+                                     .format(key, params_dict[key], params_dict[key]))
             elif isinstance(params_dict[key], torch.Tensor):
                 features = params_dict[key]
                 features_checked = self._check_features_shape(features)
diff --git a/pixyz/losses/losses.py b/pixyz/losses/losses.py
index 2f397aa3..3363abd3 100644
--- a/pixyz/losses/losses.py
+++ b/pixyz/losses/losses.py
@@ -821,7 +821,7 @@ class DataParalleledLoss(Loss):
     >>> from torch import optim
     >>> from torch.nn import functional as F
     >>> from pixyz.distributions import Bernoulli, Normal
-    >>> from pixyz.losses import StochasticReconstructionLoss, KullbackLeibler, DataParalleledLoss
+    >>> from pixyz.losses import KullbackLeibler, DataParalleledLoss
     >>> from pixyz.models import Model
     >>> used_gpu_i = set()
     >>> used_gpu_g = set()
@@ -846,7 +846,7 @@ class DataParalleledLoss(Loss):
     >>> prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
     ...                var=["z"], features_shape=[64], name="p_{prior}")
     >>> # Define a loss function (Loss API)
-    >>> reconst = StochasticReconstructionLoss(q, p)
+    >>> reconst = -p.log_prob().expectation(q)
     >>> kl = KullbackLeibler(q, prior)
     >>> batch_loss_cls = (reconst - kl)
     >>> # device settings
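
Reviewer note on PATCH 3/3 (not part of the diff): the docstring now writes the reconstruction term with the expectation form instead of the removed StochasticReconstructionLoss. A minimal sketch of that replacement with two hypothetical tensor-parameterised toy distributions, just to show the Loss API expression:

>>> import torch
>>> from pixyz.distributions import Normal
>>> from pixyz.losses.losses import Loss
>>> q = Normal(loc="x", scale=torch.tensor(1.), var=["z"], cond_var=["x"], name="q")  # toy q(z|x)
>>> p = Normal(loc="z", scale=torch.tensor(1.), var=["x"], cond_var=["z"], name="p")  # toy p(x|z)
>>> reconst = -p.log_prob().expectation(q)   # replaces StochasticReconstructionLoss(q, p)
>>> isinstance(reconst, Loss)
True
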