hare

[hare] The Hare programming language
git clone https://git.torresjrjr.com/hare.git
Log | Files | Refs | README | LICENSE

edwards25519.ha (8210B)


      1 // SPDX-License-Identifier: MPL-2.0
      2 // (c) Hare authors <https://harelang.org>
      3 
// Number of 16-bit limbs in a field element.
def FIELDSZ: size = 16;
// A field element modulo p = 2^255 - 19, held as FIELDSZ signed limbs in
// radix 2^16, least significant limb first (limb i has weight 2^(16*i)).
// Limbs may leave [0, 2^16) between carry passes (fe_reduce); fe_encode
// produces the reduced 32-byte encoding.
type elem = [FIELDSZ]i64;

// The field elements 0 and 1.
const feZero: elem = [0...];
const feOne: elem = [1, 0...];
// Curve constant multiplied by y^2 in point_decode (den = D*y^2 + 1);
// presumably the Edwards curve parameter d — verify against RFC 8032.
const D: elem = [0x78a3, 0x1359, 0x4dca, 0x75eb, 0xd8ab, 0x4141, 0x0a4d, 0x0070, 0xe898, 0x7779, 0x4079, 0x8cc7, 0xfe73, 0x2b6f, 0x6cee, 0x5203];
// Multiplies t1*t2 in point_add; presumably 2*d.
const D2: elem = [0xf159, 0x26b2, 0x9b94, 0xebd6, 0xb156, 0x8283, 0x149a, 0x00e0, 0xd130, 0xeef3, 0x80f2, 0x198e, 0xfce7, 0x56df, 0xd9dc, 0x2406];
// X and Y are the affine coordinates of the base point used by
// scalarmult_base.
const X: elem = [0xd51a, 0x8f25, 0x2d60, 0xc956, 0xa7b2, 0x9525, 0xc760, 0x692c, 0xdc5c, 0xfdd6, 0xe231, 0xc0a4, 0x53fe, 0xcd6e, 0x36d3, 0x2169];
const Y: elem = [0x6658, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666, 0x6666];
// Multiplied into the candidate x in point_decode when x^2*den != num;
// presumably a square root of -1 modulo p — verify.
const I: elem = [0xa0b0, 0x4a0e, 0x1b27, 0xc4ee, 0xe478, 0xad2f, 0x1806, 0x2f43, 0xd7a7, 0x3dfb, 0x0099, 0x2b4d, 0xdf0b, 0x4fc1, 0x2480, 0x2b83];
     14 
     15 fn fe_reduce(fe: *elem) void = {
     16 	let carry: i64 = 0;
     17 	for (let i = 0z; i < FIELDSZ; i += 1) {
     18 		carry = fe[i] >> 16;
     19 		fe[i] -= (carry << 16);
     20 		if (i+1 < FIELDSZ) {
     21 			fe[i + 1] += carry;
     22 		} else {
     23 			fe[0] += (38 * carry);
     24 		};
     25 	};
     26 };
     27 
     28 fn fe_add(out: *elem, a: const *elem, b: const *elem) *elem = {
     29 	for (let i = 0z; i < FIELDSZ; i += 1) {
     30 		out[i] = a[i] + b[i];
     31 	};
     32 	return out;
     33 };
     34 
     35 fn fe_sub(out: *elem, a: const *elem, b: const *elem) *elem = {
     36 	for (let i = 0z; i < FIELDSZ; i += 1) {
     37 		out[i] = a[i] - b[i];
     38 	};
     39 	return out;
     40 };
     41 
     42 fn fe_negate(out: *elem, a: const *elem) *elem = {
     43 	return fe_sub(out, &feZero, a);
     44 };
     45 
     46 fn fe_mul(out: *elem, a: const *elem, b: const *elem) *elem = {
     47 	let prod: [31]i64 = [0...];
     48 	for (let i = 0z; i < FIELDSZ; i += 1) {
     49 		for (let j = 0z; j < FIELDSZ; j += 1) {
     50 			prod[i + j] += a[i] * b[j];
     51 		};
     52 	};
     53 	for (let i = 0; i < 15; i += 1) {
     54 		prod[i] += (38 * prod[i + 16]);
     55 	};
     56 	out[0..FIELDSZ] = prod[0..FIELDSZ];
     57 	fe_reduce(out);
     58 	fe_reduce(out);
     59 	return out;
     60 };
     61 
     62 fn fe_square(out: *elem, a: const *elem) *elem = {
     63 	return fe_mul(out, a, a);
     64 };
     65 
     66 // out = i ** (2**252 - 3)
     67 fn fe_pow2523(out: *elem, a: *elem) *elem = {
     68 	let c: elem = [0...];
     69 	c[..] = a[..];
     70 	for (let i = 250i; i >= 0; i -= 1) {
     71 		fe_square(&c, &c);
     72 		if (i != 1) {
     73 			fe_mul(&c, &c, a);
     74 		};
     75 	};
     76 	out[..] = c[..];
     77 	return out;
     78 };
     79 
     80 fn fe_inv(out: *elem, a: const *elem) *elem = {
     81 	let c: elem = [0...];
     82 	c[..] = a[..];
     83 	for (let i = 253i; i >= 0; i -= 1) {
     84 		fe_square(&c, &c);
     85 		if (i != 2 && i != 4) {
     86 			fe_mul(&c, &c, a);
     87 		};
     88 	};
     89 	out[..] = c[..];
     90 	return out;
     91 };
     92 
     93 fn fe_parity(a: const *elem) u8 = {
     94 	let d: scalar = [0...];
     95 	fe_encode(&d, a);
     96 	return d[0]&1;
     97 };
     98 
     99 // a == b -> 0
    100 // a != b -> 1
    101 fn fe_cmp(a: const *elem, b: const *elem) u8 = {
    102 	let x: scalar = [0...];
    103 	fe_encode(&x, a);
    104 	let y: scalar = [0...];
    105 	fe_encode(&y, b);
    106 
    107 	// constant-time compare
    108 	let d: u32 = 0;
    109 	for (let i = 0z; i < SCALARSZ; i += 1) {
    110 		d |= x[i] ^ y[i];
    111 	};
    112 	return (1 & ((d - 1) >> 8): u8) - 1;
    113 };
    114 
    115 // swap p and q if bit is 1, otherwise noop
    116 fn fe_swap(p: *elem, q: *elem, bit: u8) void = {
    117 	let c = ~(bit: u64 - 1): i64;
    118 	for (let i = 0z; i < FIELDSZ; i += 1) {
    119 		let t = c & (p[i] ^ q[i]);
    120 		p[i] ^= t;
    121 		q[i] ^= t;
    122 	};
    123 };
    124 
// Writes the 32-byte little-endian encoding of a, reduced modulo
// p = 2^255 - 19, into out. a itself is not modified (a copy is used).
//
// The copy is first carried three times (fe_reduce). Then, twice, m = t - p
// is computed limb by limb with borrow propagation. If that did not borrow
// out of the top limb (t >= p), fe_swap moves m into t; otherwise t is
// kept. The selection uses fe_swap rather than a branch. Finally each
// 16-bit limb is emitted as two bytes, low byte first.
fn fe_encode(out: *scalar, a: const *elem) void = {
	let m: elem = [0...];
	let t: elem = *a;

	fe_reduce(&t);
	fe_reduce(&t);
	fe_reduce(&t);

	for (let _i = 0; _i < 2; _i += 1) {
		// m = t - p; the limbs of p are 0xffed, 0xffff (x14), 0x7fff.
		// The borrow is bit 16 of the previous (unmasked) limb.
		m[0] = t[0] - 0xffed;
		for (let i = 1z; i < 15; i += 1) {
			m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
			m[i - 1] &= 0xffff;
		};
		m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
		// b == 1 when the top limb went negative, i.e. t < p
		let b = ((m[15] >> 16): u8) & 1;
		m[14] &= 0xffff;
		// take m = t - p only when no borrow occurred
		fe_swap(&t, &m, 1-b);
	};

	for (let i = 0z; i < FIELDSZ; i += 1) {
		out[2*i+0] = (t[i] & 0xff) : u8;
		out[2*i+1] = (t[i] >> 8) : u8;
	};
};
    150 
    151 // len(in) must be SCALARSZ
    152 fn fe_decode(fe: *elem, in: []u8) *elem = {
    153 	for (let i = 0z; i < FIELDSZ; i += 1) {
    154 		fe[i] = in[2 * i] : i64 + ((in[2 * i + 1] : i64) << 8);
    155 	};
    156 	fe[15] &= 0x7fff;
    157 	return fe;
    158 };
    159 
    160 
// Size in bytes of a scalar (also the size of an encoded field element).
def SCALARSZ: size = 32;
// A 256-bit little-endian byte string, used for scalars and for encoded
// field elements (see fe_encode).
type scalar = [SCALARSZ]u8;

// The modulus used by scalar_mod_L, little-endian. These bytes encode
// 2^252 + 0x14def9dea2f79cd65812631a5cf5d3ed; presumably the order of the
// base point — verify against RFC 8032.
const L: scalar = [
	0xed, 0xd3, 0xf5, 0x5c, 0x1a, 0x63, 0x12, 0x58, 0xd6, 0x9c, 0xf7, 0xa2,
	0xde, 0xf9, 0xde, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10,
];
    169 
    170 fn scalar_clamp(s: *scalar) void = {
    171 	s[0] &= 248;
    172 	s[31] &= 127;
    173 	s[31] |= 64;
    174 };
    175 
// r = x mod L
//
// x holds a little-endian number as 64 signed limbs of nominal weight
// 2^(8*i) (e.g. a 512-bit hash or a 32x32-byte product). It is used as
// scratch space and is clobbered. The 32 low bytes of the result are
// written to r.
fn scalar_mod_L(r: *scalar, x: *[64]i64) void = {
	// Eliminate limbs 63..32 one at a time. With L = 2^252 + c,
	// 2^(8*i) = 16 * 2^252 * 2^(8*(i-32)) = -16 * c * 2^(8*(i-32)) (mod L),
	// so x[i] is folded in by subtracting 16 * x[i] * L[k] at position
	// (i - 32) + k. L[16..19] are zero, so the tail iterations only
	// propagate the (rounded, signed) carry.
	for (let i: i64 = 63; i >= 32; i -= 1) {
		let carry: i64 = 0;
		let j = i - 32;
		for (j < i - 12; j += 1) {
			x[j] += carry - 16 * x[i] * (L[j - (i - 32)]: i64);
			carry = (x[j] + 128) >> 8;
			x[j] -= carry << 8;
		};
		x[j] += carry;
		x[i] = 0;
	};

	// Remove the bits of x[31] at and above bit 4 (weight 2^252 and up) by
	// subtracting that multiple of L, normalizing limbs to 8 bits.
	let carry: i64 = 0;
	for (let j = 0; j < 32; j += 1) {
		x[j] += carry - (x[31] >> 4) * (L[j]: i64);
		carry = x[j] >> 8;
		x[j] &= 255;
	};
	// Subtract the leftover carry times L.
	for (let j = 0; j < 32; j += 1) {
		x[j] -= carry * (L[j]: i64);
	};
	// Final carry propagation into bytes.
	for (let i = 0; i < 32; i += 1) {
		x[i+1] += x[i] >> 8;
		r[i] = (x[i]&255): u8;
	};
};
    204 
    205 fn scalar_reduce(r: *scalar, h: *[64]u8) void = {
    206 	let x: [64]i64 = [0...];
    207 	for (let i = 0; i < 64; i += 1) {
    208 		x[i] = h[i]: i64;
    209 	};
    210 	scalar_mod_L(r, &x);
    211 };
    212 
    213 // s = a*b + c
    214 fn scalar_multiply_add(s: *scalar, a: *scalar, b: *scalar, c: *scalar) void = {
    215 	let x: [64]i64 = [0...];
    216 	for (let i = 0; i < 32; i += 1) {
    217 		for (let j = 0; j < 32; j += 1) {
    218 			x[i+j] += (a[i]: i64) * (b[j]: i64);
    219 		};
    220 	};
    221 	for (let i = 0; i < 32; i += 1) {
    222 		x[i] += (c[i]: i64);
    223 	};
    224 	scalar_mod_L(s, &x);
    225 };
    226 
    227 
// Size in bytes of an encoded point (see point_encode / point_decode).
def POINTSZ: size = 32;

// A curve point in projective coordinates plus an auxiliary t.
// point_encode recovers affine x and y as x/z and y/z. point_decode and
// scalarmult_base set z = 1 and t = x*y, so presumably t = x*y/z in
// general (extended twisted Edwards coordinates) — confirm.
type point = struct {
	x: elem,
	y: elem,
	z: elem,
	t: elem,
};
    236 
    237 // out = p += q
    238 fn point_add(out: *point, p: *point, q: *point) *point = {
    239 	let a: elem = [0...];
    240 	let b: elem = [0...];
    241 	let c: elem = [0...];
    242 	let d: elem = [0...];
    243 	let t: elem = [0...];
    244 	let e: elem = [0...];
    245 	let f: elem = [0...];
    246 	let g: elem = [0...];
    247 	let h: elem = [0...];
    248 
    249 	fe_sub(&a, &p.y, &p.x);
    250 	fe_sub(&t, &q.y, &q.x);
    251 	fe_mul(&a, &a, &t);
    252 	fe_add(&b, &p.x, &p.y);
    253 	fe_add(&t, &q.x, &q.y);
    254 	fe_mul(&b, &b, &t);
    255 	fe_mul(&c, &p.t, &q.t);
    256 	fe_mul(&c, &c, &D2);
    257 	fe_mul(&d, &p.z, &q.z);
    258 	fe_add(&d, &d, &d);
    259 	fe_sub(&e, &b, &a);
    260 	fe_sub(&f, &d, &c);
    261 	fe_add(&g, &d, &c);
    262 	fe_add(&h, &b, &a);
    263 
    264 	fe_mul(&out.x, &e, &f);
    265 	fe_mul(&out.y, &h, &g);
    266 	fe_mul(&out.z, &g, &f);
    267 	fe_mul(&out.t, &e, &h);
    268 	return out;
    269 };
    270 
    271 // swap p and q if bit is 1, otherwise noop
    272 fn point_swap(p: *point, q: *point, bit: u8) void = {
    273 	fe_swap(&p.x, &q.x, bit);
    274 	fe_swap(&p.y, &q.y, bit);
    275 	fe_swap(&p.z, &q.z, bit);
    276 	fe_swap(&p.t, &q.t, bit);
    277 };
    278 
    279 // p = q * s
    280 fn scalarmult(p: *point, q: *point, s: const *scalar) *point = {
    281 	p.x[..] = feZero[..];
    282 	p.y[..] = feOne[..];
    283 	p.z[..] = feOne[..];
    284 	p.t[..] = feZero[..];
    285 	for (let i = 255; i >= 0; i -= 1) {
    286 		let b: u8 = (s[i/8]>>((i: u8)&7))&1;
    287 		point_swap(p, q, b);
    288 		point_add(q, q, p);
    289 		point_add(p, p, p);
    290 		point_swap(p, q, b);
    291 	};
    292 	return p;
    293 };
    294 
    295 // p = B * s
    296 fn scalarmult_base(p: *point, s: const *scalar) *point = {
    297 	let B = point {...};
    298 	B.x[..] = X[..];
    299 	B.y[..] = Y[..];
    300 	B.z[..] = feOne[..];
    301 	fe_mul(&B.t, &X, &Y);
    302 
    303 	return scalarmult(p, &B, s);
    304 };
    305 
    306 fn point_encode(out: *scalar, p: *point) void = {
    307 	let tx: elem = [0...];
    308 	let ty: elem = [0...];
    309 	let zi: elem = [0...];
    310 	fe_inv(&zi, &p.z);
    311 	fe_mul(&tx, &p.x, &zi);
    312 	fe_mul(&ty, &p.y, &zi);
    313 	fe_encode(out, &ty);
    314 	out[31] ^= fe_parity(&tx) << 7;
    315 };
    316 
// Decodes the 32-byte encoding in `in` into p and returns whether a valid
// x coordinate was found.
//
// NOTE: x is negated when its parity EQUALS the sign bit in[31] >> 7, so
// on success p holds the NEGATION of the encoded point, with z = 1 and
// t = x*y. Callers presumably rely on this — confirm before changing it.
//
// y comes from fe_decode (bit 255 ignored, no check that y < p). x is
// recovered from x^2 * den = num with num = y^2 - 1 and den = D*y^2 + 1.
// p is partially written even when false is returned.
// NOTE(review): there is no rejection of x == 0 with the sign bit set, and
// non-canonical y is accepted. RFC 8032 decoding rejects both; check
// whether callers need that.
// len(in) must be POINTSZ
fn point_decode(p: *point, in: []u8) bool = {
	let t: elem = [0...];
	let chk: elem = [0...];
	let num: elem = [0...];
	let den: elem = [0...];
	let den2: elem = [0...];
	let den4: elem = [0...];
	let den6: elem = [0...];
	p.z[..] = feOne[..];
	fe_decode(&p.y, in);
	// num = y^2 - 1, den = D*y^2 + 1
	fe_square(&num, &p.y);
	fe_mul(&den, &num, &D);
	fe_sub(&num, &num, &p.z);
	fe_add(&den, &p.z, &den);

	// t = num * den^7
	fe_square(&den2, &den);
	fe_square(&den4, &den2);
	fe_mul(&den6, &den4, &den2);
	fe_mul(&t, &den6, &num);
	fe_mul(&t, &t, &den);

	// candidate x = num * den^3 * (num * den^7)^(2^252 - 3)
	fe_pow2523(&t, &t);
	fe_mul(&t, &t, &num);
	fe_mul(&t, &t, &den);
	fe_mul(&t, &t, &den);
	fe_mul(&p.x, &t, &den);

	// if x^2 * den != num, try x * I instead
	fe_square(&chk, &p.x);
	fe_mul(&chk, &chk, &den);
	if (fe_cmp(&chk, &num) != 0) {
		fe_mul(&p.x, &p.x, &I);
	};

	// still no match: num / den has no square root, so the encoding is invalid
	fe_square(&chk, &p.x);
	fe_mul(&chk, &chk, &den);
	if (fe_cmp(&chk, &num) != 0) {
		return false;
	};

	// deliberately yields the negated point (see NOTE above)
	if (fe_parity(&p.x) == (in[31]>>7)) {
		fe_negate(&p.x, &p.x);
	};

	fe_mul(&p.t, &p.x, &p.y);
	return true;
};