Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate messengers #3308

Merged
merged 2 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyro/poutine/escape_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _pyro_sample(self, msg: Message) -> None:
msg["done"] = True
msg["stop"] = True

def cont(m):
def cont(m: Message) -> None:
raise NonlocalExit(m)

msg["continuation"] = cont
2 changes: 1 addition & 1 deletion pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Message(TypedDict, total=False):
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
scale: float
scale: Union[torch.Tensor, float]
mask: Union[bool, torch.Tensor, None]
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
done: bool
Expand Down
11 changes: 6 additions & 5 deletions pyro/poutine/scale_messenger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled

from .messenger import Messenger


class ScaleMessenger(Messenger):
"""
Expand All @@ -33,7 +35,7 @@ class ScaleMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.ScaleMessenger`
"""

def __init__(self, scale):
def __init__(self, scale: Union[float, torch.Tensor]) -> None:
if isinstance(scale, torch.Tensor):
if is_validation_enabled() and not (scale > 0).all():
raise ValueError(
Expand All @@ -45,6 +47,5 @@ def __init__(self, scale):
super().__init__()
self.scale = scale

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
msg["scale"] = self.scale * msg["scale"]
return None
17 changes: 12 additions & 5 deletions pyro/poutine/seed_messenger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.util import get_rng_state, set_rng_seed, set_rng_state
from types import TracebackType
from typing import Optional, Type

from .messenger import Messenger
from pyro.poutine.messenger import Messenger
from pyro.util import get_rng_state, set_rng_seed, set_rng_state


class SeedMessenger(Messenger):
Expand All @@ -18,14 +20,19 @@ class SeedMessenger(Messenger):
:param int rng_seed: rng seed.
"""

def __init__(self, rng_seed):
def __init__(self, rng_seed: int) -> None:
assert isinstance(rng_seed, int)
self.rng_seed = rng_seed
super().__init__()

def __enter__(self):
def __enter__(self) -> None: # type: ignore[override]
self.old_state = get_rng_state()
set_rng_seed(self.rng_seed)

def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> None:
set_rng_state(self.old_state)
115 changes: 63 additions & 52 deletions pyro/poutine/subsample_messenger.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from pyro.distributions.distribution import Distribution
from pyro.poutine.indep_messenger import CondIndepStackFrame, IndepMessenger
from pyro.poutine.runtime import Message, apply_stack
from pyro.poutine.util import is_validation_enabled
from pyro.util import ignore_jit_warnings

from .indep_messenger import CondIndepStackFrame, IndepMessenger
from .runtime import apply_stack


class _Subsample(Distribution):
"""
Expand All @@ -18,7 +19,13 @@ class _Subsample(Distribution):
Internal use only. This should only be used by `plate`.
"""

def __init__(self, size, subsample_size, use_cuda=None, device=None):
def __init__(
self,
size: int,
subsample_size: Optional[int],
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> None:
"""
:param int size: the size of the range to subsample from
:param int subsample_size: the size of the returned subsample
Expand All @@ -38,10 +45,10 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None):
)
)
with ignore_jit_warnings(["torch.Tensor results are registered as constants"]):
self.device = torch.Tensor().device if not device else device
self.device = device or torch.Tensor().device

@ignore_jit_warnings(["Converting a tensor to a Python boolean"])
def sample(self, sample_shape=torch.Size()):
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
:returns: a random subsample of `range(size)`
:rtype: torch.LongTensor
Expand All @@ -57,7 +64,7 @@ def sample(self, sample_shape=torch.Size()):
].clone()
return result.cuda() if self.use_cuda else result

def log_prob(self, x):
def log_prob(self, x: torch.Tensor) -> torch.Tensor:
# This is zero so that plate can provide an unbiased estimate of
# the non-subsampled log_prob.
result = torch.tensor(0.0, device=self.device)
Expand All @@ -71,33 +78,34 @@ class SubsampleMessenger(IndepMessenger):

def __init__(
self,
name,
size=None,
subsample_size=None,
subsample=None,
dim=None,
use_cuda=None,
device=None,
):
super().__init__(name, size, dim, device)
self.subsample_size = subsample_size
self._indices = subsample
self.use_cuda = use_cuda
self.device = device

self.size, self.subsample_size, self._indices = self._subsample(
self.name,
self.size,
self.subsample_size,
self._indices,
self.use_cuda,
self.device,
name: str,
size: Optional[int] = None,
subsample_size: Optional[int] = None,
subsample: Optional[torch.Tensor] = None,
dim: Optional[int] = None,
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> None:
full_size, self.subsample_size, subsample = self._subsample(
name,
size,
subsample_size,
subsample,
use_cuda,
device,
)
super().__init__(name, full_size, dim, device)
self._indices = subsample

@staticmethod
def _subsample(
name, size=None, subsample_size=None, subsample=None, use_cuda=None, device=None
):
name: str,
size: Optional[int] = None,
subsample_size: Optional[int] = None,
subsample: Optional[torch.Tensor] = None,
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> Tuple[int, int, Optional[torch.Tensor]]:
"""
Helper function for plate. See its docstrings for details.
"""
Expand All @@ -107,27 +115,28 @@ def _subsample(
size = -1 # This is PyTorch convention for "arbitrary size"
subsample_size = -1
else:
msg = {
"type": "sample",
"name": name,
"fn": _Subsample(size, subsample_size, use_cuda, device),
"is_observed": False,
"args": (),
"kwargs": {},
"value": subsample,
"infer": {},
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
}
msg = Message(
type="sample",
name=name,
fn=_Subsample(size, subsample_size, use_cuda, device),
is_observed=False,
args=(),
kwargs={},
value=subsample,
infer={},
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
)
apply_stack(msg)
subsample = msg["value"]

with ignore_jit_warnings():
if subsample_size is None:
assert subsample is not None
subsample_size = (
subsample.size(0)
if isinstance(subsample, torch.Tensor)
Expand All @@ -143,11 +152,11 @@ def _subsample(

return size, subsample_size, subsample

def _reset(self):
def _reset(self) -> None:
self._indices = None
super()._reset()

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
frame = CondIndepStackFrame(
name=self.name,
dim=self.dim,
Expand All @@ -164,12 +173,13 @@ def _process_message(self, msg):
msg["scale"] = torch.tensor(msg["scale"])
msg["scale"] = msg["scale"] * self.size / self.subsample_size

def _postprocess_message(self, msg):
def _postprocess_message(self, msg: Message) -> None:
if msg["type"] in ("param", "subsample") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
assert msg["value"] is not None
shape = msg["value"].shape
if len(shape) >= -dim and shape[dim] != 1:
if is_validation_enabled() and shape[dim] != self.size:
Expand All @@ -189,18 +199,19 @@ def _postprocess_message(self, msg):
# Subsample parameters with known batch semantics.
if self.subsample_size < self.size:
value = msg["value"]
assert self._indices is not None
new_value = value.index_select(
dim, self._indices.to(value.device)
)
if msg["type"] == "param":
if hasattr(value, "_pyro_unconstrained_param"):
param = value._pyro_unconstrained_param
param = value._pyro_unconstrained_param # type: ignore[attr-defined]
else:
param = value.unconstrained()
param = value.unconstrained() # type: ignore[attr-defined]

if not hasattr(param, "_pyro_subsample"):
param._pyro_subsample = {}

param._pyro_subsample[dim] = self._indices
new_value._pyro_unconstrained_param = param
new_value._pyro_unconstrained_param = param # type: ignore[attr-defined]
msg["value"] = new_value
26 changes: 16 additions & 10 deletions pyro/poutine/substitute_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Dict, Set

import torch
from typing_extensions import Self

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled


Expand All @@ -20,30 +25,30 @@ class SubstituteMessenger(Messenger):
... a = pyro.param("a", torch.tensor(0.5))
... x = pyro.sample("x", dist.Bernoulli(probs=a))
... return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": 0.3})
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

In this example, site `a` will now have value `0.3`.
In this example, site `a` will now have value `torch.tensor(0.3)`.
:param data: dictionary of values keyed by site names.
:returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger`
"""

def __init__(self, data):
def __init__(self, data: Dict[str, torch.Tensor]) -> None:
"""
:param data: values for the parameters.
Constructor
"""
super().__init__()
self.data = data
self._data_cache = {}
self._data_cache: Dict[str, Message] = {}

def __enter__(self):
def __enter__(self) -> Self:
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
self._param_hits = set()
self._param_misses = set()
self._param_hits: Set[str] = set()
self._param_misses: Set[str] = set()
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, *args, **kwargs) -> None:
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
extra = set(self.data) - self._param_hits
Expand All @@ -56,15 +61,16 @@ def __exit__(self, *args, **kwargs):
)
return super().__exit__(*args, **kwargs)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
return None

def _pyro_param(self, msg):
def _pyro_param(self, msg: Message) -> None:
"""
Overrides the `pyro.param` with substituted values.
If the param name does not match the name the keys in `data`,
that param value is unchanged.
"""
assert msg["name"] is not None
name = msg["name"]
param_name = params.user_param_name(name)

Expand Down
Loading