diff --git a/cmd/statusd/library.go b/cmd/statusd/library.go index c4c9afa8c..5d69de63e 100644 --- a/cmd/statusd/library.go +++ b/cmd/statusd/library.go @@ -78,8 +78,8 @@ func RecoverAccount(password, mnemonic *C.char) *C.char { } //export VerifyAccountPassword -func VerifyAccountPassword(keyPath, address, password *C.char) *C.char { - _, err := geth.VerifyAccountPassword(C.GoString(keyPath), C.GoString(address), C.GoString(password)) +func VerifyAccountPassword(keyStoreDir, address, password *C.char) *C.char { + _, err := geth.VerifyAccountPassword(C.GoString(keyStoreDir), C.GoString(address), C.GoString(password)) return makeJSONErrorResponse(err) } diff --git a/cmd/statusd/utils.go b/cmd/statusd/utils.go index 35282d1ef..3fb516911 100644 --- a/cmd/statusd/utils.go +++ b/cmd/statusd/utils.go @@ -120,11 +120,20 @@ func testVerifyAccountPassword(t *testing.T) bool { if err = geth.ImportTestAccount(tmpDir, "test-account1.pk"); err != nil { t.Fatal(err) } + if err = geth.ImportTestAccount(tmpDir, "test-account2.pk"); err != nil { + t.Fatal(err) + } + + // rename account file (to see that file's internals reviewed, when locating account key) + accountFilePathOriginal := filepath.Join(tmpDir, "test-account1.pk") + accountFilePath := filepath.Join(tmpDir, "foo"+testConfig.Account1.Address+"bar.pk") + if err := os.Rename(accountFilePathOriginal, accountFilePath); err != nil { + t.Fatal(err) + } - accountFilePath := filepath.Join(tmpDir, "test-account1.pk") response := geth.JSONError{} rawResponse := VerifyAccountPassword( - C.CString(accountFilePath), + C.CString(tmpDir), C.CString(testConfig.Account1.Address), C.CString(testConfig.Account1.Password)) diff --git a/geth/accounts.go b/geth/accounts.go index 36429ce47..b81b7b0ef 100644 --- a/geth/accounts.go +++ b/geth/accounts.go @@ -1,9 +1,12 @@ package geth import ( + "bytes" "errors" "fmt" "io/ioutil" + "os" + "path/filepath" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/keystore" @@ -129,10 +132,39 @@ func RecoverAccount(password, mnemonic string) (address, pubKey string, err erro // VerifyAccountPassword tries to decrypt a given account key file, with a provided password. // If no error is returned, then account is considered verified. -func VerifyAccountPassword(keyPath, address, password string) (*keystore.Key, error) { - keyJSON, err := ioutil.ReadFile(keyPath) +func VerifyAccountPassword(keyStoreDir, address, password string) (*keystore.Key, error) { + var err error + var keyJSON []byte + + addressObj := common.BytesToAddress(common.FromHex(address)) + checkAccountKey := func(path string, fileInfo os.FileInfo) error { + if len(keyJSON) > 0 || fileInfo.IsDir() { + return nil + } + + keyJSON, err = ioutil.ReadFile(path) + if err != nil { + return fmt.Errorf("invalid account key file: %v", err) + } + if !bytes.Contains(keyJSON, []byte(fmt.Sprintf(`"address":"%s"`, addressObj.Hex()[2:]))) { + keyJSON = []byte{} + } + + return nil + } + // locate key within key store directory (address should be within the file) + err = filepath.Walk(keyStoreDir, func(path string, fileInfo os.FileInfo, err error) error { + if err != nil { + return err + } + return checkAccountKey(path, fileInfo) + }) if err != nil { - return nil, fmt.Errorf("invalid account key file: %v", err) + return nil, fmt.Errorf("cannot traverse key store folder: %v", err) + } + + if len(keyJSON) == 0 { + return nil, fmt.Errorf("cannot locate account for address: %x", addressObj) } key, err := keystore.DecryptKey(keyJSON, password) @@ -141,9 +173,8 @@ func VerifyAccountPassword(keyPath, address, password string) (*keystore.Key, er } // avoid swap attack - addr := common.BytesToAddress(common.FromHex(address)) - if key.Address != addr { - return nil, fmt.Errorf("account mismatch: have %x, want %x", key.Address, addr) + if key.Address != addressObj { + return nil, fmt.Errorf("account mismatch: have %x, want %x", key.Address, addressObj) } return key, nil diff --git a/geth/accounts_test.go b/geth/accounts_test.go index 963a5993b..9b587d670 100644 --- a/geth/accounts_test.go +++ b/geth/accounts_test.go @@ -14,19 +14,27 @@ import ( ) func TestVerifyAccountPassword(t *testing.T) { - tmpDir, err := ioutil.TempDir(os.TempDir(), "accounts") + keyStoreDir, err := ioutil.TempDir(os.TempDir(), "accounts") if err != nil { t.Fatal(err) } - defer os.RemoveAll(tmpDir) // nolint: errcheck + defer os.RemoveAll(keyStoreDir) // nolint: errcheck - if err = geth.ImportTestAccount(tmpDir, "test-account1.pk"); err != nil { + emptyKeyStoreDir, err := ioutil.TempDir(os.TempDir(), "empty") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(emptyKeyStoreDir) // nolint: errcheck + + // import account keys + if err = geth.ImportTestAccount(keyStoreDir, "test-account1.pk"); err != nil { + t.Fatal(err) + } + if err = geth.ImportTestAccount(keyStoreDir, "test-account2.pk"); err != nil { t.Fatal(err) } - accountFilePath := filepath.Join(tmpDir, "test-account1.pk") account1Address := common.BytesToAddress(common.FromHex(testConfig.Account1.Address)) - account2Address := common.BytesToAddress(common.FromHex(testConfig.Account2.Address)) testCases := []struct { name string @@ -37,28 +45,35 @@ func TestVerifyAccountPassword(t *testing.T) { }{ { "correct address, correct password (decrypt should succeed)", - accountFilePath, + keyStoreDir, testConfig.Account1.Address, testConfig.Account1.Password, nil, }, { - "correct address, correct password, invalid key file", - filepath.Join(tmpDir, "non-existent-file.pk"), + "correct address, correct password, non-existent key store", + filepath.Join(keyStoreDir, "non-existent-folder"), testConfig.Account1.Address, testConfig.Account1.Password, - fmt.Errorf("invalid account key file: open %s/non-existent-file.pk: no such file or directory", tmpDir), + fmt.Errorf("cannot traverse key store folder: lstat %s/non-existent-folder: no such file or directory", keyStoreDir), + }, + { + "correct address, correct password, empty key store (pk is not there)", + emptyKeyStoreDir, + testConfig.Account1.Address, + testConfig.Account1.Password, + fmt.Errorf("cannot locate account for address: %x", account1Address), }, { "wrong address, correct password", - accountFilePath, - testConfig.Account2.Address, // wrong address (swap attack) + keyStoreDir, + "0x79791d3e8f2daa1f7fec29649d152c0ada3cc535", testConfig.Account1.Password, - fmt.Errorf("account mismatch: have %x, want %x", account1Address, account2Address), + fmt.Errorf("cannot locate account for address: %s", "79791d3e8f2daa1f7fec29649d152c0ada3cc535"), }, { "correct address, wrong password", - accountFilePath, + keyStoreDir, testConfig.Account1.Address, "wrong password", // wrong password errors.New("could not decrypt key with given passphrase"), @@ -68,7 +83,7 @@ func TestVerifyAccountPassword(t *testing.T) { t.Log(testCase.name) accountKey, err := geth.VerifyAccountPassword(testCase.keyPath, testCase.address, testCase.password) if !reflect.DeepEqual(err, testCase.expectedError) { - t.Errorf("unexpected error: expected \n'%v', got \n'%v'", testCase.expectedError, err) + t.Fatalf("unexpected error: expected \n'%v', got \n'%v'", testCase.expectedError, err) } if err == nil { if accountKey == nil { @@ -76,7 +91,7 @@ func TestVerifyAccountPassword(t *testing.T) { } accountAddress := common.BytesToAddress(common.FromHex(testCase.address)) if accountKey.Address != accountAddress { - t.Errorf("account mismatch: have %x, want %x", accountKey.Address, accountAddress) + t.Fatalf("account mismatch: have %x, want %x", accountKey.Address, accountAddress) } } }