diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2099a8d --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +Prover.toml +Verifier.toml \ No newline at end of file diff --git a/example.py b/example.py index 6f5a1d7..760e25c 100644 --- a/example.py +++ b/example.py @@ -22,7 +22,7 @@ mlp, 'src/main.nr', 'Prover.toml', - '../zkfloat/lib.nr', + '../zkfloat/zkfloat.nr', 7 ) diff --git a/proofs/mlp-proof.proof b/proofs/mlp-proof.proof new file mode 100644 index 0000000..b6b5f4b --- /dev/null +++ b/proofs/mlp-proof.proof @@ -0,0 +1 @@ +19342440604102b52f14b0f87de5843660fcd75d80bf02b1e3d1a19ec2041ec30e73aad005e70aa5960820ea8d892073e626c9d1f784da9b450568bbd96ed84f094ffdc414c4bff828af65f8f9337ff56c30b525d2230b0ad49694b93ff7edb10421af5f3f654c811b7b50749dc594ed5c34928a7bbd5de159d8c294c5f6f0220074d2610d5800c865fd2dd91ea3b0bb0eb8c9696c773028e46585976ed6970c295455b7dc3d4cbcb66b046aa90f7dc6c131a48487b54b14b8bf1a8fac5dea601b18bee62f9ada54217593631ed8a7c0b0e3212e3bc5562cfab7f8c565468e0d0f05a84d749d41fc144140c31a0bd4824931f51e9843987542057e403ad2ac720b5ebaf074955b88a1f491a3043f46ffc817606f901c80ec1def23a6da97e6bf22af4ded3cc2979fcfbae78df92ff8aba420f29ccd24ffb33b1cd8125c621cb528275129c59191e7c0eeedee52810351c5394c53e0326561d10169be791f5bdb211a905b57e902c57ab3d0be0efbd287451577366f7137050de3fb6bef42ddd9155b3ffea3d919bb1367b559b6214d9ab8e0c9d8619e3750f7ef2cbff90538860af19e54a647c4353260aef5f7ea85843c7396f88495ce3b198b7fe90f64ced822fea70ece6f08e8623058e4e614dcaaf71d9fece6e8e599f4b5eb2f6d51042e0b0b97da63f59b282de5aac4cab8e8de907f288bea19c3907d2dff3c0589233c22c095efd68f36d1fc9270bfbafabbd2999d91610354bfd21358ad84d1a280d519f3cb041c3f9286f3ebab5f4609906943290e84f0224b71dc35cff6483862102b277500bcc11b5802b695818de50bb589de6ceba53d904478ccb41134976b2703d881d47496311c8ada54a5ab54c951045c5b7e3bf5fab798097f647712f62e159a72479f9d67df1787658ccd40bb9ba3ed718aea1ba3aa301b5afc818062c3096a41efa491e1b1e2924f77c27f65f10ec7144aea2e88b02c6c679b29ef917d293ce8d3b257bfdbdd3d80660c3ce6e43cd2d
class FloatNum:
    """
    Base-10 floating point number stored as a base-10 mantissa and a
    base-10 exponent.

    The exponent is shifted ("padded") by a fixed value so that negative
    exponents can be represented with non-negative integers.

    Attributes
    ----------
    mantissa : int
        signed integer holding the significant digits of the number
        (at most ``precision`` digits after truncation; the sign of the
        number is carried by the sign of the mantissa)
    exponent : int
        integer exponent, stored already shifted by ``exp_pad``
    precision : int
        number of significant digits kept in the mantissa

    Notes
    -----
    NOTE(review): ``__mul__``, ``__gt__``, ``__lt__`` and the zero case in
    ``truncated()`` hard-code a pad of 100.  The class therefore behaves
    consistently only when ``exp_pad == 100`` is used throughout (as the
    rest of skproof does) or when exponents are pre-padded by the caller —
    confirm before reusing with a different pad.
    """

    def __init__(self, mantissa, exponent, precision, exp_pad=0):
        self.precision = precision
        self.mantissa = mantissa

        # Example: raw exponent -1 with exp_pad 100 is stored as 99.
        self.exponent = exponent + exp_pad

    def truncate(self):
        """Truncate own mantissa to ``precision`` digits in place; return self."""
        tr = self.truncated(self.mantissa, self.exponent)
        self.mantissa = tr.mantissa
        self.exponent = tr.exponent
        return self

    def __add__(self, num_2):
        """Return the truncated sum of self and num_2 as a new FloatNum."""
        mant_1 = self.mantissa
        mant_2 = num_2.mantissa

        exp_1 = self.exponent
        exp_2 = num_2.exponent
        exp = exp_1
        diff = abs(exp_1 - exp_2)
        # Align both mantissas to the smaller of the two exponents
        # before adding the (integer) mantissas.
        if self.exponent < num_2.exponent:
            mant_2 *= 10 ** diff
            exp = exp_2 - diff
        else:
            mant_1 *= 10 ** diff
            exp = exp_1 - diff

        sum_mant = mant_1 + mant_2

        return self.truncated(sum_mant, exp)

    def __mul__(self, num_2):
        """Return the truncated product; assumes both exponents carry pad 100."""
        mant_1 = self.mantissa
        mant_2 = num_2.mantissa

        exp_1 = self.exponent
        exp_2 = num_2.exponent

        # Subtracting 100 removes the double pad introduced by adding
        # two padded exponents (see class Notes).
        return self.truncated(mant_1 * mant_2, (exp_1 + exp_2 - 100))

    def __gt__(self, num_2):
        """Overrides >: compare the decoded real values (pad 100 assumed)."""
        return self.mantissa * (10 ** (self.exponent - 100)) > num_2.mantissa * (10 ** (num_2.exponent - 100))

    def __lt__(self, num_2):
        """Overrides <: compare the decoded real values (pad 100 assumed)."""
        return self.mantissa * (10 ** (self.exponent - 100)) < num_2.mantissa * (10 ** (num_2.exponent - 100))

    def get_noir_input(self):
        """Return this number as a Noir ``Float`` struct literal string."""
        sign = 0
        mant = self.mantissa
        exp = self.exponent
        if mant < 0:
            sign = 1
            mant = -mant

        return 'Float {' f'sign: {sign}, mantissa: {mant}, exponent: {exp}' + ' }'

    def get_prover_input(self):
        """Return this number as a TOML array string for Noir's Prover.toml."""
        sign = 0
        mant = self.mantissa
        exp = self.exponent
        if mant < 0:
            sign = 1
            mant = -mant

        return f'["{sign}", "{mant}", "{exp}"]'

    def truncated(self, mant, exp):
        """Return a new FloatNum with ``mant`` cut to ``precision`` digits.

        The exponent is increased by the number of digits dropped; a zero
        mantissa is normalized to exponent 100 (the padded zero exponent).
        """
        if len(str(abs(mant))) > self.precision:
            l = len(str(abs(mant)))
            # Keep one extra leading character for the '-' sign so that
            # slicing the string keeps ``precision`` digits either way.
            sign_comp = 0
            if mant < 0:
                sign_comp = 1
            prec_diff = abs(l - self.precision)
            mant = int(str(mant)[:self.precision + sign_comp])
            exp += prec_diff

        if mant == 0:
            exp = 100

        return FloatNum(mant, exp, self.precision)

    def __str__(self):
        """Scientific-notation string, e.g. ``'6816981e93'`` (padded exponent)."""
        return f'{self.mantissa}e{self.exponent}'
import numpy as np
from skproof.float_num.FloatNum import FloatNum
import math
import os


class MLPClassifierProver:
    """
    Generates a Noir prover circuit for a pre-trained sklearn MLPClassifier.

    Attributes
    ----------
    clf : object
        Trained scikit-learn MLPClassifier model

    Methods
    -------
    quant(q) -> FloatNum
        Quantizes a decimal number into a FloatNum with the configured precision
    float_from_string(str_num) -> FloatNum
        Builds a FloatNum from a scientific-notation string such as '123e45'
    generate_statements(expressions, inputs, outputs)
        Writes Noir addition/multiplication/relu statements and the output
        "constrain" statements into the circuit file
    generate_expressions() -> (expressions, inputs, outputs)
        Extracts MUL/SUM/RELU expression tuples from the trained model
    generate_ar_statement(res, arg_1, op, arg_2) -> str
        Renders a Noir addFloats/mulFloats statement
    generate_ac_statement(res, value, act_func) -> str
        Renders a Noir activation-function call statement
    simulate_ann(expressions, inputs, outputs, data) -> (y_pred, inputs, outputs)
        Evaluates the extracted expression list on input data
    import_lib()
        Copies the zkfloat library source into the circuit file
    generate_circuit(X)
        Writes the full circuit and the Prover.toml input file
    prove(X)
        Generates the circuit and invokes `nargo prove`
    """

    def __init__(
        self,
        mlp_classifier_model,
        circuit_output_path,
        prover_output_path,
        float_num_lib_path,
        precision=7,
        exp_pad=100,
        verbose=True
    ):
        self.clf = mlp_classifier_model

        self.circuit_output_path = circuit_output_path
        self.circuit_output = open(circuit_output_path, 'w')

        # NOTE(review): despite its name this attribute holds an open file
        # handle, and generate_circuit() writes 'Prover.toml' directly
        # instead of using it — confirm whether it is still needed.
        self.prover_output_path = open(prover_output_path, 'w')
        self.float_num_lib = open(float_num_lib_path, 'r')
        self.precision = precision
        self.exp_pad = exp_pad
        self.verbose = verbose

    def quant(self, q):
        """Quantize decimal ``q`` into a truncated FloatNum with this prover's precision."""
        return FloatNum(round(q * 10**((self.precision * 2))), -(2 * self.precision), self.precision, self.exp_pad).truncate()

    def float_from_string(self, str_num):
        """Build a FloatNum from a '<mantissa>e<exponent>' string (exponent taken as already padded)."""
        [mant, exp] = str_num.split('e')
        return FloatNum(int(mant), int(exp), self.precision)

    def generate_ar_statement(self, res, arg_1, op, arg_2):
        """Render a Noir arithmetic statement for '+' (addFloats) or '*' (mulFloats)."""
        if op == '+':
            return f'let {res} = addFloats({arg_1}, {arg_2});'
        elif op == '*':
            return f'let {res} = mulFloats({arg_1}, {arg_2});'

    def generate_ac_statement(self, res, value, act_func):
        """Render a Noir activation-function call statement."""
        return f'let {res} = {act_func}({value});'

    def generate_statements(self, expressions, inputs, outputs):
        """Write the circuit's main() — declarations, arithmetic, constraints — to the circuit file."""
        statements = []

        for e in expressions:
            operation = e[0]
            variable = e[1]
            arg_1 = e[2]
            arg_2 = None

            if len(e) > 3:
                arg_2 = e[3]

            # Constant operands are embedded as Noir Float struct literals.
            if isinstance(arg_1, FloatNum):
                arg_1 = arg_1.get_noir_input()

            if isinstance(arg_2, FloatNum):
                arg_2 = arg_2.get_noir_input()

            if operation == 'SUM':
                statements.append(self.generate_ar_statement(variable, arg_1, '+', arg_2))
            elif operation == 'MUL':
                statements.append(self.generate_ar_statement(variable, arg_1, '*', arg_2))
            elif operation == 'RELU':
                statements.append(self.generate_ac_statement(variable, arg_1, 'relu'))
            else:
                # BUG FIX: the original raised a plain f-string, which is a
                # TypeError in Python 3 (exceptions must derive from
                # BaseException); raise a proper exception instead.
                raise ValueError(f'Invalid operation: {operation}')

        declarations = ''

        x_num = 1
        for inp in inputs:
            declarations += f'\tlet mut {inp} = Float' + '{ ' + \
                f'sign: x_{x_num}[0], mantissa: x_{x_num}[1], exponent: x_{x_num}[2]' + ' };\n'
            x_num += 1

        expression_statements = ''
        for st in statements:
            expression_statements += f'\t{st}\n'

        constrains = ''
        y_num = 1

        main_args_arr = []
        main_fn = 'fn main(\n'

        x_num = 1
        for inp in inputs:
            main_args_arr.append(f'\tx_{x_num} : pub [Field; 3]')
            x_num += 1

        # Each expected output y_i is constrained component-wise
        # (sign, mantissa, exponent) against the computed output node.
        for out in outputs:
            main_args_arr.append(f'\ty_{y_num} : pub [Field; 3]')
            constrains += f'\tconstrain {out}.sign == y_{y_num}[0];\n'
            constrains += f'\tconstrain {out}.mantissa == y_{y_num}[1];\n'
            constrains += f'\tconstrain {out}.exponent == y_{y_num}[2];\n'
            y_num += 1

        main_fn += ',\n'.join(main_args_arr)
        main_fn += '\n) {\n'

        self.circuit_output.write(main_fn)
        self.circuit_output.write(declarations)
        self.circuit_output.write(expression_statements)
        self.circuit_output.write('\n')
        self.circuit_output.write(constrains)
        self.circuit_output.write('}\n')

    def generate_expressions(self):
        """Extract MUL/SUM/RELU expression tuples from the trained model's weights.

        Returns (expressions, inputs, outputs) where inputs/outputs are the
        node labels of the first and last layer.
        """
        num_layers = len(self.clf.coefs_)
        expressions = []

        for l in range(num_layers):
            layer = self.clf.coefs_[l]
            ints = self.clf.intercepts_[l]

            for j in range(layer.shape[1]):
                coefs = layer[:, j]
                new_expressions = []

                # Extract weight multiplications
                for i in range(coefs.shape[0]):
                    new_expressions.append(
                        ['MUL', f'A_{l+1}_{j}_{i}', f'X_{l}_{i}', self.quant(coefs[i])])

                # Extract summations as a chain of pairwise SUMs.
                # NOTE(review): a layer with a single input produces no
                # partials and would reference B_..._-1 below — confirm
                # models always have >= 2 inputs per layer.
                partials = []
                for poly_i in range(1, len(new_expressions)):
                    if poly_i == 1:
                        sum_arg = ['SUM', f'B_{l}_{j}_{poly_i-1}',
                                   new_expressions[0][1], new_expressions[1][1]]
                    else:
                        sum_arg = ['SUM', f'B_{l}_{j}_{poly_i-1}',
                                   f'B_{l}_{j}_{poly_i-2}', new_expressions[poly_i][1]]
                    partials.append(sum_arg)

                # Add the bias term to the final partial sum.
                partials.append(
                    ['SUM', f'WXb_{l}_{j}', f'B_{l}_{j}_{len(new_expressions)-2}', self.quant(ints[j])])

                # Output-layer nodes are labeled 'O', hidden nodes 'X'.
                letter = 'X'
                if l == num_layers - 1:
                    letter = 'O'

                # Add ReLU activation function statement
                activations = [
                    ['RELU', f'{letter}_{l+1}_{j}', f'WXb_{l}_{j}']
                ]

                expressions += new_expressions
                expressions += partials
                expressions += activations

        inputs = [f'X_0_{i}' for i in range(self.clf.coefs_[0].shape[0])]
        outputs = [f'O_{len(self.clf.coefs_)}_{i}' for i in range(
            self.clf.coefs_[-1].shape[1])]

        return (expressions, inputs, outputs)

    def simulate_ann(self, expressions, inputs, outputs, data):
        """Evaluate the expression list on each row of ``data``.

        Returns (predicted labels, input node FloatNums, output node FloatNums
        of the last processed row).
        """
        y_pred = []
        for row in data:
            node_values = {}

            # Init inputs
            for i in range(len(inputs)):
                input_node = inputs[i]
                node_values[input_node] = self.quant(row[i])

            for poly in expressions:
                second_val = None

                if len(poly) == 4:
                    if isinstance(poly[3], FloatNum):
                        second_val = poly[3]
                    else:
                        second_val = node_values[poly[3]]

                if poly[0] == 'SUM':
                    node_values[poly[1]] = node_values[poly[2]] + second_val
                    if self.verbose:
                        self.circuit_output.write(
                            f"// {poly[1]} = {node_values[poly[2]]} + {second_val} = {node_values[poly[2]] + second_val}\n")
                elif poly[0] == 'MUL':
                    # Verbose logs in Noir language comments
                    if self.verbose:
                        self.circuit_output.write(
                            f"// {poly[1]} = {node_values[poly[2]]} * {second_val} = {node_values[poly[2]] * second_val}\n")
                    node_values[poly[1]] = node_values[poly[2]] * second_val
                elif poly[0] == 'RELU':
                    # ReLU: negative values clamp to the canonical zero FloatNum.
                    if node_values[poly[2]].mantissa < 0:
                        node_values[poly[1]] = FloatNum(
                            0, 0, self.precision, self.exp_pad)
                    else:
                        node_values[poly[1]] = node_values[poly[2]]

            for key, value in node_values.items():
                if self.verbose:
                    self.circuit_output.write(f"// {key} => {value}\n")

            label = np.argmax([node_values[i] for i in outputs])
            y_pred.append(label)

        return y_pred, [node_values[p] for p in inputs], [node_values[o] for o in outputs]

    def import_lib(self):
        """Copy the zkfloat library into the circuit file, patching its precision constants."""
        lib_data = self.float_num_lib.read()
        lib_data = lib_data.replace('maxValue : Field = 100000;', f'maxValue : Field = {10 ** self.precision};')
        lib_data = lib_data.replace('maxLogValue : Field = 5;', f'maxLogValue : Field = {self.precision};')
        self.circuit_output.write(lib_data)
        self.circuit_output.write('\n')

    def generate_circuit(self, X):
        """Write the full Noir circuit and Prover.toml for input matrix ``X``."""
        if self.verbose:
            print('Generating circuit...')

        # Write library into circuit file
        self.import_lib()

        # Generate expressions from MLPClassifier model
        expressions, inputs, outputs = self.generate_expressions()

        # Generate Noir language statements
        self.generate_statements(expressions, inputs, outputs)

        # Generate prover file with the quantized inputs of the first row.
        prover_input_file = open('Prover.toml', 'w')
        for i in range(X.shape[1]):
            value = X[0, i].ravel()[0]
            float_value = self.quant(value).get_prover_input()
            prover_input_file.write(f'x_{i+1} = {float_value}\n')

        # Simulate the network to obtain the expected outputs.
        _, _, output_values = self.simulate_ann(
            expressions,
            inputs,
            outputs,
            X
        )

        for i in range(len(output_values)):
            o = output_values[i].get_prover_input()
            prover_input_file.write(f'y_{i+1} = {o}\n')

        self.circuit_output.close()
        prover_input_file.close()

        if self.verbose:
            print(f'Circuit generated successfully in {self.circuit_output_path}')

    def prove(self, X):
        """Generate the circuit for ``X`` and run `nargo prove` in the project directory."""
        self.generate_circuit(X)
        path = self.circuit_output_path.split('/')[:-2]

        if self.verbose:
            print('Generating proof, this may take a while...')

        # BUG FIX: the original called os.system(f'cd {dir}') and then
        # os.system('nargo prove ...') separately; each os.system runs in
        # its own subshell, so the cd never took effect. Run both in one
        # shell command instead.
        cmd = 'nargo prove mlp-proof'
        if len(path) != 0:
            cmd = f'cd {"/".join(path)} && {cmd}'
        os.system(cmd)
        print('Done!')
learn compatible library for generating ZK proofs of execution', url='https://github.com/0x3327/skproof.git', author='Aleksandar Veljković, 3327.io', diff --git a/skproof-package/skproof.egg-info/PKG-INFO b/skproof-package/skproof.egg-info/PKG-INFO index b3bf670..4657e78 100644 --- a/skproof-package/skproof.egg-info/PKG-INFO +++ b/skproof-package/skproof.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: skproof -Version: 0.1.0 +Version: 0.1.1 Summary: SciKit learn compatible library for generating ZK proofs of execution Home-page: https://github.com/0x3327/skproof.git Author: Aleksandar Veljković, 3327.io diff --git a/skproof-package/skproof/mlp/MLPClassifierProver.py b/skproof-package/skproof/mlp/MLPClassifierProver.py index d678431..1907f2d 100644 --- a/skproof-package/skproof/mlp/MLPClassifierProver.py +++ b/skproof-package/skproof/mlp/MLPClassifierProver.py @@ -1,5 +1,5 @@ import numpy as np -from skproof.float_num import FloatNum +from skproof.float_num.FloatNum import FloatNum import math import os diff --git a/src/main.nr b/src/main.nr new file mode 100644 index 0000000..c915747 --- /dev/null +++ b/src/main.nr @@ -0,0 +1,350 @@ +// Struct representing float numbers using sign, mantissa and exponent. 
+// When Noir language gets the update to support signed integers, the sign field will be removed +struct Float { + sign: Field, + mantissa: Field, + exponent: Field, +} + +// ReLU activation function used for neural network ML models +fn relu(x : Float) -> Float { + let mut res = x; + if x.sign as u64 == 1 { + res = Float { sign: 0, mantissa: 0, exponent: 100 }; + } + + res +} + +// Truncate Float to "precision" number of digits, 5 in the example +fn truncate(num: Float) -> Float { + let lookup : [Field; 25] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + ]; + + let maxValue : Field = 10000000; + let maxLogValue : Field = 7; + let mut decValue : Field = 1; + let mut logValue : Field = 0; + + for i in 0..25 { + if num.mantissa as u64 >= lookup[i] as u64 { + decValue = lookup[i]; + logValue = i; + } + } + + decValue *= 10; + logValue += 1; + + let mut res : Float = Float { sign: num.sign, mantissa: num.mantissa, exponent: num.exponent }; + + if logValue as u64 > maxLogValue as u64 { + let diff = (decValue / maxValue) as u64; + res = Float { sign: num.sign, mantissa: (num.mantissa as u64 / diff) as Field, exponent: num.exponent + (logValue - maxLogValue)}; // + } + + if res.mantissa == 0 { + res = Float { sign: res.sign, mantissa: 0, exponent: 100 }; + } + + res +} +// Multiplication of Float numbers +fn mulFloats(x : Float, y : Float) -> Float { + let mant = x.mantissa * y.mantissa; + let exp = x.exponent + y.exponent - 100; + let mut sign : Field = 0; + + if x.sign != y.sign { + sign = 1; + } + + truncate(Float { sign: sign, mantissa: mant, exponent: exp }) +} + +// Sumation of Float 
// Summation of Float numbers
fn addFloats(x : Float, y : Float) -> Float {
    let mut mant_1 = x.mantissa;
    let mut mant_2 = y.mantissa;

    let mut exp_1 = x.exponent;
    let mut exp_2 = y.exponent;

    let mut diff : Field = 0;

    // |exp_1 - exp_2|, computed without signed arithmetic.
    if exp_1 as u64 > exp_2 as u64 {
        diff = exp_1 - exp_2;
    } else {
        diff = exp_2 - exp_1;
    }

    // Powers of ten for scaling the mantissa with the larger exponent
    // down to the smaller exponent.
    let lookup : [Field; 25] = [
        1,
        10,
        100,
        1000,
        10000,
        100000,
        1000000,
        10000000,
        100000000,
        1000000000,
        10000000000,
        100000000000,
        1000000000000,
        10000000000000,
        100000000000000,
        1000000000000000,
        10000000000000000,
        100000000000000000,
        1000000000000000000,
        10000000000000000000,
        100000000000000000000,
        1000000000000000000000,
        10000000000000000000000,
        100000000000000000000000,
        1000000000000000000000000,
    ];

    let mut pow10 : Field = 1;

    // pow10 = 10^diff (diff assumed < 25 given bounded exponents).
    for i in 0..25 {
        if i == diff {
            pow10 = lookup[i];
        }
    }

    // Align both mantissas to the smaller exponent before adding.
    if x.exponent as u64 < y.exponent as u64 {
        mant_2 *= pow10;
        exp_1 = x.exponent;
    } else {
        mant_1 *= pow10;
        exp_1 = y.exponent;
    }

    let mut sum_mant = mant_1 + mant_2;
    let mut sign = x.sign;

    // Opposite signs: subtract the smaller magnitude from the larger
    // and take the sign of the larger.
    if x.sign != y.sign {
        if mant_1 as u64 > mant_2 as u64 {
            sum_mant = mant_1 - mant_2;
        } else {
            sum_mant = mant_2 - mant_1;
            sign = y.sign;
        }
    }

    truncate(Float { sign: sign, mantissa: sum_mant, exponent: exp_1 })
}

// Subtraction of float numbers, implemented as addition with flipped sign.
fn subFloats(x : Float, y : Float) -> Float {
    addFloats(x, Float { sign: 1 - y.sign, mantissa: y.mantissa, exponent: y.exponent })
}

// Generated circuit entry point: x_1..x_4 are the [sign, mantissa, exponent]
// encodings of the 4 input features; y_1..y_3 the expected encodings of the
// 3 output nodes, constrained component-wise at the end.
fn main(
    x_1 : pub [Field; 3],
    x_2 : pub [Field; 3],
    x_3 : pub [Field; 3],
    x_4 : pub [Field; 3],
    y_1 : pub [Field; 3],
    y_2 : pub [Field; 3],
    y_3 : pub [Field; 3]
) {
    let mut X_0_0 = Float{ sign: x_1[0], mantissa: x_1[1], exponent: x_1[2] };
    let mut X_0_1 = Float{ sign: x_2[0], mantissa: x_2[1], exponent: x_2[2] };
    let mut X_0_2 = Float{ sign: x_3[0], mantissa: x_3[1], exponent: x_3[2] };
    let mut X_0_3 = Float{ sign: x_4[0], mantissa: x_4[1], exponent: x_4[2] };
    // Hidden layer 1: weighted sums plus bias, then ReLU.
    let A_1_0_0 = mulFloats(X_0_0, Float {sign: 0, mantissa: 1336663, exponent: 93 });
    let A_1_0_1 = mulFloats(X_0_1, Float {sign: 0, mantissa: 7586549, exponent: 93 });
    let A_1_0_2 = mulFloats(X_0_2, Float {sign: 0, mantissa: 1043080, exponent: 94 });
    let A_1_0_3 = mulFloats(X_0_3, Float {sign: 0, mantissa: 1170603, exponent: 94 });
    let B_0_0_0 = addFloats(A_1_0_0, A_1_0_1);
    let B_0_0_1 = addFloats(B_0_0_0, A_1_0_2);
    let B_0_0_2 = addFloats(B_0_0_1, A_1_0_3);
    let WXb_0_0 = addFloats(B_0_0_2, Float {sign: 0, mantissa: 4078493, exponent: 93 });
    let X_1_0 = relu(WXb_0_0);
    let A_1_1_0 = mulFloats(X_0_0, Float {sign: 0, mantissa: 7829273, exponent: 93 });
    let A_1_1_1 = mulFloats(X_0_1, Float {sign: 0, mantissa: 1365790, exponent: 94 });
    let A_1_1_2 = mulFloats(X_0_2, Float {sign: 1, mantissa: 1386341, exponent: 94 });
    let A_1_1_3 = mulFloats(X_0_3, Float {sign: 1, mantissa: 1053194, exponent: 94 });
    let B_0_1_0 = addFloats(A_1_1_0, A_1_1_1);
    let B_0_1_1 = addFloats(B_0_1_0, A_1_1_2);
    let B_0_1_2 = addFloats(B_0_1_1, A_1_1_3);
    let WXb_0_1 = addFloats(B_0_1_2, Float {sign: 0, mantissa: 2466248, exponent: 94 });
    let X_1_1 = relu(WXb_0_1);
    // Hidden layer 2.
    let A_2_0_0 = mulFloats(X_1_0, Float {sign: 0, mantissa: 1152279, exponent: 94 });
    let A_2_0_1 = mulFloats(X_1_1, Float {sign: 1, mantissa: 9092448, exponent: 93 });
    let B_1_0_0 = addFloats(A_2_0_0, A_2_0_1);
    let WXb_1_0 = addFloats(B_1_0_0, Float {sign: 0, mantissa: 1596867, exponent: 93 });
    let X_2_0 = relu(WXb_1_0);
    let A_2_1_0 = mulFloats(X_1_0, Float {sign: 0, mantissa: 1101590, exponent: 92 });
    let A_2_1_1 = mulFloats(X_1_1, Float {sign: 0, mantissa: 1627253, exponent: 94 });
    let B_1_1_0 = addFloats(A_2_1_0, A_2_1_1);
    let WXb_1_1 = addFloats(B_1_1_0, Float {sign: 0, mantissa: 3013073, exponent: 93 });
    let X_2_1 = relu(WXb_1_1);
    let A_2_2_0 = mulFloats(X_1_0, Float {sign: 1, mantissa: 9135291, exponent: 88 });
    let A_2_2_1 = mulFloats(X_1_1, Float {sign: 1, mantissa: 2661389, exponent: 91 });
    let B_1_2_0 = addFloats(A_2_2_0, A_2_2_1);
    let WXb_1_2 = addFloats(B_1_2_0, Float {sign: 1, mantissa: 3054311, exponent: 93 });
    let X_2_2 = relu(WXb_1_2);
    // Output layer.
    let A_3_0_0 = mulFloats(X_2_0, Float {sign: 1, mantissa: 6605394, exponent: 93 });
    let A_3_0_1 = mulFloats(X_2_1, Float {sign: 0, mantissa: 3677529, exponent: 93 });
    let A_3_0_2 = mulFloats(X_2_2, Float {sign: 1, mantissa: 14324, exponent: 86 });
    let B_2_0_0 = addFloats(A_3_0_0, A_3_0_1);
    let B_2_0_1 = addFloats(B_2_0_0, A_3_0_2);
    let WXb_2_0 = addFloats(B_2_0_1, Float {sign: 0, mantissa: 1129286, exponent: 93 });
    let O_3_0 = relu(WXb_2_0);
    let A_3_1_0 = mulFloats(X_2_0, Float {sign: 0, mantissa: 4068012, exponent: 93 });
    let A_3_1_1 = mulFloats(X_2_1, Float {sign: 0, mantissa: 3445262, exponent: 92 });
    let A_3_1_2 = mulFloats(X_2_2, Float {sign: 0, mantissa: 3444056, exponent: 90 });
    let B_2_1_0 = addFloats(A_3_1_0, A_3_1_1);
    let B_2_1_1 = addFloats(B_2_1_0, A_3_1_2);
    let WXb_2_1 = addFloats(B_2_1_1, Float {sign: 1, mantissa: 4435015, exponent: 92 });
    let O_3_1 = relu(WXb_2_1);
    let A_3_2_0 = mulFloats(X_2_0, Float {sign: 0, mantissa: 1092163, exponent: 94 });
    let A_3_2_1 = mulFloats(X_2_1, Float {sign: 1, mantissa: 1226450, exponent: 94 });
    let A_3_2_2 = mulFloats(X_2_2, Float {sign: 0, mantissa: 0, exponent: 100 });
    let B_2_2_0 = addFloats(A_3_2_0, A_3_2_1);
    let B_2_2_1 = addFloats(B_2_2_0, A_3_2_2);
    let WXb_2_2 = addFloats(B_2_2_1, Float {sign: 1, mantissa: 5861191, exponent: 93 });
    let O_3_2 = relu(WXb_2_2);

    constrain O_3_0.sign == y_1[0];
    constrain O_3_0.mantissa == y_1[1];
    constrain O_3_0.exponent == y_1[2];
    constrain O_3_1.sign == y_2[0];
    constrain O_3_1.mantissa == y_2[1];
    constrain O_3_1.exponent == y_2[2];
    constrain O_3_2.sign == y_3[0];
    constrain O_3_2.mantissa == y_3[1];
    constrain O_3_2.exponent == y_3[2];
}
// Verbose simulation trace (generated by MLPClassifierProver.simulate_ann):
// A_1_0_0 = 5100000e94 * 1336663e93 = 6816981e93
// A_1_0_1 = 3500000e94 * 7586549e93 = 2655292e94
// A_1_0_2 = 1400000e94 *
1043080e94 = 1460312e94 +// A_1_0_3 = 2000000e93 * 1170603e94 = 2341206e93 +// B_0_0_0 = 6816981e93 + 2655292e94 = 3336990e94 +// B_0_0_1 = 3336990e94 + 1460312e94 = 4797302e94 +// B_0_0_2 = 4797302e94 + 2341206e93 = 5031422e94 +// WXb_0_0 = 5031422e94 + 4078493e93 = 5439271e94 +// A_1_1_0 = 5100000e94 * 7829273e93 = 3992929e94 +// A_1_1_1 = 3500000e94 * 1365790e94 = 4780265e94 +// A_1_1_2 = 1400000e94 * -1386341e94 = -1940877e94 +// A_1_1_3 = 2000000e93 * -1053194e94 = -2106388e93 +// B_0_1_0 = 3992929e94 + 4780265e94 = 8773194e94 +// B_0_1_1 = 8773194e94 + -1940877e94 = 6832317e94 +// B_0_1_2 = 6832317e94 + -2106388e93 = 6621678e94 +// WXb_0_1 = 6621678e94 + 2466248e94 = 9087926e94 +// A_2_0_0 = 5439271e94 * 1152279e94 = 6267557e94 +// A_2_0_1 = 9087926e94 * -9092448e93 = -8263149e94 +// B_1_0_0 = 6267557e94 + -8263149e94 = -1995592e94 +// WXb_1_0 = -1995592e94 + 1596867e93 = -1835905e94 +// A_2_1_0 = 5439271e94 * 1101590e92 = 5991846e92 +// A_2_1_1 = 9087926e94 * 1627253e94 = 1478835e95 +// B_1_1_0 = 5991846e92 + 1478835e95 = 1484826e95 +// WXb_1_1 = 1484826e95 + 3013073e93 = 1514956e95 +// A_2_2_0 = 5439271e94 * -9135291e88 = -4968932e89 +// A_2_2_1 = 9087926e94 * -2661389e91 = -2418650e92 +// B_1_2_0 = -4968932e89 + -2418650e92 = -2423618e92 +// WXb_1_2 = -2423618e92 + -3054311e93 = -3296672e93 +// A_3_0_0 = 0e100 * -6605394e93 = 0e100 +// A_3_0_1 = 1514956e95 * 3677529e93 = 5571294e94 +// A_3_0_2 = 0e100 * -14324e86 = 0e100 +// B_2_0_0 = 0e100 + 5571294e94 = 5571294e94 +// B_2_0_1 = 5571294e94 + 0e100 = 5571294e94 +// WXb_2_0 = 5571294e94 + 1129286e93 = 5684222e94 +// A_3_1_0 = 0e100 * 4068012e93 = 0e100 +// A_3_1_1 = 1514956e95 * 3445262e92 = 5219420e93 +// A_3_1_2 = 0e100 * 3444056e90 = 0e100 +// B_2_1_0 = 0e100 + 5219420e93 = 5219420e93 +// B_2_1_1 = 5219420e93 + 0e100 = 5219420e93 +// WXb_2_1 = 5219420e93 + -4435015e92 = 4775918e93 +// A_3_2_0 = 0e100 * 1092163e94 = 0e100 +// A_3_2_1 = 1514956e95 * -1226450e94 = -1858017e95 +// A_3_2_2 = 0e100 * 0e100 = 
0e100 +// B_2_2_0 = 0e100 + -1858017e95 = -1858017e95 +// B_2_2_1 = -1858017e95 + 0e100 = -1858017e95 +// WXb_2_2 = -1858017e95 + -5861191e93 = -1916628e95 +// X_0_0 => 5100000e94 +// X_0_1 => 3500000e94 +// X_0_2 => 1400000e94 +// X_0_3 => 2000000e93 +// A_1_0_0 => 6816981e93 +// A_1_0_1 => 2655292e94 +// A_1_0_2 => 1460312e94 +// A_1_0_3 => 2341206e93 +// B_0_0_0 => 3336990e94 +// B_0_0_1 => 4797302e94 +// B_0_0_2 => 5031422e94 +// WXb_0_0 => 5439271e94 +// X_1_0 => 5439271e94 +// A_1_1_0 => 3992929e94 +// A_1_1_1 => 4780265e94 +// A_1_1_2 => -1940877e94 +// A_1_1_3 => -2106388e93 +// B_0_1_0 => 8773194e94 +// B_0_1_1 => 6832317e94 +// B_0_1_2 => 6621678e94 +// WXb_0_1 => 9087926e94 +// X_1_1 => 9087926e94 +// A_2_0_0 => 6267557e94 +// A_2_0_1 => -8263149e94 +// B_1_0_0 => -1995592e94 +// WXb_1_0 => -1835905e94 +// X_2_0 => 0e100 +// A_2_1_0 => 5991846e92 +// A_2_1_1 => 1478835e95 +// B_1_1_0 => 1484826e95 +// WXb_1_1 => 1514956e95 +// X_2_1 => 1514956e95 +// A_2_2_0 => -4968932e89 +// A_2_2_1 => -2418650e92 +// B_1_2_0 => -2423618e92 +// WXb_1_2 => -3296672e93 +// X_2_2 => 0e100 +// A_3_0_0 => 0e100 +// A_3_0_1 => 5571294e94 +// A_3_0_2 => 0e100 +// B_2_0_0 => 5571294e94 +// B_2_0_1 => 5571294e94 +// WXb_2_0 => 5684222e94 +// O_3_0 => 5684222e94 +// A_3_1_0 => 0e100 +// A_3_1_1 => 5219420e93 +// A_3_1_2 => 0e100 +// B_2_1_0 => 5219420e93 +// B_2_1_1 => 5219420e93 +// WXb_2_1 => 4775918e93 +// O_3_1 => 4775918e93 +// A_3_2_0 => 0e100 +// A_3_2_1 => -1858017e95 +// A_3_2_2 => 0e100 +// B_2_2_0 => -1858017e95 +// B_2_2_1 => -1858017e95 +// WXb_2_2 => -1916628e95 +// O_3_2 => 0e100