call all pcsc funcs from a single os thread

This commit is contained in:
Michele Balistreri 2023-04-26 15:42:11 +02:00
parent ad9531fa4b
commit 65313d6767
No known key found for this signature in database
GPG Key ID: E9567DA33A4F791A

View File

@ -3,6 +3,7 @@ package statuskeycardgo
import ( import (
"crypto/sha512" "crypto/sha512"
"errors" "errors"
"runtime"
"time" "time"
"github.com/ebfe/scard" "github.com/ebfe/scard"
@ -19,86 +20,133 @@ import (
const bip39Salt = "mnemonic" const bip39Salt = "mnemonic"
type commandType int
const (
Close commandType = iota
Transmit
Ack
)
type keycardContext struct { type keycardContext struct {
cardCtx *scard.Context cardCtx *scard.Context
card *scard.Card card *scard.Card
readers []string readers []string
c types.Channel c types.Channel
cmdSet *keycard.CommandSet cmdSet *keycard.CommandSet
connected chan (struct{}) connected chan (bool)
command chan (commandType)
apdu []byte
rpdu []byte
runErr error runErr error
} }
func (kc *keycardContext) Transmit(apdu []byte) ([]byte, error) {
kc.apdu = apdu
kc.command <- Transmit
<-kc.command
kc.apdu = nil
rpdu, err := kc.rpdu, kc.runErr
kc.rpdu = nil
kc.runErr = nil
return rpdu, err
}
func startKeycardContext() (*keycardContext, error) { func startKeycardContext() (*keycardContext, error) {
kctx := &keycardContext{ kctx := &keycardContext{
connected: make(chan (struct{})), connected: make(chan (bool)),
} command: make(chan (commandType)),
err := kctx.start()
if err != nil {
return nil, err
} }
go kctx.run() go kctx.run()
<-kctx.connected
if kctx.runErr != nil {
return nil, kctx.runErr
}
return kctx, nil return kctx, nil
} }
func (kc *keycardContext) run() {
runtime.LockOSThread()
var err error
defer func() {
if err != nil {
l(err.Error())
}
kc.runErr = err
if kc.cardCtx != nil {
_ = kc.cardCtx.Release()
}
close(kc.connected)
runtime.UnlockOSThread()
}()
err = kc.start()
if err != nil {
return
}
kc.connected <- true
err = kc.connect()
if err != nil {
return
}
kc.connected <- true
for cmd := range kc.command {
switch cmd {
case Transmit:
kc.rpdu, kc.runErr = kc.card.Transmit(kc.apdu)
kc.command <- Ack
case Close:
return
}
}
}
func (kc *keycardContext) start() error { func (kc *keycardContext) start() error {
cardCtx, err := scard.EstablishContext() cardCtx, err := scard.EstablishContext()
if err != nil { if err != nil {
err = errors.New(ErrorPCSC) return errors.New(ErrorPCSC)
l(err.Error())
close(kc.connected)
return err
} }
l("listing readers") l("listing readers")
readers, err := cardCtx.ListReaders() readers, err := cardCtx.ListReaders()
if err != nil { if err != nil {
err = errors.New(ErrorReaderList) return errors.New(ErrorReaderList)
l(err.Error())
close(kc.connected)
_ = cardCtx.Release()
return err
} }
kc.readers = readers kc.readers = readers
if len(readers) == 0 { if len(readers) == 0 {
err = errors.New(ErrorNoReader) return errors.New(ErrorNoReader)
l(err.Error())
close(kc.connected)
_ = cardCtx.Release()
return err
} }
kc.cardCtx = cardCtx kc.cardCtx = cardCtx
return nil return nil
} }
func (kc *keycardContext) stop() error { func (kc *keycardContext) stop() {
if kc.runErr != nil { close(kc.command)
return kc.runErr
}
if err := kc.cardCtx.Release(); err != nil {
err = errors.New(ErrorConnection)
l(err.Error())
return err
}
return nil
} }
func (kc *keycardContext) run() { func (kc *keycardContext) connect() error {
l("waiting for card") l("waiting for card")
index, err := kc.waitForCard(kc.cardCtx, kc.readers) index, err := kc.waitForCard(kc.cardCtx, kc.readers)
if err != nil { if err != nil {
l(err.Error()) return err
kc.runErr = err
close(kc.connected)
_ = kc.cardCtx.Release()
return
} }
l("card found at index %d", index) l("card found at index %d", index)
@ -109,22 +157,14 @@ func (kc *keycardContext) run() {
card, err := kc.cardCtx.Connect(reader, scard.ShareShared, scard.ProtocolAny) card, err := kc.cardCtx.Connect(reader, scard.ShareShared, scard.ProtocolAny)
if err != nil { if err != nil {
// error connecting to card // error connecting to card
l(err.Error())
kc.runErr = err
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
close(kc.connected) return err
_ = kc.cardCtx.Release()
return
} }
status, err := card.Status() status, err := card.Status()
if err != nil { if err != nil {
l(err.Error())
kc.runErr = err
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
close(kc.connected) return err
_ = kc.cardCtx.Release()
return
} }
switch status.ActiveProtocol { switch status.ActiveProtocol {
@ -139,7 +179,8 @@ func (kc *keycardContext) run() {
kc.card = card kc.card = card
kc.c = io.NewNormalChannel(card) kc.c = io.NewNormalChannel(card)
kc.cmdSet = keycard.NewCommandSet(kc.c) kc.cmdSet = keycard.NewCommandSet(kc.c)
close(kc.connected)
return nil
} }
func (kc *keycardContext) waitForCard(ctx *scard.Context, readers []string) (int, error) { func (kc *keycardContext) waitForCard(ctx *scard.Context, readers []string) (int, error) {
@ -167,11 +208,6 @@ func (kc *keycardContext) waitForCard(ctx *scard.Context, readers []string) (int
} }
func (kc *keycardContext) selectApplet() (*types.ApplicationInfo, error) { func (kc *keycardContext) selectApplet() (*types.ApplicationInfo, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
err := kc.cmdSet.Select() err := kc.cmdSet.Select()
if err != nil { if err != nil {
if e, ok := err.(*apdu.ErrBadResponse); ok && e.Sw == globalplatform.SwFileNotFound { if e, ok := err.(*apdu.ErrBadResponse); ok && e.Sw == globalplatform.SwFileNotFound {
@ -186,11 +222,6 @@ func (kc *keycardContext) selectApplet() (*types.ApplicationInfo, error) {
} }
func (kc *keycardContext) pair(pairingPassword string) (*types.PairingInfo, error) { func (kc *keycardContext) pair(pairingPassword string) (*types.PairingInfo, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
err := kc.cmdSet.Pair(pairingPassword) err := kc.cmdSet.Pair(pairingPassword)
if err != nil { if err != nil {
l("pair failed %+v", err) l("pair failed %+v", err)
@ -201,11 +232,6 @@ func (kc *keycardContext) pair(pairingPassword string) (*types.PairingInfo, erro
} }
func (kc *keycardContext) openSecureChannel(index int, key []byte) error { func (kc *keycardContext) openSecureChannel(index int, key []byte) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
kc.cmdSet.SetPairingInfo(key, index) kc.cmdSet.SetPairingInfo(key, index)
err := kc.cmdSet.OpenSecureChannel() err := kc.cmdSet.OpenSecureChannel()
if err != nil { if err != nil {
@ -217,11 +243,6 @@ func (kc *keycardContext) openSecureChannel(index int, key []byte) error {
} }
func (kc *keycardContext) verifyPin(pin string) error { func (kc *keycardContext) verifyPin(pin string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.VerifyPIN(pin) err := kc.cmdSet.VerifyPIN(pin)
if err != nil { if err != nil {
l("verifyPin failed %+v", err) l("verifyPin failed %+v", err)
@ -232,11 +253,6 @@ func (kc *keycardContext) verifyPin(pin string) error {
} }
func (kc *keycardContext) unblockPIN(puk string, newPIN string) error { func (kc *keycardContext) unblockPIN(puk string, newPIN string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.UnblockPIN(puk, newPIN) err := kc.cmdSet.UnblockPIN(puk, newPIN)
if err != nil { if err != nil {
l("unblockPIN failed %+v", err) l("unblockPIN failed %+v", err)
@ -248,11 +264,6 @@ func (kc *keycardContext) unblockPIN(puk string, newPIN string) error {
//lint:ignore U1000 will be used //lint:ignore U1000 will be used
func (kc *keycardContext) generateKey() ([]byte, error) { func (kc *keycardContext) generateKey() ([]byte, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
appStatus, err := kc.cmdSet.GetStatusApplication() appStatus, err := kc.cmdSet.GetStatusApplication()
if err != nil { if err != nil {
l("getStatus failed %+v", err) l("getStatus failed %+v", err)
@ -274,11 +285,6 @@ func (kc *keycardContext) generateKey() ([]byte, error) {
} }
func (kc *keycardContext) generateMnemonic(checksumSize int) ([]int, error) { func (kc *keycardContext) generateMnemonic(checksumSize int) ([]int, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
indexes, err := kc.cmdSet.GenerateMnemonic(checksumSize) indexes, err := kc.cmdSet.GenerateMnemonic(checksumSize)
if err != nil { if err != nil {
l("generateMnemonic failed %+v", err) l("generateMnemonic failed %+v", err)
@ -289,11 +295,6 @@ func (kc *keycardContext) generateMnemonic(checksumSize int) ([]int, error) {
} }
func (kc *keycardContext) removeKey() error { func (kc *keycardContext) removeKey() error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.RemoveKey() err := kc.cmdSet.RemoveKey()
if err != nil { if err != nil {
l("removeKey failed %+v", err) l("removeKey failed %+v", err)
@ -305,11 +306,6 @@ func (kc *keycardContext) removeKey() error {
//lint:ignore U1000 will be used //lint:ignore U1000 will be used
func (kc *keycardContext) deriveKey(path string) error { func (kc *keycardContext) deriveKey(path string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.DeriveKey(path) err := kc.cmdSet.DeriveKey(path)
if err != nil { if err != nil {
l("deriveKey failed %+v", err) l("deriveKey failed %+v", err)
@ -320,11 +316,6 @@ func (kc *keycardContext) deriveKey(path string) error {
} }
func (kc *keycardContext) signWithPath(data []byte, path string) (*types.Signature, error) { func (kc *keycardContext) signWithPath(data []byte, path string) (*types.Signature, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
sig, err := kc.cmdSet.SignWithPath(data, path) sig, err := kc.cmdSet.SignWithPath(data, path)
if err != nil { if err != nil {
l("signWithPath failed %+v", err) l("signWithPath failed %+v", err)
@ -335,11 +326,6 @@ func (kc *keycardContext) signWithPath(data []byte, path string) (*types.Signatu
} }
func (kc *keycardContext) exportKey(derive bool, makeCurrent bool, onlyPublic bool, path string) (*KeyPair, error) { func (kc *keycardContext) exportKey(derive bool, makeCurrent bool, onlyPublic bool, path string) (*KeyPair, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
address := "" address := ""
privKey, pubKey, err := kc.cmdSet.ExportKey(derive, makeCurrent, onlyPublic, path) privKey, pubKey, err := kc.cmdSet.ExportKey(derive, makeCurrent, onlyPublic, path)
if err != nil { if err != nil {
@ -360,11 +346,6 @@ func (kc *keycardContext) exportKey(derive bool, makeCurrent bool, onlyPublic bo
} }
func (kc *keycardContext) loadSeed(seed []byte) ([]byte, error) { func (kc *keycardContext) loadSeed(seed []byte) ([]byte, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
pubKey, err := kc.cmdSet.LoadSeed(seed) pubKey, err := kc.cmdSet.LoadSeed(seed)
if err != nil { if err != nil {
l("loadSeed failed %+v", err) l("loadSeed failed %+v", err)
@ -380,11 +361,6 @@ func (kc *keycardContext) loadMnemonic(mnemonic string, password string) ([]byte
} }
func (kc *keycardContext) init(pin, puk, pairingPassword string) error { func (kc *keycardContext) init(pin, puk, pairingPassword string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
secrets := keycard.NewSecrets(pin, puk, pairingPassword) secrets := keycard.NewSecrets(pin, puk, pairingPassword)
err := kc.cmdSet.Init(secrets) err := kc.cmdSet.Init(secrets)
if err != nil { if err != nil {
@ -396,11 +372,6 @@ func (kc *keycardContext) init(pin, puk, pairingPassword string) error {
} }
func (kc *keycardContext) unpair(index uint8) error { func (kc *keycardContext) unpair(index uint8) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.Unpair(index) err := kc.cmdSet.Unpair(index)
if err != nil { if err != nil {
l("unpair failed %+v", err) l("unpair failed %+v", err)
@ -415,11 +386,6 @@ func (kc *keycardContext) unpairCurrent() error {
} }
func (kc *keycardContext) getStatusApplication() (*types.ApplicationStatus, error) { func (kc *keycardContext) getStatusApplication() (*types.ApplicationStatus, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
status, err := kc.cmdSet.GetStatusApplication() status, err := kc.cmdSet.GetStatusApplication()
if err != nil { if err != nil {
l("getStatusApplication failed %+v", err) l("getStatusApplication failed %+v", err)
@ -430,11 +396,6 @@ func (kc *keycardContext) getStatusApplication() (*types.ApplicationStatus, erro
} }
func (kc *keycardContext) changePin(pin string) error { func (kc *keycardContext) changePin(pin string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.ChangePIN(pin) err := kc.cmdSet.ChangePIN(pin)
if err != nil { if err != nil {
l("chaingePin failed %+v", err) l("chaingePin failed %+v", err)
@ -445,11 +406,6 @@ func (kc *keycardContext) changePin(pin string) error {
} }
func (kc *keycardContext) changePuk(puk string) error { func (kc *keycardContext) changePuk(puk string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.ChangePUK(puk) err := kc.cmdSet.ChangePUK(puk)
if err != nil { if err != nil {
l("chaingePuk failed %+v", err) l("chaingePuk failed %+v", err)
@ -460,11 +416,6 @@ func (kc *keycardContext) changePuk(puk string) error {
} }
func (kc *keycardContext) changePairingPassword(pairingPassword string) error { func (kc *keycardContext) changePairingPassword(pairingPassword string) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.ChangePairingSecret(pairingPassword) err := kc.cmdSet.ChangePairingSecret(pairingPassword)
if err != nil { if err != nil {
l("chaingePairingPassword failed %+v", err) l("chaingePairingPassword failed %+v", err)
@ -475,11 +426,6 @@ func (kc *keycardContext) changePairingPassword(pairingPassword string) error {
} }
func (kc *keycardContext) factoryReset(retry bool) error { func (kc *keycardContext) factoryReset(retry bool) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
cmdSet := globalplatform.NewCommandSet(kc.c) cmdSet := globalplatform.NewCommandSet(kc.c)
if err := cmdSet.Select(); err != nil { if err := cmdSet.Select(); err != nil {
@ -517,11 +463,6 @@ func (kc *keycardContext) factoryReset(retry bool) error {
} }
func (kc *keycardContext) storeMetadata(metadata *types.Metadata) error { func (kc *keycardContext) storeMetadata(metadata *types.Metadata) error {
<-kc.connected
if kc.runErr != nil {
return kc.runErr
}
err := kc.cmdSet.StoreData(keycard.P1StoreDataPublic, metadata.Serialize()) err := kc.cmdSet.StoreData(keycard.P1StoreDataPublic, metadata.Serialize())
if err != nil { if err != nil {
@ -533,11 +474,6 @@ func (kc *keycardContext) storeMetadata(metadata *types.Metadata) error {
} }
func (kc *keycardContext) getMetadata() (*types.Metadata, error) { func (kc *keycardContext) getMetadata() (*types.Metadata, error) {
<-kc.connected
if kc.runErr != nil {
return nil, kc.runErr
}
data, err := kc.cmdSet.GetData(keycard.P1StoreDataPublic) data, err := kc.cmdSet.GetData(keycard.P1StoreDataPublic)
if err != nil { if err != nil {