diff --git a/assets/src/js/components/LogoutButton.js b/assets/src/js/components/LogoutButton.js index 8a4fa7a..c3e333a 100644 --- a/assets/src/js/components/LogoutButton.js +++ b/assets/src/js/components/LogoutButton.js @@ -16,6 +16,10 @@ class LogoutButton extends Component { } render() { + if(document.cookie.indexOf('auth') < 0) { + return '' + } + return ( Sign out ) diff --git a/assets/src/js/pages/dashboard.js b/assets/src/js/pages/dashboard.js index 6e95a2b..5947304 100644 --- a/assets/src/js/pages/dashboard.js +++ b/assets/src/js/pages/dashboard.js @@ -17,6 +17,7 @@ class Dashboard extends Component { period: (window.location.hash.substring(2) || 'last-7-days'), before: 0, after: 0, + isPublic: document.cookie.indexOf('auth') < 0, } } @@ -27,6 +28,11 @@ class Dashboard extends Component { } render(props, state) { + // only show logout link if this dashboard is not public + let logoutMenuItem = state.isPublic ? '' : ( +
  • ·
  • + ); + return (
    @@ -34,9 +40,8 @@ class Dashboard extends Component { diff --git a/assets/src/js/script.js b/assets/src/js/script.js index f4f5e5c..9c67338 100644 --- a/assets/src/js/script.js +++ b/assets/src/js/script.js @@ -4,6 +4,7 @@ import { h, render, Component } from 'preact' import Login from './pages/login.js' import Dashboard from './pages/dashboard.js' import { bind } from 'decko'; +import Client from './lib/client.js'; class App extends Component { constructor(props) { @@ -12,6 +13,16 @@ class App extends Component { this.state = { authenticated: document.cookie.indexOf('auth') > -1 } + + this.fetchAuthStatus() + } + + @bind + fetchAuthStatus() { + Client.request(`session`) + .then((d) => { + this.setState({ authenticated: d }) + }) } @bind diff --git a/pkg/api/auth.go b/pkg/api/auth.go index 04526c1..e33d961 100644 --- a/pkg/api/auth.go +++ b/pkg/api/auth.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/json" "net/http" "strings" @@ -25,8 +24,30 @@ func (l *login) Sanitize() { l.Email = strings.ToLower(strings.TrimSpace(l.Email)) } +// GET /api/session +func (api *API) GetSession(w http.ResponseWriter, r *http.Request) error { + userCount, err := api.database.CountUsers() + if err != nil { + return err + } + + // if 0 users in database, dashboard is public + if userCount == 0 { + return respond(w, envelope{Data: true}) + } + + // if existing session, assume logged-in + session, _ := api.sessions.Get(r, "auth") + if !session.IsNew { + respond(w, envelope{Data: true}) + } + + // otherwise: not logged-in yet + return respond(w, envelope{Data: false}) +} + // URL: POST /api/session -func (api *API) LoginHandler(w http.ResponseWriter, r *http.Request) error { +func (api *API) CreateSession(w http.ResponseWriter, r *http.Request) error { // check login creds var l login err := json.NewDecoder(r.Body).Decode(&l) @@ -59,7 +80,7 @@ func (api *API) LoginHandler(w http.ResponseWriter, r *http.Request) error { } // URL: DELETE /api/session -func (api *API) LogoutHandler(w http.ResponseWriter, r *http.Request) error { +func (api *API) DeleteSession(w http.ResponseWriter, r *http.Request) error { session, _ := api.sessions.Get(r, "auth") if !session.IsNew { session.Options.MaxAge = -1 @@ -79,27 +100,35 @@ func (api *API) Authorize(next http.Handler) http.Handler { // see http://www.gorillatoolkit.org/pkg/sessions#overview defer gcontext.Clear(r) - session, err := api.sessions.Get(r, "auth") - // an err is returned if cookie has been tampered with, so check that + // first count users in datastore + // if 0, assume dashboard is public + userCount, err := api.database.CountUsers() if err != nil { - w.WriteHeader(http.StatusUnauthorized) + w.WriteHeader(http.StatusInternalServerError) return } - userID, ok := session.Values["user_id"] - if session.IsNew || !ok { - w.WriteHeader(http.StatusUnauthorized) - return + if userCount > 0 { + session, err := api.sessions.Get(r, "auth") + // an err is returned if cookie has been tampered with, so check that + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userID, ok := session.Values["user_id"] + if session.IsNew || !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // validate user ID in session + if _, err := api.database.GetUser(userID.(int64)); err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } } - // find user - u, err := api.database.GetUser(userID.(int64)) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } - - ctx := context.WithValue(r.Context(), userKey, u) - next.ServeHTTP(w, r.WithContext(ctx)) + next.ServeHTTP(w, r) }) } diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 7553037..e64982f 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -11,8 +11,10 @@ func (api *API) Routes() *mux.Router { // register routes r := mux.NewRouter() r.Handle("/collect", NewCollector(api.database)).Methods(http.MethodGet) - r.Handle("/api/session", HandlerFunc(api.LoginHandler)).Methods(http.MethodPost) - r.Handle("/api/session", HandlerFunc(api.LogoutHandler)).Methods(http.MethodDelete) + + r.Handle("/api/session", HandlerFunc(api.GetSession)).Methods(http.MethodGet) + r.Handle("/api/session", HandlerFunc(api.CreateSession)).Methods(http.MethodPost) + r.Handle("/api/session", HandlerFunc(api.DeleteSession)).Methods(http.MethodDelete) r.Handle("/api/stats/site", api.Authorize(HandlerFunc(api.GetSiteStatsHandler))).Methods(http.MethodGet) r.Handle("/api/stats/site/groupby/day", api.Authorize(HandlerFunc(api.GetSiteStatsPerDayHandler))).Methods(http.MethodGet)