hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

salsa20.ha (5007B)


      1 // License: MPL-2.0
      2 // (c) 2021-2022 Armin Preiml <apreiml@strohwolke.at>
      3 use bytes;
      4 use crypto::cipher;
      5 use crypto::math::{rotl32, xor};
      6 use endian;
      7 use io;
      8 
      9 // Size of a Salsa key, in bytes.
     10 export def KEYSIZE: size = 32;
     11 
     12 // Size of the XSalsa20 nonce, in bytes.
     13 export def XNONCESIZE: size = 24;
     14 
     15 // Size of the Salsa20 nonce, in bytes.
     16 export def NONCESIZE: size = 8;
     17 
     18 def ROUNDS: size = 20;
     19 
     20 // The block size of the Salsa cipher.
     21 export def BLOCKSIZE: size = 64;
     22 
     23 const magic: [4]u32 = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];
     24 
// A Salsa20 or XSalsa20 stream. Create with [[salsa20]] and key it with
// [[salsa20_init]] or [[xsalsa20_init]].
export type stream = struct {
	cipher::xorstream,
	// Salsa20 input state: the four constant words, the key, the nonce, and
	// the 64-bit block counter (words 8 and 9).
	state: [16]u32,
	// The most recently generated keystream block.
	// NOTE(review): keybuf reinterprets this buffer as [16]u32 through a
	// pointer cast; presumably this relies on the field being 4-byte
	// aligned after 'state' — keep the field order.
	xorbuf: [BLOCKSIZE]u8,
	// Number of bytes of xorbuf already consumed. A value of BLOCKSIZE
	// means a new block must be generated before more keystream is used.
	xorused: size,
	// Number of Salsa20 rounds applied per block.
	rounds: size,
};
     32 
     33 // Create a Salsa20 or XSalsa20 stream. Must be initialized with either
     34 // [[salsa20_init]] or [[xsalsa20_init]], and must be closed with [[io::close]]
     35 // after use to wipe sensitive data from memory.
     36 export fn salsa20() stream = {
     37 	return stream {
     38 		stream = &cipher::xorstream_vtable,
     39 		h = 0,
     40 		keybuf = &keybuf,
     41 		advance = &advance,
     42 		finish = &finish,
     43 		xorused = BLOCKSIZE,
     44 		rounds = ROUNDS,
     45 		...
     46 	};
     47 };
     48 
     49 fn init(
     50 	state: *[16]u32,
     51 	key: []u8,
     52 	nonce: []u8,
     53 	ctr: []u8
     54 ) void = {
     55 	state[0] = magic[0];
     56 	state[1] = endian::legetu32(key[0..4]);
     57 	state[2] = endian::legetu32(key[4..8]);
     58 	state[3] = endian::legetu32(key[8..12]);
     59 	state[4] = endian::legetu32(key[12..16]);
     60 	state[5] = magic[1];
     61 	state[6] = endian::legetu32(nonce[0..4]);
     62 	state[7] = endian::legetu32(nonce[4..8]);
     63 	state[8] = endian::legetu32(ctr[0..4]);
     64 	state[9] = endian::legetu32(ctr[4..8]);
     65 	state[10] = magic[2];
     66 	state[11] = endian::legetu32(key[16..20]);
     67 	state[12] = endian::legetu32(key[20..24]);
     68 	state[13] = endian::legetu32(key[24..28]);
     69 	state[14] = endian::legetu32(key[28..32]);
     70 	state[15] = magic[3];
     71 };
     72 
     73 // Initialize a Salsa20 stream.
     74 export fn salsa20_init(
     75 	s: *stream,
     76 	h: io::handle,
     77 	key: []u8,
     78 	nonce: []u8,
     79 ) void = {
     80 	assert(len(key) == KEYSIZE);
     81 	assert(len(nonce) == NONCESIZE);
     82 
     83 	let counter: [8]u8 = [0...];
     84 	init(&s.state, key, nonce, &counter);
     85 	s.h = h;
     86 };
     87 
     88 // Initialize an XSalsa20 stream. XSalsa20 differs from Salsa20 via the use of a
     89 // larger nonce parameter.
     90 export fn xsalsa20_init(
     91 	s: *stream,
     92 	h: io::handle,
     93 	key: []u8,
     94 	nonce: []u8
     95 ) void = {
     96 	assert(len(key) == KEYSIZE);
     97 	assert(len(nonce) == XNONCESIZE);
     98 
     99 	let dkey: [32]u8 = [0...];
    100 	defer bytes::zero(dkey);
    101 	hsalsa20(&dkey, key, nonce[..16]);
    102 	salsa20_init(s, h, &dkey, nonce[16..]: *[NONCESIZE]u8);
    103 };
    104 
    105 // Derives a new key from 'key' and 'nonce' as used during XSalsa20
    106 // initialization. This function may only be used for specific purposes
    107 // such as X25519 key derivation. Do not use if in doubt.
    108 export fn hsalsa20(out: []u8, key: []u8, nonce: []u8) void = {
    109 	assert(len(out) == KEYSIZE);
    110 	assert(len(key) == KEYSIZE);
    111 	assert(len(nonce) == 16);
    112 
    113 	let state: [16]u32 = [0...];
    114 	defer bytes::zero((state: []u8: *[*]u8)[..BLOCKSIZE]);
    115 
    116 	init(&state, key, nonce[0..8]: *[8]u8, nonce[8..16]: *[8]u8);
    117 	hblock(state[..], &state, 20);
    118 
    119 	endian::leputu32(out[0..4], state[0]);
    120 	endian::leputu32(out[4..8], state[5]);
    121 	endian::leputu32(out[8..12], state[10]);
    122 	endian::leputu32(out[12..16], state[15]);
    123 	endian::leputu32(out[16..20], state[6]);
    124 	endian::leputu32(out[20..24], state[7]);
    125 	endian::leputu32(out[24..28], state[8]);
    126 	endian::leputu32(out[28..32], state[9]);
    127 };
    128 
    129 // Advances the key stream to "seek" to a future state by 'counter' times
    130 // [[BLOCKSIZE]].
    131 export fn setctr(s: *stream, counter: u64) void = {
    132 	s.state[8] = (counter & 0xFFFFFFFF): u32;
    133 	s.state[9] = (counter >> 32): u32;
    134 	s.xorused = BLOCKSIZE;
    135 };
    136 
    137 fn keybuf(s: *cipher::xorstream) []u8 = {
    138 	let s = s: *stream;
    139 	if (s.xorused >= BLOCKSIZE) {
    140 		block((s.xorbuf[..]: *[*]u32)[..16], &s.state, s.rounds);
    141 		s.state[8] += 1;
    142 		if (s.state[8] == 0) {
    143 			s.state[9] += 1;
    144 		};
    145 		s.xorused = 0;
    146 	};
    147 
    148 	return s.xorbuf[s.xorused..];
    149 };
    150 
    151 fn advance(s: *cipher::xorstream, n: size) void = {
    152 	let s = s: *stream;
    153 	assert(n <= len(s.xorbuf));
    154 	s.xorused += n;
    155 };
    156 
    157 fn block(dest: []u32, state: *[16]u32, rounds: size) void = {
    158 	hblock(dest, state, rounds);
    159 
    160 	for (let i = 0z; i < 16; i += 1) {
    161 		dest[i] += state[i];
    162 	};
    163 };
    164 
    165 fn hblock(dest: []u32, state: *[16]u32, rounds: size) void = {
    166 	for (let i = 0z; i < 16; i += 1) {
    167 		dest[i] = state[i];
    168 	};
    169 
    170 	for (let i = 0z; i < rounds; i += 2) {
    171 		qr(&dest[0], &dest[4], &dest[8], &dest[12]);
    172 		qr(&dest[5], &dest[9], &dest[13], &dest[1]);
    173 		qr(&dest[10], &dest[14], &dest[2], &dest[6]);
    174 		qr(&dest[15], &dest[3], &dest[7], &dest[11]);
    175 
    176 		qr(&dest[0], &dest[1], &dest[2], &dest[3]);
    177 		qr(&dest[5], &dest[6], &dest[7], &dest[4]);
    178 		qr(&dest[10], &dest[11], &dest[8], &dest[9]);
    179 		qr(&dest[15], &dest[12], &dest[13], &dest[14]);
    180 	};
    181 };
    182 
    183 fn qr(a: *u32, b: *u32, c: *u32, d: *u32) void = {
    184 	*b ^= rotl32(*a + *d, 7);
    185 	*c ^= rotl32(*b + *a, 9);
    186 	*d ^= rotl32(*c + *b, 13);
    187 	*a ^= rotl32(*d + *c, 18);
    188 };
    189 
    190 fn finish(s: *cipher::xorstream) void = {
    191 	let s = s: *stream;
    192 	bytes::zero((s.state[..]: *[*]u8)[..len(s.state) * size(u32)]);
    193 	bytes::zero(s.xorbuf);
    194 };