package api import ( "context" "errors" "fmt" "net/http" "strconv" "strings" "git.ctrlz.es/mgdelacroix/craban/model" "github.com/dgrijalva/jwt-go" "github.com/rs/zerolog/log" ) const userContextKey = "user" func (a *API) Login(w http.ResponseWriter, r *http.Request) { body := ParseBody(r) username := body.String("username") password := body.String("password") token, err := a.App.Login(username, password) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if token == "" { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } w.Write([]byte(token)) } func (a *API) getUserFromToken(tokenStr string) (*model.User, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { // Validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) } return []byte(*a.App.Config.Secret), nil }) if err != nil { return nil, err } if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { userIDClaim, ok := claims["userID"].(string) if !ok { return nil, errors.New("userID claim is not set") } userID, err := strconv.Atoi(userIDClaim) if err != nil { return nil, err } return a.App.GetUserByID(userID) } else { return nil, errors.New("Malformed claims") } } func (a *API) Secured(fn func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { tokenStr := r.Header.Get("Authorization") if tokenStr == "" { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } if strings.HasPrefix(tokenStr, "Bearer ") { tokenStr = tokenStr[7:] } user, err := a.getUserFromToken(tokenStr) if err != nil { log.Debug().Err(err).Str("token", tokenStr).Msg("cannot get user from token") http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } // get user ctx := context.WithValue(r.Context(), userContextKey, user) fn(w, r.Clone(ctx)) } } func UserFromRequest(r *http.Request) (*model.User, bool) { user, ok := r.Context().Value(userContextKey).(*model.User) return user, ok }