Addressed feedback from @ilmotta
This commit is contained in:
parent
43c2bc24d7
commit
b16631bbc3
|
@ -65,7 +65,7 @@ func NewChallengeGiver(e *PayloadEncryptor, logger *zap.Logger) (*ChallengeGiver
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, *ChallengeError) {
|
||||
func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, error) {
|
||||
h, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
cg.logger.Error("getIP: h, _, err := net.SplitHostPort(r.RemoteAddr)", zap.Error(err), zap.String("r.RemoteAddr", r.RemoteAddr))
|
||||
|
@ -74,16 +74,16 @@ func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, *ChallengeError) {
|
|||
return net.ParseIP(h), nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) registerClientIP(r *http.Request) *ChallengeError {
|
||||
hIP, err := cg.getIP(r)
|
||||
func (cg *ChallengeGiver) registerClientIP(r *http.Request) error {
|
||||
IP, err := cg.getIP(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cg.authedIP = hIP
|
||||
cg.authedIP = IP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) validateClientIP(r *http.Request) *ChallengeError {
|
||||
func (cg *ChallengeGiver) validateClientIP(r *http.Request) error {
|
||||
// If we haven't registered yet register the IP
|
||||
if cg.authedIP == nil || len(cg.authedIP) == 0 {
|
||||
err := cg.registerClientIP(r)
|
||||
|
@ -93,23 +93,23 @@ func (cg *ChallengeGiver) validateClientIP(r *http.Request) *ChallengeError {
|
|||
}
|
||||
|
||||
// Then compare the current req RemoteIP with the authed IP
|
||||
hIP, err := cg.getIP(r)
|
||||
IP, err := cg.getIP(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !cg.authedIP.Equal(hIP) {
|
||||
if !cg.authedIP.Equal(IP) {
|
||||
cg.logger.Error(
|
||||
"request RemoteAddr does not match authedIP: expected '%s', received '%s'",
|
||||
zap.String("expected", cg.authedIP.String()),
|
||||
zap.String("received", hIP.String()),
|
||||
zap.String("received", IP.String()),
|
||||
)
|
||||
return &ChallengeError{"forbidden", http.StatusForbidden}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, *ChallengeError) {
|
||||
func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, error) {
|
||||
s, err := cg.cookieStore.Get(r, sessionChallenge)
|
||||
if err != nil {
|
||||
cg.logger.Error("checkChallengeResponse: cg.cookieStore.Get(r, sessionChallenge)", zap.Error(err), zap.String("sessionChallenge", sessionChallenge))
|
||||
|
@ -118,7 +118,7 @@ func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, *Chall
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.ResponseWriter, r *http.Request) ([]byte, *ChallengeError) {
|
||||
func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.ResponseWriter, r *http.Request) ([]byte, error) {
|
||||
challenge := make([]byte, 64)
|
||||
_, err := rand.Read(challenge)
|
||||
if err != nil {
|
||||
|
@ -136,7 +136,7 @@ func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.Respo
|
|||
return challenge, nil
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *http.Request) *ChallengeError {
|
||||
func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *http.Request) error {
|
||||
s.Values[sessionBlocked] = true
|
||||
err := s.Save(r, w)
|
||||
if err != nil {
|
||||
|
@ -147,15 +147,15 @@ func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *h
|
|||
return &ChallengeError{"forbidden", http.StatusForbidden}
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.Request) *ChallengeError {
|
||||
ce := cg.validateClientIP(r)
|
||||
if ce != nil {
|
||||
return ce
|
||||
func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.Request) error {
|
||||
err := cg.validateClientIP(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s, ce := cg.getSession(r)
|
||||
if ce != nil {
|
||||
return ce
|
||||
s, err := cg.getSession(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blocked, ok := s.Values[sessionBlocked].(bool)
|
||||
|
@ -164,14 +164,14 @@ func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.
|
|||
}
|
||||
|
||||
// If the request header doesn't include a challenge don't punish the client, just throw a 403
|
||||
pc := r.Header.Get(sessionChallenge)
|
||||
if pc == "" {
|
||||
clientChallengeResp := r.Header.Get(sessionChallenge)
|
||||
if clientChallengeResp == "" {
|
||||
return &ChallengeError{"forbidden", http.StatusForbidden}
|
||||
}
|
||||
|
||||
c, err := cg.encryptor.decryptPlain(base58.Decode(pc))
|
||||
dcr, err := cg.encryptor.decryptPlain(base58.Decode(clientChallengeResp))
|
||||
if err != nil {
|
||||
cg.logger.Error("checkChallengeResponse: cg.encryptor.decryptPlain(base58.Decode(pc))", zap.Error(err), zap.String("pc", pc))
|
||||
cg.logger.Error("checkChallengeResponse: cg.encryptor.decryptPlain(base58.Decode(clientChallengeResp))", zap.Error(err), zap.String("clientChallengeResp", clientChallengeResp))
|
||||
return &ChallengeError{"error", http.StatusInternalServerError}
|
||||
}
|
||||
|
||||
|
@ -183,16 +183,16 @@ func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.
|
|||
|
||||
// Only if we have both a challenge in the session store and in the request header
|
||||
// do we entertain blocking the client. Because then we know someone is trying to be sneaky.
|
||||
if !bytes.Equal(c, challenge) {
|
||||
if !bytes.Equal(dcr, challenge) {
|
||||
return cg.block(s, w, r)
|
||||
}
|
||||
|
||||
// If every is ok, generate a new challenge for the next req
|
||||
_, ce = cg.generateNewChallenge(s, w, r)
|
||||
return ce
|
||||
_, err = cg.generateNewChallenge(s, w, r)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cg *ChallengeGiver) getChallenge(w http.ResponseWriter, r *http.Request) ([]byte, *ChallengeError) {
|
||||
func (cg *ChallengeGiver) getChallenge(w http.ResponseWriter, r *http.Request) ([]byte, error) {
|
||||
err := cg.validateClientIP(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -183,9 +183,14 @@ func handleSendInstallation(hs HandlerServer, pmr PayloadMounterReceiver) http.H
|
|||
|
||||
func middlewareChallenge(cg *ChallengeGiver, next http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ce := cg.checkChallengeResponse(w, r)
|
||||
if ce != nil {
|
||||
http.Error(w, ce.Text, ce.HTTPCode)
|
||||
err := cg.checkChallengeResponse(w, r)
|
||||
if err != nil {
|
||||
if cErr, ok := err.(*ChallengeError); ok {
|
||||
http.Error(w, cErr.Text, cErr.HTTPCode)
|
||||
return
|
||||
}
|
||||
cg.logger.Error("failed to checkChallengeResponse in middlewareChallenge", zap.Error(err))
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -195,16 +200,22 @@ func middlewareChallenge(cg *ChallengeGiver, next http.Handler) http.HandlerFunc
|
|||
|
||||
func handlePairingChallenge(cg *ChallengeGiver) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
challenge, ce := cg.getChallenge(w, r)
|
||||
if ce != nil {
|
||||
http.Error(w, ce.Text, ce.HTTPCode)
|
||||
challenge, err := cg.getChallenge(w, r)
|
||||
if err != nil {
|
||||
if cErr, ok := err.(*ChallengeError); ok {
|
||||
http.Error(w, cErr.Text, cErr.HTTPCode)
|
||||
return
|
||||
}
|
||||
cg.logger.Error("failed to getChallenge in handlePairingChallenge", zap.Error(err))
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, err := w.Write(challenge)
|
||||
_, err = w.Write(challenge)
|
||||
if err != nil {
|
||||
cg.logger.Error("handlePairingChallenge: _, err = w.Write(challenge)", zap.Error(err))
|
||||
cg.logger.Error("failed to Write(challenge) in handlePairingChallenge", zap.Error(err))
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue