Browse Source

What if we didn't make a new circuit for every iteration in the test

master
Your Name 2 weeks ago
parent
commit
44c9eb4785
3 changed files with 21 additions and 32 deletions
  1. 15
    27
      src/circuit.rs
  2. 2
    2
      src/main.rs
  3. 4
    3
      tests/test.rs

+ 15
- 27
src/circuit.rs View File

2
 
2
 
3
 #[derive(Debug, Clone)]
3
 #[derive(Debug, Clone)]
4
 pub struct Circuit {
4
 pub struct Circuit {
5
-    pub input_bits: Vec<bool>, // Input wires
6
     pub gates: Vec<Gate>,      // All gates
5
     pub gates: Vec<Gate>,      // All gates
7
 }
6
 }
8
 
7
 
9
 impl Circuit {
8
 impl Circuit {
10
-    pub fn eval(self) -> bool {
9
+    pub fn eval(self, input_bits: Vec<bool>,)  -> bool {
11
         let mut evaluated_gates = vec![];
10
         let mut evaluated_gates = vec![];
12
 
11
 
13
         for gate in self.gates {
12
         for gate in self.gates {
14
-            let result = gate.eval(&self.input_bits, &evaluated_gates);
13
+            let result = gate.eval(&input_bits, &evaluated_gates);
15
             evaluated_gates.push(result);
14
             evaluated_gates.push(result);
16
         }
15
         }
17
 
16
 
25
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
24
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
26
         bits) is greater than the second number B (encoded in the next n bits) .
25
         bits) is greater than the second number B (encoded in the next n bits) .
27
     */
26
     */
28
-    pub fn compare_n_bit_numbers(input_bits: Vec<bool>, n: usize) -> Self {
29
-        if input_bits.len() < 2 * n {
30
-            panic!(
31
-                "Expected input_bits to be of at least length {}, but it was {}",
32
-                2 * n,
33
-                input_bits.len()
34
-            )
35
-        }
27
+    pub fn compare_n_bit_numbers( n: usize) -> Self {
36
 
28
 
37
         /*
29
         /*
38
             base case n=1: 1-bit:
30
             base case n=1: 1-bit:
77
         */
69
         */
78
 
70
 
79
         let gates = create_n_bit_comparator_gates(n);
71
         let gates = create_n_bit_comparator_gates(n);
80
-        return Circuit { input_bits, gates };
72
+        return Circuit { gates };
81
     }
73
     }
82
 }
74
 }
83
 
75
 
84
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
76
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
85
-    let mut all_gates: Vec<Gate> = Vec::with_capacity(3*n);
86
-    let mut and_gate_indices: Vec<usize> = vec![0; n];
77
+    let mut all_gates: Vec<Gate> = Vec::with_capacity(2+3*(n-1));
87
     let mut eq_gate_indices: Vec<usize> = Vec::with_capacity(n-1);
78
     let mut eq_gate_indices: Vec<usize> = Vec::with_capacity(n-1);
79
+    let mut or_gate_input_indices: Vec<usize> = Vec::with_capacity(n);
88
 
80
 
89
     const EMPTY_VEC: Vec<usize> = Vec::new();
81
     const EMPTY_VEC: Vec<usize> = Vec::new();
90
 
82
 
91
-    for curr in 0..n {
83
+    all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![0, n]));
84
+    or_gate_input_indices.push(0);
85
+
86
+    for curr in 1..n {
92
         // Gate(A_curr > B_curr)
87
         // Gate(A_curr > B_curr)
93
         let mut and_gate_input_indices = vec!(all_gates.len());
88
         let mut and_gate_input_indices = vec!(all_gates.len());
94
-        println!("GT = {}", all_gates.len());
95
         all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![curr, curr + n]));
89
         all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![curr, curr + n]));
96
 
90
 
97
-        // Bit to the left of curr. The one at array-index 0 doesn't have one.
98
-        if curr != 0 {
99
-            // Gate(A_curr-1 = B_curr-1)
100
-            eq_gate_indices.push(all_gates.len());
101
-            println!("EQ = {}", all_gates.len());
102
-            all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
103
-        }
91
+        // Gate(A_curr-1 = B_curr-1)
92
+        eq_gate_indices.push(all_gates.len());
93
+        all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
104
 
94
 
105
         and_gate_input_indices.extend(eq_gate_indices.iter());
95
         and_gate_input_indices.extend(eq_gate_indices.iter());
106
 
96
 
107
         // The AND spanning all gates for this bit
97
         // The AND spanning all gates for this bit
108
-        and_gate_indices.push(all_gates.len());
109
-        println!("&& = {}", all_gates.len());
98
+        or_gate_input_indices.push(all_gates.len());
110
         all_gates.push(Gate::new(GateType::And, and_gate_input_indices, EMPTY_VEC));
99
         all_gates.push(Gate::new(GateType::And, and_gate_input_indices, EMPTY_VEC));
111
     }
100
     }
112
 
101
 
113
     // the OR spanning all ANDs
102
     // the OR spanning all ANDs
114
-    let or_gate = Gate::new(GateType::Or, and_gate_indices, EMPTY_VEC);
115
-    all_gates.push(or_gate);
103
+    all_gates.push(Gate::new(GateType::Or, or_gate_input_indices, EMPTY_VEC));
116
 
104
 
117
     return all_gates;
105
     return all_gates;
118
 }
106
 }

+ 2
- 2
src/main.rs View File

33
     input_bits.extend_from_slice(&b_bits);
33
     input_bits.extend_from_slice(&b_bits);
34
 
34
 
35
     // Build and evaluate the comparison circuit
35
     // Build and evaluate the comparison circuit
36
-    let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
37
-    let circuit_result = circuit.eval();
36
+    let circuit = Circuit::compare_n_bit_numbers(256);
37
+    let circuit_result = circuit.eval(input_bits);
38
 
38
 
39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());
40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());

+ 4
- 3
tests/test.rs View File

6
     use rand::rngs::OsRng;
6
     use rand::rngs::OsRng;
7
     #[test]
7
     #[test]
8
     fn test_scalar_comparison_via_circuit() {
8
     fn test_scalar_comparison_via_circuit() {
9
+        let circuit = Circuit::compare_n_bit_numbers(256);
10
+
9
         for _ in 0..100 {
11
         for _ in 0..100 {
10
             let a = Scalar::random(&mut OsRng);
12
             let a = Scalar::random(&mut OsRng);
11
             let b = Scalar::random(&mut OsRng);
13
             let b = Scalar::random(&mut OsRng);
19
             input_bits.extend_from_slice(&a_bits);
21
             input_bits.extend_from_slice(&a_bits);
20
             input_bits.extend_from_slice(&b_bits);
22
             input_bits.extend_from_slice(&b_bits);
21
 
23
 
22
-            // Build and evaluate the comparison circuit
23
-            let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
24
-            let circuit_result = circuit.eval();
24
+            // Evaluate the comparison circuit
25
+            let circuit_result = circuit.clone().eval(input_bits);
25
 
26
 
26
             // Compare expected result using BigUint
27
             // Compare expected result using BigUint
27
             let a_int = BigUint::from_bytes_le(&a.to_bytes());
28
             let a_int = BigUint::from_bytes_le(&a.to_bytes());

Loading…
Cancel
Save