From 352cf930fc47cc2cab7c34fd853ebdc9cf0f4b42 Mon Sep 17 00:00:00 2001 From: Pierre Cauchois Date: Sat, 19 Sep 2020 01:59:04 +0000 Subject: [PATCH] ServerError type check before EOF string comparison --- lib/eof.go | 13 +++++++++++-- lib/eof_test.go | 10 ++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/eof.go b/lib/eof.go index 75408c8c0d..3023a6ce02 100644 --- a/lib/eof.go +++ b/lib/eof.go @@ -2,7 +2,9 @@ package lib import ( "errors" + "fmt" "io" + "net/rpc" "strings" "github.com/hashicorp/yamux" @@ -20,10 +22,17 @@ func IsErrEOF(err error) bool { errStr := err.Error() if strings.Contains(errStr, yamuxStreamClosed) || - strings.Contains(errStr, yamuxSessionShutdown) || - strings.HasSuffix(errStr, io.EOF.Error()) { + strings.Contains(errStr, yamuxSessionShutdown) { return true } + if srvErr, ok := err.(rpc.ServerError); ok { + return strings.HasSuffix(srvErr.Error(), fmt.Sprintf(": %s", io.EOF.Error())) + } + + if srvErr, ok := errors.Unwrap(err).(rpc.ServerError); ok { + return strings.HasSuffix(srvErr.Error(), fmt.Sprintf(": %s", io.EOF.Error())) + } + return false } diff --git a/lib/eof_test.go b/lib/eof_test.go index 964f9e70e9..38106ae997 100644 --- a/lib/eof_test.go +++ b/lib/eof_test.go @@ -3,6 +3,7 @@ package lib import ( "fmt" "io" + "net/rpc" "testing" "github.com/hashicorp/yamux" @@ -15,19 +16,16 @@ func TestErrIsEOF(t *testing.T) { err error }{ {name: "EOF", err: io.EOF}, + {name: "Wrapped EOF", err: fmt.Errorf("test: %w", io.EOF)}, {name: "yamuxStreamClosed", err: yamux.ErrStreamClosed}, {name: "yamuxSessionShutdown", err: yamux.ErrSessionShutdown}, + {name: "ServerError(___: EOF)", err: rpc.ServerError(fmt.Sprintf("rpc error: %s", io.EOF.Error()))}, + {name: "Wrapped ServerError(___: EOF)", err: fmt.Errorf("rpc error: %w", rpc.ServerError(fmt.Sprintf("rpc error: %s", io.EOF.Error())))}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.True(t, IsErrEOF(tt.err)) }) - t.Run(fmt.Sprintf("Wrapped %s", tt.name), func(t *testing.T) { - require.True(t, IsErrEOF(fmt.Errorf("test: %w", tt.err))) - }) - t.Run(fmt.Sprintf("String suffix is %s", tt.name), func(t *testing.T) { - require.True(t, IsErrEOF(fmt.Errorf("rpc error: %s", tt.err.Error()))) - }) } }