diff --git a/src/keccak.rs b/src/keccak.rs index e90339c..cc2327e 100644 --- a/src/keccak.rs +++ b/src/keccak.rs @@ -108,6 +108,83 @@ fn keccakf(st: &mut [Chunk], rounds: usize) } } +fn temporary_shim(state: &mut [Byte]) { + assert_eq!(state.len(), 200); + + let mut chunks = Vec::with_capacity(25); + for i in 0..25 { + let i = i * 8; + + chunks.push(Chunk::from(&state[i..i+8])); + } + + keccakf(&mut chunks, 24); + + for (i, bit) in chunks.iter().flat_map(|c| c.bits.iter()).enumerate() { + state[i / 8].bits[i % 8] = bit.clone(); + } +} + +fn sha3_256(message: &[Byte]) -> Vec { + keccak(1088, 512, message, 0x06, 32) +} + +fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8, mut mdlen: usize) + -> Vec +{ + use std::cmp::min; + + let mut st: Vec = Some(Bit::byte(0)).into_iter().cycle().take(200).collect(); + + let rateInBytes = rate / 8; + let mut inputByteLen = input.len(); + let mut blockSize = 0; + + if ((rate + capacity) != 1600) || ((rate % 8) != 0) { + panic!("invalid parameters"); + } + + while inputByteLen > 0 { + blockSize = min(inputByteLen, rateInBytes); + + for i in 0..blockSize { + st[i] = st[i].xor(&input[i]); + } + + input = &input[blockSize..]; + inputByteLen -= blockSize; + + if blockSize == rateInBytes { + temporary_shim(&mut st); + blockSize = 0; + } + } + + st[blockSize] = st[blockSize].xor(&Bit::byte(delimited_suffix)); + + if ((delimited_suffix & 0x80) != 0) && (blockSize == (rateInBytes-1)) { + temporary_shim(&mut st); + } + + st[rateInBytes-1] = st[rateInBytes-1].xor(&Bit::byte(0x80)); + + temporary_shim(&mut st); + + let mut output = Vec::with_capacity(mdlen); + + while mdlen > 0 { + blockSize = min(mdlen, rateInBytes); + output.extend_from_slice(&st[0..blockSize]); + mdlen -= blockSize; + + if mdlen > 0 { + temporary_shim(&mut st); + } + } + + output +} + fn keccak256(input: &[Byte]) -> Vec { assert_eq!(input.len(), 144); @@ -226,10 +303,22 @@ enum Bit { Constant(u8) } +#[derive(Clone, Debug, PartialEq)] struct Byte { bits: Vec } +impl Byte { + fn xor(&self, other: &Byte) -> Byte { + Byte { + bits: self.bits.iter() + .zip(other.bits.iter()) + .map(|(a, b)| a.xor(b)) + .collect() + } + } +} + impl Bit { fn byte(byte: u8) -> Byte { Byte { @@ -267,6 +356,44 @@ impl Bit { } } +#[test] +fn test_shim() { + let mut chunks: Vec<_> = (0..25).map(|_| Chunk::from(0xABCDEF0123456789)).collect(); + keccakf(&mut chunks, 24); + + let mut bytes: Vec = (0..200).map(|i| { + match i % 8 { + 0 => Bit::byte(0xAB), + 1 => Bit::byte(0xCD), + 2 => Bit::byte(0xEF), + 3 => Bit::byte(0x01), + 4 => Bit::byte(0x23), + 5 => Bit::byte(0x45), + 6 => Bit::byte(0x67), + 7 => Bit::byte(0x89), + _ => unreachable!() + } + }).collect(); + + temporary_shim(&mut bytes); + + for (i, bit) in bytes.iter().flat_map(|c| c.bits.iter()).enumerate() { + //println!("i = {}", i); + if &chunks[i / 64].bits[i % 64] != bit { + panic!("fuck."); + } + } +} + +#[test] +fn woohoo() { + let message = [Bit::byte(0x30)]; + let test = sha3_256(&message); + + println!("{:?}", test); + assert!(test[0] == Bit::byte(0xf9)); +} + #[test] fn testsha3() { let bb = |x: usize| {