Skip to content

Commit

Permalink
multivariate bugfixes (0.3.0 release) (#7)
Browse files Browse the repository at this point in the history
* MVSamples: fix to pass samples.T to stats.kde_gaussian
* MVSamples.to_mvhistogram: use samples directly if N>=nsamples
* temporary fix: accept unit/as_quantity kwargs to MV*.sample
  • Loading branch information
kecnry authored Mar 25, 2021
1 parent 2d885e8 commit 5ad2aa2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
81 changes: 69 additions & 12 deletions distl/distl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
else:
_has_dill = True

__version__ = '0.2.0'
__version__ = '0.3.0'
version = __version__

_math_symbols = {'__mul__': '*', '__add__': '+', '__sub__': '-',
Expand Down Expand Up @@ -3255,7 +3255,8 @@ def uncertainties(self, sigma=1, tex=False, dimension=None, samples=None):
else:
return qs_per_dim

def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
unit=None, as_quantity=False):
"""
Sample from the distribution.
Expand All @@ -3270,13 +3271,18 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
prior to sampling.
* `cache_sample` (bool, optional, default=True): whether to override the
existing <<class>.cached_sample>.
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
Returns
---------
* float or array: float if `size=None`, otherwise a numpy array with
shape defined by `size`.
"""

if unit is not None or as_quantity:
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")

# TODO: add support for per-dimension unit, wrap_at, as_quantity (and pass in to_mvhistogram)
# TODO: add support for seed
if isinstance(seed, dict):
Expand Down Expand Up @@ -3714,14 +3720,16 @@ def uncertainties(self, sigma=1, tex=False):

### SAMPLING & PLOTTING

def sample(self, size=None, wrap_at=None, seed=None, cache_sample=True):
def sample(self, size=None, wrap_at=None, seed=None, cache_sample=True,
unit=None, as_quantity=False):
"""
Sample the underlying <<class>.multivariate> distribution in the dimension
defined in <<class>.dimension>.
"""

# TODO: support unit, wrap_at, as_quantity
return self.multivariate.sample(size=size, seed=seed, dimension=self.dimension, cache_sample=cache_sample)
return self.multivariate.sample(size=size, seed=seed, dimension=self.dimension, cache_sample=cache_sample,
unit=unit, as_quantity=as_quantity)

def plot_sample(self, *args, **kwargs):
if hasattr(self, 'bins'):
Expand Down Expand Up @@ -3958,7 +3966,7 @@ def get_distributions_with_values(self, values=None, as_univariates=False):
if not as_univariates and isinstance(dist_orig, BaseMultivariateSliceDistribution):
d = dist_orig.multivariate
else:
d = dist_orig
d = dist_orig #.to_univariate()?

# if as_univariates then we want MVSlices with the same parent MV to be treated separately
take_dimensions = not as_univariates and isinstance(dist_orig, BaseMultivariateSliceDistribution)
Expand Down Expand Up @@ -4044,7 +4052,7 @@ def logpdf(self, values=None, as_univariates=False):
samples are available, a ValueError will be raised.
* `as_univariates` (bool, optional, default=False): whether `values` corresponds
to the passed distributions (<DistributionCollection.distributions>)
or the underlying unpacked distributions (<DistributionCollection.distributions_unpacked>).
or the underlying unpacked distributions (<DistributionCollection.dists_unpacked>).
If the former (`as_univariates=False`), covariances will be respected
from any underlying multivariate distributions. If the latter
(`as_univariates=True`) covariances will be ignored.
Expand Down Expand Up @@ -7245,7 +7253,8 @@ def take_dimensions(self, dimensions):
labels=[self.labels[d] for d in dimensions] if self.labels is not None else None,
wrap_ats=[self.wrap_ats[d] for d in dimensions] if self.wrap_ats is not None else None)

def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
unit=None, as_quantity=False):
"""
Arguments
Expand All @@ -7256,6 +7265,8 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
prior to sampling.
* `cache_sample` (bool, optional, default=True): whether to override the
existing <<class>.cached_sample>.
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
"""
# if dimension is not None:
Expand All @@ -7266,6 +7277,10 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
# bins = self.bins
# density = self.density

if unit is not None or as_quantity:
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")


if isinstance(seed, dict):
seed = seed.get(self.uniqueid, None)

Expand Down Expand Up @@ -7535,7 +7550,7 @@ def __init__(self, samples, weights=None, bw_method=None, units=None,
Arguments
--------------
* `samples` (np.array object with shape (nsamples, <MVSamples.ndimensions>)):
* `samples` (np.array object with shape (<MVSamples.nsamples>, <MVSamples.ndimensions>)):
the samples.
* `weights` (np.array object with shape (nsamples) or None, optional, default=None):
weights for each entry in `samples`. NOTE: only supported with scipy
Expand All @@ -7560,6 +7575,8 @@ def __init__(self, samples, weights=None, bw_method=None, units=None,
--------
* an <MVSamples> object
"""
# NOTE: the passed samples need to be transposed, so see the override
# in dist_constructor_args
super(MVSamples, self).__init__(units, labels, labels_latex, wrap_ats,
_stats.gaussian_kde, ('samples', 'bw_method') if StrictVersion(_scipy_version) < StrictVersion("1.2.0") else ('samples', 'bw_method', 'weights'),
samples=samples, weights=weights, bw_method=bw_method,
Expand All @@ -7580,7 +7597,7 @@ def samples(self, value):
@property
def weights(self):
"""
weights for each entry in <Samples.samples>
weights for each sample in <Samples.samples> (nsamples)
"""
return self._weights

Expand Down Expand Up @@ -7608,6 +7625,25 @@ def bw_method(self, value):
self._bw_method = is_float(value)
self._dist_constructor_object_clear_cache()

@property
def dist_constructor_args(self):
"""
Return the arguments to pass to the the underlying distribution
constructor (often the scipy.stats random variable generator function)
<MVSamples.samples> is transposed before passing on to gaussian_kde
See also:
* <<class>.dist_constructor_func>
* <<class>.dist_constructor_object>
Returns
-------
* tuple
"""
return [getattr(self, a).T if a=='samples' else getattr(self,a) for a in self.dist_constructor_argnames]

@property
def ndimensions(self):
"""
Expand Down Expand Up @@ -7733,7 +7769,8 @@ def interval(self, *args, **kwargs):
# TODO: manual implementation
raise NotImplementedError()

def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
def sample(self, size=None, dimension=None, seed=None, cache_sample=True,
unit=None, as_quantity=False):
"""
Sample from the samples (<MVSamples.samples> if <MVSamples.weights>
is not provided, otherwise <MVSamples.samples_weighted>)
Expand All @@ -7746,9 +7783,16 @@ def sample(self, size=None, dimension=None, seed=None, cache_sample=True):
prior to sampling.
* `cache_sample` (bool, optional, default=True): whether to override the
existing <<class>.cached_sample>.
* `unit` (None): NOT YET IMPLEMENTED will raise error if not None
* `as_quantity` (False): NOT YET IMPLEMENTED will raise error if not False
"""

if unit is not None or as_quantity:
raise NotImplementedError("unit and quantities not yet supported for multivariate distributions")


if isinstance(seed, dict):
seed = seed.get(self.uniqueid, None)

Expand Down Expand Up @@ -7928,7 +7972,8 @@ def to_mvhistogram(self, N=1e6, bins=15, range=None):
Arguments
-----------
* `N` (int, optional, default=1e6): number of samples to use for
the histogram.
the histogram. If N>=<MVSamples.nsamples>, <MVSamples.samples>
will be passed directly.
* `bins` (int, optional, default=15): number of bins to use for the
histogram.
* `range` (tuple or None): range to use for the histogram.
Expand All @@ -7938,7 +7983,7 @@ def to_mvhistogram(self, N=1e6, bins=15, range=None):
* an <MVHistogram> object
"""
# TODO: if sample is updated to take wrap_at/wrap_ats... pass wrap_at=False here
return MVHistogram.from_data(self.sample(size=int(N), cache_sample=False),
return MVHistogram.from_data(self.samples if N >= self.nsamples else self.sample(size=int(N), cache_sample=False),
bins=bins, range=range,
units=self.units,
labels=self.labels, labels_latex=self._labels_latex,
Expand Down Expand Up @@ -8015,6 +8060,18 @@ def ppf(self, q, unit=None, as_quantity=False, wrap_at=None):
"""
return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).ppf(q, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)

# def pdf(self, x, unit=None, as_quantity=False, wrap_at=None):
# """
# See <Samples.pdf>
# """
# return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).pdf(x, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)
#
# def logpdf(self, x, unit=None, as_quantity=False, wrap_at=None):
# """
# See <Samples.logpdf>
# """
# return Samples(samples=self.samples, weights=self.weights, bw_method=self.bw_method, unit=self.unit).logpdf(x, unit=unit, as_quantity=as_quantity, wrap_at=wrap_at)

def interval(self, alpha, unit=None, as_quantity=False, wrap_at=None):
"""
See <Samples.interval>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
long_description = fh.read()

setup(name='distl',
version='0.2.0',
version='0.3.0',
description='Simple Distributions: math operations, serializing, covariances',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 5ad2aa2

Please sign in to comment.