remove relaxing option #108

Merged (5 commits, Mar 18, 2020)
32 changes: 16 additions & 16 deletions pixyz/distributions/distributions.py
@@ -680,11 +680,13 @@ def _check_features_shape(self, features):
raise ValueError("the shape of a given parameter {} and features_shape {} "
"do not match.".format(features.size(), self.features_shape))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
"""list: Return the list of parameter names for this distribution."""
raise NotImplementedError()

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
"""Return the class of PyTorch distribution."""
raise NotImplementedError()

@@ -693,17 +695,15 @@ def dist(self):
"""Return the instance of PyTorch distribution."""
return self._dist

def set_dist(self, x_dict={}, relaxing=False, batch_n=None, **kwargs):
def set_dist(self, x_dict={}, batch_n=None, **kwargs):
"""Set :attr:`dist` as PyTorch distributions given parameters.

This requires that :attr:`get_params_keys` and :attr:`get_distribution_torch_class` are set.
This requires that :attr:`params_keys` and :attr:`distribution_torch_class` are set.

Parameters
----------
x_dict : :obj:`dict`, defaults to {}.
Parameters of this distribution.
relaxing : :obj:`bool`, defaults to False.
Choose whether to use relaxed_* in PyTorch distribution.
batch_n : :obj:`int`, defaults to None.
Set batch size of parameters.
**kwargs
@@ -713,12 +713,12 @@ def set_dist(self, x_dict={}, relaxing=False, batch_n=None, **kwargs):
-------

"""
params = self.get_params(x_dict, relaxing=relaxing, **kwargs)
if set(self.get_params_keys(relaxing=relaxing, **kwargs)) != set(params.keys()):
raise ValueError("{} class requires following parameters:"
" {}\nbut got {}".format(type(self), set(self.params_keys)), set(params.keys()))
params = self.get_params(x_dict, **kwargs)
if set(self.params_keys) != set(params.keys()):
raise ValueError("{} class requires following parameters: {}\n"
"but got {}".format(type(self), set(self.params_keys), set(params.keys())))

self._dist = self.get_distribution_torch_class(relaxing=relaxing, **kwargs)(**params)
self._dist = self.distribution_torch_class(**params)

# expand batch_n
if batch_n:
@@ -761,7 +761,7 @@ def has_reparam(self):

def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
_x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
self.set_dist(_x_dict, relaxing=False)
self.set_dist(_x_dict)

x_targets = get_dict_values(x_dict, self._var)
log_prob = self.dist.log_prob(*x_targets)
@@ -777,15 +777,15 @@ def get_params(self, params_dict={}, **kwargs):
output_dict.update(params_dict)

# append constant parameters to output_dict
constant_params_dict = get_dict_values(dict(self.named_buffers()), self.get_params_keys(**kwargs),
constant_params_dict = get_dict_values(dict(self.named_buffers()), self.params_keys,
return_dict=True)
output_dict.update(constant_params_dict)

return output_dict

def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
_x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
self.set_dist(_x_dict, relaxing=False)
self.set_dist(_x_dict)

entropy = self.dist.entropy()
if sum_features:
@@ -1076,9 +1076,9 @@ def get_params(self, params_dict={}):
params_dict = replace_dict_keys(params_dict, self._replace_inv_cond_var_dict)
return self.p.get_params(params_dict)

def set_dist(self, x_dict={}, sampling=False, batch_n=None, **kwargs):
def set_dist(self, x_dict={}, batch_n=None, **kwargs):
x_dict = replace_dict_keys(x_dict, self._replace_inv_cond_var_dict)
return self.p.set_dist(x_dict=x_dict, relaxing=sampling, batch_n=batch_n, **kwargs)
return self.p.set_dist(x_dict=x_dict, batch_n=batch_n, **kwargs)

def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False, **kwargs):
input_dict = get_dict_values(x_dict, self.cond_var, return_dict=True)
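To illustrate the new base-class API in isolation: params_keys and distribution_torch_class are now plain properties, and set_dist no longer accepts a relaxing keyword. The sketch below reuses the constant-parameter Normal from the losses.py doctest further down in this diff; using it outside that doctest is an assumption of the sketch, not part of the patch.

import torch
from pixyz.distributions import Normal

# Constant-parameter distribution, as in the losses.py doctest in this PR.
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[64], name="p_{prior}")

print(prior.params_keys)               # now a property: ["loc", "scale"]
print(prior.distribution_torch_class)  # now a property: torch.distributions.Normal

prior.set_dist()   # no relaxing= keyword; parameters come from buffers / x_dict
print(prior.dist)  # the underlying torch.distributions instance built by set_dist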
100 changes: 52 additions & 48 deletions pixyz/distributions/exponential_distributions.py
@@ -23,10 +23,12 @@ class Normal(DistributionBase):
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), loc=None, scale=None):
super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'loc': loc, 'scale': scale}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["loc", "scale"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return NormalTorch

@property
@@ -43,10 +45,12 @@ class Bernoulli(DistributionBase):
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), probs=None):
super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'probs': probs}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["probs"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return BernoulliTorch

@property
@@ -65,33 +69,27 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super(Bernoulli, self).__init__(cond_var, var, name, features_shape, **_valid_param_dict({
'probs': probs, 'temperature': temperature}))

def get_params_keys(self, relaxing=True, **kwargs):
if relaxing:
return ["probs", "temperature"]
else:
return ["probs"]
@property
def params_keys(self):
return ["probs", "temperature"]

def get_distribution_torch_class(self, relaxing=True, **kwargs):
@property
def distribution_torch_class(self):
"""Use relaxed version only when sampling"""
if relaxing:
return RelaxedBernoulliTorch
else:
return BernoulliTorch
return RelaxedBernoulliTorch

@property
def distribution_name(self):
return "RelaxedBernoulli"

def set_dist(self, x_dict={}, relaxing=True, batch_n=None, **kwargs):
super().set_dist(x_dict, relaxing, batch_n, **kwargs)
def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
raise NotImplementedError()

def sample_mean(self, x_dict={}):
self.set_dist(x_dict, relaxing=False)
return self.dist.mean
raise NotImplementedError()

def sample_variance(self, x_dict={}):
self.set_dist(x_dict, relaxing=False)
return self.dist.variance
raise NotImplementedError()

@property
def has_reparam(self):
@@ -128,10 +126,12 @@ def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size()
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'probs': probs}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["probs"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return CategoricalTorch

@property
@@ -153,33 +153,27 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super(Categorical, self).__init__(cond_var, var, name, features_shape,
**_valid_param_dict({'probs': probs, 'temperature': temperature}))

def get_params_keys(self, relaxing=True, **kwargs):
if relaxing:
return ['probs', 'temperature']
else:
return ['probs']
@property
def params_keys(self):
return ['probs', 'temperature']

def get_distribution_torch_class(self, relaxing=True, **kwargs):
@property
def distribution_torch_class(self):
"""Use relaxed version only when sampling"""
if relaxing:
return RelaxedOneHotCategoricalTorch
else:
return CategoricalTorch
return RelaxedOneHotCategoricalTorch

@property
def distribution_name(self):
return "RelaxedCategorical"

def set_dist(self, x_dict={}, relaxing=True, batch_n=None, **kwargs):
super().set_dist(x_dict, relaxing, batch_n, **kwargs)
def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
raise NotImplementedError()

def sample_mean(self, x_dict={}):
self.set_dist(x_dict, relaxing=False)
return self.dist.mean
raise NotImplementedError()

def sample_variance(self, x_dict={}):
self.set_dist(x_dict, relaxing=False)
return self.dist.variance
raise NotImplementedError()

@property
def has_reparam(self):
@@ -199,10 +193,12 @@ def __init__(self, total_count=1, cond_var=[], var=["x"], name="p", features_sha
def total_count(self):
return self._total_count

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["probs"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return MultinomialTorch

@property
@@ -220,10 +216,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration': concentration}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["concentration"]

def get_distribution_torch_class(self, kwargs):
@property
def distribution_torch_class(self):
return DirichletTorch

@property
@@ -242,10 +240,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration1': concentration1, 'concentration0': concentration0}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["concentration1", "concentration0"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return BetaTorch

@property
@@ -265,10 +265,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'loc': loc, 'scale': scale}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["loc", "scale"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return LaplaceTorch

@property
@@ -288,10 +290,12 @@ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size()
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration': concentration, 'rate': rate}))

def get_params_keys(self, **kwargs):
@property
def params_keys(self):
return ["concentration", "rate"]

def get_distribution_torch_class(self, **kwargs):
@property
def distribution_torch_class(self):
return GammaTorch

@property
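The user-visible effect on the relaxed distributions: RelaxedBernoulli (and RelaxedCategorical likewise) now always instantiates the relaxed PyTorch class, and the quantities that previously switched back to the non-relaxed distribution (sample_mean, sample_variance, get_entropy) now raise NotImplementedError. A minimal sketch; the scalar parameter values and the direct call pattern are illustrative assumptions, not taken from the patch.

import torch
from pixyz.distributions import RelaxedBernoulli

q = RelaxedBernoulli(temperature=torch.tensor(0.5), probs=torch.tensor(0.3), var=["z"])

print(q.sample()["z"])   # sampling always goes through torch's RelaxedBernoulli now

try:
    q.sample_mean()      # previously re-built the dist with relaxing=False
except NotImplementedError:
    print("mean/variance/entropy are no longer defined for the relaxed version")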
2 changes: 1 addition & 1 deletion pixyz/losses/divergences.py
@@ -53,7 +53,7 @@ def _symbol(self):
return sympy.Symbol("D_{{KL}} \\left[{}||{} \\right]".format(self.p.prob_text, self.q.prob_text))

def forward(self, x_dict, **kwargs):
if (not hasattr(self.p, 'get_distribution_torch_class')) or (not hasattr(self.q, 'get_distribution_torch_class')):
if (not hasattr(self.p, 'distribution_torch_class')) or (not hasattr(self.q, 'distribution_torch_class')):
raise ValueError("Divergence between these two distributions cannot be evaluated, "
"got %s and %s." % (self.p.distribution_name, self.q.distribution_name))

2 changes: 1 addition & 1 deletion pixyz/losses/entropy.py
@@ -54,7 +54,7 @@ def _symbol(self):
return sympy.Symbol("H \\left[ {} \\right]".format(p_text))

def forward(self, x_dict, **kwargs):
if not hasattr(self.p, 'get_distribution_torch_class'):
if not hasattr(self.p, 'distribution_torch_class'):
raise ValueError("Entropy of this distribution cannot be evaluated, "
"got %s." % self.p.distribution_name)

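Both analytic losses now probe for the renamed attribute, which is all that changes from the caller's side. A minimal sketch, assuming the class in entropy.py is exported as pixyz.losses.Entropy and that eval() without arguments is valid for unconditional distributions:

import torch
from pixyz.distributions import Normal
from pixyz.losses import Entropy, KullbackLeibler

p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"], features_shape=[64], name="p")
q = Normal(loc=torch.tensor(1.), scale=torch.tensor(2.), var=["z"], features_shape=[64], name="q")

# Internally both losses now check hasattr(self.p, 'distribution_torch_class')
# (a property after this PR) before evaluating analytically.
print(KullbackLeibler(q, p).eval())
print(Entropy(p).eval())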
4 changes: 2 additions & 2 deletions pixyz/losses/losses.py
@@ -821,7 +821,7 @@ class DataParalleledLoss(Loss):
>>> from torch import optim
>>> from torch.nn import functional as F
>>> from pixyz.distributions import Bernoulli, Normal
>>> from pixyz.losses import StochasticReconstructionLoss, KullbackLeibler, DataParalleledLoss
>>> from pixyz.losses import KullbackLeibler, DataParalleledLoss
>>> from pixyz.models import Model
>>> used_gpu_i = set()
>>> used_gpu_g = set()
@@ -846,7 +846,7 @@ class DataParalleledLoss(Loss):
>>> prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
... var=["z"], features_shape=[64], name="p_{prior}")
>>> # Define a loss function (Loss API)
>>> reconst = StochasticReconstructionLoss(q, p)
>>> reconst = -p.log_prob().expectation(q)
>>> kl = KullbackLeibler(q, prior)
>>> batch_loss_cls = (reconst - kl)
>>> # device settings
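For reference, the replacement shown in the doctest above, stripped down to network-free distributions so it runs on its own. The distributions here are illustrative stand-ins for the doctest's encoder/decoder, and evaluating the composed loss without inputs is an assumption that holds only because everything below is unconditional.

import torch
from pixyz.distributions import Normal
from pixyz.losses import KullbackLeibler

q = Normal(loc=torch.tensor(1.), scale=torch.tensor(2.), var=["z"], features_shape=[64], name="q")
p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"], features_shape=[64], name="p")
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[64], name="p_{prior}")

reconst = -p.log_prob().expectation(q)   # replaces StochasticReconstructionLoss(q, p)
kl = KullbackLeibler(q, prior)
batch_loss_cls = (reconst - kl)          # composed exactly as in the doctest above
print(batch_loss_cls)                    # symbolic form of the composed loss
print(batch_loss_cls.eval())             # numeric value; no x_dict needed here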