2 Commits

Author SHA1 Message Date
  Your Name 44c9eb4785 What if we didn't make a new circuit for every iteration in the test 3 weeks ago
  Your Name 588bfe15fe 0.04 3 weeks ago
3 changed files with 31 additions and 70 deletions
  1. 25
    65
      src/circuit.rs
  2. 2
    2
      src/main.rs
  3. 4
    3
      tests/test.rs

+ 25
- 65
src/circuit.rs View File

2
 
2
 
3
 #[derive(Debug, Clone)]
3
 #[derive(Debug, Clone)]
4
 pub struct Circuit {
4
 pub struct Circuit {
5
-    pub input_bits: Vec<bool>, // Input wires
6
     pub gates: Vec<Gate>,      // All gates
5
     pub gates: Vec<Gate>,      // All gates
7
 }
6
 }
8
 
7
 
9
 impl Circuit {
8
 impl Circuit {
10
-    pub fn eval(self) -> bool {
9
+    pub fn eval(self, input_bits: Vec<bool>,)  -> bool {
11
         let mut evaluated_gates = vec![];
10
         let mut evaluated_gates = vec![];
12
 
11
 
13
         for gate in self.gates {
12
         for gate in self.gates {
14
-            let result = gate.eval(&self.input_bits, &evaluated_gates);
13
+            let result = gate.eval(&input_bits, &evaluated_gates);
15
             evaluated_gates.push(result);
14
             evaluated_gates.push(result);
16
         }
15
         }
17
 
16
 
25
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
24
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
26
         bits) is greater than the second number B (encoded in the next n bits) .
25
         bits) is greater than the second number B (encoded in the next n bits) .
27
     */
26
     */
28
-    pub fn compare_n_bit_numbers(input_bits: Vec<bool>, n: usize) -> Self {
29
-        if input_bits.len() < 2 * n {
30
-            panic!(
31
-                "Expected input_bits to be of at least length {}, but it was {}",
32
-                2 * n,
33
-                input_bits.len()
34
-            )
35
-        }
27
+    pub fn compare_n_bit_numbers( n: usize) -> Self {
36
 
28
 
37
         /*
29
         /*
38
             base case n=1: 1-bit:
30
             base case n=1: 1-bit:
77
         */
69
         */
78
 
70
 
79
         let gates = create_n_bit_comparator_gates(n);
71
         let gates = create_n_bit_comparator_gates(n);
80
-        return Circuit { input_bits, gates };
72
+        return Circuit { gates };
81
     }
73
     }
82
 }
74
 }
83
 
75
 
84
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
76
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
85
-    let mut indices: Vec<usize> = vec![0; n];
86
-    let mut all_gates: Vec<Gate> = Vec::with_capacity(1+3*n);
87
-    let mut and_gate_indices: Vec<usize> = vec![0; n];
88
-
89
-    for curr in 0..n {
90
-        // Gate(A_current > B_current)
91
-        let gt_gate = Gate::new(GateType::Bigger, vec![], vec![curr, curr + n]);
92
-        let gt_gate_index = all_gates.len();
93
-        all_gates.push(gt_gate);
94
-
95
-        // print!("( ({}) ", format!("{} > {}", format!("A{}", n-1-curr), format!("B{}", n-1-curr)));
96
-
97
-        let mut current_bit_gate_indices: Vec<usize> = Vec::with_capacity(curr+1);
98
-        current_bit_gate_indices.push(gt_gate_index);
99
-
100
-        // Bit to the left of curr. The one at array-index 0 doesn't have one.
101
-        if curr != 0 {
102
-            // Gate(A_i = B_i)
103
-            let eq_gate = Gate::new(GateType::Equal, vec![], vec![curr - 1, curr - 1 + n]);
104
-            // remember which key-index (see below) this bit belongs to.
105
-            indices[n - curr] = all_gates.len();
106
-            all_gates.push(eq_gate);
107
-        }
77
+    let mut all_gates: Vec<Gate> = Vec::with_capacity(2+3*(n-1));
78
+    let mut eq_gate_indices: Vec<usize> = Vec::with_capacity(n-1);
79
+    let mut or_gate_input_indices: Vec<usize> = Vec::with_capacity(n);
108
 
80
 
109
-        for i in 0..curr {
110
-            // Translate i to the key'th bit position because the array index does not match the bit index
111
-            // i.e.
112
-            //
113
-            //  let input_bits = vec![
114
-            //         A2     A1    A0    <- A_key
115
-            //          0      1     2    <- i iterates all array indices less than curr
116
-            //        false, false, true,
117
-            //
118
-            //         B2     B1    B0
119
-            //          3      4     5
120
-            //        false, true, false
121
-            //  ];
122
-            let key = n - 1 - i;
123
-
124
-            // print!("&& ({}) ", format!("{} = {}", format!("A{}", key), format!("B{}", key)));
125
-
126
-            // Index of Gate(A_i = B_i). 
127
-            // You could maybe find a solution without the indices vector but it's insanely cheap anyway and maybe even better
128
-            let eq_gate_index = indices[key];
129
-            current_bit_gate_indices.push(eq_gate_index);
130
-        }
81
+    const EMPTY_VEC: Vec<usize> = Vec::new();
131
 
82
 
132
-        // The AND spanning all gates for this bit
133
-        let and_gate_index = all_gates.len();
134
-        let and_gate = Gate::new(GateType::And, current_bit_gate_indices, vec![]);
83
+    all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![0, n]));
84
+    or_gate_input_indices.push(0);
135
 
85
 
136
-        // println!(")");
137
-        all_gates.push(and_gate);
138
-        and_gate_indices.push(and_gate_index);
86
+    for curr in 1..n {
87
+        // Gate(A_curr > B_curr)
88
+        let mut and_gate_input_indices = vec!(all_gates.len());
89
+        all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![curr, curr + n]));
90
+
91
+        // Gate(A_curr-1 = B_curr-1)
92
+        eq_gate_indices.push(all_gates.len());
93
+        all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
94
+
95
+        and_gate_input_indices.extend(eq_gate_indices.iter());
96
+
97
+        // The AND spanning all gates for this bit
98
+        or_gate_input_indices.push(all_gates.len());
99
+        all_gates.push(Gate::new(GateType::And, and_gate_input_indices, EMPTY_VEC));
139
     }
100
     }
140
 
101
 
141
     // the OR spanning all ANDs
102
     // the OR spanning all ANDs
142
-    let or_gate = Gate::new(GateType::Or, and_gate_indices, vec![]);
143
-    all_gates.push(or_gate);
103
+    all_gates.push(Gate::new(GateType::Or, or_gate_input_indices, EMPTY_VEC));
144
 
104
 
145
     return all_gates;
105
     return all_gates;
146
 }
106
 }

+ 2
- 2
src/main.rs View File

33
     input_bits.extend_from_slice(&b_bits);
33
     input_bits.extend_from_slice(&b_bits);
34
 
34
 
35
     // Build and evaluate the comparison circuit
35
     // Build and evaluate the comparison circuit
36
-    let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
37
-    let circuit_result = circuit.eval();
36
+    let circuit = Circuit::compare_n_bit_numbers(256);
37
+    let circuit_result = circuit.eval(input_bits);
38
 
38
 
39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());
40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());

+ 4
- 3
tests/test.rs View File

6
     use rand::rngs::OsRng;
6
     use rand::rngs::OsRng;
7
     #[test]
7
     #[test]
8
     fn test_scalar_comparison_via_circuit() {
8
     fn test_scalar_comparison_via_circuit() {
9
+        let circuit = Circuit::compare_n_bit_numbers(256);
10
+
9
         for _ in 0..100 {
11
         for _ in 0..100 {
10
             let a = Scalar::random(&mut OsRng);
12
             let a = Scalar::random(&mut OsRng);
11
             let b = Scalar::random(&mut OsRng);
13
             let b = Scalar::random(&mut OsRng);
19
             input_bits.extend_from_slice(&a_bits);
21
             input_bits.extend_from_slice(&a_bits);
20
             input_bits.extend_from_slice(&b_bits);
22
             input_bits.extend_from_slice(&b_bits);
21
 
23
 
22
-            // Build and evaluate the comparison circuit
23
-            let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
24
-            let circuit_result = circuit.eval();
24
+            // Evaluate the comparison circuit
25
+            let circuit_result = circuit.clone().eval(input_bits);
25
 
26
 
26
             // Compare expected result using BigUint
27
             // Compare expected result using BigUint
27
             let a_int = BigUint::from_bytes_le(&a.to_bytes());
28
             let a_int = BigUint::from_bytes_le(&a.to_bytes());

Loading…
Cancel
Save