From 7138f7e94dc4c6b3bb410ebc06d47c37f2ae3bc9 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 11 Sep 2019 19:03:39 +0300 Subject: [PATCH] Add more primitives for SecIO. Fix SIGSEGV inside of rsa.nim and ecnist.nim. --- libp2p/crypto/crypto.nim | 109 ++++++++++++++++---- libp2p/crypto/ecnist.nim | 218 ++++++++++++++++++++++++++++++--------- libp2p/crypto/rsa.nim | 102 ++++++++++++------ tests/testcrypto.nim | 8 +- 4 files changed, 328 insertions(+), 109 deletions(-) diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 83401f4a0..49416d504 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -8,14 +8,15 @@ ## those terms. ## This module implements Public Key and Private Key interface for libp2p. -import strutils import rsa, ecnist, ed25519/ed25519, secp -import ../protobuf/minprotobuf, ../vbuffer +import ../protobuf/minprotobuf, ../vbuffer, ../multihash, ../multicodec import nimcrypto/[rijndael, blowfish, sha, sha2, hash, hmac, utils] # This is workaround for Nim's `import` bug export rijndael, blowfish, sha, sha2, hash, hmac, utils +from strutils import split + type PKScheme* = enum RSA = 0, @@ -463,7 +464,7 @@ proc verify*(sig: Signature, message: openarray[byte], if signature.init(sig.data): result = signature.verify(message, key.skkey) -template makeSecret(buffer, hmactype, secret, seed) = +template makeSecret(buffer, hmactype, secret, seed: untyped) {.dirty.}= var ctx: hmactype var j = 0 # We need to strip leading zeros, because Go bigint serialization do it. @@ -489,16 +490,16 @@ template makeSecret(buffer, hmactype, secret, seed) = ctx.update(a.data) a = ctx.finish() -proc stretchKeys*(cipherScheme: CipherScheme, hashScheme: DigestSheme, - secret: openarray[byte]): Secret = +proc stretchKeys*(cipherType: string, hashType: string, + sharedSecret: seq[byte]): Secret = ## Expand shared secret to cryptographic keys. - if cipherScheme == Aes128: + if cipherType == "AES-128": result.ivsize = aes128.sizeBlock result.keysize = aes128.sizeKey - elif cipherScheme == Aes256: + elif cipherType == "AES-256": result.ivsize = aes256.sizeBlock result.keysize = aes256.sizeKey - elif cipherScheme == Blowfish: + elif cipherType == "BLOWFISH": result.ivsize = 8 result.keysize = 32 @@ -507,12 +508,12 @@ proc stretchKeys*(cipherScheme: CipherScheme, hashScheme: DigestSheme, let length = result.ivsize + result.keysize + result.macsize result.data = newSeq[byte](2 * length) - if hashScheme == Sha256: - makeSecret(result.data, HMAC[sha256], secret, seed) - elif hashScheme == Sha512: - makeSecret(result.data, HMAC[sha512], secret, seed) - elif hashScheme == Sha1: - makeSecret(result.data, HMAC[sha1], secret, seed) + if hashType == "SHA256": + makeSecret(result.data, HMAC[sha256], sharedSecret, seed) + elif hashType == "SHA512": + makeSecret(result.data, HMAC[sha512], sharedSecret, seed) + elif hashType == "SHA1": + makeSecret(result.data, HMAC[sha1], sharedSecret, seed) template goffset*(secret, id, o: untyped): untyped = id * (len(secret.data) shr 1) + o @@ -561,11 +562,27 @@ proc ephemeral*(scheme: ECDHEScheme): KeyPair = result.seckey.eckey = keypair.seckey result.pubkey.eckey = keypair.pubkey +proc ephemeral*(scheme: string): KeyPair {.inline.} = + ## Generate ephemeral keys used to perform ECDHE using string encoding. + ## + ## Currently supported encoding strings are P-256, P-384, P-521, if encoding + ## string is not supported P-521 key will be generated. + if scheme == "P-256": + result = ephemeral(Secp256r1) + elif scheme == "P-384": + result = ephemeral(Secp384r1) + elif scheme == "P-521": + result = ephemeral(Secp521r1) + else: + result = ephemeral(Secp521r1) + proc makeSecret*(remoteEPublic: PublicKey, localEPrivate: PrivateKey, data: var openarray[byte]): int = ## Calculate shared secret using remote ephemeral public key ## ``remoteEPublic`` and local ephemeral private key ``localEPrivate`` and - ## store shared secret to ``data`` + ## store shared secret to ``data``. + ## + ## Note this procedure supports only ECDSA keys. ## ## Returns number of bytes (octets) used to store shared secret data, or ## ``0`` on error. @@ -573,6 +590,19 @@ proc makeSecret*(remoteEPublic: PublicKey, localEPrivate: PrivateKey, if localEPrivate.scheme == remoteEPublic.scheme: result = toSecret(remoteEPublic.eckey, localEPrivate.eckey, data) +proc getSecret*(remoteEPublic: PublicKey, + localEPrivate: PrivateKey): seq[byte] = + ## Calculate shared secret using remote ephemeral public key + ## ``remoteEPublic`` and local ephemeral private key ``localEPrivate`` and + ## store shared secret to ``data``. + ## + ## Note this procedure supports only ECDSA keys. + ## + ## Returns shared secret on success. + if remoteEPublic.scheme == ECDSA: + if localEPrivate.scheme == remoteEPublic.scheme: + result = getSecret(remoteEPublic.eckey, localEPrivate.eckey) + proc getOrder*(remotePubkey, localNonce: openarray[byte], localPubkey, remoteNonce: openarray[byte]): int = ## Compare values and calculate `order` parameter. @@ -585,10 +615,16 @@ proc getOrder*(remotePubkey, localNonce: openarray[byte], ctx.update(localPubkey) ctx.update(remoteNonce) var digest2 = ctx.finish() - var diff = 0 - for i in 0 ..< len(digest1.data): - diff = int(digest1.data[i]) - int(digest2.data[i]) - result = (result and -not(diff)) or diff + var mh1 = MultiHash.init(multiCodec("sha2-256"), digest1) + var mh2 = MultiHash.init(multiCodec("sha2-256"), digest2) + for i in 0 ..< len(mh1.data.buffer): + result = int(mh1.data.buffer[i]) - int(mh2.data.buffer[i]) + if result != 0: + if result > 0: + result = -1 + elif result > 0: + result = 1 + break proc selectBest*(order: int, p1, p2: string): string = ## Determines which algorithm to use from list `p1` and `p2`. @@ -610,10 +646,14 @@ proc selectBest*(order: int, p1, p2: string): string = for selement in s: if felement == selement: result = felement - break + return proc createProposal*(nonce, pubkey: openarray[byte], exchanges, ciphers, hashes: string): seq[byte] = + ## Create SecIO proposal message using random ``nonce``, local public key + ## ``pubkey``, comma-delimieted list of supported exchange schemes + ## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and + ## comma-delimeted list of supported hashes ``hashes``. var msg = initProtoBuffer({WithUint32BeLength}) msg.write(initProtoField(1, nonce)) msg.write(initProtoField(2, pubkey)) @@ -623,13 +663,42 @@ proc createProposal*(nonce, pubkey: openarray[byte], msg.finish() shallowCopy(result, msg.buffer) +proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte], + exchanges, ciphers, hashes: var string): bool = + ## Parse incoming proposal message and decode remote random nonce ``nonce``, + ## remote public key ``pubkey``, comma-delimieted list of supported exchange + ## schemes ``exchanges``, comma-delimeted list of supported ciphers + ## ``ciphers`` and comma-delimeted list of supported hashes ``hashes``. + ## + ## Procedure returns ``true`` on success and ``false`` on error. + var pb = initProtoBuffer(message) + if pb.getLengthValue(1, nonce) != -1 and + pb.getLengthValue(2, pubkey) != -1 and + pb.getLengthValue(3, exchanges) != -1 and + pb.getLengthValue(4, ciphers) != -1 and + pb.getLengthValue(5, hashes) != -1: + result = true + proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = + ## Create SecIO exchange message using ephemeral public key ``epubkey`` and + ## signature of proposal blocks ``signature``. var msg = initProtoBuffer({WithUint32BeLength}) msg.write(initProtoField(1, epubkey)) msg.write(initProtoField(2, signature)) msg.finish() shallowCopy(result, msg.buffer) +proc decodeExchange*(message: seq[byte], + pubkey, signature: var seq[byte]): bool = + ## Parse incoming exchange message and decode remote ephemeral public key + ## ``pubkey`` and signature ``signature``. + ## + ## Procedure returns ``true`` on success and ``false`` on error. + var pb = initProtoBuffer(message) + if pb.getLengthValue(1, pubkey) != -1 and + pb.getLengthValue(2, signature) != -1: + result = true + ## Serialization/Deserialization helpers proc write*(vb: var VBuffer, pubkey: PublicKey) {.inline.} = diff --git a/libp2p/crypto/ecnist.nim b/libp2p/crypto/ecnist.nim index 81ab34a5c..7d37c3a8f 100644 --- a/libp2p/crypto/ecnist.nim +++ b/libp2p/crypto/ecnist.nim @@ -141,36 +141,48 @@ template getPublicKeyLength*(curve: EcCurveKind): int = of Secp521r1: PubKey521Length +template getPrivateKeyLength*(curve: EcCurveKind): int = + case curve + of Secp256r1: + SecKey256Length + of Secp384r1: + SecKey384Length + of Secp521r1: + SecKey521Length + proc copy*[T: EcPKI](dst: var T, src: T): bool = ## Copy EC `private key`, `public key` or `signature` ``src`` to ``dst``. ## ## Returns ``true`` on success, ``false`` otherwise. - dst = new T - when T is EcPrivateKey: - let length = src.key.xlen - if length > 0 and len(src.buffer) > 0: - let offset = getOffset(src) - if offset >= 0: - dst.buffer = src.buffer - dst.key.curve = src.key.curve - dst.key.xlen = length - dst.key.x = cast[ptr cuchar](addr dst.buffer[offset]) - result = true - elif T is EcPublicKey: - let length = src.key.qlen - if length > 0 and len(src.buffer) > 0: - let offset = getOffset(src) - if offset >= 0: - dst.buffer = src.buffer - dst.key.curve = src.key.curve - dst.key.qlen = length - dst.key.q = cast[ptr cuchar](addr dst.buffer[offset]) - result = true + if isNil(src): + result = false else: - let length = len(src.buffer) - if length > 0: - dst.buffer = src.buffer - result = true + dst = new T + when T is EcPrivateKey: + let length = src.key.xlen + if length > 0 and len(src.buffer) > 0: + let offset = getOffset(src) + if offset >= 0: + dst.buffer = src.buffer + dst.key.curve = src.key.curve + dst.key.xlen = length + dst.key.x = cast[ptr cuchar](addr dst.buffer[offset]) + result = true + elif T is EcPublicKey: + let length = src.key.qlen + if length > 0 and len(src.buffer) > 0: + let offset = getOffset(src) + if offset >= 0: + dst.buffer = src.buffer + dst.key.curve = src.key.curve + dst.key.qlen = length + dst.key.q = cast[ptr cuchar](addr dst.buffer[offset]) + result = true + else: + let length = len(src.buffer) + if length > 0: + dst.buffer = src.buffer + result = true proc copy*[T: EcPKI](src: T): T {.inline.} = ## Returns copy of EC `private key`, `public key` or `signature` @@ -180,6 +192,7 @@ proc copy*[T: EcPKI](src: T): T {.inline.} = proc clear*[T: EcPKI|EcKeyPair](pki: var T) = ## Wipe and clear EC `private key`, `public key` or `signature` object. + doAssert(not isNil(pki)) when T is EcPrivateKey: burnMem(pki.buffer) pki.buffer.setLen(0) @@ -228,6 +241,7 @@ proc random*(t: typedesc[EcPrivateKey], kind: EcCurveKind): EcPrivateKey = proc getKey*(seckey: EcPrivateKey): EcPublicKey = ## Calculate and return EC public key from private key ``seckey``. + doAssert(not isNil(seckey)) var ecimp = brEcGetDefault() if seckey.key.curve in EcSupportedCurvesCint: var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve)) @@ -250,8 +264,9 @@ proc random*(t: typedesc[EcKeyPair], kind: EcCurveKind): EcKeyPair {.inline.} = proc `$`*(seckey: EcPrivateKey): string = ## Return string representation of EC private key. - if seckey.key.curve == 0 or seckey.key.xlen == 0 or len(seckey.buffer) == 0: - result = "Empty key" + if isNil(seckey) or seckey.key.curve == 0 or seckey.key.xlen == 0 or + len(seckey.buffer) == 0: + result = "Empty or uninitialized ECNIST key" else: if seckey.key.curve notin EcSupportedCurvesCint: result = "Unknown key" @@ -265,8 +280,9 @@ proc `$`*(seckey: EcPrivateKey): string = proc `$`*(pubkey: EcPublicKey): string = ## Return string representation of EC public key. - if pubkey.key.curve == 0 or pubkey.key.qlen == 0 or len(pubkey.buffer) == 0: - result = "Empty key" + if isNil(pubkey) or pubkey.key.curve == 0 or pubkey.key.qlen == 0 or + len(pubkey.buffer) == 0: + result = "Empty or uninitialized ECNIST key" else: if pubkey.key.curve notin EcSupportedCurvesCint: result = "Unknown key" @@ -280,7 +296,45 @@ proc `$`*(pubkey: EcPublicKey): string = proc `$`*(sig: EcSignature): string = ## Return hexadecimal string representation of EC signature. - result = toHex(sig.buffer) + if isNil(sig) or len(sig.buffer) == 0: + result = "Empty or uninitialized ECNIST signature" + else: + result = toHex(sig.buffer) + +proc toRawBytes*(seckey: EcPrivateKey, data: var openarray[byte]): int = + ## Serialize EC private key ``seckey`` to raw binary form and store it + ## to ``data``. + ## + ## Returns number of bytes (octets) needed to store EC private key, or `0` + ## if private key is not in supported curve. + doAssert(not isNil(seckey)) + if seckey.key.curve in EcSupportedCurvesCint: + result = getPrivateKeyLength(cast[EcCurveKind](seckey.key.curve)) + if len(data) >= result: + copyMem(addr data[0], unsafeAddr seckey.buffer[0], result) + +proc toRawBytes*(pubkey: EcPublicKey, data: var openarray[byte]): int = + ## Serialize EC public key ``pubkey`` to uncompressed form specified in + ## section 4.3.6 of ANSI X9.62. + ## + ## Returns number of bytes (octets) needed to store EC public key, or `0` + ## if public key is not in supported curve. + doAssert(not isNil(pubkey)) + if pubkey.key.curve in EcSupportedCurvesCint: + result = getPublicKeyLength(cast[EcCurveKind](pubkey.key.curve)) + if len(data) >= result: + copyMem(addr data[0], unsafeAddr pubkey.buffer[0], result) + +proc toRawBytes*(sig: EcSignature, data: var openarray[byte]): int = + ## Serialize EC signature ``sig`` to raw binary form and store it to ``data``. + ## + ## Returns number of bytes (octets) needed to store EC signature, or `0` + ## if signature is not in supported curve. + doAssert(not isNil(sig)) + result = len(sig.buffer) + if len(data) >= len(sig.buffer): + if len(sig.buffer) > 0: + copyMem(addr data[0], unsafeAddr sig.buffer[0], len(sig.buffer)) proc toBytes*(seckey: EcPrivateKey, data: var openarray[byte]): int = ## Serialize EC private key ``seckey`` to ASN.1 DER binary form and store it @@ -288,6 +342,7 @@ proc toBytes*(seckey: EcPrivateKey, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store EC private key, ## or `0` if private key is not in supported curve. + doAssert(not isNil(seckey)) if seckey.key.curve in EcSupportedCurvesCint: var offset, length: int var pubkey = seckey.getKey() @@ -327,6 +382,7 @@ proc toBytes*(pubkey: EcPublicKey, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store EC public key, ## or `0` if public key is not in supported curve. + doAssert(not isNil(pubkey)) if pubkey.key.curve in EcSupportedCurvesCint: var b = Asn1Buffer.init() var p = Asn1Composite.init(Asn1Tag.Sequence) @@ -357,12 +413,14 @@ proc toBytes*(sig: EcSignature, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store EC signature, ## or `0` if signature is not in supported curve. + doAssert(not isNil(sig)) result = len(sig.buffer) if len(data) >= result: copyMem(addr data[0], unsafeAddr sig.buffer[0], result) proc getBytes*(seckey: EcPrivateKey): seq[byte] = ## Serialize EC private key ``seckey`` to ASN.1 DER binary form and return it. + doAssert(not isNil(seckey)) if seckey.key.curve in EcSupportedCurvesCint: result = newSeq[byte]() let length = seckey.toBytes(result) @@ -373,6 +431,7 @@ proc getBytes*(seckey: EcPrivateKey): seq[byte] = proc getBytes*(pubkey: EcPublicKey): seq[byte] = ## Serialize EC public key ``pubkey`` to ASN.1 DER binary form and return it. + doAssert(not isNil(pubkey)) if pubkey.key.curve in EcSupportedCurvesCint: result = newSeq[byte]() let length = pubkey.toBytes(result) @@ -383,6 +442,37 @@ proc getBytes*(pubkey: EcPublicKey): seq[byte] = proc getBytes*(sig: EcSignature): seq[byte] = ## Serialize EC signature ``sig`` to ASN.1 DER binary form and return it. + doAssert(not isNil(sig)) + result = newSeq[byte]() + let length = sig.toBytes(result) + result.setLen(length) + discard sig.toBytes(result) + +proc getRawBytes*(seckey: EcPrivateKey): seq[byte] = + ## Serialize EC private key ``seckey`` to raw binary form and return it. + doAssert(not isNil(seckey)) + if seckey.key.curve in EcSupportedCurvesCint: + result = newSeq[byte]() + let length = seckey.toRawBytes(result) + result.setLen(length) + discard seckey.toRawBytes(result) + else: + raise newException(EcKeyIncorrectError, "Incorrect private key") + +proc getRawBytes*(pubkey: EcPublicKey): seq[byte] = + ## Serialize EC public key ``pubkey`` to raw binary form and return it. + doAssert(not isNil(pubkey)) + if pubkey.key.curve in EcSupportedCurvesCint: + result = newSeq[byte]() + let length = pubkey.toRawBytes(result) + result.setLen(length) + discard pubkey.toRawBytes(result) + else: + raise newException(EcKeyIncorrectError, "Incorrect public key") + +proc getRawBytes*(sig: EcSignature): seq[byte] = + ## Serialize EC signature ``sig`` to raw binary form and return it. + doAssert(not isNil(sig)) result = newSeq[byte]() let length = sig.toBytes(result) result.setLen(length) @@ -390,33 +480,54 @@ proc getBytes*(sig: EcSignature): seq[byte] = proc `==`*(pubkey1, pubkey2: EcPublicKey): bool = ## Returns ``true`` if both keys ``pubkey1`` and ``pubkey2`` are equal. - if pubkey1.key.curve != pubkey2.key.curve: - return false - if pubkey1.key.qlen != pubkey2.key.qlen: - return false - let op1 = pubkey1.getOffset() - let op2 = pubkey2.getOffset() - if op1 == -1 or op2 == -1: - return false - result = equalMem(unsafeAddr pubkey1.buffer[op1], - unsafeAddr pubkey2.buffer[op2], pubkey1.key.qlen) + if isNil(pubkey1) and isNil(pubkey2): + result = true + elif isNil(pubkey1) and (not isNil(pubkey2)): + result = false + elif isNil(pubkey2) and (not isNil(pubkey1)): + result = false + else: + if pubkey1.key.curve != pubkey2.key.curve: + return false + if pubkey1.key.qlen != pubkey2.key.qlen: + return false + let op1 = pubkey1.getOffset() + let op2 = pubkey2.getOffset() + if op1 == -1 or op2 == -1: + return false + result = equalMem(unsafeAddr pubkey1.buffer[op1], + unsafeAddr pubkey2.buffer[op2], pubkey1.key.qlen) proc `==`*(seckey1, seckey2: EcPrivateKey): bool = ## Returns ``true`` if both keys ``seckey1`` and ``seckey2`` are equal. - if seckey1.key.curve != seckey2.key.curve: - return false - if seckey1.key.xlen != seckey2.key.xlen: - return false - let op1 = seckey1.getOffset() - let op2 = seckey2.getOffset() - if op1 == -1 or op2 == -1: - return false - result = equalMem(unsafeAddr seckey1.buffer[op1], - unsafeAddr seckey2.buffer[op2], seckey1.key.xlen) + if isNil(seckey1) and isNil(seckey2): + result = true + elif isNil(seckey1) and (not isNil(seckey2)): + result = false + elif isNil(seckey2) and (not isNil(seckey1)): + result = false + else: + if seckey1.key.curve != seckey2.key.curve: + return false + if seckey1.key.xlen != seckey2.key.xlen: + return false + let op1 = seckey1.getOffset() + let op2 = seckey2.getOffset() + if op1 == -1 or op2 == -1: + return false + result = equalMem(unsafeAddr seckey1.buffer[op1], + unsafeAddr seckey2.buffer[op2], seckey1.key.xlen) proc `==`*(sig1, sig2: EcSignature): bool = ## Return ``true`` if both signatures ``sig1`` and ``sig2`` are equal. - result = (sig1.buffer == sig2.buffer) + if isNil(sig1) and isNil(sig2): + result = true + elif isNil(sig1) and (not isNil(sig2)): + result = false + elif isNil(sig2) and (not isNil(sig1)): + result = false + else: + result = (sig1.buffer == sig2.buffer) proc init*(key: var EcPrivateKey, data: openarray[byte]): Asn1Status = ## Initialize EC `private key` or `signature` ``key`` from ASN.1 DER binary @@ -698,6 +809,7 @@ proc scalarMul*(pub: EcPublicKey, sec: EcPrivateKey): EcPublicKey = ## Return scalar multiplication of ``pub`` and ``sec``. ## ## Returns point in curve as ``pub * sec`` or ``nil`` otherwise. + doAssert((not isNil(pub)) and (not isNil(sec))) var impl = brEcGetDefault() if sec.key.curve in EcSupportedCurvesCint: if pub.key.curve == sec.key.curve: @@ -726,6 +838,7 @@ proc toSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey, ## ## ``data`` array length must be at least 32 bytes for `secp256r1`, 48 bytes ## for `secp384r1` and 66 bytes for `secp521r1`. + doAssert((not isNil(pubkey)) and (not isNil(seckey))) var mult = scalarMul(pubkey, seckey) var length = 0 if not isNil(mult): @@ -745,6 +858,7 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] = ## shared secret. ## ## If error happens length of result array will be ``0``. + doAssert((not isNil(pubkey)) and (not isNil(seckey))) var data: array[Secret521Length, byte] let res = toSecret(pubkey, seckey, data) if res > 0: @@ -754,6 +868,7 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] = proc sign*[T: byte|char](seckey: EcPrivateKey, message: openarray[T]): EcSignature = ## Get ECDSA signature of data ``message`` using private key ``seckey``. + doAssert(not isNil(seckey)) var hc: BrHashCompatContext var hash: array[32, byte] var impl = brEcGetDefault() @@ -785,6 +900,7 @@ proc verify*[T: byte|char](sig: EcSignature, message: openarray[T], ## ## Return ``true`` if message verification succeeded, ``false`` if ## verification failed. + doAssert((not isNil(sig)) and (not isNil(pubkey))) var hc: BrHashCompatContext var hash: array[32, byte] var impl = brEcGetDefault() diff --git a/libp2p/crypto/rsa.nim b/libp2p/crypto/rsa.nim index 0b0ba9543..f2205fd34 100644 --- a/libp2p/crypto/rsa.nim +++ b/libp2p/crypto/rsa.nim @@ -12,7 +12,6 @@ ## This module uses unmodified parts of code from ## BearSSL library ## Copyright(C) 2018 Thomas Pornin . - import nimcrypto/utils import common, minasn1 export Asn1Status @@ -160,6 +159,7 @@ proc random*[T: RsaKP](t: typedesc[T], bits = DefaultKeySize, proc copy*[T: RsaPKI](key: T): T = ## Create copy of RSA private key, public key or signature. + doAssert(not isNil(key)) when T is RsaPrivateKey: if len(key.buffer) > 0: let length = key.seck.plen + key.seck.qlen + key.seck.dplen + @@ -220,6 +220,7 @@ proc copy*[T: RsaPKI](key: T): T = proc getKey*(key: RsaPrivateKey): RsaPublicKey = ## Get RSA public key from RSA private key. + doAssert(not isNil(key)) let length = key.pubk.nlen + key.pubk.elen result = new RsaPublicKey result.buffer = newSeq[byte](length) @@ -241,6 +242,7 @@ proc pubkey*(pair: RsaKeyPair): RsaPublicKey {.inline.} = proc clear*[T: RsaPKI|RsaKeyPair](pki: var T) = ## Wipe and clear EC private key, public key or scalar object. + doAssert(not isNil(pki)) when T is RsaPrivateKey: burnMem(pki.buffer) pki.buffer.setLen(0) @@ -276,6 +278,7 @@ proc toBytes*(key: RsaPrivateKey, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store RSA private key, ## or `0` if private key is is incorrect. + doAssert(not isNil(key)) if len(key.buffer) > 0: var b = Asn1Buffer.init() var p = Asn1Composite.init(Asn1Tag.Sequence) @@ -308,6 +311,7 @@ proc toBytes*(key: RsaPublicKey, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store RSA public key, ## or `0` if public key is incorrect. + doAssert(not isNil(key)) if len(key.buffer) > 0: var b = Asn1Buffer.init() var p = Asn1Composite.init(Asn1Tag.Sequence) @@ -337,6 +341,7 @@ proc toBytes*(sig: RsaSignature, data: var openarray[byte]): int = ## ## Procedure returns number of bytes (octets) needed to store RSA public key, ## or `0` if public key is incorrect. + doAssert(not isNil(sig)) result = len(sig.buffer) if len(data) >= result: copyMem(addr data[0], addr sig.buffer[0], result) @@ -344,6 +349,7 @@ proc toBytes*(sig: RsaSignature, data: var openarray[byte]): int = proc getBytes*(key: RsaPrivateKey): seq[byte] = ## Serialize RSA private key ``key`` to ASN.1 DER binary form and ## return it. + doAssert(not isNil(key)) result = newSeq[byte](4096) let length = key.toBytes(result) if length > 0: @@ -354,6 +360,7 @@ proc getBytes*(key: RsaPrivateKey): seq[byte] = proc getBytes*(key: RsaPublicKey): seq[byte] = ## Serialize RSA public key ``key`` to ASN.1 DER binary form and ## return it. + doAssert(not isNil(key)) result = newSeq[byte](4096) let length = key.toBytes(result) if length > 0: @@ -363,6 +370,7 @@ proc getBytes*(key: RsaPublicKey): seq[byte] = proc getBytes*(sig: RsaSignature): seq[byte] = ## Serialize RSA signature ``sig`` to raw binary form and return it. + doAssert(not isNil(sig)) result = newSeq[byte](4096) let length = sig.toBytes(result) if length > 0: @@ -592,8 +600,8 @@ proc init*[T: RsaPKI](t: typedesc[T], data: string): T {.inline.} = proc `$`*(key: RsaPrivateKey): string = ## Return string representation of RSA private key. - if len(key.buffer) == 0: - result = "Empty RSA key" + if isNil(key) or len(key.buffer) == 0: + result = "Empty or uninitialized RSA key" else: result = "RSA key (" result.add($key.seck.nBitlen) @@ -618,8 +626,8 @@ proc `$`*(key: RsaPrivateKey): string = proc `$`*(key: RsaPublicKey): string = ## Return string representation of RSA public key. - if len(key.buffer) == 0: - result = "Empty RSA key" + if isNil(key) or len(key.buffer) == 0: + result = "Empty or uninitialized RSA key" else: let nbitlen = key.key.nlen shl 3 result = "RSA key (" @@ -632,8 +640,8 @@ proc `$`*(key: RsaPublicKey): string = proc `$`*(sig: RsaSignature): string = ## Return string representation of RSA signature. - if len(sig.buffer) == 0: - result = "Empty RSA signature" + if isNil(sig) or len(sig.buffer) == 0: + result = "Empty or uninitialized RSA signature" else: result = "RSA signature (" result.add(toHex(sig.buffer)) @@ -656,44 +664,69 @@ proc cmp(a: openarray[byte], b: openarray[byte]): bool = proc `==`*(a, b: RsaPrivateKey): bool = ## Compare two RSA private keys for equality. - if a.seck.nBitlen == b.seck.nBitlen: - if cast[int](a.seck.nBitlen) > 0: - let r1 = cmp(getArray(a.buffer, a.seck.p, a.seck.plen), - getArray(b.buffer, b.seck.p, b.seck.plen)) - let r2 = cmp(getArray(a.buffer, a.seck.q, a.seck.qlen), - getArray(b.buffer, b.seck.q, b.seck.qlen)) - let r3 = cmp(getArray(a.buffer, a.seck.dp, a.seck.dplen), - getArray(b.buffer, b.seck.dp, b.seck.dplen)) - let r4 = cmp(getArray(a.buffer, a.seck.dq, a.seck.dqlen), - getArray(b.buffer, b.seck.dq, b.seck.dqlen)) - let r5 = cmp(getArray(a.buffer, a.seck.iq, a.seck.iqlen), - getArray(b.buffer, b.seck.iq, b.seck.iqlen)) - let r6 = cmp(getArray(a.buffer, a.pexp, a.pexplen), - getArray(b.buffer, b.pexp, b.pexplen)) - let r7 = cmp(getArray(a.buffer, a.pubk.n, a.pubk.nlen), - getArray(b.buffer, b.pubk.n, b.pubk.nlen)) - let r8 = cmp(getArray(a.buffer, a.pubk.e, a.pubk.elen), - getArray(b.buffer, b.pubk.e, b.pubk.elen)) - result = r1 and r2 and r3 and r4 and r5 and r6 and r7 and r8 - else: - result = true + ## + ## Result is true if ``a`` and ``b`` are both ``nil`` or ``a`` and ``b`` are + ## equal by value. + if isNil(a) and isNil(b): + result = true + elif isNil(a) and (not isNil(b)): + result = false + elif isNil(b) and (not isNil(a)): + result = false + else: + if a.seck.nBitlen == b.seck.nBitlen: + if cast[int](a.seck.nBitlen) > 0: + let r1 = cmp(getArray(a.buffer, a.seck.p, a.seck.plen), + getArray(b.buffer, b.seck.p, b.seck.plen)) + let r2 = cmp(getArray(a.buffer, a.seck.q, a.seck.qlen), + getArray(b.buffer, b.seck.q, b.seck.qlen)) + let r3 = cmp(getArray(a.buffer, a.seck.dp, a.seck.dplen), + getArray(b.buffer, b.seck.dp, b.seck.dplen)) + let r4 = cmp(getArray(a.buffer, a.seck.dq, a.seck.dqlen), + getArray(b.buffer, b.seck.dq, b.seck.dqlen)) + let r5 = cmp(getArray(a.buffer, a.seck.iq, a.seck.iqlen), + getArray(b.buffer, b.seck.iq, b.seck.iqlen)) + let r6 = cmp(getArray(a.buffer, a.pexp, a.pexplen), + getArray(b.buffer, b.pexp, b.pexplen)) + let r7 = cmp(getArray(a.buffer, a.pubk.n, a.pubk.nlen), + getArray(b.buffer, b.pubk.n, b.pubk.nlen)) + let r8 = cmp(getArray(a.buffer, a.pubk.e, a.pubk.elen), + getArray(b.buffer, b.pubk.e, b.pubk.elen)) + result = r1 and r2 and r3 and r4 and r5 and r6 and r7 and r8 + else: + result = true proc `==`*(a, b: RsaSignature): bool = ## Compare two RSA signatures for equality. - result = (a.buffer == b.buffer) + if isNil(a) and isNil(b): + result = true + elif isNil(a) and (not isNil(b)): + result = false + elif isNil(b) and (not isNil(a)): + result = false + else: + result = (a.buffer == b.buffer) proc `==`*(a, b: RsaPublicKey): bool = ## Compare two RSA public keys for equality. - let r1 = cmp(getArray(a.buffer, a.key.n, a.key.nlen), - getArray(b.buffer, b.key.n, b.key.nlen)) - let r2 = cmp(getArray(a.buffer, a.key.e, a.key.elen), - getArray(b.buffer, b.key.e, b.key.elen)) - result = r1 and r2 + if isNil(a) and isNil(b): + result = true + elif isNil(a) and (not isNil(b)): + result = false + elif isNil(b) and (not isNil(a)): + result = false + else: + let r1 = cmp(getArray(a.buffer, a.key.n, a.key.nlen), + getArray(b.buffer, b.key.n, b.key.nlen)) + let r2 = cmp(getArray(a.buffer, a.key.e, a.key.elen), + getArray(b.buffer, b.key.e, b.key.elen)) + result = r1 and r2 proc sign*[T: byte|char](key: RsaPrivateKey, message: openarray[T]): RsaSignature = ## Get RSA PKCS1.5 signature of data ``message`` using SHA256 and private ## key ``key``. + doAssert(not isNil(key)) var hc: BrHashCompatContext var hash: array[32, byte] var impl = BrRsaPkcs1SignGetDefault() @@ -720,6 +753,7 @@ proc verify*[T: byte|char](sig: RsaSignature, message: openarray[T], ## ## Return ``true`` if message verification succeeded, ``false`` if ## verification failed. + doAssert((not isNil(sig)) and (not isNil(pubkey))) if len(sig.buffer) > 0: var hc: BrHashCompatContext var hash: array[32, byte] diff --git a/tests/testcrypto.nim b/tests/testcrypto.nim index 5f13a9007..be2d0476f 100644 --- a/tests/testcrypto.nim +++ b/tests/testcrypto.nim @@ -345,7 +345,7 @@ const proc cmp(a, b: openarray[byte]): bool = result = (@a == @b) -proc testStretcher(s, e: int, cs: CipherScheme, ds: DigestSheme): bool = +proc testStretcher(s, e: int, cs: string, ds: string): bool = for i in s..