src/internal/security/security.go
5,218 bytes · 211 lines · capsule://quake0day/[email protected]
raw on github
package security
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"net/http"
"net/url"
"strings"
"time"
"scenemint/internal/config"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/sunls24/gox/server"
)
// Settings shared by the CSRF helpers and the request body limiter.
const (
	// csrfCookieName is the name of the cookie that carries the CSRF token.
	csrfCookieName = "_scenemint_csrf"

	// csrfTokenBytes is the number of random bytes in a token; the token
	// string is the hex encoding of those bytes, so it is twice as long.
	csrfTokenBytes = 32

	// maxBodyBytes is the request body cap handed to BodyLimit (16 MiB).
	maxBodyBytes = 16 * 1024 * 1024
)
// SessionResponse is the payload returned by the session endpoint. It carries
// the CSRF token and serializes it under the "csrfToken" JSON key.
type SessionResponse struct {
	// CSRFToken is the token value that is also stored in the CSRF cookie.
	CSRFToken string `json:"csrfToken"`
}
// Middleware groups the request-source and CSRF protections of this package
// together with the settings they share.
type Middleware struct {
	// secureCookies is copied into the Secure attribute of the CSRF cookie.
	secureCookies bool
}
// New builds a Middleware from the security configuration. Only the
// SecureCookies setting is read from cfg.
func New(cfg config.Security) *Middleware {
	m := new(Middleware)
	m.secureCookies = cfg.SecureCookies
	return m
}
// Headers returns echo's Secure middleware configured with the package's
// response-header policy: framing is denied (X-Frame-Options and a
// frame-ancestors CSP directive), content-type sniffing is disabled, and the
// referrer policy is "same-origin".
func Headers() echo.MiddlewareFunc {
	policy := middleware.SecureConfig{
		XFrameOptions:         "DENY",
		ContentSecurityPolicy: "frame-ancestors 'none'",
		ContentTypeNosniff:    "nosniff",
		ReferrerPolicy:        "same-origin",
	}
	return middleware.SecureWithConfig(policy)
}
// BodyLimit returns echo's BodyLimit middleware configured with maxBodyBytes
// as the limit.
func BodyLimit() echo.MiddlewareFunc {
	limiter := middleware.BodyLimit(maxBodyBytes)
	return limiter
}
// SourceGuard returns a middleware that lets a request through only when
// allowedSource accepts it. Any other request is answered with 403 Forbidden
// and the next handler is not called.
func (m *Middleware) SourceGuard() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if m.allowedSource(c) {
				return next(c)
			}
			return reject(c, http.StatusForbidden, "请求来源不被允许")
		}
	}
}
// CSRF returns a middleware that compares the CSRF cookie with the token sent
// in the request header named by echo.HeaderXCSRFToken. The request is
// rejected with 403 Forbidden when the cookie is missing or malformed, or
// when the header value (after trimming whitespace) does not match the
// cookie value.
func (m *Middleware) CSRF() echo.MiddlewareFunc {
	// Every failure mode answers with the same message.
	const failure = "会话校验失败,请刷新页面后重试"

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			cookie, err := c.Cookie(csrfCookieName)
			if err != nil {
				return reject(c, http.StatusForbidden, failure)
			}
			if !validToken(cookie.Value) {
				return reject(c, http.StatusForbidden, failure)
			}

			header := c.Request().Header.Get(echo.HeaderXCSRFToken)
			if sameToken(cookie.Value, strings.TrimSpace(header)) {
				return next(c)
			}
			return reject(c, http.StatusForbidden, failure)
		}
	}
}
// Session is the handler that hands the CSRF token to the client. It makes
// sure a token cookie is set on the response (see ensureToken) and returns the
// same token in the JSON body. If no token can be obtained it responds with
// 500 Internal Server Error.
func (m *Middleware) Session(c *echo.Context) error {
	token, err := m.ensureToken(c)
	if err != nil {
		return reject(c, http.StatusInternalServerError, "会话创建失败")
	}

	payload := server.Envelope{
		Message: "ok",
		Data:    SessionResponse{CSRFToken: token},
	}
	return c.JSON(http.StatusOK, payload)
}
// ensureToken returns the CSRF token for the current request. A well-formed
// token already present in the request cookie is reused; otherwise a new one
// is generated. In both cases the cookie is written to the response again.
// An error is returned only when generating a new token fails.
func (m *Middleware) ensureToken(c *echo.Context) (string, error) {
	existing, err := c.Cookie(csrfCookieName)
	if err == nil && validToken(existing.Value) {
		// Re-issue the existing value so the cookie lifetime is renewed.
		m.setTokenCookie(c, existing.Value)
		return existing.Value, nil
	}

	fresh, err := newToken()
	if err != nil {
		return "", err
	}
	m.setTokenCookie(c, fresh)
	return fresh, nil
}
// setTokenCookie writes the CSRF cookie holding token to the response and
// adds "Cookie" to the response's Vary header. The cookie is site-wide
// (Path "/"), HttpOnly, SameSite=Lax, lives for 24 hours, and its Secure
// attribute follows the middleware's secureCookies setting.
func (m *Middleware) setTokenCookie(c *echo.Context, token string) {
	// One lifetime drives both MaxAge (seconds) and the absolute Expires time.
	const lifetime = 24 * time.Hour

	cookie := &http.Cookie{
		Name:     csrfCookieName,
		Value:    token,
		Path:     "/",
		MaxAge:   int(lifetime / time.Second),
		Expires:  time.Now().Add(lifetime),
		Secure:   m.secureCookies,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	c.SetCookie(cookie)

	headers := c.Response().Header()
	headers.Add(echo.HeaderVary, echo.HeaderCookie)
}
// allowedSource reports whether the request comes from an acceptable source.
// The checks run in this order:
//
//  1. A Sec-Fetch-Site header equal to "cross-site" (case-insensitive) is
//     always refused.
//  2. If an Origin header is present, it must match the request's own origin.
//  3. Otherwise, if a Referer header is present, its origin must match.
//  4. With neither header, only safe HTTP methods are accepted.
func (m *Middleware) allowedSource(c *echo.Context) bool {
	req := c.Request()
	// header returns the named request header with surrounding space removed.
	header := func(name string) string {
		return strings.TrimSpace(req.Header.Get(name))
	}

	if strings.ToLower(header(echo.HeaderSecFetchSite)) == "cross-site" {
		return false
	}

	if origin := header(echo.HeaderOrigin); origin != "" {
		return sameOrigin(c, origin)
	}

	referer := header("Referer")
	if referer == "" {
		return isSafeMethod(req.Method)
	}

	derived, ok := refererOrigin(referer)
	return ok && sameOrigin(c, derived)
}
// sameOrigin reports whether origin, once normalized, equals the origin of
// the current request. An origin that cannot be normalized never matches.
func sameOrigin(c *echo.Context, origin string) bool {
	candidate, ok := normalizeOrigin(origin)
	if !ok {
		return false
	}
	return candidate == requestOrigin(c)
}
// requestOrigin builds the lowercase "scheme://host" origin of the current
// request from the context's scheme and the request's Host value.
func requestOrigin(c *echo.Context) string {
	scheme := strings.ToLower(c.Scheme())
	host := strings.ToLower(c.Request().Host)
	return scheme + "://" + host
}
// normalizeOrigin converts raw into a canonical lowercase "scheme://host"
// origin. The boolean result is false, and the string empty, when raw
//
//   - does not parse as a URL or lacks a scheme or host,
//   - carries user info, a query string or a fragment,
//   - has a path other than "" or "/", or
//   - uses a scheme other than http or https.
//
// Surrounding whitespace in raw is ignored.
func normalizeOrigin(raw string) (string, bool) {
	u, err := url.Parse(strings.TrimSpace(raw))
	switch {
	case err != nil:
		return "", false
	case u.Scheme == "" || u.Host == "":
		return "", false
	case u.User != nil, u.RawQuery != "", u.Fragment != "":
		return "", false
	case u.Path != "" && u.Path != "/":
		return "", false
	}

	switch scheme := strings.ToLower(u.Scheme); scheme {
	case "http", "https":
		return scheme + "://" + strings.ToLower(u.Host), true
	default:
		return "", false
	}
}
// refererOrigin extracts the origin from a Referer header value. It keeps
// only the scheme and host of the parsed URL, discarding path, query and
// fragment, and passes the result through normalizeOrigin. The boolean
// result is false when raw cannot be parsed, lacks a scheme or host, or is
// refused by normalizeOrigin.
func refererOrigin(raw string) (string, bool) {
	u, err := url.Parse(strings.TrimSpace(raw))
	if err == nil && u.Scheme != "" && u.Host != "" {
		return normalizeOrigin(u.Scheme + "://" + u.Host)
	}
	return "", false
}
// isSafeMethod reports whether method is GET, HEAD, OPTIONS or TRACE. The
// comparison is exact, so differently-cased names are not considered safe.
func isSafeMethod(method string) bool {
	return method == http.MethodGet ||
		method == http.MethodHead ||
		method == http.MethodOptions ||
		method == http.MethodTrace
}
func newToken() (string, error) {
var data [csrfTokenBytes]byte
if _, err := rand.Read(data[:]); err != nil {
return "", err
}
return hex.EncodeToString(data[:]), nil
}
// validToken reports whether token has the shape of a token produced by
// newToken: exactly 2*csrfTokenBytes characters that decode as hexadecimal.
func validToken(token string) bool {
	const hexLen = 2 * csrfTokenBytes
	if len(token) == hexLen {
		if _, err := hex.DecodeString(token); err == nil {
			return true
		}
	}
	return false
}
// sameToken reports whether actual is a well-formed token equal to expected.
// The byte comparison uses subtle.ConstantTimeCompare so its duration does
// not depend on where the two values differ.
func sameToken(expected string, actual string) bool {
	if validToken(actual) && len(actual) == len(expected) {
		return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
	}
	return false
}
// reject writes a JSON error response with the given HTTP status. The body is
// the envelope built from message via server.ErrMsg.
func reject(c *echo.Context, status int, message string) error {
	body := server.ErrMsg(message).Envelope()
	return c.JSON(status, body)
}