Addressed feedback from @ilmotta

This commit is contained in:
Samuel Hawksby-Robinson 2023-03-22 12:58:09 +00:00
parent 43c2bc24d7
commit b16631bbc3
2 changed files with 45 additions and 34 deletions

View File

@ -65,7 +65,7 @@ func NewChallengeGiver(e *PayloadEncryptor, logger *zap.Logger) (*ChallengeGiver
}, nil
}
func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, *ChallengeError) {
func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, error) {
h, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
cg.logger.Error("getIP: h, _, err := net.SplitHostPort(r.RemoteAddr)", zap.Error(err), zap.String("r.RemoteAddr", r.RemoteAddr))
@ -74,16 +74,16 @@ func (cg *ChallengeGiver) getIP(r *http.Request) (net.IP, *ChallengeError) {
return net.ParseIP(h), nil
}
func (cg *ChallengeGiver) registerClientIP(r *http.Request) *ChallengeError {
hIP, err := cg.getIP(r)
func (cg *ChallengeGiver) registerClientIP(r *http.Request) error {
IP, err := cg.getIP(r)
if err != nil {
return err
}
cg.authedIP = hIP
cg.authedIP = IP
return nil
}
func (cg *ChallengeGiver) validateClientIP(r *http.Request) *ChallengeError {
func (cg *ChallengeGiver) validateClientIP(r *http.Request) error {
// If we haven't registered yet register the IP
if cg.authedIP == nil || len(cg.authedIP) == 0 {
err := cg.registerClientIP(r)
@ -93,23 +93,23 @@ func (cg *ChallengeGiver) validateClientIP(r *http.Request) *ChallengeError {
}
// Then compare the current req RemoteIP with the authed IP
hIP, err := cg.getIP(r)
IP, err := cg.getIP(r)
if err != nil {
return err
}
if !cg.authedIP.Equal(hIP) {
if !cg.authedIP.Equal(IP) {
cg.logger.Error(
"request RemoteAddr does not match authedIP: expected '%s', received '%s'",
zap.String("expected", cg.authedIP.String()),
zap.String("received", hIP.String()),
zap.String("received", IP.String()),
)
return &ChallengeError{"forbidden", http.StatusForbidden}
}
return nil
}
func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, *ChallengeError) {
func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, error) {
s, err := cg.cookieStore.Get(r, sessionChallenge)
if err != nil {
cg.logger.Error("checkChallengeResponse: cg.cookieStore.Get(r, sessionChallenge)", zap.Error(err), zap.String("sessionChallenge", sessionChallenge))
@ -118,7 +118,7 @@ func (cg *ChallengeGiver) getSession(r *http.Request) (*sessions.Session, *Chall
return s, nil
}
func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.ResponseWriter, r *http.Request) ([]byte, *ChallengeError) {
func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.ResponseWriter, r *http.Request) ([]byte, error) {
challenge := make([]byte, 64)
_, err := rand.Read(challenge)
if err != nil {
@ -136,7 +136,7 @@ func (cg *ChallengeGiver) generateNewChallenge(s *sessions.Session, w http.Respo
return challenge, nil
}
func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *http.Request) *ChallengeError {
func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *http.Request) error {
s.Values[sessionBlocked] = true
err := s.Save(r, w)
if err != nil {
@ -147,15 +147,15 @@ func (cg *ChallengeGiver) block(s *sessions.Session, w http.ResponseWriter, r *h
return &ChallengeError{"forbidden", http.StatusForbidden}
}
func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.Request) *ChallengeError {
ce := cg.validateClientIP(r)
if ce != nil {
return ce
func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.Request) error {
err := cg.validateClientIP(r)
if err != nil {
return err
}
s, ce := cg.getSession(r)
if ce != nil {
return ce
s, err := cg.getSession(r)
if err != nil {
return err
}
blocked, ok := s.Values[sessionBlocked].(bool)
@ -164,14 +164,14 @@ func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.
}
// If the request header doesn't include a challenge don't punish the client, just throw a 403
pc := r.Header.Get(sessionChallenge)
if pc == "" {
clientChallengeResp := r.Header.Get(sessionChallenge)
if clientChallengeResp == "" {
return &ChallengeError{"forbidden", http.StatusForbidden}
}
c, err := cg.encryptor.decryptPlain(base58.Decode(pc))
dcr, err := cg.encryptor.decryptPlain(base58.Decode(clientChallengeResp))
if err != nil {
cg.logger.Error("checkChallengeResponse: cg.encryptor.decryptPlain(base58.Decode(pc))", zap.Error(err), zap.String("pc", pc))
cg.logger.Error("checkChallengeResponse: cg.encryptor.decryptPlain(base58.Decode(clientChallengeResp))", zap.Error(err), zap.String("clientChallengeResp", clientChallengeResp))
return &ChallengeError{"error", http.StatusInternalServerError}
}
@ -183,16 +183,16 @@ func (cg *ChallengeGiver) checkChallengeResponse(w http.ResponseWriter, r *http.
// Only if we have both a challenge in the session store and in the request header
// do we entertain blocking the client. Because then we know someone is trying to be sneaky.
if !bytes.Equal(c, challenge) {
if !bytes.Equal(dcr, challenge) {
return cg.block(s, w, r)
}
// If every is ok, generate a new challenge for the next req
_, ce = cg.generateNewChallenge(s, w, r)
return ce
_, err = cg.generateNewChallenge(s, w, r)
return err
}
func (cg *ChallengeGiver) getChallenge(w http.ResponseWriter, r *http.Request) ([]byte, *ChallengeError) {
func (cg *ChallengeGiver) getChallenge(w http.ResponseWriter, r *http.Request) ([]byte, error) {
err := cg.validateClientIP(r)
if err != nil {
return nil, err

View File

@ -183,9 +183,14 @@ func handleSendInstallation(hs HandlerServer, pmr PayloadMounterReceiver) http.H
func middlewareChallenge(cg *ChallengeGiver, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ce := cg.checkChallengeResponse(w, r)
if ce != nil {
http.Error(w, ce.Text, ce.HTTPCode)
err := cg.checkChallengeResponse(w, r)
if err != nil {
if cErr, ok := err.(*ChallengeError); ok {
http.Error(w, cErr.Text, cErr.HTTPCode)
return
}
cg.logger.Error("failed to checkChallengeResponse in middlewareChallenge", zap.Error(err))
http.Error(w, "error", http.StatusInternalServerError)
return
}
@ -195,16 +200,22 @@ func middlewareChallenge(cg *ChallengeGiver, next http.Handler) http.HandlerFunc
func handlePairingChallenge(cg *ChallengeGiver) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
challenge, ce := cg.getChallenge(w, r)
if ce != nil {
http.Error(w, ce.Text, ce.HTTPCode)
challenge, err := cg.getChallenge(w, r)
if err != nil {
if cErr, ok := err.(*ChallengeError); ok {
http.Error(w, cErr.Text, cErr.HTTPCode)
return
}
cg.logger.Error("failed to getChallenge in handlePairingChallenge", zap.Error(err))
http.Error(w, "error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
_, err := w.Write(challenge)
_, err = w.Write(challenge)
if err != nil {
cg.logger.Error("handlePairingChallenge: _, err = w.Write(challenge)", zap.Error(err))
cg.logger.Error("failed to Write(challenge) in handlePairingChallenge", zap.Error(err))
http.Error(w, "error", http.StatusInternalServerError)
return
}
}