hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

salsa20.ha (4979B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
      4 use bytes;
      5 use crypto::cipher;
      6 use crypto::math::{rotl32, xor};
      7 use endian;
      8 use io;
      9 
     10 // Size of a Salsa key, in bytes.
     11 export def KEYSZ: size = 32;
     12 
     13 // Size of the XSalsa20 nonce, in bytes.
     14 export def XNONCESZ: size = 24;
     15 
     16 // Size of the Salsa20 nonce, in bytes.
     17 export def NONCESZ: size = 8;
     18 
     19 def ROUNDS: size = 20;
     20 
     21 // The block size of the Salsa cipher.
     22 export def BLOCKSZ: size = 64;
     23 
     24 const magic: [4]u32 = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];
     25 
     26 export type stream = struct {
     27 	cipher::xorstream,
     28 	state: [16]u32,
     29 	xorbuf: [BLOCKSZ]u8,
     30 	xorused: size,
     31 	rounds: size,
     32 };
     33 
     34 // Create a Salsa20 or XSalsa20 stream. Must be initialized with either
     35 // [[salsa20_init]] or [[xsalsa20_init]], and must be closed with [[io::close]]
     36 // after use to wipe sensitive data from memory.
     37 export fn salsa20() stream = {
     38 	return stream {
     39 		stream = &cipher::xorstream_vtable,
     40 		h = 0,
     41 		keybuf = &keybuf,
     42 		advance = &advance,
     43 		finish = &finish,
     44 		xorused = BLOCKSZ,
     45 		rounds = ROUNDS,
     46 		...
     47 	};
     48 };
     49 
     50 fn init(
     51 	state: *[16]u32,
     52 	key: []u8,
     53 	nonce: []u8,
     54 	ctr: []u8
     55 ) void = {
     56 	state[0] = magic[0];
     57 	state[1] = endian::legetu32(key[0..4]);
     58 	state[2] = endian::legetu32(key[4..8]);
     59 	state[3] = endian::legetu32(key[8..12]);
     60 	state[4] = endian::legetu32(key[12..16]);
     61 	state[5] = magic[1];
     62 	state[6] = endian::legetu32(nonce[0..4]);
     63 	state[7] = endian::legetu32(nonce[4..8]);
     64 	state[8] = endian::legetu32(ctr[0..4]);
     65 	state[9] = endian::legetu32(ctr[4..8]);
     66 	state[10] = magic[2];
     67 	state[11] = endian::legetu32(key[16..20]);
     68 	state[12] = endian::legetu32(key[20..24]);
     69 	state[13] = endian::legetu32(key[24..28]);
     70 	state[14] = endian::legetu32(key[28..32]);
     71 	state[15] = magic[3];
     72 };
     73 
     74 // Initialize a Salsa20 stream.
     75 export fn salsa20_init(
     76 	s: *stream,
     77 	h: io::handle,
     78 	key: []u8,
     79 	nonce: []u8,
     80 ) void = {
     81 	assert(len(key) == KEYSZ);
     82 	assert(len(nonce) == NONCESZ);
     83 
     84 	let counter: [8]u8 = [0...];
     85 	init(&s.state, key, nonce, &counter);
     86 	s.h = h;
     87 };
     88 
     89 // Initialize an XSalsa20 stream. XSalsa20 differs from Salsa20 via the use of a
     90 // larger nonce parameter.
     91 export fn xsalsa20_init(
     92 	s: *stream,
     93 	h: io::handle,
     94 	key: []u8,
     95 	nonce: []u8
     96 ) void = {
     97 	assert(len(key) == KEYSZ);
     98 	assert(len(nonce) == XNONCESZ);
     99 
    100 	let dkey: [32]u8 = [0...];
    101 	defer bytes::zero(dkey);
    102 	hsalsa20(&dkey, key, nonce[..16]);
    103 	salsa20_init(s, h, &dkey, nonce[16..]: *[NONCESZ]u8);
    104 };
    105 
    106 // Derives a new key from 'key' and 'nonce' as used during XSalsa20
    107 // initialization. This function may only be used for specific purposes
    108 // such as X25519 key derivation. Do not use if in doubt.
    109 export fn hsalsa20(out: []u8, key: []u8, nonce: []u8) void = {
    110 	assert(len(out) == KEYSZ);
    111 	assert(len(key) == KEYSZ);
    112 	assert(len(nonce) == 16);
    113 
    114 	let state: [16]u32 = [0...];
    115 	defer bytes::zero((state: []u8: *[*]u8)[..BLOCKSZ]);
    116 
    117 	init(&state, key, nonce[0..8]: *[8]u8, nonce[8..16]: *[8]u8);
    118 	hblock(state[..], &state, 20);
    119 
    120 	endian::leputu32(out[0..4], state[0]);
    121 	endian::leputu32(out[4..8], state[5]);
    122 	endian::leputu32(out[8..12], state[10]);
    123 	endian::leputu32(out[12..16], state[15]);
    124 	endian::leputu32(out[16..20], state[6]);
    125 	endian::leputu32(out[20..24], state[7]);
    126 	endian::leputu32(out[24..28], state[8]);
    127 	endian::leputu32(out[28..32], state[9]);
    128 };
    129 
    130 // Advances the key stream to "seek" to a future state by 'counter' times
    131 // [[BLOCKSZ]].
    132 export fn setctr(s: *stream, counter: u64) void = {
    133 	s.state[8] = (counter & 0xFFFFFFFF): u32;
    134 	s.state[9] = (counter >> 32): u32;
    135 	s.xorused = BLOCKSZ;
    136 };
    137 
    138 fn keybuf(s: *cipher::xorstream) []u8 = {
    139 	let s = s: *stream;
    140 	if (s.xorused >= BLOCKSZ) {
    141 		block((s.xorbuf[..]: *[*]u32)[..16], &s.state, s.rounds);
    142 		s.state[8] += 1;
    143 		if (s.state[8] == 0) {
    144 			s.state[9] += 1;
    145 		};
    146 		s.xorused = 0;
    147 	};
    148 
    149 	return s.xorbuf[s.xorused..];
    150 };
    151 
    152 fn advance(s: *cipher::xorstream, n: size) void = {
    153 	let s = s: *stream;
    154 	assert(n <= len(s.xorbuf));
    155 	s.xorused += n;
    156 };
    157 
    158 fn block(dest: []u32, state: *[16]u32, rounds: size) void = {
    159 	hblock(dest, state, rounds);
    160 
    161 	for (let i = 0z; i < 16; i += 1) {
    162 		dest[i] += state[i];
    163 	};
    164 };
    165 
    166 fn hblock(dest: []u32, state: *[16]u32, rounds: size) void = {
    167 	for (let i = 0z; i < 16; i += 1) {
    168 		dest[i] = state[i];
    169 	};
    170 
    171 	for (let i = 0z; i < rounds; i += 2) {
    172 		qr(&dest[0], &dest[4], &dest[8], &dest[12]);
    173 		qr(&dest[5], &dest[9], &dest[13], &dest[1]);
    174 		qr(&dest[10], &dest[14], &dest[2], &dest[6]);
    175 		qr(&dest[15], &dest[3], &dest[7], &dest[11]);
    176 
    177 		qr(&dest[0], &dest[1], &dest[2], &dest[3]);
    178 		qr(&dest[5], &dest[6], &dest[7], &dest[4]);
    179 		qr(&dest[10], &dest[11], &dest[8], &dest[9]);
    180 		qr(&dest[15], &dest[12], &dest[13], &dest[14]);
    181 	};
    182 };
    183 
    184 fn qr(a: *u32, b: *u32, c: *u32, d: *u32) void = {
    185 	*b ^= rotl32(*a + *d, 7);
    186 	*c ^= rotl32(*b + *a, 9);
    187 	*d ^= rotl32(*c + *b, 13);
    188 	*a ^= rotl32(*d + *c, 18);
    189 };
    190 
    191 fn finish(s: *cipher::xorstream) void = {
    192 	let s = s: *stream;
    193 	bytes::zero((s.state[..]: *[*]u8)[..len(s.state) * size(u32)]);
    194 	bytes::zero(s.xorbuf);
    195 };