diff --git a/cmd/waku/server/rest/filter.go b/cmd/waku/server/rest/filter.go index 30666c6f..d6719e03 100644 --- a/cmd/waku/server/rest/filter.go +++ b/cmd/waku/server/rest/filter.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "net/http" - "net/url" "strings" "github.com/go-chi/chi/v5" @@ -327,24 +326,24 @@ func (s FilterService) getRandomFilterPeer(ctx context.Context, requestId []byte } func (s *FilterService) getMessagesByContentTopic(w http.ResponseWriter, req *http.Request) { - contentTopic := s.topicFromPath(w, req, "contentTopic") + contentTopic := topicFromPath(w, req, "contentTopic", s.log) if contentTopic == "" { return } pubsubTopic, err := protocol.GetPubSubTopicFromContentTopic(contentTopic) if err != nil { - s.writeGetMessageErr(w, fmt.Errorf("bad content topic"), http.StatusBadRequest) + writeGetMessageErr(w, fmt.Errorf("bad content topic"), http.StatusBadRequest, s.log) return } s.getMessages(w, req, pubsubTopic, contentTopic) } func (s *FilterService) getMessagesByPubsubTopic(w http.ResponseWriter, req *http.Request) { - contentTopic := s.topicFromPath(w, req, "contentTopic") + contentTopic := topicFromPath(w, req, "contentTopic", s.log) if contentTopic == "" { return } - pubsubTopic := s.topicFromPath(w, req, "pubsubTopic") + pubsubTopic := topicFromPath(w, req, "pubsubTopic", s.log) if pubsubTopic == "" { return } @@ -358,33 +357,8 @@ func (s *FilterService) getMessagesByPubsubTopic(w http.ResponseWriter, req *htt func (s *FilterService) getMessages(w http.ResponseWriter, req *http.Request, pubsubTopic, contentTopic string) { msgs, err := s.cache.getMessages(pubsubTopic, contentTopic) if err != nil { - s.writeGetMessageErr(w, err, http.StatusNotFound) + writeGetMessageErr(w, err, http.StatusNotFound, s.log) return } writeResponse(w, msgs, http.StatusOK) } - -func (s *FilterService) topicFromPath(w http.ResponseWriter, req *http.Request, field string) string { - cTopic := chi.URLParam(req, field) - if cTopic == "" { - errMissing := fmt.Errorf("missing %s", field) - s.writeGetMessageErr(w, errMissing, http.StatusBadRequest) - return "" - } - cTopic, err := url.QueryUnescape(cTopic) - if err != nil { - errInvalid := fmt.Errorf("invalid %s format", field) - s.writeGetMessageErr(w, errInvalid, http.StatusBadRequest) - return "" - } - return cTopic -} - -func (s *FilterService) writeGetMessageErr(w http.ResponseWriter, err error, code int) { - // write status before the body - w.WriteHeader(code) - s.log.Error("get message", zap.Error(err)) - if _, err := w.Write([]byte(err.Error())); err != nil { - s.log.Error("writing response", zap.Error(err)) - } -} diff --git a/cmd/waku/server/rest/filter_cache.go b/cmd/waku/server/rest/filter_cache.go index 49fb57bf..6e684ab8 100644 --- a/cmd/waku/server/rest/filter_cache.go +++ b/cmd/waku/server/rest/filter_cache.go @@ -68,7 +68,7 @@ func (c *filterCache) getMessages(pubsubTopic string, contentTopic string) ([]*p defer c.mu.RUnlock() if c.data[pubsubTopic] == nil || c.data[pubsubTopic][contentTopic] == nil { - return nil, fmt.Errorf("Not subscribed to pubsubTopic:%s contentTopic: %s", pubsubTopic, contentTopic) + return nil, fmt.Errorf("not subscribed to pubsubTopic:%s contentTopic: %s", pubsubTopic, contentTopic) } msgs := c.data[pubsubTopic][contentTopic] c.data[pubsubTopic][contentTopic] = []*pb.WakuMessage{} diff --git a/cmd/waku/server/rest/relay.go b/cmd/waku/server/rest/relay.go index 1261f6ed..968ed338 100644 --- a/cmd/waku/server/rest/relay.go +++ b/cmd/waku/server/rest/relay.go @@ -166,20 +166,17 @@ func (r *RelayService) postV1Subscriptions(w http.ResponseWriter, req *http.Requ } func (r *RelayService) getV1Messages(w http.ResponseWriter, req *http.Request) { - topic := chi.URLParam(req, "topic") + topic := topicFromPath(w, req, "topic", r.log) if topic == "" { - w.WriteHeader(http.StatusBadRequest) return } - var err error - r.messagesMutex.Lock() defer r.messagesMutex.Unlock() if _, ok := r.messages[topic]; !ok { w.WriteHeader(http.StatusNotFound) - _, err = w.Write([]byte("not subscribed to topic")) + _, err := w.Write([]byte("not subscribed to topic")) r.log.Error("writing response", zap.Error(err)) return } @@ -191,9 +188,8 @@ func (r *RelayService) getV1Messages(w http.ResponseWriter, req *http.Request) { } func (r *RelayService) postV1Message(w http.ResponseWriter, req *http.Request) { - topic := chi.URLParam(req, "topic") + topic := topicFromPath(w, req, "topic", r.log) if topic == "" { - w.WriteHeader(http.StatusBadRequest) return } @@ -205,7 +201,6 @@ func (r *RelayService) postV1Message(w http.ResponseWriter, req *http.Request) { } defer req.Body.Close() - var err error if topic == "" { topic = relay.DefaultWakuTopic } @@ -215,12 +210,12 @@ func (r *RelayService) postV1Message(w http.ResponseWriter, req *http.Request) { return } - if err = server.AppendRLNProof(r.node, message); err != nil { + if err := server.AppendRLNProof(r.node, message); err != nil { writeErrOrResponse(w, err, nil) return } - _, err = r.node.Relay().Publish(req.Context(), message, relay.WithPubSubTopic(strings.Replace(topic, "\n", "", -1))) + _, err := r.node.Relay().Publish(req.Context(), message, relay.WithPubSubTopic(strings.Replace(topic, "\n", "", -1))) if err != nil { r.log.Error("publishing message", zap.Error(err)) } diff --git a/cmd/waku/server/rest/utils.go b/cmd/waku/server/rest/utils.go index 0178bc0c..caa81c59 100644 --- a/cmd/waku/server/rest/utils.go +++ b/cmd/waku/server/rest/utils.go @@ -2,7 +2,12 @@ package rest import ( "encoding/json" + "fmt" "net/http" + "net/url" + + "github.com/go-chi/chi/v5" + "go.uber.org/zap" ) func writeErrOrResponse(w http.ResponseWriter, err error, value interface{}) { @@ -40,3 +45,28 @@ func writeResponse(w http.ResponseWriter, value interface{}, code int) { w.WriteHeader(code) _, _ = w.Write(jsonResponse) } + +func topicFromPath(w http.ResponseWriter, req *http.Request, field string, logger *zap.Logger) string { + topic := chi.URLParam(req, field) + if topic == "" { + errMissing := fmt.Errorf("missing %s", field) + writeGetMessageErr(w, errMissing, http.StatusBadRequest, logger) + return "" + } + topic, err := url.QueryUnescape(topic) + if err != nil { + errInvalid := fmt.Errorf("invalid %s format", field) + writeGetMessageErr(w, errInvalid, http.StatusBadRequest, logger) + return "" + } + return topic +} + +func writeGetMessageErr(w http.ResponseWriter, err error, code int, logger *zap.Logger) { + // write status before the body + w.WriteHeader(code) + logger.Error("get message", zap.Error(err)) + if _, err := w.Write([]byte(err.Error())); err != nil { + logger.Error("writing response", zap.Error(err)) + } +}