Преглед изворни кода

What if we didn't make a new circuit for every iteration in the test

master
Your Name пре 2 недеља
родитељ
комит
44c9eb4785
3 измењених фајлова са 21 додато и 32 уклоњено
  1. 15
    27
      src/circuit.rs
  2. 2
    2
      src/main.rs
  3. 4
    3
      tests/test.rs

+ 15
- 27
src/circuit.rs Прегледај датотеку

@@ -2,16 +2,15 @@ use crate::{gate::Gate, gate_type::GateType};
2 2
 
3 3
 #[derive(Debug, Clone)]
4 4
 pub struct Circuit {
5
-    pub input_bits: Vec<bool>, // Input wires
6 5
     pub gates: Vec<Gate>,      // All gates
7 6
 }
8 7
 
9 8
 impl Circuit {
10
-    pub fn eval(self) -> bool {
9
+    pub fn eval(self, input_bits: Vec<bool>,)  -> bool {
11 10
         let mut evaluated_gates = vec![];
12 11
 
13 12
         for gate in self.gates {
14
-            let result = gate.eval(&self.input_bits, &evaluated_gates);
13
+            let result = gate.eval(&input_bits, &evaluated_gates);
15 14
             evaluated_gates.push(result);
16 15
         }
17 16
 
@@ -25,14 +24,7 @@ impl Circuit {
25 24
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
26 25
         bits) is greater than the second number B (encoded in the next n bits) .
27 26
     */
28
-    pub fn compare_n_bit_numbers(input_bits: Vec<bool>, n: usize) -> Self {
29
-        if input_bits.len() < 2 * n {
30
-            panic!(
31
-                "Expected input_bits to be of at least length {}, but it was {}",
32
-                2 * n,
33
-                input_bits.len()
34
-            )
35
-        }
27
+    pub fn compare_n_bit_numbers( n: usize) -> Self {
36 28
 
37 29
         /*
38 30
             base case n=1: 1-bit:
@@ -77,42 +69,38 @@ impl Circuit {
77 69
         */
78 70
 
79 71
         let gates = create_n_bit_comparator_gates(n);
80
-        return Circuit { input_bits, gates };
72
+        return Circuit { gates };
81 73
     }
82 74
 }
83 75
 
84 76
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
85
-    let mut all_gates: Vec<Gate> = Vec::with_capacity(3*n);
86
-    let mut and_gate_indices: Vec<usize> = vec![0; n];
77
+    let mut all_gates: Vec<Gate> = Vec::with_capacity(2+3*(n-1));
87 78
     let mut eq_gate_indices: Vec<usize> = Vec::with_capacity(n-1);
79
+    let mut or_gate_input_indices: Vec<usize> = Vec::with_capacity(n);
88 80
 
89 81
     const EMPTY_VEC: Vec<usize> = Vec::new();
90 82
 
91
-    for curr in 0..n {
83
+    all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![0, n]));
84
+    or_gate_input_indices.push(0);
85
+
86
+    for curr in 1..n {
92 87
         // Gate(A_curr > B_curr)
93 88
         let mut and_gate_input_indices = vec!(all_gates.len());
94
-        println!("GT = {}", all_gates.len());
95 89
         all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![curr, curr + n]));
96 90
 
97
-        // Bit to the left of curr. The one at array-index 0 doesn't have one.
98
-        if curr != 0 {
99
-            // Gate(A_curr-1 = B_curr-1)
100
-            eq_gate_indices.push(all_gates.len());
101
-            println!("EQ = {}", all_gates.len());
102
-            all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
103
-        }
91
+        // Gate(A_curr-1 = B_curr-1)
92
+        eq_gate_indices.push(all_gates.len());
93
+        all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
104 94
 
105 95
         and_gate_input_indices.extend(eq_gate_indices.iter());
106 96
 
107 97
         // The AND spanning all gates for this bit
108
-        and_gate_indices.push(all_gates.len());
109
-        println!("&& = {}", all_gates.len());
98
+        or_gate_input_indices.push(all_gates.len());
110 99
         all_gates.push(Gate::new(GateType::And, and_gate_input_indices, EMPTY_VEC));
111 100
     }
112 101
 
113 102
     // the OR spanning all ANDs
114
-    let or_gate = Gate::new(GateType::Or, and_gate_indices, EMPTY_VEC);
115
-    all_gates.push(or_gate);
103
+    all_gates.push(Gate::new(GateType::Or, or_gate_input_indices, EMPTY_VEC));
116 104
 
117 105
     return all_gates;
118 106
 }

+ 2
- 2
src/main.rs Прегледај датотеку

@@ -33,8 +33,8 @@ fn main() {
33 33
     input_bits.extend_from_slice(&b_bits);
34 34
 
35 35
     // Build and evaluate the comparison circuit
36
-    let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
37
-    let circuit_result = circuit.eval();
36
+    let circuit = Circuit::compare_n_bit_numbers(256);
37
+    let circuit_result = circuit.eval(input_bits);
38 38
 
39 39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
40 40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());

+ 4
- 3
tests/test.rs Прегледај датотеку

@@ -6,6 +6,8 @@ mod tests {
6 6
     use rand::rngs::OsRng;
7 7
     #[test]
8 8
     fn test_scalar_comparison_via_circuit() {
9
+        let circuit = Circuit::compare_n_bit_numbers(256);
10
+
9 11
         for _ in 0..100 {
10 12
             let a = Scalar::random(&mut OsRng);
11 13
             let b = Scalar::random(&mut OsRng);
@@ -19,9 +21,8 @@ mod tests {
19 21
             input_bits.extend_from_slice(&a_bits);
20 22
             input_bits.extend_from_slice(&b_bits);
21 23
 
22
-            // Build and evaluate the comparison circuit
23
-            let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
24
-            let circuit_result = circuit.eval();
24
+            // Evaluate the comparison circuit
25
+            let circuit_result = circuit.clone().eval(input_bits);
25 26
 
26 27
             // Compare expected result using BigUint
27 28
             let a_int = BigUint::from_bytes_le(&a.to_bytes());

Loading…
Откажи
Сачувај