diff --git a/api/auth.go b/api/auth.go index 83229a9..d3ae561 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,16 +1,22 @@ package api import ( + "context" "encoding/json" "net/http" "os" "github.com/dannyvankooten/ana/datastore" - "github.com/dannyvankooten/ana/models" "github.com/gorilla/sessions" "golang.org/x/crypto/bcrypt" ) +type key int + +const ( + userKey key = 0 +) + type login struct { Email string `json:"email"` Password string `json:"password"` @@ -24,19 +30,16 @@ var LoginHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) // check login creds var l login json.NewDecoder(r.Body).Decode(&l) - var hashedPassword string - var u models.User - stmt, _ := datastore.DB.Prepare("SELECT id, email, password FROM users WHERE email = ? LIMIT 1") - err := stmt.QueryRow(l.Email).Scan(&u.ID, &u.Email, &hashedPassword) + + u, err := datastore.GetUserByEmail(l.Email) // compare pwd - if err != nil || bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(l.Password)) != nil { + if err != nil || bcrypt.CompareHashAndPassword([]byte(u.HashedPassword), []byte(l.Password)) != nil { w.WriteHeader(http.StatusUnauthorized) respond(w, envelope{Error: "invalid_credentials"}) return } - // TODO: Replace session filesystem store with DB store. session, _ := store.Get(r, "auth") session.Values["user_id"] = u.ID err = session.Save(r, w) @@ -68,14 +71,13 @@ func Authorize(next http.Handler) http.Handler { } // find user - var u models.User - stmt, _ := datastore.DB.Prepare("SELECT id, email FROM users WHERE id = ? LIMIT 1") - err := stmt.QueryRow(userID).Scan(&u.ID, &u.Email) + u, err := datastore.GetUser(userID.(int64)) if err != nil { w.WriteHeader(http.StatusUnauthorized) return } - next.ServeHTTP(w, r) + ctx := context.WithValue(r.Context(), userKey, u) + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/datastore/users.go b/datastore/users.go new file mode 100644 index 0000000..90d7104 --- /dev/null +++ b/datastore/users.go @@ -0,0 +1,22 @@ +package datastore + +import ( + "database/sql" + "github.com/dannyvankooten/ana/models" +) + +var err error +var stmt *sql.Stmt +var u models.User + +func GetUser(id int64) (*models.User, error) { + stmt, err = DB.Prepare("SELECT id, email FROM users WHERE id = ? LIMIT 1") + err = stmt.QueryRow(id).Scan(&u.ID, &u.Email) + return &u, err +} + +func GetUserByEmail(email string) (*models.User, error) { + stmt, err = DB.Prepare("SELECT id, email, password FROM users WHERE email = ? LIMIT 1") + err := stmt.QueryRow(email).Scan(&u.ID, &u.Email, &u.HashedPassword) + return &u, err +} diff --git a/models/user.go b/models/user.go index 1ed74b0..7677717 100644 --- a/models/user.go +++ b/models/user.go @@ -5,9 +5,10 @@ import ( ) type User struct { - ID int64 - Email string - Password string `json:"-"` + ID int64 + Email string + Password string `json:"-"` + HashedPassword string `json:"-"` } func (u *User) Save(conn *sql.DB) error {