feat: Add bfloat16 support #136

Merged: 13 commits, Jan 20, 2025
6 changes: 3 additions & 3 deletions README.md
@@ -56,9 +56,9 @@ from mithril.models import Model, Linear
# Build a simple linear model
model = Linear(16)

-# Create backends, specify the precision
-backend_jax = ml.JaxBackend(precision=64)
-backend_numpy = ml.NumpyBackend(precision=32)
+# Create backends, specify the default dtype
+backend_jax = ml.JaxBackend(dtype=ml.float64)
+backend_numpy = ml.NumpyBackend(dtype=ml.float32)

# Compile the model with different backends, optionally specify
# the file to write the generated code into and whether to use jit
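The README change above replaces the integer precision argument with a Dtype object. A minimal sketch of the new backend construction, assuming the backend you pick supports the bfloat16 dtype this PR adds (names as exported from mithril/__init__.py):

import mithril as ml
from mithril.models import Linear

model = Linear(16)

# Choose a default dtype per backend; ml.bfloat16 is the new option.
backend_bf16 = ml.TorchBackend(dtype=ml.bfloat16)
backend_f64 = ml.JaxBackend(dtype=ml.float64)

# Compile the same model against either backend, as in the README snippet.
compiled = ml.compile(model=model, backend=backend_bf16)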
32 changes: 17 additions & 15 deletions benchmarks/speed_benchmarks/benchmark.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import mithril as ml
from benchmarks.speed_benchmarks.jax_fns import mlp_v_jax
from benchmarks.speed_benchmarks.speed_helper import colorize_str
from benchmarks.speed_benchmarks.torch_fns import conv_v_torch, mlp_v_torch
+from mithril.backends.utils import DtypeBits
from mithril.framework.common import Table
from mithril.models import Relu, Sigmoid, Tanh

# MLX is not included due to Ubuntu OS in Github
backends = ["Torch", "Jax"]
-precisions = [64, 32, 16]
+dtypes = [ml.float64, ml.float32, ml.float16]

iterations = 100
table = Table()
@@ -55,20 +57,20 @@

for backend in backends:
fn = mlp_v_jax if backend == "Jax" else mlp_v_torch
-for precision in precisions:
-if not (precision == 16 and backend == "Torch"):
+for dtype in dtypes:
+if not (DtypeBits[dtype.name].value == 16 and backend == "Torch"):
num_params, time_backend, time_mithril = fn(
activations=activations,
dimensions=dimensions,
input_shape=input_shape,
iterations=iterations,
-precision=precision,
+dtype=dtype,
)
table.add_row(
[
"MLP Large",
backend,
-str(precision),
+str(dtype),
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -82,20 +84,20 @@

for backend in backends:
fn = mlp_v_jax if backend == "Jax" else mlp_v_torch
-for precision in precisions:
-if not (precision == 16 and backend == "Torch"):
+for dtype in dtypes:
+if not (DtypeBits[dtype.name].value == 16 and backend == "Torch"):
num_params, time_backend, time_mithril = fn(
activations=activations,
dimensions=dimensions,
input_shape=(128, 128),
iterations=iterations,
-precision=precision,
+dtype=dtype,
)
table.add_row(
[
"MLP Small",
backend,
-str(precision),
+dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -107,21 +109,21 @@
dimensions = [12, 16, 32]
stride = (2, 2)
padding = 1
-for precision in [32, 64]:
+for dtype in [ml.float32, ml.float64]:
num_params, time_backend, time_mithril = conv_v_torch(
activations=activations,
dimensions=dimensions,
input_shape=(4, 4, 128, 128),
iterations=iterations,
-precision=precision,
+dtype=dtype,
stride=stride,
padding=padding,
)
table.add_row(
[
"Conv Small",
"Torch",
-str(precision),
+dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
@@ -134,21 +136,21 @@
dimensions = [1024, 1024, 1024, 256]
stride = (2, 2)
padding = 2
-for precision in [32, 64]:
+for dtype in [ml.float32, ml.float64]:
num_params, time_backend, time_mithril = conv_v_torch(
activations=activations,
dimensions=dimensions,
input_shape=(2, 1, 128, 128),
iterations=iterations,
-precision=precision,
+dtype=dtype,
stride=stride,
padding=padding,
)
table.add_row(
[
"Conv Large",
"Torch",
-str(precision),
+dtype.name,
str(num_params),
f"{time_backend:.4f}",
f"{time_mithril:.4f}",
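The benchmark now loops over Dtype members rather than raw bit counts and recovers the width through DtypeBits when deciding whether to skip the 16-bit Torch runs. A small sketch of that lookup, using the same names this diff imports:

import mithril as ml
from mithril.backends.utils import DtypeBits

dtypes = [ml.float64, ml.float32, ml.float16]

for dtype in dtypes:
    bits = DtypeBits[dtype.name].value  # 64, 32 or 16
    if bits == 16:
        continue  # half precision is skipped for the Torch runs above
    print(f"benchmark {dtype.name} at {bits} bits")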
9 changes: 5 additions & 4 deletions benchmarks/speed_benchmarks/jax_fns.py
@@ -25,7 +25,8 @@
create_compl_mlp,
measure_time_and_grads_mithril,
)
-from mithril import JaxBackend
+from mithril import JaxBackend, core
+from mithril.backends.utils import DtypeBits
from mithril.models import (
AbsoluteError,
Gelu,
@@ -200,15 +201,15 @@ def mlp_v_jax(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int],
-precision: int,
+dtype: core.Dtype,
iterations: int,
):
lr = 0.001
_input_shape, batch_size = input_shape
# batch_size, input_shape = input_shape[-1], input_shape[0]
output_shape = [_input_shape] + [dimensions[-1]]
device = "cpu"
dtype_jax = getattr(jnp, f"float{precision}")
dtype_jax = getattr(jnp, f"float{DtypeBits[dtype.name]}")
device = "cpu"
inputs = {
"input": jnp.array(np.random.randn(batch_size, *input_shape), dtype=dtype_jax),
Expand All @@ -231,7 +232,7 @@ def mlp_v_jax(
)
comp_ctx = mithril.compile(
model=ctx,
-backend=JaxBackend(device=device, precision=precision),
+backend=JaxBackend(device=device, dtype=dtype),
constant_keys=inputs,
)

15 changes: 8 additions & 7 deletions benchmarks/speed_benchmarks/torch_fns.py
@@ -24,7 +24,8 @@
create_compl_mlp,
measure_time_and_grads_mithril,
)
-from mithril import TorchBackend
+from mithril import TorchBackend, core
+from mithril.backends.utils import DtypeBits
from mithril.models import (
AbsoluteError,
Gelu,
@@ -138,14 +139,14 @@ def mlp_v_torch(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int],
-precision: int,
+dtype: core.Dtype,
iterations: int,
):
lr = 0.001
batch_size, _input_shape = input_shape[-1], input_shape[0]
output_shape = [_input_shape] + [dimensions[-1]]
device = "cpu"
dtype_torch = getattr(torch, f"float{precision}")
dtype_torch = getattr(torch, f"float{DtypeBits[dtype.name]}")
torch.set_default_dtype(dtype_torch)

inputs = {
@@ -170,7 +171,7 @@
)
comp_ctx = mithril.compile(
model=ctx,
-backend=TorchBackend(device=device, precision=precision),
+backend=TorchBackend(device=device, dtype=dtype),
constant_keys=inputs,
)
randomized_inputs = comp_ctx.randomize_params()
@@ -207,15 +208,15 @@ def conv_v_torch(
activations: list,
dimensions: list[int],
input_shape: tuple[int, int, int, int],
-precision: int,
+dtype: core.Dtype,
iterations: int,
stride: tuple[int, int] | int,
padding: int,
):
lr = 0.001
batch_size, in_shape, tensor_shape = input_shape[0], input_shape[1], input_shape[2:]
device = "cpu"
-dtype_torch = getattr(torch, f"float{precision}")
+dtype_torch = getattr(torch, f"float{DtypeBits[dtype.name]}")
torch.set_default_dtype(dtype_torch)
inputs = {
"input": torch.randn(*input_shape, device=device),
@@ -252,7 +253,7 @@
)
comp_ctx = mithril.compile(
model=ctx,
-backend=TorchBackend(device=device, precision=precision),
+backend=TorchBackend(device=device, dtype=dtype),
constant_keys=inputs,
)
randomized_inputs = comp_ctx.randomize_params()
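Both Torch helpers derive a torch.dtype from the mithril Dtype and install it with torch.set_default_dtype, so tensors created afterwards without an explicit dtype (such as the torch.randn inputs above) pick it up. A standalone illustration of that PyTorch behaviour, not part of this diff:

import torch

torch.set_default_dtype(torch.float64)
x = torch.randn(4, 4)                   # no dtype passed
print(x.dtype)                          # torch.float64
torch.set_default_dtype(torch.float32)  # restore the usual default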
2 changes: 1 addition & 1 deletion examples/gpt/run_sample.py
@@ -60,7 +60,7 @@ def run_sample(
)

# Create backend.
-backend_obj = backend_map[backend](precision=32, device="cpu")
+backend_obj = backend_map[backend](device="cpu")
# Set seed.
backend_obj.set_seed(seed)
# Compile gpt model.
2 changes: 1 addition & 1 deletion examples/model_api/cnn_forcast_sine_training.py
@@ -33,7 +33,7 @@
# TODO: Remove numpy dependencies from the code.

# Define backend. It would also work with any available backend you prefer.
-backend = ml.TorchBackend(precision=32)
+backend = ml.TorchBackend()


# Generate synthetic data: a sine wave
2 changes: 1 addition & 1 deletion examples/model_api/convolution_with_svm.py
@@ -74,7 +74,7 @@
)

# Set up device and precision of our backend of choice
backend = ml.TorchBackend(precision=32, device="mps")
backend = ml.TorchBackend(device="mps")

# Compile the model with given non-trainable keys
compiled_model = ml.compile(
2 changes: 1 addition & 1 deletion examples/model_api/many_to_one_any_backend_training.py
@@ -19,7 +19,7 @@
from mithril.models import ManyToOne, Mean, RNNCell, SquaredError, TrainModel

# Define backend. It would also work with any available backend you prefer.
-backend = ml.JaxBackend(precision=64)
+backend = ml.JaxBackend(dtype=ml.float64)

batch_size = 20
input_features = 10
2 changes: 1 addition & 1 deletion examples/model_api/variable_length_many_to_one_lstm.py
@@ -32,7 +32,7 @@
# given an ordered sequence of numbers.

# Define the backend
-backend = ml.JaxBackend(precision=64)
+backend = ml.JaxBackend(dtype=ml.float64)
backend.set_seed(42)

# Prepare training data. We will test the case for which the input data
2 changes: 2 additions & 0 deletions mithril/__init__.py
@@ -20,6 +20,7 @@
from .core import (
Constant,
DataType,
+bfloat16,
bool,
double,
epsilon_table,
@@ -50,6 +51,7 @@
"bool",
"float",
"float16",
"bfloat16",
"float32",
"float64",
"int",
26 changes: 13 additions & 13 deletions mithril/backends/backend.py
@@ -34,28 +34,28 @@ class Backend(ABC, Generic[DataType]):

backend_type = ""
device_type = None
-supported_precisions = [16, 32, 64]
is_installed = True
_device: Any
_precision: int
+supported_dtypes = [
+core.Dtype.float16,
+core.Dtype.bfloat16,
+core.Dtype.float32,
+core.Dtype.float64,
+]
primitive_function_dict: dict[str, Callable[..., DataType | Any]]
registered_primitives: dict[str, Callable[..., DataType]]
array_creation_funcs: list[str]
primitive_fn_path: str

-def __init__(self, precision: int = 32, device: str = "cpu") -> None:
-# Check if given precision is a valid one.
-if self.precision not in self.supported_precisions:
-raise Exception(
-f"'{self.precision}' bits precision is not available!"
-" Available precisions: '{self.supported_precisions}'"
+def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> None:
+# Check if given dtype is a valid one.
+if dtype not in self.supported_dtypes:
+raise ValueError(
+f"Invalid dtype {dtype}. Supported dtypes are {self.supported_dtypes}."
)
self.seed = 10 # Can be set any integer.

# Initialize epsilon constants according to given precision.
# for key, value in core.epsilon_table[f"float{self.precision}"].items():
# setattr(self, key, value)

@property
def precision(self) -> int:
return self._precision
@@ -1076,11 +1076,11 @@ def __repr__(self) -> str:


class ParallelBackend(Backend[DataType]):
-def __init__(self, device_mesh: tuple[int, ...] | None) -> None:
+def __init__(self, dtype: core.Dtype, device_mesh: tuple[int, ...] | None) -> None:
assert (
isinstance(device_mesh, tuple) or device_mesh is None
), "device_mesh must be a tuple or None."
-super().__init__()
+super().__init__(dtype=dtype)

self._raw_device_mesh = device_mesh
self.n_devices = math.prod(device_mesh) if device_mesh is not None else 1
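Backend.__init__ now validates the requested dtype against supported_dtypes and raises ValueError instead of a bare Exception. A hedged sketch of what a caller might see; ml.bool is used only because it clearly falls outside the float-only supported list, and the exact message text is approximate:

import mithril as ml

try:
    backend = ml.TorchBackend(dtype=ml.bool)  # not in supported_dtypes
except ValueError as err:
    print(err)  # roughly: "Invalid dtype Dtype.bool. Supported dtypes are [...]"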
25 changes: 25 additions & 0 deletions mithril/backends/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import enum
from collections.abc import Sequence

from ..utils.type_utils import is_tuple_int
@@ -36,3 +37,27 @@ def process_shape(
)

return _shape


class DtypeBits(enum.IntEnum):
bool = 8
int8 = 8
int16 = 16
int32 = 32
int64 = 64
float16 = 16
bfloat16 = 16
float32 = 32
float64 = 64


class DtypeSubTypes(enum.Enum):
bool = "bool"
int8 = "int"
int16 = "int"
int32 = "int"
int64 = "int"
float16 = "float"
bfloat16 = "bfloat"
float32 = "float"
float64 = "float"