Skip to content

Commit

Permalink
fix: multiple fixes in source reference calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
pablojhl committed Jan 23, 2025
1 parent 5bbcb1d commit 205fc81
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 13 deletions.
25 changes: 14 additions & 11 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ def equals_operation(
return Boolean(value=bool(f(left.value, right.value)))
case Mode.PUBLIC:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return PublicBoolean(child=operation)
case Mode.SECRET:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return SecretBoolean(child=operation)

Expand Down Expand Up @@ -213,9 +213,12 @@ def __radd__(self, other):
other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)(
other
)
return self.__add__(other_type)

return self.__add__(other)
return binary_arithmetic_operation(
"Addition", "+", self, other_type, lambda lhs, rhs: lhs + rhs
)
return binary_arithmetic_operation(
"Addition", "+", self, other, lambda lhs, rhs: lhs + rhs
)


def binary_arithmetic_operation(
Expand All @@ -233,7 +236,7 @@ def binary_arithmetic_operation(
return new_scalar_type(mode, base_type)(f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(child)

Expand All @@ -255,7 +258,7 @@ def shift_operation(
return new_scalar_type(mode, base_type)(f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(child)

Expand All @@ -273,7 +276,7 @@ def binary_relational_operation(
return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value)) # type: ignore
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, BaseType.BOOLEAN)(child) # type: ignore

Expand All @@ -290,7 +293,7 @@ def public_equals_operation(

return PublicBoolean(
child=PublicOutputEquality(
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
) # type: ignore
)

Expand Down Expand Up @@ -346,12 +349,12 @@ def binary_logical_operation(
return Boolean(value=bool(f(left.value, right.value)))
if mode == Mode.PUBLIC:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return PublicBoolean(child=operation)

operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return SecretBoolean(child=operation)

Expand Down
9 changes: 7 additions & 2 deletions nada_dsl/source_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class SourceRef:
@classmethod
def back_frame(cls) -> "SourceRef":
"""Get the source reference of the calling frame."""
backend_frame = inspect.currentframe().f_back.f_back
backend_frame = inspect.currentframe()
while "nada_dsl/" in backend_frame.f_code.co_filename:
backend_frame = backend_frame.f_back

lineno = backend_frame.f_lineno
(offset, length) = SourceRef.try_get_line_info(backend_frame, lineno)
return cls(
Expand Down Expand Up @@ -66,7 +69,9 @@ def try_get_line_info(backend_frame, lineno) -> Tuple[int, int]:
return 0, 0

lines = src.splitlines()
if lineno < len(lines):

# lineno starts counting from 1
if lineno <= len(lines):
offset = 0
for i in range(lineno - 1):
offset += len(lines[i]) + 1
Expand Down
5 changes: 5 additions & 0 deletions test-programs/lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from build.lib.nada_dsl.audit.abstract import PublicInteger


def function(a: PublicInteger, b: PublicInteger) -> PublicInteger:
return a + b
26 changes: 26 additions & 0 deletions test-programs/multiple_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from nada_dsl import *
from lib import function


def nada_main():
party1 = Party(name="Party1")
my_int1 = PublicInteger(Input(name="my_int1", party=party1))
my_int2 = PublicInteger(Input(name="my_int2", party=party1))

addition = my_int1 * my_int2
equals = my_int1 == my_int2
pow = my_int1 ** my_int2
sum_list = sum([my_int1, my_int2])
shift_l = my_int1 << UnsignedInteger(2)
shift_r = my_int1 >> UnsignedInteger(2)
function_result = function(my_int1, my_int2)

return [
Output(addition, "addition", party1),
Output(equals, "equals", party1),
Output(pow, "pow", party1),
Output(sum_list, "sum_list", party1),
Output(shift_l, "shift_l", party1),
Output(shift_r, "shift_r", party1),
Output(function_result, "function_result", party1),
]
53 changes: 53 additions & 0 deletions tests/source_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

import sys
import os.path

from nada_dsl.compiler_frontend import nada_dsl_to_nada_mir, print_operations, print_mir
from nada_mir_proto.nillion.nada.mir.v1 import ProgramMir, SourceRef
from nada_dsl.errors import MissingEntryPointError, MissingProgramArgumentError
from tests.compile_test import get_test_programs_folder


def mir_model(script_path) -> ProgramMir:
script_dir = os.path.dirname(script_path)
sys.path.insert(0, script_dir)
script_name = os.path.basename(script_path)
if script_name.endswith(".py"):
script_name = script_name[:-3]
script = __import__(script_name)

try:
main = getattr(script, "nada_main")
except Exception as exc:
raise MissingEntryPointError(
"'nada_dsl' entrypoint function is missing in program " + script_name
) from exc

outputs = main()
return nada_dsl_to_nada_mir(outputs)

def assert_source_ref(source_ref: SourceRef, file: str, lineno: int, offset: int, length: int):
assert source_ref.file == file
assert source_ref.lineno == lineno
assert source_ref.offset == offset
assert source_ref.length == length

def test_multiple_operations():
mir = mir_model(f"{get_test_programs_folder()}/multiple_operations.py")
assert_source_ref(mir.source_refs[0], 'multiple_operations.py', 10, 232, 32)
assert_source_ref(mir.source_refs[1], 'multiple_operations.py', 8, 166, 64)
assert_source_ref(mir.source_refs[2], 'multiple_operations.py', 7, 101, 64)
assert_source_ref(mir.source_refs[3], 'multiple_operations.py', 19, 516, 45)
assert_source_ref(mir.source_refs[4], 'multiple_operations.py', 11, 265, 31)
assert_source_ref(mir.source_refs[5], 'multiple_operations.py', 20, 562, 41)
assert_source_ref(mir.source_refs[6], 'multiple_operations.py', 12, 297, 28)
assert_source_ref(mir.source_refs[7], 'multiple_operations.py', 21, 604, 35)
assert_source_ref(mir.source_refs[8], 'multiple_operations.py', 13, 326, 38)
assert_source_ref(mir.source_refs[9], 'multiple_operations.py', 22, 640, 45)
assert_source_ref(mir.source_refs[10], 'multiple_operations.py', 14, 365, 43)
assert_source_ref(mir.source_refs[11], 'multiple_operations.py', 23, 686, 43)
assert_source_ref(mir.source_refs[12], 'multiple_operations.py', 15, 409, 43)
assert_source_ref(mir.source_refs[13], 'multiple_operations.py', 24, 730, 43)
assert_source_ref(mir.source_refs[14], 'lib.py', 5, 129, 16)
assert_source_ref(mir.source_refs[15], 'multiple_operations.py', 25, 774, 59)
assert_source_ref(mir.source_refs[16], 'multiple_operations.py', 6, 67, 33)

0 comments on commit 205fc81

Please sign in to comment.