Clean up implementation to use vectors. We'll use arrays when Rust doesn't suck at them.

This commit is contained in:
Sean Bowe 2015-12-27 19:33:36 -07:00
parent 5fa9d9f438
commit ca289581a8
2 changed files with 122 additions and 248 deletions

View File

@ -24,8 +24,9 @@ const keccakf_piln: [usize; 24] =
15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1
]; ];
fn keccakf(st: &mut [Chunk; 25], rounds: usize) fn keccakf(st: &mut [Chunk], rounds: usize)
{ {
assert_eq!(st.len(), 25);
for round in 0..rounds { for round in 0..rounds {
/* /*
// Theta // Theta
@ -33,22 +34,12 @@ fn keccakf(st: &mut [Chunk; 25], rounds: usize)
bc[i] = st[i] ^ st[i + 5] ^ st[i + 10] ^ st[i + 15] ^ st[i + 20]; bc[i] = st[i] ^ st[i + 5] ^ st[i + 10] ^ st[i + 15] ^ st[i + 20];
*/ */
// TODO: Rust arrays don't implement FromIterator. let mut bc: Vec<Chunk> = (0..5).map(|i| st[i]
let mut bc: [Option<Chunk>; 5] = [None, None, None, None, None];
for i in 0..5 {
bc[i] = Some(st[i]
.xor(&st[i+5]) .xor(&st[i+5])
.xor(&st[i+10]) .xor(&st[i+10])
.xor(&st[i+15]) .xor(&st[i+15])
.xor(&st[i+20])); .xor(&st[i+20])
} ).collect();
let mut bc: [Chunk; 5] = [bc[0].take().unwrap(),
bc[1].take().unwrap(),
bc[2].take().unwrap(),
bc[3].take().unwrap(),
bc[4].take().unwrap()];
/* /*
for (i = 0; i < 5; i++) { for (i = 0; i < 5; i++) {
@ -119,17 +110,10 @@ fn keccakf(st: &mut [Chunk; 25], rounds: usize)
} }
} }
// TODO: don't return a vec. currently don't have any fn keccak256(mut input: &[Byte]) -> Vec<Bit> {
// more patience for rust's awful arrays
fn keccak256(mut input: &[[Bit; 8]]) -> Vec<Bit> {
assert_eq!(input.len(), 144); assert_eq!(input.len(), 144);
let mut st: [Chunk; 25] = [Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), let mut st: Vec<Chunk> = Some(Chunk::from(0)).into_iter().cycle().take(25).collect();
Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0),
Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0),
Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0),
Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0), Chunk::from(0)
];
let mdlen = 32; // 256 bit let mdlen = 32; // 256 bit
let rsiz = 200 - 2 * mdlen; let rsiz = 200 - 2 * mdlen;
@ -145,7 +129,11 @@ fn keccak256(mut input: &[[Bit; 8]]) -> Vec<Bit> {
let mut v = vec![]; let mut v = vec![];
for i in 0..4 { for i in 0..4 {
// due to endianness... // due to endianness...
let tmp: Vec<_> = st[i].bits.chunks(8).rev().flat_map(|x| x.iter()).map(|x| x.clone()).collect(); let tmp: Vec<_> = st[i].bits.chunks(8)
.rev()
.flat_map(|x| x.iter())
.map(|x| x.clone())
.collect();
v.extend_from_slice(&tmp); v.extend_from_slice(&tmp);
} }
@ -154,54 +142,39 @@ fn keccak256(mut input: &[[Bit; 8]]) -> Vec<Bit> {
v v
} }
#[derive(Clone)]
struct Chunk { struct Chunk {
bits: [Bit; 64] bits: Vec<Bit>
}
impl Clone for Chunk {
fn clone(&self) -> Chunk {
let mut new_chunk = Chunk::from(0);
for i in 0..64 {
new_chunk.bits[i] = self.bits[i].clone();
}
new_chunk
}
} }
impl Chunk { impl Chunk {
fn xor(&self, other: &Chunk) -> Chunk { fn xor(&self, other: &Chunk) -> Chunk {
let mut new_chunk = Chunk::from(0); Chunk {
bits: self.bits.iter()
for i in 0..64 { .zip(other.bits.iter())
new_chunk.bits[i] = self.bits[i].xor(&other.bits[i]); .map(|(a, b)| a.xor(b))
.collect()
} }
new_chunk
} }
fn notand(&self, other: &Chunk) -> Chunk { fn notand(&self, other: &Chunk) -> Chunk {
let mut new_chunk = Chunk::from(0); Chunk {
bits: self.bits.iter()
for i in 0..64 { .zip(other.bits.iter())
new_chunk.bits[i] = self.bits[i].notand(&other.bits[i]); .map(|(a, b)| a.notand(b))
.collect()
}
} }
new_chunk fn rotl(&self, mut by: usize) -> Chunk {
by = by % 64;
Chunk {
bits: self.bits[by..].iter()
.chain(self.bits[0..by].iter())
.cloned()
.collect()
} }
fn rotl(&self, by: usize) -> Chunk {
assert!(by < 64);
let mut new_bits = vec![];
new_bits.extend_from_slice(&self.bits[by..]);
new_bits.extend_from_slice(&self.bits[0..by]);
let mut clone = self.clone();
clone.bits.clone_from_slice(&new_bits);
clone
} }
} }
@ -215,113 +188,37 @@ impl PartialEq for Chunk {
} }
} }
impl<'a> From<&'a [[Bit; 8]]> for Chunk { impl<'a> From<&'a [Byte]> for Chunk {
fn from(bytes: &'a [[Bit; 8]]) -> Chunk { fn from(bytes: &'a [Byte]) -> Chunk {
assert!(bytes.len() == 8); // must be 64 bit assert!(bytes.len() == 8); // must be 64 bit
let mut new_chunk = Chunk::from(0); Chunk {
bits: bytes.iter().rev() // endianness
for (i, byte) in bytes.iter().rev().enumerate() { .flat_map(|x| x.bits.iter())
for (j, bit) in byte.iter().enumerate() { .cloned()
new_chunk.bits[i*8 + j] = bit.clone(); .collect()
} }
} }
new_chunk
}
} }
impl<'a> From<&'a [Bit]> for Chunk { impl<'a> From<&'a [Bit]> for Chunk {
fn from(bits: &'a [Bit]) -> Chunk { fn from(bits: &'a [Bit]) -> Chunk {
assert!(bits.len() == 64); // must be 64 bit assert!(bits.len() == 64); // must be 64 bit
let mut new_chunk = Chunk::from(0); Chunk {
bits: bits.iter().cloned().collect()
for (i, bit) in bits.iter().enumerate() {
new_chunk.bits[i] = bit.clone();
} }
new_chunk
} }
} }
impl From<u64> for Chunk { impl From<u64> for Chunk {
fn from(num: u64) -> Chunk { fn from(num: u64) -> Chunk {
use std::mem;
fn bit_at(num: u64, i: usize) -> u8 { fn bit_at(num: u64, i: usize) -> u8 {
((num << i) >> 63) as u8 ((num << i) >> 63) as u8
} }
// TODO: initialize this with unsafe { }
// sadly... GET INTEGER GENERICS WORKING RUST
Chunk { Chunk {
bits: [ bits: (0..64).map(|i| Bit::constant(bit_at(num, i))).collect()
Bit::constant(bit_at(num, 0)),
Bit::constant(bit_at(num, 1)),
Bit::constant(bit_at(num, 2)),
Bit::constant(bit_at(num, 3)),
Bit::constant(bit_at(num, 4)),
Bit::constant(bit_at(num, 5)),
Bit::constant(bit_at(num, 6)),
Bit::constant(bit_at(num, 7)),
Bit::constant(bit_at(num, 8)),
Bit::constant(bit_at(num, 9)),
Bit::constant(bit_at(num, 10)),
Bit::constant(bit_at(num, 11)),
Bit::constant(bit_at(num, 12)),
Bit::constant(bit_at(num, 13)),
Bit::constant(bit_at(num, 14)),
Bit::constant(bit_at(num, 15)),
Bit::constant(bit_at(num, 16)),
Bit::constant(bit_at(num, 17)),
Bit::constant(bit_at(num, 18)),
Bit::constant(bit_at(num, 19)),
Bit::constant(bit_at(num, 20)),
Bit::constant(bit_at(num, 21)),
Bit::constant(bit_at(num, 22)),
Bit::constant(bit_at(num, 23)),
Bit::constant(bit_at(num, 24)),
Bit::constant(bit_at(num, 25)),
Bit::constant(bit_at(num, 26)),
Bit::constant(bit_at(num, 27)),
Bit::constant(bit_at(num, 28)),
Bit::constant(bit_at(num, 29)),
Bit::constant(bit_at(num, 30)),
Bit::constant(bit_at(num, 31)),
Bit::constant(bit_at(num, 32)),
Bit::constant(bit_at(num, 33)),
Bit::constant(bit_at(num, 34)),
Bit::constant(bit_at(num, 35)),
Bit::constant(bit_at(num, 36)),
Bit::constant(bit_at(num, 37)),
Bit::constant(bit_at(num, 38)),
Bit::constant(bit_at(num, 39)),
Bit::constant(bit_at(num, 40)),
Bit::constant(bit_at(num, 41)),
Bit::constant(bit_at(num, 42)),
Bit::constant(bit_at(num, 43)),
Bit::constant(bit_at(num, 44)),
Bit::constant(bit_at(num, 45)),
Bit::constant(bit_at(num, 46)),
Bit::constant(bit_at(num, 47)),
Bit::constant(bit_at(num, 48)),
Bit::constant(bit_at(num, 49)),
Bit::constant(bit_at(num, 50)),
Bit::constant(bit_at(num, 51)),
Bit::constant(bit_at(num, 52)),
Bit::constant(bit_at(num, 53)),
Bit::constant(bit_at(num, 54)),
Bit::constant(bit_at(num, 55)),
Bit::constant(bit_at(num, 56)),
Bit::constant(bit_at(num, 57)),
Bit::constant(bit_at(num, 58)),
Bit::constant(bit_at(num, 59)),
Bit::constant(bit_at(num, 60)),
Bit::constant(bit_at(num, 61)),
Bit::constant(bit_at(num, 62)),
Bit::constant(bit_at(num, 63))
]
} }
} }
} }
@ -331,18 +228,18 @@ enum Bit {
Constant(u8) Constant(u8)
} }
struct Byte {
bits: Vec<Bit>
}
impl Bit { impl Bit {
fn byte(byte: u8) -> [Bit; 8] { fn byte(byte: u8) -> Byte {
[ Byte {
Bit::constant({if byte & 0b10000000 != 0 { 1 } else { 0 }}), bits: (0..8).map(|i| byte & (0b00000001 << i) != 0)
Bit::constant({if byte & 0b01000000 != 0 { 1 } else { 0 }}), .map(|b| Bit::constant(if b { 1 } else { 0 }))
Bit::constant({if byte & 0b00100000 != 0 { 1 } else { 0 }}), .rev()
Bit::constant({if byte & 0b00010000 != 0 { 1 } else { 0 }}), .collect()
Bit::constant({if byte & 0b00001000 != 0 { 1 } else { 0 }}), }
Bit::constant({if byte & 0b00000100 != 0 { 1 } else { 0 }}),
Bit::constant({if byte & 0b00000010 != 0 { 1 } else { 0 }}),
Bit::constant({if byte & 0b00000001 != 0 { 1 } else { 0 }}),
]
} }
fn constant(num: u8) -> Bit { fn constant(num: u8) -> Bit {
@ -385,7 +282,7 @@ fn testsha3() {
} }
}; };
let msg: Vec<_> = (0..144).map(bb).collect(); let msg: Vec<Byte> = (0..144).map(bb).collect();
let result = keccak256(&msg); let result = keccak256(&msg);
@ -407,66 +304,40 @@ fn testsha3() {
#[test] #[test]
fn testff() { fn testff() {
let mut a: [Chunk; 25] = let base = Chunk::from(0xABCDEF0123456789);
[
Chunk::from(0xABCDEF0123456789), let mut a: Vec<Chunk> = (0..25).map(|i| base.rotl(i*4)).collect();
Chunk::from(0x9ABCDEF012345678),
Chunk::from(0x89ABCDEF01234567),
Chunk::from(0x789ABCDEF0123456),
Chunk::from(0x6789ABCDEF012345),
Chunk::from(0x56789ABCDEF01234),
Chunk::from(0x456789ABCDEF0123),
Chunk::from(0x3456789ABCDEF012),
Chunk::from(0x23456789ABCDEF01),
Chunk::from(0x123456789ABCDEF0),
Chunk::from(0x0123456789ABCDEF),
Chunk::from(0xF0123456789ABCDE),
Chunk::from(0xEF0123456789ABCD),
Chunk::from(0xDEF0123456789ABC),
Chunk::from(0xCDEF0123456789AB),
Chunk::from(0xBCDEF0123456789A),
Chunk::from(0xABCDEF0123456789),
Chunk::from(0x9ABCDEF012345678),
Chunk::from(0x89ABCDEF01234567),
Chunk::from(0x789ABCDEF0123456),
Chunk::from(0x6789ABCDEF012345),
Chunk::from(0x56789ABCDEF01234),
Chunk::from(0x456789ABCDEF0123),
Chunk::from(0x3456789ABCDEF012),
Chunk::from(0x23456789ABCDEF01)
];
keccakf(&mut a, 24); keccakf(&mut a, 24);
/* const TEST_VECTOR: [u64; 25] = [
ebf3844f878a7d3b 0x4c8948fcb6616044,
4c9a23df85c470ef 0x75642a21f8bd1299,
4c2e69353217ca2b 0xb2e949825ace668e,
a3ffa213a668ba9d 0x9b73a04c53826c35,
34082fa7dc4c944b 0x914989b8d38ea4d1,
b8bd0a4331665932 0xdc73480ade4e2664,
bfcee841052def2d 0x931394137c6fbd69,
09e2f6993a65ac0b 0x234fa173896019f5,
ec78b15ef42a11e6 0x906da29a7796b157,
5088c480e6a77eb8 0x7666ebe222445610,
9c1ff840c7758823 0x41d77796738c884e,
df8f367ad977a6b1 0x8861db16234437fa,
517b9c3505b4195a 0xf07cb925b71f27f2,
04624d3094c46c2c 0xfec25b4810a2202c,
e71674d1b70748e2 0xa8ba9bbfa9076b54,
6739a678e25ae9f4 0x18d9b9e748d655b9,
2e64f74a9528d091 0xa2172c0059955be6,
9c17a1105709cbfe 0xea602c863b7947b8,
54678a20a3ac5925 0xc77f9f23851bc2bd,
0297df877fa4a559 0x0e8ab0a29b3fef79,
f55ec61b328a5cc5 0xfd73c2cd3b443de4,
56637274c0f2c301 0x447892bf2c03c2ef,
33943408ffd9b9c5 0xd5b3dae382c238b1,
f4b87c711ed56d77 0x2103d8a64e9f4cb6,
3300e5d2414b6a93 0xfe1f57d88e2de92f
*/ ];
assert!(a[0] == Chunk::from(0xebf3844f878a7d3b));
assert!(a[1] == Chunk::from(0x4c9a23df85c470ef)); for i in 0..25 {
assert!(a[2] == Chunk::from(0x4c2e69353217ca2b)); assert!(a[i] == Chunk::from(TEST_VECTOR[i]));
assert!(a[3] == Chunk::from(0xa3ffa213a668ba9d)); }
assert!(a[24] == Chunk::from(0x3300e5d2414b6a93));
} }

View File

@ -4,6 +4,7 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include "keccak.h" #include "keccak.h"
#include <inttypes.h>
// test each length // test each length
@ -79,7 +80,7 @@ int main(int argc, char **argv)
return 0; return 0;
*/ */
/*
uint8_t in[144] = { uint8_t in[144] = {
0xBB, 0x3B, 0x1B, 0x0B, 0xFF, 0xBB, 0x3B, 0x1B, 0x0B, 0xFF,
0xBB, 0x3B, 0x1B, 0x0B, 0xFF, 0xBB, 0x3B, 0x1B, 0x0B, 0xFF,
@ -123,40 +124,42 @@ int main(int argc, char **argv)
for (int i = 0; i < 32; i++) { for (int i = 0; i < 32; i++) {
printf("%02x ", md[i]); printf("%02x ", md[i]);
} }
*/
/*
uint64_t st[25] = { uint64_t st[25] = {
0xabcdef0123456789, 0xabcdef0123456789,
0x9abcdef012345678,
0x89abcdef01234567,
0x789abcdef0123456,
0x6789abcdef012345,
0x56789abcdef01234,
0x456789abcdef0123,
0x3456789abcdef012,
0x23456789abcdef01,
0x123456789abcdef0,
0x0123456789abcdef,
0xf0123456789abcde,
0xef0123456789abcd,
0xdef0123456789abc,
0xcdef0123456789ab,
0xbcdef0123456789a, 0xbcdef0123456789a,
0xabcdef0123456789, 0xcdef0123456789ab,
0x9abcdef012345678, 0xdef0123456789abc,
0x89abcdef01234567, 0xef0123456789abcd,
0x789abcdef0123456,
0x6789abcdef012345, 0xf0123456789abcde,
0x56789abcdef01234, 0x0123456789abcdef,
0x456789abcdef0123, 0x123456789abcdef0,
0x23456789abcdef01,
0x3456789abcdef012, 0x3456789abcdef012,
0x23456789abcdef01
0x456789abcdef0123,
0x56789abcdef01234,
0x6789abcdef012345,
0x789abcdef0123456,
0x89abcdef01234567,
0x9abcdef012345678,
0xabcdef0123456789,
0xbcdef0123456789a,
0xcdef0123456789ab,
0xdef0123456789abc,
0xef0123456789abcd,
0xf0123456789abcde,
0x0123456789abcdef,
0x123456789abcdef0,
0x23456789abcdef01,
}; };
keccakf(st, 24); keccakf(st, 24);
if(st[24] == 0x3300e5d2414b6a93) { for (int i = 0; i < 25; i++) {
printf("wow good job\n"); printf("%" PRIx64 "\n", st[i]);
} }
*/
} }