diff --git a/keycard_context.go b/keycard_context.go index 780be91..d5d9d25 100644 --- a/keycard_context.go +++ b/keycard_context.go @@ -3,6 +3,7 @@ package statuskeycardgo import ( "crypto/sha512" "errors" + "runtime" "time" "github.com/ebfe/scard" @@ -19,86 +20,133 @@ import ( const bip39Salt = "mnemonic" +type commandType int + +const ( + Close commandType = iota + Transmit + Ack +) + type keycardContext struct { cardCtx *scard.Context card *scard.Card readers []string c types.Channel cmdSet *keycard.CommandSet - connected chan (struct{}) + connected chan (bool) + command chan (commandType) + apdu []byte + rpdu []byte runErr error } +func (kc *keycardContext) Transmit(apdu []byte) ([]byte, error) { + kc.apdu = apdu + kc.command <- Transmit + <-kc.command + kc.apdu = nil + rpdu, err := kc.rpdu, kc.runErr + kc.rpdu = nil + kc.runErr = nil + return rpdu, err +} + func startKeycardContext() (*keycardContext, error) { kctx := &keycardContext{ - connected: make(chan (struct{})), - } - err := kctx.start() - if err != nil { - return nil, err + connected: make(chan (bool)), + command: make(chan (commandType)), } go kctx.run() + <-kctx.connected + + if kctx.runErr != nil { + return nil, kctx.runErr + } + return kctx, nil } +func (kc *keycardContext) run() { + runtime.LockOSThread() + + var err error + + defer func() { + if err != nil { + l(err.Error()) + } + + kc.runErr = err + + if kc.cardCtx != nil { + _ = kc.cardCtx.Release() + } + + close(kc.connected) + runtime.UnlockOSThread() + }() + + err = kc.start() + + if err != nil { + return + } + + kc.connected <- true + + err = kc.connect() + + if err != nil { + return + } + + kc.connected <- true + + for cmd := range kc.command { + switch cmd { + case Transmit: + kc.rpdu, kc.runErr = kc.card.Transmit(kc.apdu) + kc.command <- Ack + case Close: + return + } + } +} + func (kc *keycardContext) start() error { cardCtx, err := scard.EstablishContext() if err != nil { - err = errors.New(ErrorPCSC) - l(err.Error()) - close(kc.connected) - return err + return errors.New(ErrorPCSC) } l("listing readers") readers, err := cardCtx.ListReaders() if err != nil { - err = errors.New(ErrorReaderList) - l(err.Error()) - close(kc.connected) - _ = cardCtx.Release() - return err + return errors.New(ErrorReaderList) } kc.readers = readers if len(readers) == 0 { - err = errors.New(ErrorNoReader) - l(err.Error()) - close(kc.connected) - _ = cardCtx.Release() - return err + return errors.New(ErrorNoReader) } kc.cardCtx = cardCtx return nil } -func (kc *keycardContext) stop() error { - if kc.runErr != nil { - return kc.runErr - } - - if err := kc.cardCtx.Release(); err != nil { - err = errors.New(ErrorConnection) - l(err.Error()) - return err - } - - return nil +func (kc *keycardContext) stop() { + close(kc.command) } -func (kc *keycardContext) run() { +func (kc *keycardContext) connect() error { l("waiting for card") index, err := kc.waitForCard(kc.cardCtx, kc.readers) if err != nil { - l(err.Error()) - kc.runErr = err - close(kc.connected) - _ = kc.cardCtx.Release() - return + return err } l("card found at index %d", index) @@ -109,22 +157,14 @@ func (kc *keycardContext) run() { card, err := kc.cardCtx.Connect(reader, scard.ShareShared, scard.ProtocolAny) if err != nil { // error connecting to card - l(err.Error()) - kc.runErr = err time.Sleep(500 * time.Millisecond) - close(kc.connected) - _ = kc.cardCtx.Release() - return + return err } status, err := card.Status() if err != nil { - l(err.Error()) - kc.runErr = err time.Sleep(500 * time.Millisecond) - close(kc.connected) - _ = kc.cardCtx.Release() - return + return err } switch status.ActiveProtocol { @@ -139,7 +179,8 @@ func (kc *keycardContext) run() { kc.card = card kc.c = io.NewNormalChannel(card) kc.cmdSet = keycard.NewCommandSet(kc.c) - close(kc.connected) + + return nil } func (kc *keycardContext) waitForCard(ctx *scard.Context, readers []string) (int, error) { @@ -167,11 +208,6 @@ func (kc *keycardContext) waitForCard(ctx *scard.Context, readers []string) (int } func (kc *keycardContext) selectApplet() (*types.ApplicationInfo, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - err := kc.cmdSet.Select() if err != nil { if e, ok := err.(*apdu.ErrBadResponse); ok && e.Sw == globalplatform.SwFileNotFound { @@ -186,11 +222,6 @@ func (kc *keycardContext) selectApplet() (*types.ApplicationInfo, error) { } func (kc *keycardContext) pair(pairingPassword string) (*types.PairingInfo, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - err := kc.cmdSet.Pair(pairingPassword) if err != nil { l("pair failed %+v", err) @@ -201,11 +232,6 @@ func (kc *keycardContext) pair(pairingPassword string) (*types.PairingInfo, erro } func (kc *keycardContext) openSecureChannel(index int, key []byte) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - kc.cmdSet.SetPairingInfo(key, index) err := kc.cmdSet.OpenSecureChannel() if err != nil { @@ -217,11 +243,6 @@ func (kc *keycardContext) openSecureChannel(index int, key []byte) error { } func (kc *keycardContext) verifyPin(pin string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.VerifyPIN(pin) if err != nil { l("verifyPin failed %+v", err) @@ -232,11 +253,6 @@ func (kc *keycardContext) verifyPin(pin string) error { } func (kc *keycardContext) unblockPIN(puk string, newPIN string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.UnblockPIN(puk, newPIN) if err != nil { l("unblockPIN failed %+v", err) @@ -248,11 +264,6 @@ func (kc *keycardContext) unblockPIN(puk string, newPIN string) error { //lint:ignore U1000 will be used func (kc *keycardContext) generateKey() ([]byte, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - appStatus, err := kc.cmdSet.GetStatusApplication() if err != nil { l("getStatus failed %+v", err) @@ -274,11 +285,6 @@ func (kc *keycardContext) generateKey() ([]byte, error) { } func (kc *keycardContext) generateMnemonic(checksumSize int) ([]int, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - indexes, err := kc.cmdSet.GenerateMnemonic(checksumSize) if err != nil { l("generateMnemonic failed %+v", err) @@ -289,11 +295,6 @@ func (kc *keycardContext) generateMnemonic(checksumSize int) ([]int, error) { } func (kc *keycardContext) removeKey() error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.RemoveKey() if err != nil { l("removeKey failed %+v", err) @@ -305,11 +306,6 @@ func (kc *keycardContext) removeKey() error { //lint:ignore U1000 will be used func (kc *keycardContext) deriveKey(path string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.DeriveKey(path) if err != nil { l("deriveKey failed %+v", err) @@ -320,11 +316,6 @@ func (kc *keycardContext) deriveKey(path string) error { } func (kc *keycardContext) signWithPath(data []byte, path string) (*types.Signature, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - sig, err := kc.cmdSet.SignWithPath(data, path) if err != nil { l("signWithPath failed %+v", err) @@ -335,11 +326,6 @@ func (kc *keycardContext) signWithPath(data []byte, path string) (*types.Signatu } func (kc *keycardContext) exportKey(derive bool, makeCurrent bool, onlyPublic bool, path string) (*KeyPair, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - address := "" privKey, pubKey, err := kc.cmdSet.ExportKey(derive, makeCurrent, onlyPublic, path) if err != nil { @@ -360,11 +346,6 @@ func (kc *keycardContext) exportKey(derive bool, makeCurrent bool, onlyPublic bo } func (kc *keycardContext) loadSeed(seed []byte) ([]byte, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - pubKey, err := kc.cmdSet.LoadSeed(seed) if err != nil { l("loadSeed failed %+v", err) @@ -380,11 +361,6 @@ func (kc *keycardContext) loadMnemonic(mnemonic string, password string) ([]byte } func (kc *keycardContext) init(pin, puk, pairingPassword string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - secrets := keycard.NewSecrets(pin, puk, pairingPassword) err := kc.cmdSet.Init(secrets) if err != nil { @@ -396,11 +372,6 @@ func (kc *keycardContext) init(pin, puk, pairingPassword string) error { } func (kc *keycardContext) unpair(index uint8) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.Unpair(index) if err != nil { l("unpair failed %+v", err) @@ -415,11 +386,6 @@ func (kc *keycardContext) unpairCurrent() error { } func (kc *keycardContext) getStatusApplication() (*types.ApplicationStatus, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - status, err := kc.cmdSet.GetStatusApplication() if err != nil { l("getStatusApplication failed %+v", err) @@ -430,11 +396,6 @@ func (kc *keycardContext) getStatusApplication() (*types.ApplicationStatus, erro } func (kc *keycardContext) changePin(pin string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.ChangePIN(pin) if err != nil { l("chaingePin failed %+v", err) @@ -445,11 +406,6 @@ func (kc *keycardContext) changePin(pin string) error { } func (kc *keycardContext) changePuk(puk string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.ChangePUK(puk) if err != nil { l("chaingePuk failed %+v", err) @@ -460,11 +416,6 @@ func (kc *keycardContext) changePuk(puk string) error { } func (kc *keycardContext) changePairingPassword(pairingPassword string) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.ChangePairingSecret(pairingPassword) if err != nil { l("chaingePairingPassword failed %+v", err) @@ -475,11 +426,6 @@ func (kc *keycardContext) changePairingPassword(pairingPassword string) error { } func (kc *keycardContext) factoryReset(retry bool) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - cmdSet := globalplatform.NewCommandSet(kc.c) if err := cmdSet.Select(); err != nil { @@ -517,11 +463,6 @@ func (kc *keycardContext) factoryReset(retry bool) error { } func (kc *keycardContext) storeMetadata(metadata *types.Metadata) error { - <-kc.connected - if kc.runErr != nil { - return kc.runErr - } - err := kc.cmdSet.StoreData(keycard.P1StoreDataPublic, metadata.Serialize()) if err != nil { @@ -533,11 +474,6 @@ func (kc *keycardContext) storeMetadata(metadata *types.Metadata) error { } func (kc *keycardContext) getMetadata() (*types.Metadata, error) { - <-kc.connected - if kc.runErr != nil { - return nil, kc.runErr - } - data, err := kc.cmdSet.GetData(keycard.P1StoreDataPublic) if err != nil {