@@ -0,0 +1,115 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"ymhut-box/server/feedback-mailer/internal/config"
|
||||
)
|
||||
|
||||
type rateLimitSet struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*visitorBucket
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type visitorBucket struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
func newRateLimitSet(cfg *config.Config) *rateLimitSet {
|
||||
return &rateLimitSet{cfg: cfg, buckets: map[string]*visitorBucket{}}
|
||||
}
|
||||
|
||||
func (s *rateLimitSet) allow(kind, ip string) bool {
|
||||
if ip == "" {
|
||||
ip = "unknown"
|
||||
}
|
||||
limit, burst := s.policy(kind)
|
||||
key := kind + ":" + ip
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.buckets) > 4096 {
|
||||
for key, bucket := range s.buckets {
|
||||
if now.Sub(bucket.lastSeen) > 10*time.Minute {
|
||||
delete(s.buckets, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
bucket, ok := s.buckets[key]
|
||||
if !ok {
|
||||
bucket = &visitorBucket{limiter: rate.NewLimiter(limit, burst)}
|
||||
s.buckets[key] = bucket
|
||||
}
|
||||
bucket.lastSeen = now
|
||||
return bucket.limiter.Allow()
|
||||
}
|
||||
|
||||
func (s *rateLimitSet) middleware(kind string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !s.allow(kind, c.ClientIP()) {
|
||||
tooManyRequests(c)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rateLimitSet) adminMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
kind := "admin_read"
|
||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead && c.Request.Method != http.MethodOptions {
|
||||
kind = "admin_write"
|
||||
}
|
||||
if !s.allow(kind, c.ClientIP()) {
|
||||
tooManyRequests(c)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rateLimitSet) policy(kind string) (rate.Limit, int) {
|
||||
perMinute := s.cfg.RateLimit.AdminReadPerMinute
|
||||
burst := s.cfg.RateLimit.AdminReadBurst
|
||||
switch kind {
|
||||
case "submission":
|
||||
perMinute = s.cfg.RateLimit.SubmissionPerMinute
|
||||
burst = s.cfg.RateLimit.SubmissionBurst
|
||||
case "status":
|
||||
perMinute = s.cfg.RateLimit.StatusPerMinute
|
||||
burst = s.cfg.RateLimit.StatusBurst
|
||||
case "captcha":
|
||||
perMinute = s.cfg.RateLimit.CaptchaPerMinute
|
||||
burst = s.cfg.RateLimit.CaptchaBurst
|
||||
case "login":
|
||||
perMinute = s.cfg.RateLimit.LoginPerMinute
|
||||
burst = s.cfg.RateLimit.LoginBurst
|
||||
case "admin_write":
|
||||
perMinute = s.cfg.RateLimit.AdminWritePerMinute
|
||||
burst = s.cfg.RateLimit.AdminWriteBurst
|
||||
}
|
||||
if perMinute <= 0 {
|
||||
perMinute = 60
|
||||
}
|
||||
if burst <= 0 {
|
||||
burst = 5
|
||||
}
|
||||
return rate.Limit(float64(perMinute) / 60.0), burst
|
||||
}
|
||||
|
||||
func tooManyRequests(c *gin.Context) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"ok": false,
|
||||
"error": "RATE_LIMITED",
|
||||
"message": "Too many requests, please retry later",
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
Reference in New Issue
Block a user