Skip to content

Commit

Permalink
Add BitCheck Process (lava-nc#802)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgkwill authored Oct 25, 2023
1 parent fff3b10 commit 429c3bc
Show file tree
Hide file tree
Showing 7 changed files with 481 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/lava/magma/core/process/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
shape: ty.Tuple[int, ...],
init: ty.Union[bool, float, list, tuple, np.ndarray] = 0,
shareable: bool = True,
name: str = "Unnamed variable",
):
"""Initializes a new Lava variable.
Expand All @@ -61,8 +62,8 @@ def __init__(
super().__init__(shape)
self.init = init
self.shareable = shareable
self.name = name
self.id: int = VarServer().register(self)
self.name: str = "Unnamed variable"
self.aliased_var: ty.Optional[Var] = None
# VarModel generated during compilation
self._model: ty.Optional["AbstractVarModel"] = None
Expand Down
117 changes: 117 additions & 0 deletions src/lava/proc/bit_check/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import typing as ty

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyRefPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel

from lava.proc.bit_check.process import BitCheck


class AbstractPyBitCheckModel(PyLoihiProcessModel):
"""Abstract implementation of BitCheckModel.
Specific implementations inherit from here.
"""

state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)


class AbstractBitCheckModel(AbstractPyBitCheckModel):
"""Abstract implementation of BitCheck process. This
short and simple ProcessModel can be used for quick
checking of bit-accurate process runs as to whether
bits will overflow when running on bit limited hardware.
"""

state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)
_overflowed: int = LavaPyType(int, int)

def post_guard(self):
return True

def run_post_mgmt(self):
value = self.state.read()

if self.debug == 1:
print("Value is: {} at time step: {}"
.format(value, self.time_step))

# If self.check_bit_overflow(value) is true,
# the value overflowed the allowed bits from self.bits
if self.check_bit_overflow(value):
self._overflowed = 1
if self.debug == 1:
if self.layerid:
print("layer id number: {}".format(self.layerid))
print(
"value.max: overflows {} bits {}".format(
self.bits, value.max()
)
)
print(
"max signed value {}".format(
self.max_signed_int_per_bits(self.bits)
)
)
print(
"value.min: overflows {} bits {}".format(
self.bits, value.min()
)
)
print(
"min signed value {}".format(
self.max_signed_int_per_bits(self.bits)
)
)

def check_bit_overflow(self, value: ty.Type[np.ndarray]):
value = value.astype(np.int32)
shift_amt = 32 - self.bits
# Shift value left by shift_amt and
# then shift value right by shift_amt,
# the result should equal unshifted value
# if the value did not overflow bits in self.bits
return not np.all(
((value << shift_amt) >> shift_amt) == value
)

def max_unsigned_int_per_bits(self, bits: ty.Type[int]):
return (1 << bits) - 1

def min_signed_int_per_bits(self, bits: ty.Type[int]):
return -1 << (bits - 1)

def max_signed_int_per_bits(self, bits: ty.Type[int]):
return (1 << (bits - 1)) - 1


@implements(proc=BitCheck, protocol=LoihiProtocol)
@requires(CPU)
class LoihiBitCheckModel(AbstractBitCheckModel):
"""Implementation of Loihi BitCheck process. This
short and simple ProcessModel can be used for quick
checking of Loihi bit-accurate process run as to
whether bits will overflow when running on Loihi Hardware.
"""

state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)
93 changes: 93 additions & 0 deletions src/lava/proc/bit_check/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import typing as ty

from lava.magma.core.process.process import LogConfig, AbstractProcess
from lava.magma.core.process.ports.ports import RefPort
from lava.magma.core.process.variable import Var


class BitCheck(AbstractProcess):
def __init__(
self,
*,
layerid: ty.Optional[int] = None,
debug: ty.Optional[int] = 0,
bits: ty.Optional[int] = 24,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
**kwargs,
) -> None:
"""BitCheck process.
This process is used for quick checking of
bit-accurate process run as to whether bits will
overflow when running on bit limited hardware.
Parameters
----------
shape: Tuple
Shape of the sigma process.
Default is (1,).
layerid: int or float
Layer number of network.
Default is None.
debug: 0 or 1
Enable (1) or disable (0) debug print.
Default is 0.
bits: int
Bits to use when checking overflow, 1-32.
Default is 24.
"""
super().__init__(
name=name,
log_config=log_config,
**kwargs
)

initial_shape = (1,)
self.state = RefPort(initial_shape)

self.layerid: ty.Type(Var) = Var(shape=initial_shape, init=layerid)
self.debug: ty.Type(Var) = Var(shape=initial_shape, init=debug)
if bits <= 31 and bits >= 1:
self.bits: ty.Type(Var) = Var(shape=initial_shape, init=bits)
else:
raise ValueError("bits value is \
{} but should be 1-31".format(bits))
self._overflowed: ty.Type(Var) = Var(shape=initial_shape, init=0)

def connect_var(self, var: Var) -> None:
self.state = RefPort(var.shape)
self.state.connect_var(var)

self.layerid = Var(name="layerid",
shape=var.shape,
init=self.layerid.init
)
self.debug = Var(name="debug",
shape=var.shape,
init=self.debug.init
)
self.bits = Var(name="bits",
shape=var.shape,
init=self.bits.init
)
self._overflowed = Var(name="overflowed",
shape=var.shape,
init=self._overflowed.init
)

self._post_init()

@property
def shape(self) -> ty.Tuple[int, ...]:
"""Return shape of the Process."""
return self.state.shape

@property
def overflowed(self) -> ty.Type[int]:
"""Return overflow Var of Process.
1 is overflowed, 0 is not overflowed."""
return self._overflowed.get()
Empty file.
Loading

0 comments on commit 429c3bc

Please sign in to comment.