diff --git a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/FindNodeMessage.java b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/FindNodeMessage.java index 08cc5819..14185cbf 100644 --- a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/FindNodeMessage.java +++ b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/FindNodeMessage.java @@ -1,11 +1,12 @@ package org.ethereum.net.rlpx; +import org.ethereum.crypto.ECKey; import org.ethereum.util.ByteUtil; import org.ethereum.util.RLP; public class FindNodeMessage extends Message { - public static Message create(byte[] target) { + public static Message create(byte[] target, ECKey privKey) { long expiration = System.currentTimeMillis(); @@ -17,7 +18,7 @@ public class FindNodeMessage extends Message { byte[] data = RLP.encodeList(rlpToken, rlpExp); FindNodeMessage findNode = new FindNodeMessage(); - findNode.encode(type, data); + findNode.encode(type, data, privKey); return findNode; } diff --git a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/Message.java b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/Message.java index 2c159d80..7927d870 100644 --- a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/Message.java +++ b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/Message.java @@ -5,11 +5,15 @@ import org.ethereum.util.FastByteComparisons; import org.spongycastle.util.BigIntegers; import org.spongycastle.util.encoders.Hex; +import java.security.SignatureException; + import static org.ethereum.crypto.HashUtil.sha3; import static org.ethereum.util.ByteUtil.merge; public class Message { + byte[] wire; + byte[] mdc; byte[] signature; byte[] type; @@ -48,12 +52,13 @@ public class Message { msg.signature = signature; msg.type = type; msg.data = data; + msg.wire = wire; return msg; } - public Message encode(byte[] type, byte[] data) { + public Message encode(byte[] type, byte[] data, ECKey privKey) { /* [1] Calc sha3 - prepare for sig */ byte[] payload = new byte[type.length + data.length]; @@ -62,12 +67,11 @@ public class Message { byte[] forSig = sha3(payload); /* [2] Crate signature*/ - ECKey privKey = ECKey.fromPrivate(Hex.decode("3ecb44df2159c26e0f995712d4f39b6f6e499b40749b1cf1246c37f9516cb6a4")); ECKey.ECDSASignature signature = privKey.sign(forSig); byte[] sigBytes = - merge(new byte[]{signature.v}, BigIntegers.asUnsignedByteArray(signature.r), - BigIntegers.asUnsignedByteArray(signature.s)); + merge(BigIntegers.asUnsignedByteArray(signature.r), + BigIntegers.asUnsignedByteArray(signature.s), new byte[]{signature.v}); // [3] calculate MDC byte[] forSha = merge(sigBytes, type, data); @@ -79,12 +83,35 @@ public class Message { this.type = type; this.data = data; + this.wire = merge(this.mdc, this.signature, this.type, this.data); + return this; } + public ECKey getKey() { + + byte[] r = new byte[32]; + byte[] s = new byte[32]; + byte v = signature[64]; + + System.arraycopy(signature, 0, r, 0, 32); + System.arraycopy(signature, 32, s, 0, 32); + + ECKey.ECDSASignature signature = ECKey.ECDSASignature.fromComponents(r, s, v); + byte[] msgHash = sha3(wire, 97, wire.length - 97); + + ECKey outKey = null; + try { + outKey = ECKey.signatureToKey(msgHash, signature.toBase64()); + } catch (SignatureException e) { + e.printStackTrace(); + } + + return outKey; + } + public byte[] getPacket() { - byte[] packet = merge(mdc, signature, type, data); - return packet; + return wire; } public byte[] getMdc() { diff --git a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/NeighborsMessage.java b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/NeighborsMessage.java index 27b9679c..dcbc122a 100644 --- a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/NeighborsMessage.java +++ b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/NeighborsMessage.java @@ -1,5 +1,6 @@ package org.ethereum.net.rlpx; +import org.ethereum.crypto.ECKey; import org.ethereum.util.ByteUtil; import org.ethereum.util.RLP; @@ -7,7 +8,7 @@ import java.util.List; public class NeighborsMessage extends Message { - public static Message create(List nodes) { + public static Message create(List nodes, ECKey privKey) { long expiration = System.currentTimeMillis(); @@ -27,7 +28,7 @@ public class NeighborsMessage extends Message { byte[] data = RLP.encodeList(rlpListNodes, rlpExp); NeighborsMessage neighborsMessage = new NeighborsMessage(); - neighborsMessage.encode(type, data); + neighborsMessage.encode(type, data, privKey); return neighborsMessage; } diff --git a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PingMessage.java b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PingMessage.java index 8a5df15e..e7a06741 100644 --- a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PingMessage.java +++ b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PingMessage.java @@ -1,11 +1,12 @@ package org.ethereum.net.rlpx; +import org.ethereum.crypto.ECKey; import org.ethereum.util.ByteUtil; import org.ethereum.util.RLP; public class PingMessage extends Message { - public static Message create(String ip, int port){ + public static Message create(String ip, int port, ECKey privKey){ long expiration = System.currentTimeMillis(); @@ -18,7 +19,7 @@ public class PingMessage extends Message { byte[] data = RLP.encodeList(rlpIp, rlpPort, rlpExp); PingMessage ping = new PingMessage(); - ping.encode(type, data); + ping.encode(type, data, privKey); return ping; } diff --git a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PongMessage.java b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PongMessage.java index 27ad21c7..d731e468 100644 --- a/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PongMessage.java +++ b/ethereumj-core/src/main/java/org/ethereum/net/rlpx/PongMessage.java @@ -1,11 +1,12 @@ package org.ethereum.net.rlpx; +import org.ethereum.crypto.ECKey; import org.ethereum.util.ByteUtil; import org.ethereum.util.RLP; public class PongMessage extends Message { - public static Message create(byte[] token) { + public static Message create(byte[] token, ECKey privKey) { long expiration = System.currentTimeMillis(); @@ -17,7 +18,7 @@ public class PongMessage extends Message { byte[] data = RLP.encodeList(rlpToken, rlpExp); PongMessage pong = new PongMessage(); - pong.encode(type, data); + pong.encode(type, data, privKey); return pong; } diff --git a/ethereumj-core/src/test/java/test/ethereum/net/RLPXTest.java b/ethereumj-core/src/test/java/test/ethereum/net/RLPXTest.java index bc582bc3..030c013b 100644 --- a/ethereumj-core/src/test/java/test/ethereum/net/RLPXTest.java +++ b/ethereumj-core/src/test/java/test/ethereum/net/RLPXTest.java @@ -1,10 +1,12 @@ package test.ethereum.net; +import org.ethereum.crypto.ECKey; import org.ethereum.net.rlpx.*; import org.junit.Assert; import org.junit.Test; import org.slf4j.LoggerFactory; +import java.math.BigInteger; import java.nio.charset.Charset; import java.util.Arrays; import java.util.List; @@ -22,9 +24,9 @@ public class RLPXTest { String ip = "85.65.19.231"; int port = 30303; - long expiration = System.currentTimeMillis(); + ECKey key = ECKey.fromPrivate(BigInteger.TEN); - Message ping = PingMessage.create(ip, port); + Message ping = PingMessage.create(ip, port, key); logger.info("{}", ping); byte[] wire = ping.getPacket(); @@ -32,14 +34,18 @@ public class RLPXTest { logger.info("{}", ping2); assertEquals(ping.toString(), ping2.toString()); + + String key2 = ping2.getKey().toString(); + assertEquals(key.toString(), key2.toString()); } @Test // pong test public void test2(){ byte[] token = sha3("+++".getBytes(Charset.forName("UTF-8"))); + ECKey key = ECKey.fromPrivate(BigInteger.TEN); - Message pong = PongMessage.create(token); + Message pong = PongMessage.create(token, key); logger.info("{}", pong); byte[] wire = pong.getPacket(); @@ -47,6 +53,9 @@ public class RLPXTest { logger.info("{}", pong); assertEquals(pong.toString(), pong2.toString()); + + String key2 = pong2.getKey().toString(); + assertEquals(key.toString(), key2.toString()); } @Test // neighbors message @@ -62,7 +71,9 @@ public class RLPXTest { Node node = new Node(id, ip, port); List nodes = Arrays.asList(node); - Message neighbors = NeighborsMessage.create(nodes); + ECKey key = ECKey.fromPrivate(BigInteger.TEN); + + Message neighbors = NeighborsMessage.create(nodes, key); logger.info("{}", neighbors); byte[] wire = neighbors.getPacket(); @@ -70,14 +81,18 @@ public class RLPXTest { logger.info("{}", neighbors2); assertEquals(neighbors.toString(), neighbors2.toString()); + + String key2 = neighbors2.getKey().toString(); + assertEquals(key.toString(), key2.toString()); } @Test // find node message public void test4(){ byte[] id = sha3("+++".getBytes(Charset.forName("UTF-8"))); + ECKey key = ECKey.fromPrivate(BigInteger.TEN); - Message findNode = FindNodeMessage.create(id); + Message findNode = FindNodeMessage.create(id, key); logger.info("{}", findNode); byte[] wire = findNode.getPacket(); @@ -85,6 +100,9 @@ public class RLPXTest { logger.info("{}", findNode2); assertEquals(findNode.toString(), findNode2.toString()); + + String key2 = findNode2.getKey().toString(); + assertEquals(key.toString(), key2.toString()); } @@ -92,8 +110,9 @@ public class RLPXTest { public void test5(){ byte[] id = sha3("+++".getBytes(Charset.forName("UTF-8"))); + ECKey key = ECKey.fromPrivate(BigInteger.TEN); - Message findNode = FindNodeMessage.create(id); + Message findNode = FindNodeMessage.create(id, key); logger.info("{}", findNode); byte[] wire = findNode.getPacket();