
Merge pull request #163 from masa-su/fix/feature_dims
Fix/feature dims
masa-su authored Mar 2, 2021
2 parents ec5135a + 585e51b commit 7ad8ff5
Showing 12 changed files with 102 additions and 54 deletions.
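For context, a minimal sketch of what this change enables once feature_dims and other extra keyword arguments are forwarded. It assumes pixyz's standard Normal API; the names and shapes here are illustrative, and the reduction itself is performed by pixyz's sum_samples helper, whose implementation is not shown in this diff:

import torch
from pixyz.distributions import Normal

# Toy Gaussian over a variable "x" with a 2-D feature shape (illustrative).
p = Normal(var=["x"], loc=torch.tensor(0.), scale=torch.tensor(1.),
           features_shape=[8, 4])

x_dict = p.sample(batch_n=16)   # {"x": tensor of shape (16, 8, 4)}

# Default: the log-density is summed over all feature dimensions.
log_p = p.get_log_prob(x_dict)  # shape (16,)

# With this fix, feature_dims is no longer dropped on the way to sum_samples,
# so the caller can choose which dimensions are reduced. Extra keyword
# arguments (e.g. compute_jacobian for flow priors) are forwarded the same way.
log_p_partial = p.get_log_prob(x_dict, feature_dims=[-1])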
4 changes: 2 additions & 2 deletions pixyz/distributions/custom_distributions.py
@@ -55,11 +55,11 @@ def input_var(self):
     def distribution_name(self):
         return self._distribution_name

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         x_dict = get_dict_values(x_dict, self._var, return_dict=True)
         log_prob = self.log_prob_function(**x_dict)
         if sum_features:
-            log_prob = sum_samples(log_prob)
+            log_prob = sum_samples(log_prob, feature_dims)

         return log_prob

30 changes: 18 additions & 12 deletions pixyz/distributions/distributions.py
@@ -449,11 +449,15 @@ def forward(self, mode, kwargs):
         else:
             raise ValueError()

-    def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False, sample_mean=False):
-        return self('sample', kwargs={'x_dict': x_dict, 'batch_n': batch_n, 'sample_shape': sample_shape,
-                                      'return_all': return_all, 'reparam': reparam, 'sample_mean': sample_mean})
+    def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False,
+               sample_mean=False, **kwargs):
+        _kwargs = dict(x_dict=x_dict, batch_n=batch_n, sample_shape=sample_shape,
+                       return_all=return_all, reparam=reparam, sample_mean=sample_mean)
+        _kwargs.update(kwargs)
+        return self('sample', kwargs=_kwargs)

-    def _sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False, sample_mean=False):
+    def _sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False,
+                sample_mean=False, **kwargs):
         """
         Sample variables of this distribution.
         If :attr:`cond_var` is not empty, you should set inputs as :obj:`dict`.
@@ -537,6 +541,7 @@ def _sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all
         sample_option = dict(self.global_option)
         sample_option.update(dict(batch_n=batch_n, sample_shape=sample_shape,
                                   return_all=False, reparam=reparam, sample_mean=sample_mean))
+        sample_option.update(kwargs)
         # ignore return_all because overriding is now under control.
         if not(set(x_dict) >= set(self.input_var)):
             raise ValueError(f"Input keys are not valid, expected {set(self.input_var)} but got {set(x_dict)}.")
@@ -554,11 +559,11 @@ def _sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all
         else:
             return delete_dict_values(result_dict, self.input_var)

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         return self(mode='get_log_prob', kwargs={'x_dict': x_dict, 'sum_features': sum_features,
                                                  'feature_dims': feature_dims})

-    def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         """ Giving variables, this method returns values of log-pdf.
         Parameters
@@ -626,6 +631,7 @@ def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None):

         log_prob_option = dict(self.global_option)
         log_prob_option.update(dict(sum_features=sum_features, feature_dims=feature_dims))
+        log_prob_option.update(kwargs)

         require_var = self.var + self.cond_var
         if not(set(x_dict) >= set(require_var)):
@@ -999,7 +1005,7 @@ def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=
         """
         if self.graph:
-            return self.graph.sample(x_dict, batch_n, sample_shape, return_all, reparam, sample_mean)
+            return self.graph.sample(x_dict, batch_n, sample_shape, return_all, reparam, sample_mean, **kwargs)
         raise NotImplementedError()

     @property
@@ -1069,7 +1075,7 @@ def sample_variance(self, x_dict={}):
         """
         raise NotImplementedError()

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         """Giving variables, this method returns values of log-pdf.
         Parameters
@@ -1108,7 +1114,7 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
         """
         if self.graph:
-            return self.graph.get_log_prob(x_dict, sum_features, feature_dims)
+            return self.graph.get_log_prob(x_dict, sum_features, feature_dims, **kwargs)
         raise NotImplementedError()

     def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
@@ -1430,7 +1436,7 @@ def get_sample(self, reparam=False, sample_shape=torch.Size()):
     def has_reparam(self):
         raise NotImplementedError()

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         _x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
         self.set_dist(_x_dict)

@@ -1439,7 +1445,7 @@
             raise ValueError(f"x_dict has no value of the stochastic variable. x_dict: {x_dict}")
         log_prob = self.dist.log_prob(*x_targets)
         if sum_features:
-            log_prob = sum_samples(log_prob)
+            log_prob = sum_samples(log_prob, feature_dims)

         return log_prob

@@ -1514,7 +1520,7 @@ def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):

         entropy = self.dist.entropy()
         if sum_features:
-            entropy = sum_samples(entropy)
+            entropy = sum_samples(entropy, feature_dims)

         return entropy

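The distributions.py changes above are the plumbing that carries the new options: sample() packs its named arguments into one dict, merges in any extra keyword arguments, and dispatches through forward('sample', ...), while _sample() and _get_log_prob() merge the same extras into their per-call option dicts. A stripped-down, self-contained sketch of that pattern (hypothetical Node class, not pixyz code):

class Node:
    """Minimal stand-in for the dispatch pattern shown in the diff above."""
    global_option = {"reparam": False}

    def sample(self, x_dict={}, batch_n=None, **kwargs):
        _kwargs = dict(x_dict=x_dict, batch_n=batch_n)
        _kwargs.update(kwargs)          # unknown options ride along unchanged
        return self._sample(**_kwargs)

    def _sample(self, x_dict={}, batch_n=None, **kwargs):
        sample_option = dict(self.global_option)
        sample_option.update(dict(batch_n=batch_n))
        sample_option.update(kwargs)    # ...and land in the per-call options
        return sample_option


print(Node().sample(batch_n=2, compute_jacobian=False))
# {'reparam': False, 'batch_n': 2, 'compute_jacobian': False}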
4 changes: 2 additions & 2 deletions pixyz/distributions/exponential_distributions.py
@@ -167,8 +167,8 @@ def __init__(self, var=['x'], cond_var=[], name='p', features_shape=torch.Size()
     def distribution_name(self):
         return "FactorizedBernoulli"

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
-        log_prob = super().get_log_prob(x_dict, sum_features=False)
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
+        log_prob = super().get_log_prob(x_dict, sum_features=False, **kwargs)
         [_x] = get_dict_values(x_dict, self._var)
         log_prob[_x == 0] = 0
         if sum_features:
14 changes: 8 additions & 6 deletions pixyz/distributions/flow_distribution.py
@@ -62,7 +62,7 @@ def logdet_jacobian(self):
     def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True, reparam=False,
               compute_jacobian=True, **kwargs):
         # sample from the prior
-        sample_dict = self.prior.sample(x_dict, batch_n=batch_n, sample_shape=sample_shape, return_all=False)
+        sample_dict = self.prior.sample(x_dict, batch_n=batch_n, sample_shape=sample_shape, return_all=False, **kwargs)

         # flow transformation
         _x = get_dict_values(sample_dict, self.flow_input_var)[0]
@@ -83,14 +83,15 @@ def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=
     def has_reparam(self):
         return self.prior.has_reparam

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, compute_jacobian=False):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, compute_jacobian=False, **kwargs):
         """
         It calculates the log-likelihood for a given z.
         If a flow module has no inverse method, it only supports the previously sampled z-values.
         """
         inf_dict = self._inference(x_dict, compute_jacobian=compute_jacobian)
         # prior
-        log_prob_prior = self.prior.get_log_prob(inf_dict, sum_features=sum_features, feature_dims=feature_dims)
+        log_prob_prior = self.prior.get_log_prob(inf_dict, sum_features=sum_features, feature_dims=feature_dims,
+                                                 **kwargs)

         return log_prob_prior - self.logdet_jacobian

@@ -232,7 +233,7 @@ def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=
                return_hidden=True, sample_mean=False, **kwargs):
         # sample from the prior
         sample_dict = self.prior.sample(x_dict, batch_n=batch_n, sample_shape=sample_shape, return_all=False,
-                                        reparam=reparam, sample_mean=sample_mean)
+                                        reparam=reparam, sample_mean=sample_mean, **kwargs)

         # inverse flow transformation
         _z = get_dict_values(sample_dict, self.flow_output_var)
@@ -276,12 +277,13 @@ def inference(self, x_dict, return_all=True, compute_jacobian=False):

         return output_dict

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         # flow
         output_dict = self.inference(x_dict, return_all=True, compute_jacobian=True)

         # prior
-        log_prob_prior = self.prior.get_log_prob(output_dict, sum_features=sum_features, feature_dims=feature_dims)
+        log_prob_prior = self.prior.get_log_prob(output_dict, sum_features=sum_features, feature_dims=feature_dims,
+                                                 **kwargs)

         return log_prob_prior + self.logdet_jacobian

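For reference, the two get_log_prob implementations in this file are the usual change-of-variables identities: the first maps x back through the flow and subtracts the log-determinant of the Jacobian, the second maps x forward and adds it (writing f for the flow and p_Z for the prior):

\log p_X(x) = \log p_Z\bigl(f^{-1}(x)\bigr) - \log\left|\det \frac{\partial f(z)}{\partial z}\right|
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|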
6 changes: 3 additions & 3 deletions pixyz/distributions/mixture_distributions.py
@@ -137,12 +137,12 @@ def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=

         # sample from prior
         hidden_output = self.prior.sample(input_dict, batch_n=batch_n,
-                                          sample_mean=sample_mean, return_all=False)[self._hidden_var[0]]
+                                          sample_mean=sample_mean, return_all=False, **kwargs)[self._hidden_var[0]]

         var_output = []
         for _hidden_output in hidden_output:
             var_output.append(self.distributions[_hidden_output.argmax(dim=-1)].sample(
-                input_dict, sample_mean=sample_mean, return_all=False)[self._var[0]])
+                input_dict, sample_mean=sample_mean, return_all=False, **kwargs)[self._var[0]])

         var_output = torch.cat(var_output, dim=0)
         output_dict = {self._var[0]: var_output}
@@ -250,5 +250,5 @@ def has_reparam(self):

     def get_log_prob(self, x_dict, **kwargs):
         # log p(z|x) = log p(x, z) - log p(x)
-        log_prob = self.p.get_log_prob(x_dict, return_hidden=True) - self.p.get_log_prob(x_dict)
+        log_prob = self.p.get_log_prob(x_dict, return_hidden=True, **kwargs) - self.p.get_log_prob(x_dict, **kwargs)
         return log_prob  # (num_mix, batch_size)
2 changes: 1 addition & 1 deletion pixyz/distributions/poe.py
@@ -233,7 +233,7 @@ def log_prob(self, sum_features=True, feature_dims=None):
     def prob(self, sum_features=True, feature_dims=None):
         raise NotImplementedError()

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         raise NotImplementedError()


4 changes: 2 additions & 2 deletions pixyz/distributions/special_distributions.py
@@ -57,7 +57,7 @@ def sample(self, x_dict={}, return_all=True, **kwargs):
     def sample_mean(self, x_dict):
         return self.sample(x_dict, return_all=False)[self._var[0]]

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         raise NotImplementedError("Log probability of deterministic distribution is not defined.")

     @property
@@ -105,7 +105,7 @@ def sample(self, x_dict={}, return_all=True, **kwargs):
     def sample_mean(self, x_dict):
         return self.sample(x_dict, return_all=False)[self._var[0]]

-    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
+    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         raise NotImplementedError()

     @property
22 changes: 11 additions & 11 deletions pixyz/losses/adversarial_loss.py
@@ -228,21 +228,21 @@ def forward(self, x_dict, discriminator=False, **kwargs):
         batch_n = self._get_batch_n(x_dict)

         # sample x_p from p
-        x_p_dict = get_dict_values(self.p.sample(x_dict, batch_n=batch_n), self.d.input_var, True)
+        x_p_dict = get_dict_values(self.p.sample(x_dict, batch_n=batch_n, **kwargs), self.d.input_var, True)
         # sample x_q from q
-        x_q_dict = get_dict_values(self.q.sample(x_dict, batch_n=batch_n), self.d.input_var, True)
+        x_q_dict = get_dict_values(self.q.sample(x_dict, batch_n=batch_n, **kwargs), self.d.input_var, True)
         if discriminator:
             # sample y_p from d
-            y_p = get_dict_values(self.d.sample(detach_dict(x_p_dict)), self.d.var)[0]
+            y_p = get_dict_values(self.d.sample(detach_dict(x_p_dict), **kwargs), self.d.var)[0]
             # sample y_q from d
-            y_q = get_dict_values(self.d.sample(detach_dict(x_q_dict)), self.d.var)[0]
+            y_q = get_dict_values(self.d.sample(detach_dict(x_q_dict), **kwargs), self.d.var)[0]

             return self.d_loss(y_p, y_q, batch_n), x_dict

         # sample y_p from d
-        y_p_dict = self.d.sample(x_p_dict)
+        y_p_dict = self.d.sample(x_p_dict, **kwargs)
         # sample y_q from d
-        y_q_dict = self.d.sample(x_q_dict)
+        y_q_dict = self.d.sample(x_q_dict, **kwargs)

         y_p = get_dict_values(y_p_dict, self.d.var)[0]
         y_q = get_dict_values(y_q_dict, self.d.var)[0]
@@ -391,21 +391,21 @@ def forward(self, x_dict, discriminator=False, **kwargs):
         batch_n = self._get_batch_n(x_dict)

         # sample x_p from p
-        x_p_dict = get_dict_values(self.p.sample(x_dict, batch_n=batch_n), self.d.input_var, True)
+        x_p_dict = get_dict_values(self.p.sample(x_dict, batch_n=batch_n, **kwargs), self.d.input_var, True)

         if discriminator:
             # sample x_q from q
-            x_q_dict = get_dict_values(self.q.sample(x_dict, batch_n=batch_n), self.d.input_var, True)
+            x_q_dict = get_dict_values(self.q.sample(x_dict, batch_n=batch_n, **kwargs), self.d.input_var, True)

             # sample y_p from d
-            y_p = get_dict_values(self.d.sample(detach_dict(x_p_dict)), self.d.var)[0]
+            y_p = get_dict_values(self.d.sample(detach_dict(x_p_dict), **kwargs), self.d.var)[0]
             # sample y_q from d
-            y_q = get_dict_values(self.d.sample(detach_dict(x_q_dict)), self.d.var)[0]
+            y_q = get_dict_values(self.d.sample(detach_dict(x_q_dict), **kwargs), self.d.var)[0]

             return self.d_loss(y_p, y_q, batch_n), {}

         # sample y from d
-        y_p = get_dict_values(self.d.sample(x_p_dict), self.d.var)[0]
+        y_p = get_dict_values(self.d.sample(x_p_dict, **kwargs), self.d.var)[0]

         return self.g_loss(y_p, batch_n), {}

4 changes: 2 additions & 2 deletions pixyz/losses/mmd.py
@@ -65,8 +65,8 @@ def forward(self, x_dict={}, **kwargs):
         batch_n = self._get_batch_n(x_dict)

         # sample from distributions
-        p_x = get_dict_values(self.p.sample(x_dict, batch_n=batch_n), self.p.var)[0]
-        q_x = get_dict_values(self.q.sample(x_dict, batch_n=batch_n), self.q.var)[0]
+        p_x = get_dict_values(self.p.sample(x_dict, batch_n=batch_n, **kwargs), self.p.var)[0]
+        q_x = get_dict_values(self.q.sample(x_dict, batch_n=batch_n, **kwargs), self.q.var)[0]

         if p_x.shape != q_x.shape:
             raise ValueError("The two distribution variables must have the same shape.")
2 changes: 1 addition & 1 deletion pixyz/losses/pdf.py
@@ -39,7 +39,7 @@ def _symbol(self):
         return sympy.Symbol("\\log {}".format(self.p.prob_text))

     def forward(self, x={}, **kwargs):
-        log_prob = self.p.get_log_prob(x, sum_features=self.sum_features, feature_dims=self.feature_dims)
+        log_prob = self.p.get_log_prob(x, sum_features=self.sum_features, feature_dims=self.feature_dims, **kwargs)
         return log_prob, {}


4 changes: 2 additions & 2 deletions pixyz/losses/wasserstein.py
@@ -63,8 +63,8 @@ def forward(self, x_dict, **kwargs):
         batch_n = self._get_batch_n(x_dict)

         # sample from distributions
-        p_x = get_dict_values(self.p.sample(x_dict, batch_n=batch_n), self.p.var)[0]
-        q_x = get_dict_values(self.q.sample(x_dict, batch_n=batch_n), self.q.var)[0]
+        p_x = get_dict_values(self.p.sample(x_dict, batch_n=batch_n, **kwargs), self.p.var)[0]
+        q_x = get_dict_values(self.q.sample(x_dict, batch_n=batch_n, **kwargs), self.q.var)[0]

         if p_x.shape != q_x.shape:
             raise ValueError("The two distribution variables must have the same shape.")