diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 67ed015..318e226 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -100,12 +100,12 @@ def equals_operation( return Boolean(value=bool(f(left.value, right.value))) case Mode.PUBLIC: operation = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return PublicBoolean(child=operation) case Mode.SECRET: operation = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return SecretBoolean(child=operation) @@ -214,7 +214,6 @@ def __radd__(self, other): other ) return self.__add__(other_type) - return self.__add__(other) @@ -233,7 +232,7 @@ def binary_arithmetic_operation( return new_scalar_type(mode, base_type)(f(left.value, right.value)) case Mode.PUBLIC | Mode.SECRET: child = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return new_scalar_type(mode, base_type)(child) @@ -255,7 +254,7 @@ def shift_operation( return new_scalar_type(mode, base_type)(f(left.value, right.value)) case Mode.PUBLIC | Mode.SECRET: child = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return new_scalar_type(mode, base_type)(child) @@ -273,7 +272,7 @@ def binary_relational_operation( return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value)) # type: ignore case Mode.PUBLIC | Mode.SECRET: child = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return new_scalar_type(mode, BaseType.BOOLEAN)(child) # type: ignore @@ -290,7 +289,7 @@ def public_equals_operation( return PublicBoolean( child=PublicOutputEquality( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) # type: ignore ) @@ -346,12 +345,12 @@ def binary_logical_operation( return Boolean(value=bool(f(left.value, right.value))) if mode == Mode.PUBLIC: operation = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return PublicBoolean(child=operation) operation = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame().back_frame() + left=left, right=right, source_ref=SourceRef.back_frame() ) return SecretBoolean(child=operation) diff --git a/nada_dsl/source_ref.py b/nada_dsl/source_ref.py index 77c3c4f..a40552f 100644 --- a/nada_dsl/source_ref.py +++ b/nada_dsl/source_ref.py @@ -33,7 +33,10 @@ class SourceRef: @classmethod def back_frame(cls) -> "SourceRef": """Get the source reference of the calling frame.""" - backend_frame = inspect.currentframe().f_back.f_back + backend_frame = inspect.currentframe() + while "nada_dsl/" in backend_frame.f_code.co_filename: + backend_frame = backend_frame.f_back + lineno = backend_frame.f_lineno (offset, length) = SourceRef.try_get_line_info(backend_frame, lineno) return cls( @@ -66,7 +69,9 @@ def try_get_line_info(backend_frame, lineno) -> Tuple[int, int]: return 0, 0 lines = src.splitlines() - if lineno < len(lines): + + # lineno starts counting from 1 + if lineno <= len(lines): offset = 0 for i in range(lineno - 1): offset += len(lines[i]) + 1 diff --git a/test-programs/lib.py b/test-programs/lib.py new file mode 100644 index 0000000..d0721de --- /dev/null +++ b/test-programs/lib.py @@ -0,0 +1,5 @@ +from nada_dsl import PublicInteger + + +def function(a: PublicInteger, b: PublicInteger) -> PublicInteger: + return a + b diff --git a/test-programs/multiple_operations.py b/test-programs/multiple_operations.py new file mode 100644 index 0000000..f8a4250 --- /dev/null +++ b/test-programs/multiple_operations.py @@ -0,0 +1,26 @@ +from nada_dsl import * +from lib import function + + +def nada_main(): + party1 = Party(name="Party1") + my_int1 = PublicInteger(Input(name="my_int1", party=party1)) + my_int2 = PublicInteger(Input(name="my_int2", party=party1)) + + addition = my_int1 * my_int2 + equals = my_int1 == my_int2 + pow = my_int1**my_int2 + sum_list = sum([my_int1, my_int2]) + shift_l = my_int1 << UnsignedInteger(2) + shift_r = my_int1 >> UnsignedInteger(2) + function_result = function(my_int1, my_int2) + + return [ + Output(addition, "addition", party1), + Output(equals, "equals", party1), + Output(pow, "pow", party1), + Output(sum_list, "sum_list", party1), + Output(shift_l, "shift_l", party1), + Output(shift_r, "shift_r", party1), + Output(function_result, "function_result", party1), + ] diff --git a/tests/source_ref.py b/tests/source_ref.py new file mode 100644 index 0000000..763c8e7 --- /dev/null +++ b/tests/source_ref.py @@ -0,0 +1,92 @@ +import sys +import os.path + +from nada_dsl.compiler_frontend import nada_dsl_to_nada_mir, print_operations, print_mir +from nada_mir_proto.nillion.nada.mir.v1 import ProgramMir, SourceRef +from nada_dsl.errors import MissingEntryPointError, MissingProgramArgumentError +from tests.compile_test import get_test_programs_folder + + +def mir_model(script_path) -> ProgramMir: + script_dir = os.path.dirname(script_path) + sys.path.insert(0, script_dir) + script_name = os.path.basename(script_path) + if script_name.endswith(".py"): + script_name = script_name[:-3] + script = __import__(script_name) + + try: + main = getattr(script, "nada_main") + except Exception as exc: + raise MissingEntryPointError( + "'nada_dsl' entrypoint function is missing in program " + script_name + ) from exc + + outputs = main() + return nada_dsl_to_nada_mir(outputs) + + +def assert_source_ref( + mir: ProgramMir, source_ref_index: int, file: str, lineno: int, offset: int, length: int +): + source_ref = mir.source_refs[source_ref_index] + assert source_ref.file == file + assert source_ref.lineno == lineno + assert source_ref.offset == offset + assert source_ref.length == length + + +def test_multiple_operations(): + mir = mir_model(f"{get_test_programs_folder()}/multiple_operations.py") + + + # party1 = Party(name="Party1") + assert_source_ref(mir, mir.parties[0].source_ref_index, "multiple_operations.py", 6, 67, 33) + + # my_int1 = PublicInteger(Input(name="my_int1", party=party1)) + assert_source_ref(mir, mir.inputs[0].source_ref_index, "multiple_operations.py", 7, 101, 64) + assert_source_ref(mir, mir.operations[0].operation.source_ref_index, "multiple_operations.py", 7, 101, 64) + # my_int2 = PublicInteger(Input(name="my_int2", party=party1)) + assert_source_ref(mir, mir.inputs[1].source_ref_index, "multiple_operations.py", 8, 166, 64) + assert_source_ref(mir, mir.operations[1].operation.source_ref_index, "multiple_operations.py", 8, 166, 64) + + # addition = my_int1 * my_int2 + assert_source_ref(mir, mir.operations[2].operation.source_ref_index, "multiple_operations.py", 10, 232, 32) + # equals = my_int1 == my_int2 + assert_source_ref(mir, mir.operations[3].operation.source_ref_index, "multiple_operations.py", 11, 265, 31) + # pow = my_int1**my_int2 + assert_source_ref(mir, mir.operations[4].operation.source_ref_index, "multiple_operations.py", 12, 297, 26) + # sum_list = sum([my_int1, my_int2]) = 0 + my_int1 + my_int2 + # literal_ref + assert_source_ref(mir, mir.operations[5].operation.source_ref_index, "multiple_operations.py", 13, 324, 38) + # 0 + my_int1 + assert_source_ref(mir, mir.operations[6].operation.source_ref_index, "multiple_operations.py", 13, 324, 38) + # 0 + my_int1 + my_int2 + assert_source_ref(mir, mir.operations[7].operation.source_ref_index, "multiple_operations.py", 13, 324, 38) + # shift_l = my_int1 << UnsignedInteger(2) + # literal_ref + assert_source_ref(mir, mir.operations[8].operation.source_ref_index, "multiple_operations.py", 14, 363, 43) + # my_int1 << UnsignedInteger(2) + assert_source_ref(mir, mir.operations[9].operation.source_ref_index, "multiple_operations.py", 14, 363, 43) + # shift_l = my_int1 >> UnsignedInteger(2) + # literal_ref + assert_source_ref(mir, mir.operations[10].operation.source_ref_index, "multiple_operations.py", 15, 407, 43) + # my_int1 >> UnsignedInteger(2) + assert_source_ref(mir, mir.operations[11].operation.source_ref_index, "multiple_operations.py", 15, 407, 43) + # function_result = function(my_int1, my_int2) -> return a + b + assert_source_ref(mir, mir.operations[12].operation.source_ref_index, "lib.py", 5, 104, 16) + + # Output(addition, "addition", party1), + assert_source_ref(mir, mir.outputs[0].source_ref_index, "multiple_operations.py", 19, 514, 45) + # Output(equals, "equals", party1), + assert_source_ref(mir, mir.outputs[1].source_ref_index, "multiple_operations.py", 20, 560, 41) + # Output(pow, "pow", party1), + assert_source_ref(mir, mir.outputs[2].source_ref_index, "multiple_operations.py", 21, 602, 35) + # Output(sum_list, "sum_list", party1), + assert_source_ref(mir, mir.outputs[3].source_ref_index, "multiple_operations.py", 22, 638, 45) + # Output(shift_l, "shift_l", party1), + assert_source_ref(mir, mir.outputs[4].source_ref_index, "multiple_operations.py", 23, 684, 43) + # Output(shift_r, "shift_r", party1), + assert_source_ref(mir, mir.outputs[5].source_ref_index, "multiple_operations.py", 24, 728, 43) + # Output(function_result, "function_result", party1), + assert_source_ref(mir, mir.outputs[6].source_ref_index, "multiple_operations.py", 25, 772, 59) diff --git a/uv.lock b/uv.lock index 2e2c2e9..91a9491 100644 --- a/uv.lock +++ b/uv.lock @@ -628,7 +628,7 @@ wheels = [ [[package]] name = "nada-dsl" -version = "0.8.0rc1" +version = "0.8.0rc2" source = { editable = "." } dependencies = [ { name = "asttokens" }, @@ -707,7 +707,7 @@ dev = [ [[package]] name = "nada-mir-proto" -version = "0.1.0" +version = "0.2.0rc1" source = { editable = "nada_mir" } dependencies = [ { name = "betterproto" },