diff --git a/assets/src/js/components/LogoutButton.js b/assets/src/js/components/LogoutButton.js
index 8a4fa7a..c3e333a 100644
--- a/assets/src/js/components/LogoutButton.js
+++ b/assets/src/js/components/LogoutButton.js
@@ -16,6 +16,10 @@ class LogoutButton extends Component {
}
render() {
+ if(document.cookie.indexOf('auth') < 0) {
+ return ''
+ }
+
return (
Sign out
)
diff --git a/assets/src/js/pages/dashboard.js b/assets/src/js/pages/dashboard.js
index 6e95a2b..5947304 100644
--- a/assets/src/js/pages/dashboard.js
+++ b/assets/src/js/pages/dashboard.js
@@ -17,6 +17,7 @@ class Dashboard extends Component {
period: (window.location.hash.substring(2) || 'last-7-days'),
before: 0,
after: 0,
+ isPublic: document.cookie.indexOf('auth') < 0,
}
}
@@ -27,6 +28,11 @@ class Dashboard extends Component {
}
render(props, state) {
+ // only show logout link if this dashboard is not public
+ let logoutMenuItem = state.isPublic ? '' : (
+
@@ -34,9 +40,8 @@ class Dashboard extends Component {
diff --git a/assets/src/js/script.js b/assets/src/js/script.js
index f4f5e5c..9c67338 100644
--- a/assets/src/js/script.js
+++ b/assets/src/js/script.js
@@ -4,6 +4,7 @@ import { h, render, Component } from 'preact'
import Login from './pages/login.js'
import Dashboard from './pages/dashboard.js'
import { bind } from 'decko';
+import Client from './lib/client.js';
class App extends Component {
constructor(props) {
@@ -12,6 +13,16 @@ class App extends Component {
this.state = {
authenticated: document.cookie.indexOf('auth') > -1
}
+
+ this.fetchAuthStatus()
+ }
+
+ @bind
+ fetchAuthStatus() {
+ Client.request(`session`)
+ .then((d) => {
+ this.setState({ authenticated: d })
+ })
}
@bind
diff --git a/pkg/api/auth.go b/pkg/api/auth.go
index 04526c1..e33d961 100644
--- a/pkg/api/auth.go
+++ b/pkg/api/auth.go
@@ -1,7 +1,6 @@
package api
import (
- "context"
"encoding/json"
"net/http"
"strings"
@@ -25,8 +24,30 @@ func (l *login) Sanitize() {
l.Email = strings.ToLower(strings.TrimSpace(l.Email))
}
+// GET /api/session
+func (api *API) GetSession(w http.ResponseWriter, r *http.Request) error {
+ userCount, err := api.database.CountUsers()
+ if err != nil {
+ return err
+ }
+
+ // if 0 users in database, dashboard is public
+ if userCount == 0 {
+ return respond(w, envelope{Data: true})
+ }
+
+ // if existing session, assume logged-in
+ session, _ := api.sessions.Get(r, "auth")
+ if !session.IsNew {
+ respond(w, envelope{Data: true})
+ }
+
+ // otherwise: not logged-in yet
+ return respond(w, envelope{Data: false})
+}
+
// URL: POST /api/session
-func (api *API) LoginHandler(w http.ResponseWriter, r *http.Request) error {
+func (api *API) CreateSession(w http.ResponseWriter, r *http.Request) error {
// check login creds
var l login
err := json.NewDecoder(r.Body).Decode(&l)
@@ -59,7 +80,7 @@ func (api *API) LoginHandler(w http.ResponseWriter, r *http.Request) error {
}
// URL: DELETE /api/session
-func (api *API) LogoutHandler(w http.ResponseWriter, r *http.Request) error {
+func (api *API) DeleteSession(w http.ResponseWriter, r *http.Request) error {
session, _ := api.sessions.Get(r, "auth")
if !session.IsNew {
session.Options.MaxAge = -1
@@ -79,27 +100,35 @@ func (api *API) Authorize(next http.Handler) http.Handler {
// see http://www.gorillatoolkit.org/pkg/sessions#overview
defer gcontext.Clear(r)
- session, err := api.sessions.Get(r, "auth")
- // an err is returned if cookie has been tampered with, so check that
+ // first count users in datastore
+ // if 0, assume dashboard is public
+ userCount, err := api.database.CountUsers()
if err != nil {
- w.WriteHeader(http.StatusUnauthorized)
+ w.WriteHeader(http.StatusInternalServerError)
return
}
- userID, ok := session.Values["user_id"]
- if session.IsNew || !ok {
- w.WriteHeader(http.StatusUnauthorized)
- return
+ if userCount > 0 {
+ session, err := api.sessions.Get(r, "auth")
+ // an err is returned if cookie has been tampered with, so check that
+ if err != nil {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ userID, ok := session.Values["user_id"]
+ if session.IsNew || !ok {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ // validate user ID in session
+ if _, err := api.database.GetUser(userID.(int64)); err != nil {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
}
- // find user
- u, err := api.database.GetUser(userID.(int64))
- if err != nil {
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
-
- ctx := context.WithValue(r.Context(), userKey, u)
- next.ServeHTTP(w, r.WithContext(ctx))
+ next.ServeHTTP(w, r)
})
}
diff --git a/pkg/api/routes.go b/pkg/api/routes.go
index 7553037..e64982f 100644
--- a/pkg/api/routes.go
+++ b/pkg/api/routes.go
@@ -11,8 +11,10 @@ func (api *API) Routes() *mux.Router {
// register routes
r := mux.NewRouter()
r.Handle("/collect", NewCollector(api.database)).Methods(http.MethodGet)
- r.Handle("/api/session", HandlerFunc(api.LoginHandler)).Methods(http.MethodPost)
- r.Handle("/api/session", HandlerFunc(api.LogoutHandler)).Methods(http.MethodDelete)
+
+ r.Handle("/api/session", HandlerFunc(api.GetSession)).Methods(http.MethodGet)
+ r.Handle("/api/session", HandlerFunc(api.CreateSession)).Methods(http.MethodPost)
+ r.Handle("/api/session", HandlerFunc(api.DeleteSession)).Methods(http.MethodDelete)
r.Handle("/api/stats/site", api.Authorize(HandlerFunc(api.GetSiteStatsHandler))).Methods(http.MethodGet)
r.Handle("/api/stats/site/groupby/day", api.Authorize(HandlerFunc(api.GetSiteStatsPerDayHandler))).Methods(http.MethodGet)