diff --git a/commands.go b/commands.go index 24f5c36..1949f1d 100644 --- a/commands.go +++ b/commands.go @@ -111,7 +111,7 @@ func NewCommandVerifyPIN(pin string) *apdu.Command { } func NewCommandDeriveKey(pathStr string) (*apdu.Command, error) { - startingPoint, path, err := derivationpath.Parse(pathStr) + startingPoint, path, err := derivationpath.Decode(pathStr) if err != nil { return nil, err } diff --git a/derivationpath/decoder.go b/derivationpath/decoder.go new file mode 100644 index 0000000..f269da2 --- /dev/null +++ b/derivationpath/decoder.go @@ -0,0 +1,214 @@ +package derivationpath + +import ( + "fmt" + "io" + "strconv" + "strings" +) + +type StartingPoint int + +const ( + tokenMaster = 0x6D // char m + tokenSeparator = 0x2F // char / + tokenHardened = 0x27 // char ' + tokenDot = 0x2E // char . + + hardenedStart = 0x80000000 // 2^31 +) + +const ( + StartingPointMaster StartingPoint = iota + 1 + StartingPointCurrent + StartingPointParent +) + +type parseFunc = func() error + +type decoder struct { + r *strings.Reader + f parseFunc + pos int + path []uint32 + start StartingPoint + currentToken string + currentTokenHardened bool +} + +func newDecoder(path string) *decoder { + d := &decoder{ + r: strings.NewReader(path), + } + + d.reset() + + return d +} + +func (d *decoder) reset() { + d.r.Seek(0, io.SeekStart) + d.pos = 0 + d.start = StartingPointCurrent + d.f = d.parseStart + d.path = make([]uint32, 0) + d.resetCurrentToken() +} + +func (d *decoder) resetCurrentToken() { + d.currentToken = "" + d.currentTokenHardened = false +} + +func (d *decoder) parse() (StartingPoint, []uint32, error) { + for { + err := d.f() + if err != nil { + if err == io.EOF { + err = nil + } else { + err = fmt.Errorf("at position %d, %s", d.pos, err.Error()) + } + + return d.start, d.path, err + } + } + + return d.start, d.path, nil +} + +func (d *decoder) readByte() (byte, error) { + b, err := d.r.ReadByte() + if err != nil { + return b, err + } + + d.pos++ + + return b, nil +} + +func (d *decoder) unreadByte() error { + err := d.r.UnreadByte() + if err != nil { + return err + } + + d.pos-- + + return nil +} + +func (d *decoder) parseStart() error { + b, err := d.readByte() + if err != nil { + return err + } + + if b == tokenMaster { + d.start = StartingPointMaster + d.f = d.parseSeparator + return nil + } + + if b == tokenDot { + b2, err := d.readByte() + if err != nil { + return err + } + + if b2 == tokenDot { + d.f = d.parseSeparator + d.start = StartingPointParent + return nil + } + + d.f = d.parseSeparator + d.start = StartingPointCurrent + return d.unreadByte() + } + + d.f = d.parseSegment + + return d.unreadByte() +} + +func (d *decoder) saveSegment() error { + if len(d.currentToken) > 0 { + i, err := strconv.ParseUint(d.currentToken, 10, 32) + if err != nil { + return err + } + + if i >= hardenedStart { + d.pos -= len(d.currentToken) - 1 + return fmt.Errorf("index must be lower than 2^31, got %d", i) + } + + if d.currentTokenHardened { + i += hardenedStart + } + + d.path = append(d.path, uint32(i)) + } + + d.f = d.parseSegment + d.resetCurrentToken() + + return nil +} + +func (d *decoder) parseSeparator() error { + b, err := d.readByte() + if err != nil { + return err + } + + if b == tokenSeparator { + return d.saveSegment() + } + + return fmt.Errorf("expected %s, got %s", string(tokenSeparator), string(b)) +} + +func (d *decoder) parseSegment() error { + b, err := d.readByte() + if err == io.EOF { + if len(d.currentToken) == 0 { + return fmt.Errorf("expected number, got EOF") + } + + if newErr := d.saveSegment(); newErr != nil { + return newErr + } + + return err + } + + if err != nil { + return err + } + + if len(d.currentToken) > 0 && b == tokenSeparator { + return d.saveSegment() + } + + if len(d.currentToken) > 0 && b == tokenHardened { + d.currentTokenHardened = true + d.f = d.parseSeparator + return nil + } + + if b < 0x30 || b > 0x39 { + return fmt.Errorf("expected number, got %s", string(b)) + } + + d.currentToken = fmt.Sprintf("%s%s", d.currentToken, string(b)) + + return nil +} + +func Decode(str string) (StartingPoint, []uint32, error) { + d := newDecoder(str) + return d.parse() +} diff --git a/derivationpath/derivationpath_test.go b/derivationpath/decoder_test.go similarity index 96% rename from derivationpath/derivationpath_test.go rename to derivationpath/decoder_test.go index ae3fe64..b276707 100644 --- a/derivationpath/derivationpath_test.go +++ b/derivationpath/decoder_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestParse(t *testing.T) { +func TestDecode(t *testing.T) { scenarios := []struct { path string expectedPath []uint32 @@ -73,7 +73,7 @@ func TestParse(t *testing.T) { for i, s := range scenarios { t.Run(fmt.Sprintf("scenario %d", i), func(t *testing.T) { - startingPoint, path, err := Parse(s.path) + startingPoint, path, err := Decode(s.path) if s.err == nil { assert.NoError(t, err) assert.Equal(t, s.expectedStartingPoint, startingPoint) diff --git a/derivationpath/derivationpath.go b/derivationpath/derivationpath.go deleted file mode 100644 index 4f49654..0000000 --- a/derivationpath/derivationpath.go +++ /dev/null @@ -1,214 +0,0 @@ -package derivationpath - -import ( - "fmt" - "io" - "strconv" - "strings" -) - -type StartingPoint int - -const ( - tokenMaster = 0x6D // char m - tokenSeparator = 0x2F // char / - tokenHardened = 0x27 // char ' - tokenDot = 0x2E // char . - - hardenedStart = 0x80000000 // 2^31 -) - -const ( - StartingPointMaster StartingPoint = iota + 1 - StartingPointCurrent - StartingPointParent -) - -type parseFunc = func() error - -type parser struct { - r *strings.Reader - f parseFunc - pos int - path []uint32 - start StartingPoint - currentToken string - currentTokenHardened bool -} - -func newParser(path string) *parser { - p := &parser{ - r: strings.NewReader(path), - } - - p.reset() - - return p -} - -func (p *parser) reset() { - p.r.Seek(0, io.SeekStart) - p.pos = 0 - p.start = StartingPointCurrent - p.f = p.parseStart - p.path = make([]uint32, 0) - p.resetCurrentToken() -} - -func (p *parser) resetCurrentToken() { - p.currentToken = "" - p.currentTokenHardened = false -} - -func (p *parser) parse() (StartingPoint, []uint32, error) { - for { - err := p.f() - if err != nil { - if err == io.EOF { - err = nil - } else { - err = fmt.Errorf("at position %d, %s", p.pos, err.Error()) - } - - return p.start, p.path, err - } - } - - return p.start, p.path, nil -} - -func (p *parser) readByte() (byte, error) { - b, err := p.r.ReadByte() - if err != nil { - return b, err - } - - p.pos++ - - return b, nil -} - -func (p *parser) unreadByte() error { - err := p.r.UnreadByte() - if err != nil { - return err - } - - p.pos-- - - return nil -} - -func (p *parser) parseStart() error { - b, err := p.readByte() - if err != nil { - return err - } - - if b == tokenMaster { - p.start = StartingPointMaster - p.f = p.parseSeparator - return nil - } - - if b == tokenDot { - b2, err := p.readByte() - if err != nil { - return err - } - - if b2 == tokenDot { - p.f = p.parseSeparator - p.start = StartingPointParent - return nil - } - - p.f = p.parseSeparator - p.start = StartingPointCurrent - return p.unreadByte() - } - - p.f = p.parseSegment - - return p.unreadByte() -} - -func (p *parser) saveSegment() error { - if len(p.currentToken) > 0 { - i, err := strconv.ParseUint(p.currentToken, 10, 32) - if err != nil { - return err - } - - if i >= hardenedStart { - p.pos -= len(p.currentToken) - 1 - return fmt.Errorf("index must be lower than 2^31, got %d", i) - } - - if p.currentTokenHardened { - i += hardenedStart - } - - p.path = append(p.path, uint32(i)) - } - - p.f = p.parseSegment - p.resetCurrentToken() - - return nil -} - -func (p *parser) parseSeparator() error { - b, err := p.readByte() - if err != nil { - return err - } - - if b == tokenSeparator { - return p.saveSegment() - } - - return fmt.Errorf("expected %s, got %s", string(tokenSeparator), string(b)) -} - -func (p *parser) parseSegment() error { - b, err := p.readByte() - if err == io.EOF { - if len(p.currentToken) == 0 { - return fmt.Errorf("expected number, got EOF") - } - - if newErr := p.saveSegment(); newErr != nil { - return newErr - } - - return err - } - - if err != nil { - return err - } - - if len(p.currentToken) > 0 && b == tokenSeparator { - return p.saveSegment() - } - - if len(p.currentToken) > 0 && b == tokenHardened { - p.currentTokenHardened = true - p.f = p.parseSeparator - return nil - } - - if b < 0x30 || b > 0x39 { - return fmt.Errorf("expected number, got %s", string(b)) - } - - p.currentToken = fmt.Sprintf("%s%s", p.currentToken, string(b)) - - return nil -} - -func Parse(str string) (StartingPoint, []uint32, error) { - p := newParser(str) - return p.parse() -}