diff --git a/README.md b/README.md index 059ce19..6c6a552 100644 --- a/README.md +++ b/README.md @@ -40,10 +40,12 @@ if err != nil { // Find a YubiKey and open the reader. var yk *piv.YubiKey for _, card := range cards { - if strings.Contains(strings.ToLower(card), "yubikey") { - if yk, err = piv.Open(card); err != nil { - // ... + if yk, err := piv.Open(card); err == nil { + status := yk.Status() + if !strings.Contains(strings.ToLower(string(status.Atr())), "ubike") { + continue } + // .. break } } diff --git a/piv/pcsc.go b/piv/pcsc.go index 00e16b3..861a338 100644 --- a/piv/pcsc.go +++ b/piv/pcsc.go @@ -127,6 +127,29 @@ func (a *apduErr) Unwrap() error { return nil } +type scStatus struct { + reader string + state uint32 + protocol uint32 + atr []byte +} + +func (s *scStatus) Reader() string { + return s.reader +} + +func (s *scStatus) State() uint32 { + return s.state +} + +func (s *scStatus) Protocol() uint32 { + return s.protocol +} + +func (s *scStatus) Atr() []byte { + return s.atr +} + type apdu struct { instruction byte param1 byte diff --git a/piv/pcsc_unix.go b/piv/pcsc_unix.go index a43d259..af029a3 100644 --- a/piv/pcsc_unix.go +++ b/piv/pcsc_unix.go @@ -89,7 +89,8 @@ func (c *scContext) ListReaders() ([]string, error) { } type scHandle struct { - h C.SCARDHANDLE + h C.SCARDHANDLE + status scStatus } func (c *scContext) Connect(reader string) (*scHandle, error) { @@ -103,7 +104,44 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { if err := scCheck(rc); err != nil { return nil, err } - return &scHandle{handle}, nil + + var readerNameLen C.DWORD + var atrLen C.DWORD + + C.SCardStatus( + handle, + nil, + &readerNameLen, + nil, + nil, + nil, + &atrLen, + ) + + var state uint32 + var protocol uint32 + + readerName := make([]byte, readerNameLen) + atr := make([]byte, atrLen) + + C.SCardStatus( + handle, + (*C.char)(unsafe.Pointer(&readerName[0])), + (*C.DWORD)(unsafe.Pointer(&readerNameLen)), + (*C.DWORD)(unsafe.Pointer(&state)), + (*C.DWORD)(unsafe.Pointer(&protocol)), + (*C.uchar)(unsafe.Pointer(&atr[0])), + (*C.DWORD)(unsafe.Pointer(&atrLen)), + ) + + status := scStatus{ + reader: string(readerName), + state: state, + protocol: protocol, + atr: atr, + } + + return &scHandle{handle, status}, nil } func (h *scHandle) Close() error { diff --git a/piv/pcsc_windows.go b/piv/pcsc_windows.go index 845194f..48a8dd1 100644 --- a/piv/pcsc_windows.go +++ b/piv/pcsc_windows.go @@ -24,6 +24,7 @@ var ( winscard = syscall.NewLazyDLL("Winscard.dll") procSCardEstablishContext = winscard.NewProc("SCardEstablishContext") procSCardListReadersW = winscard.NewProc("SCardListReadersW") + procSCardStatusW = winscard.NewProc("SCardStatusW") procSCardReleaseContext = winscard.NewProc("SCardReleaseContext") procSCardConnectW = winscard.NewProc("SCardConnectW") procSCardDisconnect = winscard.NewProc("SCardDisconnect") @@ -122,6 +123,11 @@ func (c *scContext) ListReaders() ([]string, error) { return readers, nil } +type scHandle struct { + handle syscall.Handle + status scStatus +} + func (c *scContext) Connect(reader string) (*scHandle, error) { var ( handle syscall.Handle @@ -142,11 +148,49 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { if err := scCheck(r0); err != nil { return nil, err } - return &scHandle{handle}, nil -} -type scHandle struct { - handle syscall.Handle + var readerNameLen uint32 + var atrLen uint32 + r0, _, _ = procSCardStatusW.Call( + uintptr(handle), + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(&readerNameLen)), + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(nil)), + uintptr(unsafe.Pointer(&atrLen)), + ) + if err := scCheck(r0); err != nil { + return nil, err + } + + var state uint32 + var protocol uint32 + + readerName := make([]uint16, readerNameLen) + atr := make([]byte, atrLen) + + r0, _, _ = procSCardStatusW.Call( + uintptr(handle), + uintptr(unsafe.Pointer(&readerName[0])), + uintptr(unsafe.Pointer(&readerNameLen)), + uintptr(unsafe.Pointer(&state)), + uintptr(unsafe.Pointer(&protocol)), + uintptr(unsafe.Pointer(&atr[0])), + uintptr(unsafe.Pointer(&atrLen)), + ) + if err := scCheck(r0); err != nil { + return nil, err + } + + status := scStatus{ + reader: syscall.UTF16ToString(readerName), + state: state, + protocol: protocol, + atr: atr, + } + + return &scHandle{handle, status}, nil } func (h *scHandle) Close() error { diff --git a/piv/piv.go b/piv/piv.go index 4e7171a..570777c 100644 --- a/piv/piv.go +++ b/piv/piv.go @@ -196,6 +196,10 @@ func (yk *YubiKey) Version() Version { } } +func (yk *YubiKey) Status() scStatus { + return yk.h.status +} + // Serial returns the YubiKey's serial number. func (yk *YubiKey) Serial() (uint32, error) { return ykSerial(yk.tx, yk.version) @@ -216,7 +220,7 @@ func encodePIN(pin string) ([]byte, error) { return data, nil } -// authPIN attempts to authenticate against the card with the provided PIN. +// AuthPIN attempts to authenticate against the card with the provided PIN. // The PIN is required to use and modify certain slots. // // After a specific number of authentication attemps with an invalid PIN, @@ -224,7 +228,7 @@ func encodePIN(pin string) ([]byte, error) { // point the PUK must be used to unblock the PIN. // // Use DefaultPIN if the PIN hasn't been set. -func (yk *YubiKey) authPIN(pin string) error { +func (yk *YubiKey) AuthPIN(pin string) error { return ykLogin(yk.tx, pin) } diff --git a/piv/piv_test.go b/piv/piv_test.go index 534b259..a978fa0 100644 --- a/piv/piv_test.go +++ b/piv/piv_test.go @@ -190,7 +190,7 @@ func TestYubiKeyReset(t *testing.T) { if err := yk.Reset(); err != nil { t.Fatalf("resetting yubikey: %v", err) } - if err := yk.authPIN(DefaultPIN); err != nil { + if err := yk.AuthPIN(DefaultPIN); err != nil { t.Fatalf("login: %v", err) } } @@ -199,7 +199,7 @@ func TestYubiKeyLogin(t *testing.T) { yk, close := newTestYubiKey(t) defer close() - if err := yk.authPIN(DefaultPIN); err != nil { + if err := yk.AuthPIN(DefaultPIN); err != nil { t.Fatalf("login: %v", err) } }