Source code
pub const rsa = struct {
    const max_modulus_bits = 4096;
    const Uint = std.crypto.ff.Uint(max_modulus_bits);
    const Modulus = std.crypto.ff.Modulus(max_modulus_bits);
    const Fe = Modulus.Fe;

    /// Copies `sig` into a `modulus_len`-byte array, zero-filling any trailing bytes.
    ///
    /// RFC 8017 (§8.1.2 / §8.2.2, step 1) requires a signature to be exactly
    /// `modulus_len` octets, so longer input can never be valid. Instead of writing
    /// past the array, such input is mapped to the all-0xff value 256^k - 1, which is
    /// not smaller than any k-byte modulus and so is not a valid signature representative.
    fn signatureFromBytes(comptime modulus_len: usize, sig: []const u8) [modulus_len]u8 {
        if (sig.len > modulus_len) return [_]u8{0xff} ** modulus_len;
        var result: [modulus_len]u8 = undefined;
        @memcpy(result[0..sig.len], sig);
        @memset(result[sig.len..], 0);
        return result;
    }

    /// RSASSA-PSS signature verification (RFC 8017 §8.1.2), using MGF1 with the
    /// signature hash and a salt length equal to that hash's digest length.
    pub const PSSSignature = struct {
        /// Converts raw signature bytes into a `modulus_len`-byte array.
        /// Shorter input is zero-filled at the end; longer input is replaced by an
        /// all-0xff value that cannot verify.
        pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
            return signatureFromBytes(modulus_len, msg);
        }

        pub const VerifyError = EncryptError || error{InvalidSignature};

        /// Verifies `sig` over `msg` with `public_key`, hashing with `Hash`.
        pub fn verify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash);
        }

        /// Verifies `sig` over the concatenation of the parts in `msg`.
        pub fn concatVerify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            const mod_bits = public_key.n.bits();
            const em_dec = try encrypt(modulus_len, sig, public_key);
            try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash);
        }

        /// EMSA-PSS-VERIFY (RFC 8017 §9.1.2).
        ///
        /// `msg` is hashed as the concatenation of its parts. `em_in` is the decrypted
        /// signature representative, `emBit` is modBits - 1 and `sLen` the expected salt
        /// length. Any inconsistency yields `error.InvalidSignature`.
        fn EMSA_PSS_VERIFY(msg: []const []const u8, em_in: []const u8, emBit: usize, sLen: usize, comptime Hash: type) VerifyError!void {
            const h_len = Hash.digest_length;

            // Step 1: every std hash accepts at least 2^61 - 1 bytes, so this only bounds
            // emBit. emBit == 0 is also rejected since it would underflow emLen below.
            if (emBit == 0 or emBit >= 1 << 61) return error.InvalidSignature;

            // emLen = ceil(emBit / 8)
            const emLen = ((emBit - 1) / 8) + 1;

            // The representative may be longer than emLen (by one octet when modBits - 1
            // is a multiple of 8). Per I2OSP, any such leading octets must be zero,
            // otherwise the value does not fit in emLen octets.
            if (em_in.len < emLen) return error.InvalidSignature;
            const excess = em_in.len - emLen;
            for (em_in[0..excess]) |byte| {
                if (byte != 0) return error.InvalidSignature;
            }
            const em = em_in[excess..];

            // Step 2: mHash = Hash(M).
            var mHash: [h_len]u8 = undefined;
            {
                var hasher: Hash = .init(.{});
                for (msg) |part| hasher.update(part);
                hasher.final(&mHash);
            }

            // Step 3.
            if (emLen < h_len + sLen + 2) return error.InvalidSignature;

            // Step 4: trailer field.
            if (em[emLen - 1] != 0xbc) return error.InvalidSignature;

            // Step 5: EM = maskedDB || H || 0xbc.
            const db_len = emLen - h_len - 1;
            const maskedDB = em[0..db_len];
            const h = em[db_len..][0..h_len];

            // Step 6: the leftmost 8*emLen - emBit bits (always 0..7) must be zero.
            const zero_bits: u3 = @intCast(emLen * 8 - emBit);
            const top_byte_mask: u8 = @as(u8, 0xff) >> zero_bits;
            if (maskedDB[0] & ~top_byte_mask != 0) return error.InvalidSignature;

            // Step 7: dbMask = MGF1(H, emLen - hLen - 1). MGF1 produces whole digests, so
            // the output buffer must hold db_len rounded up to a digest multiple.
            var mgf_out_buf: [512]u8 = undefined;
            const mgf_out_len = ((db_len - 1) / h_len + 1) * h_len;
            if (mgf_out_len > mgf_out_buf.len) return error.InvalidSignature; // modulus > 4096 bits
            const db = MGF1(Hash, mgf_out_buf[0..mgf_out_len], h, db_len);

            // Step 8: DB = maskedDB xor dbMask, computed in place.
            for (db, maskedDB) |*b, m| b.* ^= m;

            // Step 9: clear the leftmost 8*emLen - emBit bits of DB.
            db[0] &= top_byte_mask;

            // Step 10: DB must be PS || 0x01 || salt, where every PS octet is zero.
            // Step 3 guarantees db_len >= sLen + 1, so ps_len cannot underflow.
            const ps_len = db_len - sLen - 1;
            for (db[0..ps_len]) |byte| {
                if (byte != 0) return error.InvalidSignature;
            }
            if (db[ps_len] != 0x01) return error.InvalidSignature;

            // Step 11: salt is the last sLen octets of DB.
            const salt = db[ps_len + 1 ..];

            // Step 12: M' = 00 00 00 00 00 00 00 00 || mHash || salt.
            // A salt longer than the digest would be unusual and would not fit m_p_buf.
            if (sLen > h_len) return error.InvalidSignature;
            var m_p_buf: [8 + h_len + h_len]u8 = undefined;
            const m_p = m_p_buf[0 .. 8 + h_len + sLen];
            @memset(m_p[0..8], 0);
            @memcpy(m_p[8..][0..h_len], &mHash);
            @memcpy(m_p[8 + h_len ..], salt);

            // Step 13: H' = Hash(M').
            var h_p: [h_len]u8 = undefined;
            Hash.hash(m_p, &h_p, .{});

            // Step 14: consistent iff H == H'.
            if (!std.mem.eql(u8, h, &h_p)) return error.InvalidSignature;
        }

        /// MGF1 mask generation (RFC 8017 Appendix B.2.1).
        ///
        /// Fills `out` with Hash(seed || counter) blocks (counter as a 4-byte big-endian
        /// integer starting at 0) and returns the first `len` bytes. `out.len` must be at
        /// least `len` rounded up to a multiple of `Hash.digest_length`.
        fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) []u8 {
            var counter: u32 = 0;
            var idx: usize = 0;
            var hash = seed.* ++ @as([4]u8, undefined);
            while (idx < len) : ({
                idx += Hash.digest_length;
                counter += 1;
            }) {
                std.mem.writeInt(u32, hash[seed.len..][0..4], counter, .big);
                Hash.hash(&hash, out[idx..][0..Hash.digest_length], .{});
            }
            return out[0..len];
        }
    };

    /// RSASSA-PKCS1-v1_5 signature verification (RFC 8017 §8.2.2): the decrypted
    /// signature must equal the locally built EMSA-PKCS1-v1_5 encoding.
    pub const PKCS1v1_5Signature = struct {
        /// Converts raw signature bytes into a `modulus_len`-byte array.
        /// Shorter input is zero-filled at the end; longer input is replaced by an
        /// all-0xff value that cannot verify.
        pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
            return signatureFromBytes(modulus_len, msg);
        }

        pub const VerifyError = EncryptError || error{InvalidSignature};

        /// Verifies `sig` over `msg` with `public_key`, hashing with `Hash`.
        pub fn verify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash);
        }

        /// Verifies `sig` over the concatenation of the parts in `msg`.
        pub fn concatVerify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            const em_dec = try encrypt(modulus_len, sig, public_key);
            const em = try EMSA_PKCS1_V1_5_ENCODE(msg, modulus_len, Hash);
            if (!std.mem.eql(u8, &em_dec, &em)) return error.InvalidSignature;
        }

        /// EMSA-PKCS1-v1_5 encoding (RFC 8017 §9.2):
        /// EM = 0x00 || 0x01 || PS (0xff...) || 0x00 || DigestInfo-prefix || Hash(msg).
        ///
        /// Only SHA-1 and SHA-2 (224/256/384/512) are supported; other hashes, or an
        /// `emLen` too small for at least 8 padding octets, are compile errors.
        fn EMSA_PKCS1_V1_5_ENCODE(msg: []const []const u8, comptime emLen: usize, comptime Hash: type) VerifyError![emLen]u8 {
            // DER prefix of DigestInfo (AlgorithmIdentifier + OCTET STRING header).
            const hash_der: []const u8 = &switch (Hash) {
                crypto.hash.Sha1 => .{
                    0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e,
                    0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14,
                },
                crypto.hash.sha2.Sha224 => .{
                    0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05,
                    0x00, 0x04, 0x1c,
                },
                crypto.hash.sha2.Sha256 => .{
                    0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
                    0x00, 0x04, 0x20,
                },
                crypto.hash.sha2.Sha384 => .{
                    0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
                    0x00, 0x04, 0x30,
                },
                crypto.hash.sha2.Sha512 => .{
                    0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
                    0x00, 0x04, 0x40,
                },
                else => @compileError("unsupported hash function for PKCS#1 v1.5: " ++ @typeName(Hash)),
            };

            // T = DigestInfo prefix || H, of length tLen.
            const t_len = hash_der.len + Hash.digest_length;
            // RFC 8017 §9.2 step 3: "If emLen < tLen + 11, output 'intended encoded
            // message length too short'" -- i.e. PS needs at least 8 octets.
            if (emLen < t_len + 11) {
                @compileError("modulus too short for PKCS#1 v1.5 signatures with " ++ @typeName(Hash));
            }
            // Index of the 0x00 separator between PS and T.
            const sep = emLen - t_len - 1;

            var em: [emLen]u8 = undefined;
            em[0] = 0x00;
            em[1] = 0x01;
            @memset(em[2..sep], 0xff);
            em[sep] = 0x00;
            @memcpy(em[sep + 1 ..][0..hash_der.len], hash_der);

            var hasher: Hash = .init(.{});
            for (msg) |part| hasher.update(part);
            hasher.final(em[emLen - Hash.digest_length ..]);
            return em;
        }
    };

    /// RSA public key: modulus `n` and public exponent `e`.
    pub const PublicKey = struct {
        n: Modulus,
        e: Fe,

        pub const FromBytesError = error{CertificatePublicKeyInvalid};

        /// Builds a key from big-endian exponent (`pub_bytes`) and modulus bytes.
        ///
        /// Rejects moduli shorter than 512 bits, exponents longer than 4 bytes, even
        /// exponents, exponents below 2, and anything `Modulus.fromBytes` /
        /// `Fe.fromBytes` reject.
        pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) FromBytesError!PublicKey {
            const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid;
            if (_n.bits() < 512) return error.CertificatePublicKeyInvalid;

            // Limiting the exponent to 4 bytes keeps it representable as a u32 below.
            if (pub_bytes.len > 4) return error.CertificatePublicKeyInvalid;
            const _e = Fe.fromBytes(_n, pub_bytes, .big) catch return error.CertificatePublicKeyInvalid;
            if (!_e.isOdd()) return error.CertificatePublicKeyInvalid;
            const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid;
            if (e_v < 2) return error.CertificatePublicKeyInvalid;

            return .{
                .n = _n,
                .e = _e,
            };
        }

        pub const ParseDerError = der.Element.ParseError || error{CertificateFieldHasWrongDataType};

        /// Parses a DER SEQUENCE of two INTEGERs (modulus, then exponent) and returns
        /// the raw contents of both. Leading zero bytes are stripped from the modulus
        /// only; the returned slices point into `pub_key`.
        pub fn parseDer(pub_key: []const u8) ParseDerError!struct { modulus: []const u8, exponent: []const u8 } {
            const pub_key_seq = try der.Element.parse(pub_key, 0);
            if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
            const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
            if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
            const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
            if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;

            // Skip leading zero bytes (e.g. the sign octet DER adds when the top bit is set).
            const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
            const modulus_offset = for (modulus_raw, 0..) |byte, i| {
                if (byte != 0) break i;
            } else modulus_raw.len;

            return .{
                .modulus = modulus_raw[modulus_offset..],
                .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
            };
        }
    };

    const EncryptError = error{MessageTooLong};

    /// Computes msg^e mod n and returns it as `modulus_len` big-endian bytes.
    ///
    /// Returns `error.MessageTooLong` when `msg` cannot be converted to an element
    /// modulo n. NOTE(review): the `catch unreachable`s assume a non-zero exponent
    /// (ensured by `PublicKey.fromBytes`) and that n fits in `modulus_len` bytes.
    fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) EncryptError![modulus_len]u8 {
        const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong;
        const e = public_key.n.powPublic(m, public_key.e) catch unreachable;
        var res: [modulus_len]u8 = undefined;
        e.toBytes(&res, .big) catch unreachable;
        return res;
    }
};