salsa20.ha (4521B)
// License: MPL-2.0
// (c) 2021-2022 Armin Preiml <apreiml@strohwolke.at>
use bytes;
use crypto::cipher;
use crypto::math::{rotl32, xor};
use endian;
use io;

// Size of a Salsa key, in bytes.
export def KEYSIZE: size = 32;

// Size of the XSalsa20 nonce, in bytes.
export def XNONCESIZE: size = 24;

// Size of the Salsa20 nonce, in bytes.
export def NONCESIZE: size = 8;

// Round count given to streams created by [[salsa20]].
def ROUNDS: size = 20;

// The block size of the Salsa cipher.
export def BLOCKSIZE: size = 64;

// The constant words placed on the state diagonal; read as little-endian
// bytes they spell "expand 32-byte k".
const magic: [4]u32 = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];

// State of a Salsa20 or XSalsa20 key stream.
export type stream = struct {
	cipher::xorstream,
	// The 4x4 input matrix: constants, key, nonce and block counter.
	state: [16]u32,
	// The most recently generated key stream block.
	xorbuf: [BLOCKSIZE]u8,
	// Bytes of xorbuf already consumed; BLOCKSIZE means a new block is
	// generated on the next keybuf call.
	xorused: size,
	// Number of rounds applied per block.
	rounds: size,
};

// Creates a Salsa20 or XSalsa20 stream. Before use it has to be initialized
// with either [[salsa20_init]] or [[xsalsa20_init]], and when done it has to be
// closed with [[io::close]] so that sensitive data is wiped from memory.
export fn salsa20() stream = stream {
	keybuf = &keybuf,
	advance = &advance,
	finish = &finish,
	xorused = BLOCKSIZE,
	rounds = ROUNDS,
	...
};

// Fills 'state' with the Salsa20 input matrix built from the constants, the
// key, an 8-byte nonce and an 8-byte little-endian block counter.
fn init(
	state: *[16]u32,
	key: *[KEYSIZE]u8,
	nonce: *[8]u8,
	ctr: *[8]u8
) void = {
	// Constants go on the diagonal.
	state[0] = magic[0];
	state[5] = magic[1];
	state[10] = magic[2];
	state[15] = magic[3];

	// Key bytes 0..16 fill words 1..4; key bytes 16..32 fill words 11..14.
	for (let i = 0z; i < 4; i += 1) {
		const off = i * 4;
		state[1 + i] = endian::legetu32(key[off..off + 4]);
		state[11 + i] = endian::legetu32(key[16 + off..16 + off + 4]);
	};

	// The nonce fills words 6..7 and the counter words 8..9.
	for (let i = 0z; i < 2; i += 1) {
		const off = i * 4;
		state[6 + i] = endian::legetu32(nonce[off..off + 4]);
		state[8 + i] = endian::legetu32(ctr[off..off + 4]);
	};
};

// Initialize a Salsa20 stream.
// The 64-bit block counter starts at zero. 's' must have been created with
// [[salsa20]], since this function does not set the round count or callbacks.
export fn salsa20_init(
	s: *stream,
	h: io::handle,
	key: *[KEYSIZE]u8,
	nonce: *[NONCESIZE]u8,
) void = {
	let counter: [8]u8 = [0...];
	init(&s.state, key, nonce, &counter);
	// Force the next keybuf call to generate block 0.
	s.xorused = BLOCKSIZE;

	cipher::xorstream_init(s, h);
};

// Initialize an XSalsa20 stream. XSalsa20 differs from Salsa20 via the use of a
// larger nonce parameter: the first 16 nonce bytes are mixed with the key to
// derive a subkey, and the remaining 8 bytes serve as the Salsa20 nonce under
// that subkey.
export fn xsalsa20_init(
	s: *stream,
	h: io::handle,
	key: *[KEYSIZE]u8,
	nonce: *[XNONCESIZE]u8
) void = {
	// The 16-byte nonce prefix occupies words 6..9, where init otherwise
	// places the nonce and the block counter.
	let state: [16]u32 = [0...];
	init(&state, key, nonce[0..8]: *[8]u8, nonce[8..16]: *[8]u8);
	// The derivation always uses 20 rounds (independent of s.rounds) and,
	// unlike block, does not add the input back into the result.
	hblock(state[..], &state, 20);

	// The subkey is the diagonal words 0, 5, 10, 15 followed by words 6..9.
	let dkey: [KEYSIZE]u8 = [0...];
	endian::leputu32(dkey[0..4], state[0]);
	endian::leputu32(dkey[4..8], state[5]);
	endian::leputu32(dkey[8..12], state[10]);
	endian::leputu32(dkey[12..16], state[15]);
	endian::leputu32(dkey[16..20], state[6]);
	endian::leputu32(dkey[20..24], state[7]);
	endian::leputu32(dkey[24..28], state[8]);
	endian::leputu32(dkey[28..], state[9]);

	salsa20_init(s, h, &dkey, nonce[16..]: *[NONCESIZE]u8);

	// Wipe the intermediate state and subkey; the wipe length is derived
	// from the array itself rather than hard-coded.
	bytes::zero((state[..]: *[*]u8)[..len(state) * size(u32)]);
	bytes::zero(dkey);
};

// Seeks the key stream to the start of block number 'counter', i.e. to byte
// offset 'counter' * [[BLOCKSIZE]] from the beginning of the stream. The
// position is absolute, not relative to the current one, and any unused bytes
// of the current block are discarded.
export fn setctr(s: *stream, counter: u64) void = {
	// The block counter occupies state words 8 (low half) and 9 (high half).
	s.state[8] = (counter & 0xFFFFFFFF): u32;
	s.state[9] = (counter >> 32): u32;
	// Drop the buffered block so the next keybuf call generates block
	// 'counter'.
	s.xorused = BLOCKSIZE;
};

// Returns the unused remainder of the current key stream block. If the current
// block is exhausted, the next one is generated first and the counter is
// incremented.
fn keybuf(s: *cipher::xorstream) []u8 = {
	let s = s: *stream;
	if (s.xorused >= BLOCKSIZE) {
		let words = (s.xorbuf[..]: *[*]u32)[..16];
		block(words, &s.state, s.rounds);
		// The key stream is the output words serialized in little-endian
		// order. block stores them in host order, so rewrite each word in
		// place (a no-op on little-endian hosts). The word is read before
		// its own four bytes are overwritten.
		for (let i = 0z; i < len(words); i += 1) {
			const w = words[i];
			const off = i * size(u32);
			endian::leputu32(s.xorbuf[off..off + size(u32)], w);
		};
		// 64-bit block counter increment with carry into word 9.
		s.state[8] += 1;
		if (s.state[8] == 0) {
			s.state[9] += 1;
		};
		s.xorused = 0;
	};

	return s.xorbuf[s.xorused..];
};

// Marks 'n' bytes of the buffer most recently returned by keybuf as consumed.
fn advance(s: *cipher::xorstream, n: size) void = {
	let s = s: *stream;
	// 'n' may not exceed what keybuf returned, i.e. the unused remainder of
	// the current block. xorused never exceeds BLOCKSIZE, so this cannot
	// underflow.
	assert(n <= len(s.xorbuf) - s.xorused);
	s.xorused += n;
};

// Computes one key stream block into 'dest': the state after 'rounds' rounds,
// with the original state added back word by word.
fn block(dest: []u32, state: *[16]u32, rounds: size) void = {
	hblock(dest, state, rounds);

	for (let i = 0z; i < 16; i += 1) {
		dest[i] += state[i];
	};
};

// Copies 'state' into 'dest' (which must hold at least 16 words) and applies
// the rounds to it in place, without the final addition done by block. Each
// loop iteration is one column round followed by one row round, so 'rounds' is
// expected to be even; an odd value is effectively rounded up.
fn hblock(dest: []u32, state: *[16]u32, rounds: size) void = {
	for (let i = 0z; i < 16; i += 1) {
		dest[i] = state[i];
	};

	for (let i = 0z; i < rounds; i += 2) {
		// Column round.
		qr(&dest[0], &dest[4], &dest[8], &dest[12]);
		qr(&dest[5], &dest[9], &dest[13], &dest[1]);
		qr(&dest[10], &dest[14], &dest[2], &dest[6]);
		qr(&dest[15], &dest[3], &dest[7], &dest[11]);

		// Row round.
		qr(&dest[0], &dest[1], &dest[2], &dest[3]);
		qr(&dest[5], &dest[6], &dest[7], &dest[4]);
		qr(&dest[10], &dest[11], &dest[8], &dest[9]);
		qr(&dest[15], &dest[12], &dest[13], &dest[14]);
	};
};

// The Salsa20 quarter-round, applied in place to four words.
fn qr(a: *u32, b: *u32, c: *u32, d: *u32) void = {
	*b ^= rotl32(*a + *d, 7);
	*c ^= rotl32(*b + *a, 9);
	*d ^= rotl32(*c + *b, 13);
	*a ^= rotl32(*d + *c, 18);
};

// Wipes the key-dependent state and the buffered key stream.
fn finish(s: *cipher::xorstream) void = {
	let s = s: *stream;
	bytes::zero((s.state[..]: *[*]u8)[..len(s.state) * size(u32)]);
	bytes::zero(s.xorbuf);
};