pss.ha (5235B)
// SPDX-License-Identifier: MPL-2.0
// (c) Hare authors <https://harelang.org>

use bytes;
use crypto::math;
use endian;
use errors;
use hash;
use io;
use types;

// Passed as 'saltsz' to [[pss_sign]] and [[pss_verify]] to select the maximum
// salt size that fits into the encoded message.
export type default = void;

// Required minimum buffer size for [[pss_verify]]
export def PSS_VERIFYBUFSZ = PUBEXP_BUFSZ + ((BITSZ + 7) / 8);

// Required minimum buffer size for [[pss_sign]]
export def PSS_SIGNBUFSZ = PRIVEXP_BUFSZ;

// Signs the hash 'msghash' using the private key 'privkey' by applying the PSS
// signature scheme as defined in RFC 8017. 'sig' must be the size of the
// modulus n (see [[privkey_nsize]]).
//
// It is recommended that 'hf' is the same hash function that was used to
// generate 'msghash'. The hash field of the encoded message is always
// hash::sz('hf') bytes long, which is what [[pss_verify]] expects. 'buf' needs
// to be at least the size of [[PSS_SIGNBUFSZ]]. 'rand' must be an
// [[io::handle]] that returns cryptographically random data on read, like
// [[crypto::random::stream]]. The expected size of the salt is provided with
// 'saltsz'. Default is the maximum possible salt size.
//
// Returns [[errors::invalid]], if one of the parameters is invalid (including a
// salt size that does not fit into the encoded message). [[errors::overflow]]
// is returned, if 'buf' is too small. Errors that occur by reading from 'rand'
// are returned as [[io::error]].
export fn pss_sign(
	privkey: []u8,
	msghash: []u8,
	sig: []u8,
	hf: *hash::hash,
	rand: io::handle,
	buf: []u8,
	saltsz: (size | default) = default,
) (void | error | io::error) = {
	let priv = privkey_params(privkey);

	// Use var names that match the RFC.
	const embits = priv.nbitlen - 1;
	// NOTE(review): this requires len(sig) == ceil(embits / 8), which is one
	// byte less than the modulus size when nbitlen % 8 == 1, whereas
	// [[pss_verify]] requires len(sig) == len(n). Confirm the intended
	// behaviour for such moduli.
	if (len(sig) != (embits + 7) / 8) {
		return errors::invalid: error;
	};
	let em = sig;
	const emlen = len(em);
	// hLen is the output size of the hash function (RFC 8017, and the same
	// value pss_verify uses); it must not be taken from len(msghash).
	const hlen = hash::sz(hf);

	// Validated before the salt size is derived, so that none of the
	// unsigned subtractions below can wrap around.
	if (emlen < hlen + 2) {
		return errors::invalid: error;
	};
	const slen = dsaltsz(emlen, hlen, saltsz);
	if (slen > emlen - hlen - 2) {
		return errors::invalid: error;
	};
	// mgfxor needs hlen bytes of scratch space in 'buf'.
	if (len(buf) < hlen) {
		return errors::overflow: error;
	};

	// DB = PS || 0x01 || salt, where PS consists of zero bytes.
	let db = em[..emlen - hlen - 1];
	db[..] = [0...];
	db[len(db) - slen - 1] = 0x01;
	let salt = db[len(db) - slen..];
	io::readall(rand, salt)?;

	// H = Hash((0x00 * 8) || mHash || salt), written directly into its final
	// position in EM, right after maskedDB and before the trailer byte.
	let h = em[emlen - hlen - 1..emlen - 1];
	const padding: [8]u8 = [0...];
	hash::reset(hf);
	hash::write(hf, padding);
	hash::write(hf, msghash);
	hash::write(hf, salt);
	hash::sum(hf, h);

	// maskedDB = DB XOR MGF(H, emLen - hLen - 1), computed in place.
	mgfxor(db, hf, h, buf);

	// Clear the leftmost 8*emLen - emBits bits and append the trailer. H is
	// already in place, so EM = maskedDB || H || 0xbc.
	em[0] &= 0xff >> (8*emlen - embits): u8;
	em[emlen - 1] = 0xbc;

	privexp(&priv, em, buf)?;
};

// Returns the salt size to use: 's' if it was given explicitly, otherwise the
// maximum salt size nsz - hsz - 2. Callers must ensure nsz >= hsz + 2, or the
// default wraps around.
fn dsaltsz(nsz: size, hsz: size, s: (size | default)) size = {
	match (s) {
	case let s: size =>
		return s;
	case default =>
		return nsz - hsz - 2;
	};
};

// Verifies a PSS signature 'sig' of the message hash 'msghash' using the public
// key 'pubkey' as defined in RFC 8017.
//
// 'hf' must be the hash that was used to create the signature. 'buf' needs to
// be at least the size of [[PSS_VERIFYBUFSZ]]. The expected size of the salt is
// provided with 'saltsz'. Default is the maximum possible salt size. The
// function will fail, if the signature's salt size does not match the expected.
//
// Returns [[badsig]], if the signature verification fails. [[errors::overflow]]
// is returned, if 'buf' is too small.
export fn pss_verify(
	pubkey: []u8,
	msghash: []u8,
	sig: []u8,
	hf: *hash::hash,
	buf: []u8,
	saltsz: (size | default) = default,
) (void | error) = {
	let pub = pubkey_params(pubkey);
	if (len(sig) != len(pub.n)) {
		return badsig;
	};

	// rename some variables to match the ones in the RFC
	let mhash = msghash;
	const hlen = hash::sz(hf);
	const emlen = len(sig);

	// Validated before the salt size is derived, so that none of the
	// unsigned subtractions below can wrap around.
	if (emlen < hlen + 2) {
		return badsig;
	};
	const slen = dsaltsz(emlen, hlen, saltsz);
	if (slen > emlen - hlen - 2) {
		return badsig;
	};

	// 'buf' holds a copy of the signature followed by scratch space, which
	// must at least fit the MGF block and the recomputed hash (hlen bytes).
	if (len(buf) < emlen + hlen) {
		return errors::overflow: error;
	};
	let em = buf[..emlen];
	em[..] = sig[..];
	let pubbuf = buf[emlen..];

	match (pubexp(&pub, em, pubbuf)) {
	case void => void;
	case errors::invalid =>
		return badsig;
	case let e: error =>
		return e;
	};

	if (em[emlen - 1] != 0xbc) {
		return badsig;
	};

	// EM = maskedDB || H || 0xbc
	const maskdbsz = emlen - hlen - 1;
	let maskeddb = em[..maskdbsz];
	let h = em[maskdbsz..maskdbsz + hlen];

	// The leftmost 8*emLen - emBits bits of maskedDB must be zero.
	const embitlen = pubkey_nbitlen(pubkey) - 1;
	const zerobitsh = 8 - (8*len(em) - embitlen);
	if (maskeddb[0] >> zerobitsh != 0) {
		return badsig;
	};

	// DB = maskedDB XOR MGF(H, emLen - hLen - 1), computed in place.
	let db = maskeddb;
	mgfxor(db, hf, h, pubbuf);
	db[0] &= 0xff >> (8*len(em) - embitlen): u8;

	// DB must be PS || 0x01 || salt, with PS consisting of zero bytes only.
	const seppos = len(em) - hlen - slen - 2;
	for (let i = 0z; i < seppos; i += 1) {
		if (db[i] != 0x00) {
			return badsig;
		};
	};
	if (db[seppos] != 0x01) {
		return badsig;
	};

	// H' = Hash((0x00 * 8) || mHash || salt)
	const salt = db[len(db) - slen..];
	const padding: [8]u8 = [0...];
	hash::reset(hf);
	hash::write(hf, padding);
	hash::write(hf, mhash);
	hash::write(hf, salt);

	let genh = pubbuf[..hlen];
	hash::sum(hf, genh);

	// The signature is valid only if H == H'.
	if (math::eqslice(genh, h) != 1) {
		return badsig;
	};
};

// dest = dest XOR mgf(h, seed, len(dest)), where the mask is the concatenation
// of Hash(seed || I2OSP(ctr, 4)) for ctr = 0, 1, ... (MGF1 from RFC 8017).
// 'buf' must be at least hash::sz(h) bytes long; its first hash::sz(h) bytes
// are used as scratch space and zeroed before returning.
fn mgfxor(dest: []u8, h: *hash::hash, seed: []u8, buf: []u8) void = {
	assert(len(buf) >= hash::sz(h));

	let ctrbuf: [4]u8 = [0...];
	let sum = buf[..hash::sz(h)];
	const iterations = (len(dest) + len(sum) - 1) / len(sum);

	for (let ctr: u32 = 0; ctr < iterations; ctr += 1) {
		// T = Hash(seed || big-endian 32-bit counter)
		endian::beputu32(ctrbuf, ctr);
		hash::reset(h);
		hash::write(h, seed);
		hash::write(h, ctrbuf);
		hash::sum(h, sum);

		// The last chunk may be shorter than one hash output.
		const start = ctr * len(sum);
		const remain = len(dest) - start;
		const chunksz = if (remain < len(sum)) remain else len(sum);

		let chunk = dest[start..start + chunksz];
		math::xor(chunk, chunk, sum[..chunksz]);
	};

	bytes::zero(sum);
};