231 lines
5.3 KiB
Go
231 lines
5.3 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gothinkster/golang-gin-realworld-example-app/config"
)
|
||
|
|
|
||
|
|
// Response payload of the auth center's token-validation endpoint, as decoded
// by validateTokenWithAuthCenter. The JSON tags must match the field names the
// auth center emits.
type (
	// AuthResponse is the top-level envelope: a success flag, a human-readable
	// message, and the validated identity in Data.
	AuthResponse struct {
		Success bool     `json:"success"`
		Message string   `json:"message"`
		Data    AuthData `json:"data"`
	}

	// AuthData describes the authenticated user as reported by the auth center.
	AuthData struct {
		Email     string    `json:"email"`
		Roles     []string  `json:"roles"`
		Status    string    `json:"status"`
		TokenInfo TokenInfo `json:"tokenInfo"`
		UserID    int64     `json:"userID"`
		Username  string    `json:"username"`
	}

	// TokenInfo carries the token's exp, iat and jti values exactly as the auth
	// center sends them.
	// NOTE(review): these are decoded as strings; if the auth center ever emits
	// numeric JWT claims, decoding the whole response fails — confirm format.
	TokenInfo struct {
		Exp string `json:"exp"`
		Iat string `json:"iat"`
		Jti string `json:"jti"`
	}
)
|
||
|
|
|
||
|
|
// UserContext is the authenticated identity that AuthMiddleware places in the
// gin context under the "user" key and that GetCurrentUser reads back.
type UserContext struct {
	UserID   int64    `json:"user_id"`
	Username string   `json:"username"`
	Email    string   `json:"email"`
	Roles    []string `json:"roles"` // role names as reported by the auth center
}
|
||
|
|
|
||
|
|
// Simple in-memory cache (in production, use Redis)
|
||
|
|
var tokenCache = make(map[string]*CacheEntry)
|
||
|
|
|
||
|
|
type CacheEntry struct {
|
||
|
|
UserData *UserContext
|
||
|
|
ExpiresAt time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// hashToken returns the lowercase hexadecimal SHA-256 digest of token. The
// token cache is keyed by this digest rather than by the raw bearer token.
func hashToken(token string) string {
	digest := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	_, _ = digest.Write([]byte(token))
	return hex.EncodeToString(digest.Sum(nil))
}
|
||
|
|
|
||
|
|
// validateTokenWithAuthCenter validates token with auth center
|
||
|
|
func validateTokenWithAuthCenter(token string) (*AuthData, error) {
|
||
|
|
cfg := config.AppConfig
|
||
|
|
url := fmt.Sprintf("%s/api/auth/validate-token", cfg.AuthCenterURL)
|
||
|
|
|
||
|
|
req, err := http.NewRequest("POST", url, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
client := &http.Client{Timeout: 5 * time.Second}
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("auth service unavailable: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
body, err := io.ReadAll(resp.Body)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to read response: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
return nil, fmt.Errorf("auth failed: %s", string(body))
|
||
|
|
}
|
||
|
|
|
||
|
|
var authResp AuthResponse
|
||
|
|
if err := json.Unmarshal(body, &authResp); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to parse response: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if !authResp.Success {
|
||
|
|
return nil, fmt.Errorf("token validation failed: %s", authResp.Message)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &authResp.Data, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// AuthMiddleware validates JWT token from auth center
|
||
|
|
func AuthMiddleware(required bool) gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// Extract token from Authorization header
|
||
|
|
authHeader := c.GetHeader("Authorization")
|
||
|
|
if authHeader == "" {
|
||
|
|
if required {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"message": "Authorization header required",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse Bearer token
|
||
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
||
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
|
|
if required {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"message": "Invalid authorization header format",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
token := parts[1]
|
||
|
|
tokenHash := hashToken(token)
|
||
|
|
|
||
|
|
// Check cache first
|
||
|
|
if cached, exists := tokenCache[tokenHash]; exists {
|
||
|
|
if time.Now().Before(cached.ExpiresAt) {
|
||
|
|
// Cache hit and not expired
|
||
|
|
c.Set("user", cached.UserData)
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
// Cache expired, remove it
|
||
|
|
delete(tokenCache, tokenHash)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate with auth center
|
||
|
|
authData, err := validateTokenWithAuthCenter(token)
|
||
|
|
if err != nil {
|
||
|
|
if required {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"message": "Token validation failed",
|
||
|
|
"error": err.Error(),
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create user context
|
||
|
|
userCtx := &UserContext{
|
||
|
|
UserID: authData.UserID,
|
||
|
|
Username: authData.Username,
|
||
|
|
Email: authData.Email,
|
||
|
|
Roles: authData.Roles,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Cache the result (5 minutes)
|
||
|
|
tokenCache[tokenHash] = &CacheEntry{
|
||
|
|
UserData: userCtx,
|
||
|
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||
|
|
}
|
||
|
|
|
||
|
|
// Set user in context
|
||
|
|
c.Set("user", userCtx)
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetCurrentUser retrieves user from context
|
||
|
|
func GetCurrentUser(c *gin.Context) (*UserContext, bool) {
|
||
|
|
userVal, exists := c.Get("user")
|
||
|
|
if !exists {
|
||
|
|
return nil, false
|
||
|
|
}
|
||
|
|
|
||
|
|
user, ok := userVal.(*UserContext)
|
||
|
|
return user, ok
|
||
|
|
}
|
||
|
|
|
||
|
|
// RequireAdmin middleware to check admin role
|
||
|
|
func RequireAdmin() gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
user, exists := GetCurrentUser(c)
|
||
|
|
if !exists {
|
||
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"message": "Authentication required",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check if user has admin role
|
||
|
|
isAdmin := false
|
||
|
|
for _, role := range user.Roles {
|
||
|
|
if role == "admin" {
|
||
|
|
isAdmin = true
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if !isAdmin {
|
||
|
|
c.JSON(http.StatusForbidden, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"message": "Admin access required",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|