forked from lava-nc/lava
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
481 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
import numpy as np | ||
import typing as ty | ||
|
||
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol | ||
from lava.magma.core.model.py.ports import PyRefPort | ||
from lava.magma.core.model.py.type import LavaPyType | ||
from lava.magma.core.resources import CPU | ||
from lava.magma.core.decorator import implements, requires | ||
from lava.magma.core.model.py.model import PyLoihiProcessModel | ||
|
||
from lava.proc.bit_check.process import BitCheck | ||
|
||
|
||
class AbstractPyBitCheckModel(PyLoihiProcessModel): | ||
"""Abstract implementation of BitCheckModel. | ||
Specific implementations inherit from here. | ||
""" | ||
|
||
state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) | ||
|
||
bits: int = LavaPyType(int, int) | ||
layerid: int = LavaPyType(int, int) | ||
debug: int = LavaPyType(int, int) | ||
|
||
|
||
class AbstractBitCheckModel(AbstractPyBitCheckModel): | ||
"""Abstract implementation of BitCheck process. This | ||
short and simple ProcessModel can be used for quick | ||
checking of bit-accurate process runs as to whether | ||
bits will overflow when running on bit limited hardware. | ||
""" | ||
|
||
state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) | ||
|
||
bits: int = LavaPyType(int, int) | ||
layerid: int = LavaPyType(int, int) | ||
debug: int = LavaPyType(int, int) | ||
_overflowed: int = LavaPyType(int, int) | ||
|
||
def post_guard(self): | ||
return True | ||
|
||
def run_post_mgmt(self): | ||
value = self.state.read() | ||
|
||
if self.debug == 1: | ||
print("Value is: {} at time step: {}" | ||
.format(value, self.time_step)) | ||
|
||
# If self.check_bit_overflow(value) is true, | ||
# the value overflowed the allowed bits from self.bits | ||
if self.check_bit_overflow(value): | ||
self._overflowed = 1 | ||
if self.debug == 1: | ||
if self.layerid: | ||
print("layer id number: {}".format(self.layerid)) | ||
print( | ||
"value.max: overflows {} bits {}".format( | ||
self.bits, value.max() | ||
) | ||
) | ||
print( | ||
"max signed value {}".format( | ||
self.max_signed_int_per_bits(self.bits) | ||
) | ||
) | ||
print( | ||
"value.min: overflows {} bits {}".format( | ||
self.bits, value.min() | ||
) | ||
) | ||
print( | ||
"min signed value {}".format( | ||
self.max_signed_int_per_bits(self.bits) | ||
) | ||
) | ||
|
||
def check_bit_overflow(self, value: ty.Type[np.ndarray]): | ||
value = value.astype(np.int32) | ||
shift_amt = 32 - self.bits | ||
# Shift value left by shift_amt and | ||
# then shift value right by shift_amt, | ||
# the result should equal unshifted value | ||
# if the value did not overflow bits in self.bits | ||
return not np.all( | ||
((value << shift_amt) >> shift_amt) == value | ||
) | ||
|
||
def max_unsigned_int_per_bits(self, bits: ty.Type[int]): | ||
return (1 << bits) - 1 | ||
|
||
def min_signed_int_per_bits(self, bits: ty.Type[int]): | ||
return -1 << (bits - 1) | ||
|
||
def max_signed_int_per_bits(self, bits: ty.Type[int]): | ||
return (1 << (bits - 1)) - 1 | ||
|
||
|
||
@implements(proc=BitCheck, protocol=LoihiProtocol) | ||
@requires(CPU) | ||
class LoihiBitCheckModel(AbstractBitCheckModel): | ||
"""Implementation of Loihi BitCheck process. This | ||
short and simple ProcessModel can be used for quick | ||
checking of Loihi bit-accurate process run as to | ||
whether bits will overflow when running on Loihi Hardware. | ||
""" | ||
|
||
state: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) | ||
|
||
bits: int = LavaPyType(int, int) | ||
layerid: int = LavaPyType(int, int) | ||
debug: int = LavaPyType(int, int) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
import typing as ty | ||
|
||
from lava.magma.core.process.process import LogConfig, AbstractProcess | ||
from lava.magma.core.process.ports.ports import RefPort | ||
from lava.magma.core.process.variable import Var | ||
|
||
|
||
class BitCheck(AbstractProcess): | ||
def __init__( | ||
self, | ||
*, | ||
layerid: ty.Optional[int] = None, | ||
debug: ty.Optional[int] = 0, | ||
bits: ty.Optional[int] = 24, | ||
name: ty.Optional[str] = None, | ||
log_config: ty.Optional[LogConfig] = None, | ||
**kwargs, | ||
) -> None: | ||
"""BitCheck process. | ||
This process is used for quick checking of | ||
bit-accurate process run as to whether bits will | ||
overflow when running on bit limited hardware. | ||
Parameters | ||
---------- | ||
shape: Tuple | ||
Shape of the sigma process. | ||
Default is (1,). | ||
layerid: int or float | ||
Layer number of network. | ||
Default is None. | ||
debug: 0 or 1 | ||
Enable (1) or disable (0) debug print. | ||
Default is 0. | ||
bits: int | ||
Bits to use when checking overflow, 1-32. | ||
Default is 24. | ||
""" | ||
super().__init__( | ||
name=name, | ||
log_config=log_config, | ||
**kwargs | ||
) | ||
|
||
initial_shape = (1,) | ||
self.state = RefPort(initial_shape) | ||
|
||
self.layerid: ty.Type(Var) = Var(shape=initial_shape, init=layerid) | ||
self.debug: ty.Type(Var) = Var(shape=initial_shape, init=debug) | ||
if bits <= 31 and bits >= 1: | ||
self.bits: ty.Type(Var) = Var(shape=initial_shape, init=bits) | ||
else: | ||
raise ValueError("bits value is \ | ||
{} but should be 1-31".format(bits)) | ||
self._overflowed: ty.Type(Var) = Var(shape=initial_shape, init=0) | ||
|
||
def connect_var(self, var: Var) -> None: | ||
self.state = RefPort(var.shape) | ||
self.state.connect_var(var) | ||
|
||
self.layerid = Var(name="layerid", | ||
shape=var.shape, | ||
init=self.layerid.init | ||
) | ||
self.debug = Var(name="debug", | ||
shape=var.shape, | ||
init=self.debug.init | ||
) | ||
self.bits = Var(name="bits", | ||
shape=var.shape, | ||
init=self.bits.init | ||
) | ||
self._overflowed = Var(name="overflowed", | ||
shape=var.shape, | ||
init=self._overflowed.init | ||
) | ||
|
||
self._post_init() | ||
|
||
@property | ||
def shape(self) -> ty.Tuple[int, ...]: | ||
"""Return shape of the Process.""" | ||
return self.state.shape | ||
|
||
@property | ||
def overflowed(self) -> ty.Type[int]: | ||
"""Return overflow Var of Process. | ||
1 is overflowed, 0 is not overflowed.""" | ||
return self._overflowed.get() |
Empty file.
Oops, something went wrong.