Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Python 3.9 #1923

Merged
merged 5 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9","3.10","3.12"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -50,7 +50,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable
from typing import Callable, Optional

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -348,7 +348,7 @@ def scan(
f: Callable,
init,
xs,
length: int | None = None,
length: Optional[int] = None,
reverse: bool = False,
history: int = 1,
):
Expand Down
12 changes: 6 additions & 6 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType
from typing import Any, Callable, OrderedDict as OrderedDictType, Union

import jax
from jax import random
Expand All @@ -21,9 +21,9 @@

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

RunInferenceResult = (
dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
) # for mcmc or sdvi
RunInferenceResult = Union[
dict[str, Any], tuple[AutoGuide, dict[str, Any]]
] # for mcmc or sdvi


class StochasticSupportInference(ABC):
Expand Down Expand Up @@ -124,12 +124,12 @@ def _combine_inferences(
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
raise NotImplementedError

def run(
self, rng_key: ArrayLike, *args: Any, **kwargs: Any
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
"""
Run inference on each SLP separately and combine the results.

Expand Down
5 changes: 3 additions & 2 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from collections import OrderedDict
from itertools import product
from typing import Union

import numpy as np

Expand Down Expand Up @@ -230,7 +231,7 @@ def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:


def summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> dict:
"""
Returns a summary table displaying diagnostics of ``samples`` from the
Expand Down Expand Up @@ -284,7 +285,7 @@ def summary(


def print_summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
10 changes: 5 additions & 5 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import random
import re
from threading import Lock
from typing import Any, Callable, Generator
from typing import Any, Callable, Generator, Optional
import warnings

import numpy as np
Expand All @@ -27,7 +27,7 @@
_CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3'


def set_rng_seed(rng_seed: int | None = None) -> None:
def set_rng_seed(rng_seed: Optional[int] = None) -> None:
"""
Initializes internal state for the Python and NumPy random number generators.

Expand All @@ -49,7 +49,7 @@ def enable_x64(use_x64: bool = True) -> None:
jax.config.update("jax_enable_x64", use_x64)


def set_platform(platform: str | None = None) -> None:
def set_platform(platform: Optional[str] = None) -> None:
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand Down Expand Up @@ -408,7 +408,7 @@ def loop_fn(collection):


def soft_vmap(
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: Optional[int] = None
) -> Any:
"""
Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
Expand Down Expand Up @@ -466,7 +466,7 @@ def format_shapes(
*,
compute_log_prob: bool = False,
title: str = "Trace Shapes:",
last_site: str | None = None,
last_site: Optional[str] = None,
):
"""
Given the trace of a function, returns a string showing a table of the shapes of
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,7 @@
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
)