diff --git a/src/main/java/im/status/keycard/SECP256k1.java b/src/main/java/im/status/keycard/SECP256k1.java index caba2f7..6d18cb0 100644 --- a/src/main/java/im/status/keycard/SECP256k1.java +++ b/src/main/java/im/status/keycard/SECP256k1.java @@ -60,7 +60,7 @@ public class SECP256k1 { static final short SCHNORR_MULT_KEY_SIZE = KeyBuilder.LENGTH_RSA_736; static final short SCHNORR_COMPONENT_SIZE = (short) (SCHNORR_MULT_KEY_SIZE / 8); - static final short SCHNORR_S_OUT_SIZE = (short) 64; + static final short MULT_OUT_SIZE = (short) 64; static final short SCHNORR_K_OUT_OFF = (short) 0; static final short SCHNORR_E_OUT_OFF = (short) (SECP256K1_BYTE_SIZE + SCHNORR_K_OUT_OFF); @@ -70,11 +70,15 @@ public class SECP256k1 { static final short SCHNORR_E_32_OFF = (short) (SCHNORR_COMPONENT_SIZE - SECP256K1_BYTE_SIZE + SCHNORR_E_OUT_OFF); static final short SCHNORR_D_32_OFF = (short) (SCHNORR_COMPONENT_SIZE - SECP256K1_BYTE_SIZE + SCHNORR_D_OUT_OFF); static final short SCHNORR_RES_32_OFF = (short) (SCHNORR_COMPONENT_SIZE - SECP256K1_BYTE_SIZE + SCHNORR_RES_OUT_OFF); - static final short SCHNORR_RES_64_OFF = (short) (SCHNORR_COMPONENT_SIZE - SCHNORR_S_OUT_SIZE + SCHNORR_RES_OUT_OFF); + static final short SCHNORR_RES_64_OFF = (short) (SCHNORR_COMPONENT_SIZE - MULT_OUT_SIZE + SCHNORR_RES_OUT_OFF); static final short TMP_LEN = (short) (SECP256K1_BYTE_SIZE + (SCHNORR_COMPONENT_SIZE * 3)); private static final byte ALG_EC_SVDP_DH_PLAIN_XY = 6; // constant from JavaCard 3.0.5 + private static final short MOD_DIGIT_LEN = 8; + private static final short MOD_DDIGIT_LEN = 16; + private static final short MOD_DIGIT_MASK = 0xff; + private static final short MOD_DDIGIT_MASK = 0x7fff; private KeyAgreement ecPointMultiplier; private Crypto crypto; @@ -171,7 +175,7 @@ public class SECP256k1 { short signSchnorr(ECPrivateKey privKey, byte[] pubKey, short pubOff, byte[] data, short dataOff, short dataLen, byte[] output, short outOff) { output[outOff++] = TLV_SCHNORR_SIGNATURE; output[outOff++] = (byte) 0x81; - output[outOff++] = (byte) (Crypto.KEY_PUB_SIZE + SCHNORR_S_OUT_SIZE); + output[outOff++] = (byte) (Crypto.KEY_PUB_SIZE + MULT_OUT_SIZE); crypto.random.generateData(tmp, SCHNORR_K_OUT_OFF, SECP256K1_BYTE_SIZE); Util.arrayFillNonAtomic(tmp, SCHNORR_E_OUT_OFF, (short)(TMP_LEN - SCHNORR_E_OUT_OFF), (byte) 0x00); @@ -191,17 +195,151 @@ public class SECP256k1 { divideResBy2(); - crypto.addBig(tmp, SCHNORR_RES_64_OFF, SCHNORR_S_OUT_SIZE, tmp, SCHNORR_K_OUT_OFF, SECP256K1_BYTE_SIZE, output, (short) (outOff + Crypto.KEY_PUB_SIZE)); - return (short) (3 + Crypto.KEY_PUB_SIZE + SCHNORR_S_OUT_SIZE); + crypto.addBig(tmp, SCHNORR_RES_64_OFF, MULT_OUT_SIZE, tmp, SCHNORR_K_OUT_OFF, SECP256K1_BYTE_SIZE, output, (short) (outOff + Crypto.KEY_PUB_SIZE)); + secp256k1Mod(output, (short) (outOff + Crypto.KEY_PUB_SIZE)); + return (short) (3 + Crypto.KEY_PUB_SIZE + MULT_OUT_SIZE); } private void divideResBy2() { short res, res2; - for (short i = (short) (SCHNORR_COMPONENT_SIZE - 1); i >= (short) (SCHNORR_COMPONENT_SIZE - SCHNORR_S_OUT_SIZE - 1); i--) { + for (short i = (short) (SCHNORR_COMPONENT_SIZE - 1); i >= (short) (SCHNORR_COMPONENT_SIZE - MULT_OUT_SIZE - 1); i--) { res = (short) ((short) (tmp[(short)(SCHNORR_RES_OUT_OFF + i)] & 0xff) >> 1); res2 = (short) ((short) (tmp[(short)(SCHNORR_RES_OUT_OFF + i - 1)] & 0xff) << 7); tmp[(short)(SCHNORR_RES_OUT_OFF + i)] = (byte) ((short) (res | res2)); } } + + private void secp256k1Mod(byte[] value, short offset) { + short divisorShift = (short) (MULT_OUT_SIZE - SECP256K1_R.length); + short divisionRound = 0; + + short firstDivisorDigit = (short) (SECP256K1_R[(short) 0] & MOD_DIGIT_MASK); + short divisorBitShift = (short) (highestBit((short) (firstDivisorDigit + 1)) - 1); + byte secondDivisorDigit = SECP256K1_R[(short) 1]; + byte thirdDivisorDigit = SECP256K1_R[(short) 2]; + + short dividendDigits, divisorDigit; + short dividendBitShift, bitShift; + short multiple; + + while (divisorShift >= 0) { + while (!shiftLesser(value, offset, divisorShift, (short) (divisionRound > 0 ? divisionRound - 1 : 0))) { + dividendDigits = divisionRound == 0 ? 0 : (short) ((short) (value[(short) (offset + divisionRound - 1)]) << MOD_DIGIT_LEN); + dividendDigits |= (short) (value[(short)(offset + divisionRound)] & MOD_DIGIT_MASK); + + if (dividendDigits < 0) { + dividendDigits = (short) ((dividendDigits >>> 1) & MOD_DDIGIT_MASK); + divisorDigit = (short) ((firstDivisorDigit >>> 1) & MOD_DDIGIT_MASK); + } else { + dividendBitShift = (short) (highestBit(dividendDigits) - 1); + bitShift = dividendBitShift <= divisorBitShift ? dividendBitShift : divisorBitShift; + + dividendDigits = shiftBits(dividendDigits, divisionRound < (short) (MULT_OUT_SIZE - 1) ? value[(short) (offset + divisionRound + 1)] : 0, divisionRound < (short) (SECP256K1_R.length - 2) ? value[(short) (offset + divisionRound + 2)] : 0, bitShift); + divisorDigit = shiftBits(firstDivisorDigit, secondDivisorDigit, thirdDivisorDigit, bitShift); + } + + multiple = (short) (dividendDigits / (short) (divisorDigit + 1)); + + if (multiple < 1) { + multiple = 1; + } + + timesMinus(value, offset, divisorShift, multiple); + } + + divisionRound++; + divisorShift--; + } + } + + private void timesMinus(byte[] value, short offset, short shift, short mult) { + short accu = 0; + short subtractionResult; + short i = (short) (MULT_OUT_SIZE - 1 - shift); + short j = (short) (SECP256K1_R.length - 1); + + for (; i >= 0 && j >= 0; i--, j--) { + accu = (short) (accu + (short) (mult * (SECP256K1_R[j] & MOD_DIGIT_MASK))); + subtractionResult = (short) ((value[(short)(offset + i)] & MOD_DIGIT_MASK) - (accu & MOD_DIGIT_MASK)); + + value[(short)(offset + i)] = (byte) (subtractionResult & MOD_DIGIT_MASK); + accu = (short) ((accu >> MOD_DIGIT_LEN) & MOD_DIGIT_MASK); + if (subtractionResult < 0) { + accu++; + } + } + + while (i >= 0 && accu != 0) { + subtractionResult = (short) ((value[(short)(offset + i)] & MOD_DIGIT_MASK) - (accu & MOD_DIGIT_MASK)); + value[(short)(offset + i)] = (byte) (subtractionResult & MOD_DIGIT_MASK); + accu = (short) ((accu >> MOD_DIGIT_LEN) & MOD_DIGIT_MASK); + if (subtractionResult < 0) { + accu++; + } + i--; + } + } + + private static short highestBit(short x) { + for (short i = 0; i < MOD_DDIGIT_LEN; i++) { + if (x < 0) { + return i; + } + + x <<= 1; + } + + return MOD_DDIGIT_LEN; + } + + private static short shiftBits(short high, byte middle, byte low, short shift) { + high <<= shift; + + byte mask = (byte) (MOD_DIGIT_MASK << (shift >= MOD_DIGIT_LEN ? 0 : MOD_DIGIT_LEN - shift)); + short bits = (short) ((short) (middle & mask) & MOD_DIGIT_MASK); + + if (shift > MOD_DIGIT_LEN) { + bits <<= shift - MOD_DIGIT_LEN; + } else { + bits >>>= MOD_DIGIT_LEN - shift; + } + + high |= bits; + + if (shift <= MOD_DIGIT_LEN) { + return high; + } + + mask = (byte) (MOD_DIGIT_MASK << MOD_DDIGIT_LEN - shift); + bits = (short) ((((short) (low & mask) & MOD_DIGIT_MASK) >> MOD_DDIGIT_LEN - shift)); + high |= bits; + + return high; + } + + private static boolean shiftLesser(byte[] value, short offset, short shift, short start) { + short j; + + j = (short) (SECP256K1_R.length + shift - MULT_OUT_SIZE + start); + short valShort, divisorShort; + + for (short i = start; i < MULT_OUT_SIZE; i++, j++) { + valShort = (short) (value[(short)(i + offset)] & MOD_DIGIT_MASK); + + if (j >= 0 && j < SECP256K1_R.length) { + divisorShort = (short) (SECP256K1_R[j] & MOD_DIGIT_MASK); + } + else { + divisorShort = 0; + } + if (valShort < divisorShort) { + return true; // CTO + } + if (valShort > divisorShort) { + return false; + } + } + return false; + } } diff --git a/src/test/java/im/status/keycard/KeycardTest.java b/src/test/java/im/status/keycard/KeycardTest.java index 3c3b08f..cf4126c 100644 --- a/src/test/java/im/status/keycard/KeycardTest.java +++ b/src/test/java/im/status/keycard/KeycardTest.java @@ -1085,7 +1085,6 @@ public class KeycardTest { ECPoint G = ecSpec.getG(); BigInteger s = new BigInteger(1, Arrays.copyOfRange(rawSig, 65, rawSig.length)); - s = s.mod(ecSpec.getCurve().getOrder()); ECPoint R = G.multiply(s).subtract(P.multiply(e)); System.out.println("R = " + Hex.toHexString(R.getEncoded(false)));