Skip to content

Commit

Permalink
fix: multiple fixes in source reference calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
pablojhl committed Jan 23, 2025
1 parent 5bbcb1d commit 434c79a
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 13 deletions.
17 changes: 8 additions & 9 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ def equals_operation(
return Boolean(value=bool(f(left.value, right.value)))
case Mode.PUBLIC:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return PublicBoolean(child=operation)
case Mode.SECRET:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return SecretBoolean(child=operation)

Expand Down Expand Up @@ -214,7 +214,6 @@ def __radd__(self, other):
other
)
return self.__add__(other_type)

return self.__add__(other)


Expand All @@ -233,7 +232,7 @@ def binary_arithmetic_operation(
return new_scalar_type(mode, base_type)(f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(child)

Expand All @@ -255,7 +254,7 @@ def shift_operation(
return new_scalar_type(mode, base_type)(f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(child)

Expand All @@ -273,7 +272,7 @@ def binary_relational_operation(
return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value)) # type: ignore
case Mode.PUBLIC | Mode.SECRET:
child = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, BaseType.BOOLEAN)(child) # type: ignore

Expand All @@ -290,7 +289,7 @@ def public_equals_operation(

return PublicBoolean(
child=PublicOutputEquality(
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
) # type: ignore
)

Expand Down Expand Up @@ -346,12 +345,12 @@ def binary_logical_operation(
return Boolean(value=bool(f(left.value, right.value)))
if mode == Mode.PUBLIC:
operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return PublicBoolean(child=operation)

operation = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame().back_frame()
left=left, right=right, source_ref=SourceRef.back_frame()
)
return SecretBoolean(child=operation)

Expand Down
9 changes: 7 additions & 2 deletions nada_dsl/source_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class SourceRef:
@classmethod
def back_frame(cls) -> "SourceRef":
"""Get the source reference of the calling frame."""
backend_frame = inspect.currentframe().f_back.f_back
backend_frame = inspect.currentframe()
while "nada_dsl/" in backend_frame.f_code.co_filename:
backend_frame = backend_frame.f_back

lineno = backend_frame.f_lineno
(offset, length) = SourceRef.try_get_line_info(backend_frame, lineno)
return cls(
Expand Down Expand Up @@ -66,7 +69,9 @@ def try_get_line_info(backend_frame, lineno) -> Tuple[int, int]:
return 0, 0

lines = src.splitlines()
if lineno < len(lines):

# lineno starts counting from 1
if lineno <= len(lines):
offset = 0
for i in range(lineno - 1):
offset += len(lines[i]) + 1
Expand Down
5 changes: 5 additions & 0 deletions test-programs/lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from nada_dsl import PublicInteger


def function(a: PublicInteger, b: PublicInteger) -> PublicInteger:
return a + b
26 changes: 26 additions & 0 deletions test-programs/multiple_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from nada_dsl import *
from lib import function


def nada_main():
party1 = Party(name="Party1")
my_int1 = PublicInteger(Input(name="my_int1", party=party1))
my_int2 = PublicInteger(Input(name="my_int2", party=party1))

addition = my_int1 * my_int2
equals = my_int1 == my_int2
pow = my_int1**my_int2
sum_list = sum([my_int1, my_int2])
shift_l = my_int1 << UnsignedInteger(2)
shift_r = my_int1 >> UnsignedInteger(2)
function_result = function(my_int1, my_int2)

return [
Output(addition, "addition", party1),
Output(equals, "equals", party1),
Output(pow, "pow", party1),
Output(sum_list, "sum_list", party1),
Output(shift_l, "shift_l", party1),
Output(shift_r, "shift_r", party1),
Output(function_result, "function_result", party1),
]
92 changes: 92 additions & 0 deletions tests/source_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import sys
import os.path

from nada_dsl.compiler_frontend import nada_dsl_to_nada_mir, print_operations, print_mir
from nada_mir_proto.nillion.nada.mir.v1 import ProgramMir, SourceRef
from nada_dsl.errors import MissingEntryPointError, MissingProgramArgumentError
from tests.compile_test import get_test_programs_folder


def mir_model(script_path) -> ProgramMir:
script_dir = os.path.dirname(script_path)
sys.path.insert(0, script_dir)
script_name = os.path.basename(script_path)
if script_name.endswith(".py"):
script_name = script_name[:-3]
script = __import__(script_name)

try:
main = getattr(script, "nada_main")
except Exception as exc:
raise MissingEntryPointError(
"'nada_dsl' entrypoint function is missing in program " + script_name
) from exc

outputs = main()
return nada_dsl_to_nada_mir(outputs)


def assert_source_ref(
mir: ProgramMir, source_ref_index: int, file: str, lineno: int, offset: int, length: int
):
source_ref = mir.source_refs[source_ref_index]
assert source_ref.file == file
assert source_ref.lineno == lineno
assert source_ref.offset == offset
assert source_ref.length == length


def test_multiple_operations():
mir = mir_model(f"{get_test_programs_folder()}/multiple_operations.py")


# party1 = Party(name="Party1")
assert_source_ref(mir, mir.parties[0].source_ref_index, "multiple_operations.py", 6, 67, 33)

# my_int1 = PublicInteger(Input(name="my_int1", party=party1))
assert_source_ref(mir, mir.inputs[0].source_ref_index, "multiple_operations.py", 7, 101, 64)
assert_source_ref(mir, mir.operations[0].operation.source_ref_index, "multiple_operations.py", 7, 101, 64)
# my_int2 = PublicInteger(Input(name="my_int2", party=party1))
assert_source_ref(mir, mir.inputs[1].source_ref_index, "multiple_operations.py", 8, 166, 64)
assert_source_ref(mir, mir.operations[1].operation.source_ref_index, "multiple_operations.py", 8, 166, 64)

# addition = my_int1 * my_int2
assert_source_ref(mir, mir.operations[2].operation.source_ref_index, "multiple_operations.py", 10, 232, 32)
# equals = my_int1 == my_int2
assert_source_ref(mir, mir.operations[3].operation.source_ref_index, "multiple_operations.py", 11, 265, 31)
# pow = my_int1**my_int2
assert_source_ref(mir, mir.operations[4].operation.source_ref_index, "multiple_operations.py", 12, 297, 26)
# sum_list = sum([my_int1, my_int2]) = 0 + my_int1 + my_int2
# literal_ref
assert_source_ref(mir, mir.operations[5].operation.source_ref_index, "multiple_operations.py", 13, 324, 38)
# 0 + my_int1
assert_source_ref(mir, mir.operations[6].operation.source_ref_index, "multiple_operations.py", 13, 324, 38)
# 0 + my_int1 + my_int2
assert_source_ref(mir, mir.operations[7].operation.source_ref_index, "multiple_operations.py", 13, 324, 38)
# shift_l = my_int1 << UnsignedInteger(2)
# literal_ref
assert_source_ref(mir, mir.operations[8].operation.source_ref_index, "multiple_operations.py", 14, 363, 43)
# my_int1 << UnsignedInteger(2)
assert_source_ref(mir, mir.operations[9].operation.source_ref_index, "multiple_operations.py", 14, 363, 43)
# shift_l = my_int1 >> UnsignedInteger(2)
# literal_ref
assert_source_ref(mir, mir.operations[10].operation.source_ref_index, "multiple_operations.py", 15, 407, 43)
# my_int1 >> UnsignedInteger(2)
assert_source_ref(mir, mir.operations[11].operation.source_ref_index, "multiple_operations.py", 15, 407, 43)
# function_result = function(my_int1, my_int2) -> return a + b
assert_source_ref(mir, mir.operations[12].operation.source_ref_index, "lib.py", 5, 104, 16)

# Output(addition, "addition", party1),
assert_source_ref(mir, mir.outputs[0].source_ref_index, "multiple_operations.py", 19, 514, 45)
# Output(equals, "equals", party1),
assert_source_ref(mir, mir.outputs[1].source_ref_index, "multiple_operations.py", 20, 560, 41)
# Output(pow, "pow", party1),
assert_source_ref(mir, mir.outputs[2].source_ref_index, "multiple_operations.py", 21, 602, 35)
# Output(sum_list, "sum_list", party1),
assert_source_ref(mir, mir.outputs[3].source_ref_index, "multiple_operations.py", 22, 638, 45)
# Output(shift_l, "shift_l", party1),
assert_source_ref(mir, mir.outputs[4].source_ref_index, "multiple_operations.py", 23, 684, 43)
# Output(shift_r, "shift_r", party1),
assert_source_ref(mir, mir.outputs[5].source_ref_index, "multiple_operations.py", 24, 728, 43)
# Output(function_result, "function_result", party1),
assert_source_ref(mir, mir.outputs[6].source_ref_index, "multiple_operations.py", 25, 772, 59)
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 434c79a

Please sign in to comment.