Skip to content

Commit

Permalink
Merge pull request #146 from masa-su/fix/rename_datadist
Browse files Browse the repository at this point in the history
rename DataDistribution to EmpiricalDistribution
  • Loading branch information
masa-su authored Oct 27, 2020
2 parents 009be8a + a08ff61 commit d94f136
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ Deterministic
:members:
:undoc-members:

DataDistribution
EmpiricalDistribution
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DataDistribution
.. autoclass:: EmpiricalDistribution
:members:
:undoc-members:

Expand Down
6 changes: 3 additions & 3 deletions examples/mmd_vae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pixyz.distributions import Normal, Bernoulli, DataDistribution\n",
"from pixyz.distributions import Normal, Bernoulli, EmpiricalDistribution\n",
"from pixyz.losses import CrossEntropy, MMD\n",
"from pixyz.models import Model\n",
"from pixyz.utils import print_latex"
Expand Down Expand Up @@ -112,7 +112,7 @@
"prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),\n",
" var=[\"z\"], features_shape=[z_dim], name=\"p_{prior}\").to(device)\n",
"\n",
"p_data = DataDistribution([\"x\"]).to(device)\n",
"p_data = EmpiricalDistribution([\"x\"]).to(device)\n",
"q_mg = (q*p_data).marginalize_var(\"x\")\n",
"q_mg.name = \"q\""
]
Expand Down Expand Up @@ -169,7 +169,7 @@
"Distribution:\n",
" q(z) = \\int q(z|x)p_{data}(x)dx\n",
"Network architecture:\n",
" DataDistribution(\n",
" EmpiricalDistribution(\n",
" name=p_{data}, distribution_name=Data distribution,\n",
" var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])\n",
" )\n",
Expand Down
4 changes: 2 additions & 2 deletions pixyz/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .special_distributions import (
Deterministic,
DataDistribution
EmpiricalDistribution
)

from .distributions import (
Expand All @@ -37,7 +37,7 @@
'Distribution',
'CustomProb',
'Deterministic',
'DataDistribution',
'EmpiricalDistribution',
'Normal',
'Bernoulli',
'RelaxedBernoulli',
Expand Down
8 changes: 4 additions & 4 deletions pixyz/distributions/special_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def has_reparam(self):
return True


class DataDistribution(Distribution):
class EmpiricalDistribution(Distribution):
"""
Data distribution.
Expand All @@ -74,12 +74,12 @@ class DataDistribution(Distribution):
Examples
--------
>>> import torch
>>> p = DataDistribution(var=["x"])
>>> p = EmpiricalDistribution(var=["x"])
>>> print(p)
Distribution:
p_{data}(x)
Network architecture:
DataDistribution(
EmpiricalDistribution(
name=p_{data}, distribution_name=Data distribution,
var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])
)
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
@property
def input_var(self):
"""
In DataDistribution, `input_var` is same as `var`.
In EmpiricalDistribution, `input_var` is same as `var`.
"""

return self.var
Expand Down
18 changes: 9 additions & 9 deletions pixyz/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class AdversarialJensenShannon(AdversarialLoss):
Examples
--------
>>> import torch
>>> from pixyz.distributions import Deterministic, DataDistribution, Normal
>>> from pixyz.distributions import Deterministic, EmpiricalDistribution, Normal
>>> # Generator
>>> class Generator(Deterministic):
... def __init__(self):
Expand Down Expand Up @@ -166,12 +166,12 @@ class AdversarialJensenShannon(AdversarialLoss):
(model): Linear(in_features=32, out_features=64, bias=True)
)
>>> # Data distribution (dummy distribution)
>>> p_data = DataDistribution(["x"])
>>> p_data = EmpiricalDistribution(["x"])
>>> print(p_data)
Distribution:
p_{data}(x)
Network architecture:
DataDistribution(
EmpiricalDistribution(
name=p_{data}, distribution_name=Data distribution,
var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])
)
Expand Down Expand Up @@ -303,7 +303,7 @@ class AdversarialKullbackLeibler(AdversarialLoss):
Examples
--------
>>> import torch
>>> from pixyz.distributions import Deterministic, DataDistribution, Normal
>>> from pixyz.distributions import Deterministic, EmpiricalDistribution, Normal
>>> # Generator
>>> class Generator(Deterministic):
... def __init__(self):
Expand Down Expand Up @@ -333,12 +333,12 @@ class AdversarialKullbackLeibler(AdversarialLoss):
(model): Linear(in_features=32, out_features=64, bias=True)
)
>>> # Data distribution (dummy distribution)
>>> p_data = DataDistribution(["x"])
>>> p_data = EmpiricalDistribution(["x"])
>>> print(p_data)
Distribution:
p_{data}(x)
Network architecture:
DataDistribution(
EmpiricalDistribution(
name=p_{data}, distribution_name=Data distribution,
var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])
)
Expand Down Expand Up @@ -461,7 +461,7 @@ class AdversarialWassersteinDistance(AdversarialJensenShannon):
Examples
--------
>>> import torch
>>> from pixyz.distributions import Deterministic, DataDistribution, Normal
>>> from pixyz.distributions import Deterministic, EmpiricalDistribution, Normal
>>> # Generator
>>> class Generator(Deterministic):
... def __init__(self):
Expand Down Expand Up @@ -491,12 +491,12 @@ class AdversarialWassersteinDistance(AdversarialJensenShannon):
(model): Linear(in_features=32, out_features=64, bias=True)
)
>>> # Data distribution (dummy distribution)
>>> p_data = DataDistribution(["x"])
>>> p_data = EmpiricalDistribution(["x"])
>>> print(p_data)
Distribution:
p_{data}(x)
Network architecture:
DataDistribution(
EmpiricalDistribution(
name=p_{data}, distribution_name=Data distribution,
var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])
)
Expand Down
4 changes: 2 additions & 2 deletions pixyz/models/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ..models.model import Model
from ..losses import AdversarialJensenShannon
from ..distributions import DataDistribution
from ..distributions import EmpiricalDistribution


class GAN(Model):
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, p, discriminator,

# set distributions (for training)
distributions = [p]
p_data = DataDistribution(p.var)
p_data = EmpiricalDistribution(p.var)

# set losses
loss = AdversarialJensenShannon(p_data, p, discriminator,
Expand Down

0 comments on commit d94f136

Please sign in to comment.