2 Commity

Autor SHA1 Wiadomość Data
  Your Name 44c9eb4785 What if we didn't make a new circuit for every iteration in the test 3 tygodni temu
  Your Name 588bfe15fe 0.04 3 tygodni temu
3 zmienionych plików z 31 dodań i 70 usunięć
  1. 25
    65
      src/circuit.rs
  2. 2
    2
      src/main.rs
  3. 4
    3
      tests/test.rs

+ 25
- 65
src/circuit.rs Wyświetl plik

@@ -2,16 +2,15 @@ use crate::{gate::Gate, gate_type::GateType};
2 2
 
3 3
 #[derive(Debug, Clone)]
4 4
 pub struct Circuit {
5
-    pub input_bits: Vec<bool>, // Input wires
6 5
     pub gates: Vec<Gate>,      // All gates
7 6
 }
8 7
 
9 8
 impl Circuit {
10
-    pub fn eval(self) -> bool {
9
+    pub fn eval(self, input_bits: Vec<bool>,)  -> bool {
11 10
         let mut evaluated_gates = vec![];
12 11
 
13 12
         for gate in self.gates {
14
-            let result = gate.eval(&self.input_bits, &evaluated_gates);
13
+            let result = gate.eval(&input_bits, &evaluated_gates);
15 14
             evaluated_gates.push(result);
16 15
         }
17 16
 
@@ -25,14 +24,7 @@ impl Circuit {
25 24
         This method should create a circuit that outputs 1 if the first number A (encoded in the first n
26 25
         bits) is greater than the second number B (encoded in the next n bits) .
27 26
     */
28
-    pub fn compare_n_bit_numbers(input_bits: Vec<bool>, n: usize) -> Self {
29
-        if input_bits.len() < 2 * n {
30
-            panic!(
31
-                "Expected input_bits to be of at least length {}, but it was {}",
32
-                2 * n,
33
-                input_bits.len()
34
-            )
35
-        }
27
+    pub fn compare_n_bit_numbers( n: usize) -> Self {
36 28
 
37 29
         /*
38 30
             base case n=1: 1-bit:
@@ -77,70 +69,38 @@ impl Circuit {
77 69
         */
78 70
 
79 71
         let gates = create_n_bit_comparator_gates(n);
80
-        return Circuit { input_bits, gates };
72
+        return Circuit { gates };
81 73
     }
82 74
 }
83 75
 
84 76
 fn create_n_bit_comparator_gates(n: usize) -> Vec<Gate> {
85
-    let mut indices: Vec<usize> = vec![0; n];
86
-    let mut all_gates: Vec<Gate> = Vec::with_capacity(1+3*n);
87
-    let mut and_gate_indices: Vec<usize> = vec![0; n];
88
-
89
-    for curr in 0..n {
90
-        // Gate(A_current > B_current)
91
-        let gt_gate = Gate::new(GateType::Bigger, vec![], vec![curr, curr + n]);
92
-        let gt_gate_index = all_gates.len();
93
-        all_gates.push(gt_gate);
94
-
95
-        // print!("( ({}) ", format!("{} > {}", format!("A{}", n-1-curr), format!("B{}", n-1-curr)));
96
-
97
-        let mut current_bit_gate_indices: Vec<usize> = Vec::with_capacity(curr+1);
98
-        current_bit_gate_indices.push(gt_gate_index);
99
-
100
-        // Bit to the left of curr. The one at array-index 0 doesn't have one.
101
-        if curr != 0 {
102
-            // Gate(A_i = B_i)
103
-            let eq_gate = Gate::new(GateType::Equal, vec![], vec![curr - 1, curr - 1 + n]);
104
-            // remember which key-index (see below) this bit belongs to.
105
-            indices[n - curr] = all_gates.len();
106
-            all_gates.push(eq_gate);
107
-        }
77
+    let mut all_gates: Vec<Gate> = Vec::with_capacity(2+3*(n-1));
78
+    let mut eq_gate_indices: Vec<usize> = Vec::with_capacity(n-1);
79
+    let mut or_gate_input_indices: Vec<usize> = Vec::with_capacity(n);
108 80
 
109
-        for i in 0..curr {
110
-            // Translate i to the key'th bit position because the array index does not match the bit index
111
-            // i.e.
112
-            //
113
-            //  let input_bits = vec![
114
-            //         A2     A1    A0    <- A_key
115
-            //          0      1     2    <- i iterates all array indices less than curr
116
-            //        false, false, true,
117
-            //
118
-            //         B2     B1    B0
119
-            //          3      4     5
120
-            //        false, true, false
121
-            //  ];
122
-            let key = n - 1 - i;
123
-
124
-            // print!("&& ({}) ", format!("{} = {}", format!("A{}", key), format!("B{}", key)));
125
-
126
-            // Index of Gate(A_i = B_i). 
127
-            // You could maybe find a solution without the indices vector but it's insanely cheap anyway and maybe even better
128
-            let eq_gate_index = indices[key];
129
-            current_bit_gate_indices.push(eq_gate_index);
130
-        }
81
+    const EMPTY_VEC: Vec<usize> = Vec::new();
131 82
 
132
-        // The AND spanning all gates for this bit
133
-        let and_gate_index = all_gates.len();
134
-        let and_gate = Gate::new(GateType::And, current_bit_gate_indices, vec![]);
83
+    all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![0, n]));
84
+    or_gate_input_indices.push(0);
135 85
 
136
-        // println!(")");
137
-        all_gates.push(and_gate);
138
-        and_gate_indices.push(and_gate_index);
86
+    for curr in 1..n {
87
+        // Gate(A_curr > B_curr)
88
+        let mut and_gate_input_indices = vec!(all_gates.len());
89
+        all_gates.push(Gate::new(GateType::Bigger, EMPTY_VEC, vec![curr, curr + n]));
90
+
91
+        // Gate(A_curr-1 = B_curr-1)
92
+        eq_gate_indices.push(all_gates.len());
93
+        all_gates.push(Gate::new(GateType::Equal, EMPTY_VEC, vec![curr - 1, curr - 1 + n]));
94
+
95
+        and_gate_input_indices.extend(eq_gate_indices.iter());
96
+
97
+        // The AND spanning all gates for this bit
98
+        or_gate_input_indices.push(all_gates.len());
99
+        all_gates.push(Gate::new(GateType::And, and_gate_input_indices, EMPTY_VEC));
139 100
     }
140 101
 
141 102
     // the OR spanning all ANDs
142
-    let or_gate = Gate::new(GateType::Or, and_gate_indices, vec![]);
143
-    all_gates.push(or_gate);
103
+    all_gates.push(Gate::new(GateType::Or, or_gate_input_indices, EMPTY_VEC));
144 104
 
145 105
     return all_gates;
146 106
 }

+ 2
- 2
src/main.rs Wyświetl plik

@@ -33,8 +33,8 @@ fn main() {
33 33
     input_bits.extend_from_slice(&b_bits);
34 34
 
35 35
     // Build and evaluate the comparison circuit
36
-    let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
37
-    let circuit_result = circuit.eval();
36
+    let circuit = Circuit::compare_n_bit_numbers(256);
37
+    let circuit_result = circuit.eval(input_bits);
38 38
 
39 39
     let a_int = BigUint::from_bytes_le(&a.to_bytes());
40 40
     let b_int = BigUint::from_bytes_le(&b.to_bytes());

+ 4
- 3
tests/test.rs Wyświetl plik

@@ -6,6 +6,8 @@ mod tests {
6 6
     use rand::rngs::OsRng;
7 7
     #[test]
8 8
     fn test_scalar_comparison_via_circuit() {
9
+        let circuit = Circuit::compare_n_bit_numbers(256);
10
+
9 11
         for _ in 0..100 {
10 12
             let a = Scalar::random(&mut OsRng);
11 13
             let b = Scalar::random(&mut OsRng);
@@ -19,9 +21,8 @@ mod tests {
19 21
             input_bits.extend_from_slice(&a_bits);
20 22
             input_bits.extend_from_slice(&b_bits);
21 23
 
22
-            // Build and evaluate the comparison circuit
23
-            let circuit = Circuit::compare_n_bit_numbers(input_bits, 256);
24
-            let circuit_result = circuit.eval();
24
+            // Evaluate the comparison circuit
25
+            let circuit_result = circuit.clone().eval(input_bits);
25 26
 
26 27
             // Compare expected result using BigUint
27 28
             let a_int = BigUint::from_bytes_le(&a.to_bytes());

Ładowanie…
Anuluj
Zapisz