Commit

precision is now property
aturker-synnada committed Jan 20, 2025
1 parent b80c64b commit 117947c
Showing 6 changed files with 15 additions and 15 deletions.
5 changes: 3 additions & 2 deletions mithril/backends/backend.py
@@ -21,6 +21,7 @@
 from .. import core
 from ..core import DataType
 from .parallel import Parallel
+from .utils import DtypeBits

 __all__ = ["Backend"]

@@ -36,7 +37,7 @@ class Backend(ABC, Generic[DataType]):
     device_type = None
     is_installed = True
     _device: Any
-    _precision: int
+    _dtype: core.Dtype
     supported_dtypes = [
         core.Dtype.float16,
         core.Dtype.bfloat16,
@@ -58,7 +59,7 @@ def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> None:

     @property
     def precision(self) -> int:
-        return self._precision
+        return DtypeBits[self._dtype.name].value

     #!!
     @property
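For context, the core of this commit is the last hunk above: `precision` is no longer a cached `_precision` field set at construction time but a property computed from `_dtype` on every access, so it cannot drift out of sync with the configured dtype. Below is a minimal, self-contained sketch of that mechanism; the class and member names are hypothetical stand-ins, not mithril's actual `Backend` or `DtypeBits` (which is imported from `.utils` in the hunk above).

from enum import IntEnum

# Stand-in for DtypeBits: an IntEnum keyed by dtype name whose value is the
# bit width. The member set here is an assumption for illustration only.
class DtypeBitsSketch(IntEnum):
    float16 = 16
    bfloat16 = 16
    float32 = 32
    float64 = 64

class BackendSketch:
    def __init__(self, dtype_name: str = "float32") -> None:
        self._dtype_name = dtype_name  # stands in for the real _dtype field

    @property
    def precision(self) -> int:
        # Derived on each access from the dtype name, mirroring
        # DtypeBits[self._dtype.name].value in the diff above.
        return DtypeBitsSketch[self._dtype_name].value

assert BackendSketch("bfloat16").precision == 16
assert BackendSketch().precision == 32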
5 changes: 2 additions & 3 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -21,7 +21,7 @@

 from ....core import Dtype
 from ...backend import PadWidthType, ParallelBackend
-from ...utils import DtypeBits, DtypeSubTypes, process_shape
+from ...utils import DtypeSubTypes, process_shape
 from . import ops, utils
 from .parallel import JaxParallel

@@ -57,7 +57,6 @@ def __init__(
         self._device = device
         utils.get_device(device)  # Check device is available
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: JaxParallel | None = None

         super().__init__(dtype=dtype, device_mesh=device_mesh)
@@ -172,7 +171,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> jax.Array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)

         with jax.default_device(self.device):
             array = jax.numpy.array(input, dtype=utils.dtype_map[_dtype])
5 changes: 2 additions & 3 deletions mithril/backends/with_autograd/mlx_backend/backend.py
@@ -22,7 +22,7 @@

 from ....core import Dtype
 from ...backend import Backend, PadWidthType
-from ...utils import DtypeBits, DtypeSubTypes, process_shape
+from ...utils import DtypeSubTypes, process_shape
 from . import ops, utils

 __all__ = ["MlxBackend"]
@@ -44,7 +44,6 @@ def __init__(
         os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._device = device
         super().__init__(dtype=dtype)

@@ -177,7 +176,7 @@ def _handle_sequence_type_fun(
         return [output]

     def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)
         return mx.array(input, dtype=utils.dtype_map[_dtype])

     def zeros(
5 changes: 2 additions & 3 deletions mithril/backends/with_autograd/torch_backend/backend.py
@@ -26,7 +26,7 @@

 from ....core import Dtype
 from ...backend import PadWidthType, ParallelBackend
-from ...utils import DtypeBits, DtypeSubTypes, process_shape
+from ...utils import DtypeSubTypes, process_shape
 from . import ops, utils
 from .parallel import TorchParallel

@@ -58,7 +58,6 @@ def __init__(
     ) -> None:
         self._device = device
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: TorchParallel | None = None

         utils.get_device(device)  # Check if device is valid
@@ -207,7 +206,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> torch.Tensor:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)

         array = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device)
         if self._parallel_manager is not None:
5 changes: 4 additions & 1 deletion mithril/backends/with_manualgrad/c_backend/backend.py
@@ -30,14 +30,17 @@ class CBackend(Backend[PyArray]):
     SRC_PATH = "mithril/backends/with_manualgrad/c_backend/src"

     def __init__(self) -> None:
-        self._precision = 32
         self._device = "cpu"
         self.primitive_function_dict = {}

     @property
     def is_manualgrad(self) -> bool:
         return True

+    @property
+    def precision(self) -> int:
+        return 32
+
     def set_seed(self, seed: int) -> None:
         pass

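The CBackend hunk shows the counterpart for a backend with no configurable dtype: instead of storing `_precision`, it overrides the new property with a constant. A small sketch of that override pattern, using hypothetical class names rather than mithril's:

# Override pattern sketch (hypothetical names): the base class derives
# precision from its dtype, while a fixed-precision backend returns a constant,
# mirroring CBackend's hard-coded 32-bit precision above.
class PrecisionBase:
    @property
    def precision(self) -> int:
        raise NotImplementedError

class Fixed32Backend(PrecisionBase):
    @property
    def precision(self) -> int:
        return 32

assert Fixed32Backend().precision == 32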
5 changes: 2 additions & 3 deletions mithril/backends/with_manualgrad/numpy_backend/backend.py
@@ -19,7 +19,7 @@

 from ....core import Dtype
 from ...backend import Backend, PadWidthType
-from ...utils import DtypeBits, process_shape
+from ...utils import process_shape
 from ..common_primitives import CacheType
 from . import ops, ops_grad, utils

@@ -46,7 +46,6 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]):

     def __init__(self, device: str = "cpu", dtype: Dtype = Dtype.float32) -> None:
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value

         if device != "cpu":
             raise RuntimeError(
@@ -118,7 +117,7 @@ def accumulate_grads(
         return utils.accumulate_grads(gradient, input, cache, idx)

     def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]:
-        _dtype = utils.determine_dtype(data, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(data, dtype, self._dtype, self.precision)

         return np.array(data, dtype=utils.dtype_map[_dtype])

