@@ -24,13 +24,17 @@ const (
|
||||
SessionCookie = "ymhut_unified_session"
|
||||
captchaTTL = 5 * time.Minute
|
||||
sessionTTL = 12 * time.Hour
|
||||
loginWindow = 5 * time.Minute
|
||||
loginLockTTL = 5 * time.Minute
|
||||
loginMaxFails = 5
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
store *db.Store
|
||||
mu sync.Mutex
|
||||
captchas map[string]captchaEntry
|
||||
sessions map[string]sessionEntry
|
||||
store *db.Store
|
||||
mu sync.Mutex
|
||||
captchas map[string]captchaEntry
|
||||
sessions map[string]sessionEntry
|
||||
loginAttempts map[string]loginAttempt
|
||||
}
|
||||
|
||||
type captchaEntry struct {
|
||||
@@ -44,6 +48,12 @@ type sessionEntry struct {
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type loginAttempt struct {
|
||||
failures int
|
||||
lastFailure time.Time
|
||||
lockedUntil time.Time
|
||||
}
|
||||
|
||||
type Captcha struct {
|
||||
ID string `json:"captchaId"`
|
||||
Image string `json:"image"`
|
||||
@@ -51,9 +61,10 @@ type Captcha struct {
|
||||
|
||||
func NewService(store *db.Store) *Service {
|
||||
return &Service{
|
||||
store: store,
|
||||
captchas: map[string]captchaEntry{},
|
||||
sessions: map[string]sessionEntry{},
|
||||
store: store,
|
||||
captchas: map[string]captchaEntry{},
|
||||
sessions: map[string]sessionEntry{},
|
||||
loginAttempts: map[string]loginAttempt{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,12 +102,18 @@ func (s *Service) NewCaptcha() (Captcha, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, username, password, captchaID, captcha string) (string, string, bool, error) {
|
||||
func (s *Service) Login(ctx context.Context, username, password, captchaID, captcha string, clientKeys ...string) (string, string, bool, error) {
|
||||
attemptKey := loginAttemptKey(username, clientKeys...)
|
||||
if s.loginLocked(attemptKey) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
if !s.consumeCaptcha(captchaID, captcha) {
|
||||
s.recordLoginFailure(attemptKey)
|
||||
return "", "", false, nil
|
||||
}
|
||||
user, ok, err := s.store.VerifyAdminPassword(ctx, username, password)
|
||||
if err != nil || !ok {
|
||||
s.recordLoginFailure(attemptKey)
|
||||
return "", "", false, err
|
||||
}
|
||||
sessionID := randomToken(32)
|
||||
@@ -104,6 +121,7 @@ func (s *Service) Login(ctx context.Context, username, password, captchaID, capt
|
||||
s.mu.Lock()
|
||||
s.cleanupLocked()
|
||||
s.sessions[sessionID] = sessionEntry{username: user.Username, csrf: csrf, expiresAt: time.Now().Add(sessionTTL)}
|
||||
delete(s.loginAttempts, attemptKey)
|
||||
s.mu.Unlock()
|
||||
return sessionID, csrf, true, nil
|
||||
}
|
||||
@@ -136,13 +154,13 @@ func (s *Service) Require(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, csrf, ok := s.UserForRequest(r)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"ok": false, "error": "UNAUTHORIZED", "message": "Login required"})
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"ok": false, "error": "UNAUTHORIZED", "message": "需要登录后继续操作"})
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
|
||||
actual := r.Header.Get("X-CSRF-Token")
|
||||
if actual == "" || subtle.ConstantTimeCompare([]byte(csrf), []byte(actual)) != 1 {
|
||||
writeJSON(w, http.StatusForbidden, map[string]any{"ok": false, "error": "CSRF_INVALID", "message": "Invalid CSRF token"})
|
||||
writeJSON(w, http.StatusForbidden, map[string]any{"ok": false, "error": "CSRF_INVALID", "message": "页面安全令牌无效,请刷新后重试"})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -151,6 +169,14 @@ func (s *Service) Require(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
func SetSessionCookie(w http.ResponseWriter, sessionID string) {
|
||||
setSessionCookie(w, sessionID, false)
|
||||
}
|
||||
|
||||
func SetSessionCookieForRequest(w http.ResponseWriter, r *http.Request, sessionID string) {
|
||||
setSessionCookie(w, sessionID, requestIsHTTPS(r))
|
||||
}
|
||||
|
||||
func setSessionCookie(w http.ResponseWriter, sessionID string, secure bool) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookie,
|
||||
Value: sessionID,
|
||||
@@ -158,6 +184,7 @@ func SetSessionCookie(w http.ResponseWriter, sessionID string) {
|
||||
MaxAge: int(sessionTTL.Seconds()),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -165,6 +192,16 @@ func clearCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{Name: SessionCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: true, SameSite: http.SameSiteLaxMode})
|
||||
}
|
||||
|
||||
func requestIsHTTPS(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")), "https")
|
||||
}
|
||||
|
||||
func (s *Service) consumeCaptcha(id, answer string) bool {
|
||||
id = strings.TrimSpace(id)
|
||||
answer = strings.TrimSpace(answer)
|
||||
@@ -193,6 +230,50 @@ func (s *Service) cleanupLocked() {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
for key, attempt := range s.loginAttempts {
|
||||
if attempt.lockedUntil.IsZero() && now.Sub(attempt.lastFailure) > loginWindow {
|
||||
delete(s.loginAttempts, key)
|
||||
continue
|
||||
}
|
||||
if !attempt.lockedUntil.IsZero() && now.After(attempt.lockedUntil) && now.Sub(attempt.lastFailure) > loginWindow {
|
||||
delete(s.loginAttempts, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) loginLocked(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupLocked()
|
||||
attempt := s.loginAttempts[key]
|
||||
return !attempt.lockedUntil.IsZero() && time.Now().Before(attempt.lockedUntil)
|
||||
}
|
||||
|
||||
func (s *Service) recordLoginFailure(key string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
now := time.Now()
|
||||
attempt := s.loginAttempts[key]
|
||||
if now.Sub(attempt.lastFailure) > loginWindow {
|
||||
attempt.failures = 0
|
||||
}
|
||||
attempt.failures++
|
||||
attempt.lastFailure = now
|
||||
if attempt.failures >= loginMaxFails {
|
||||
attempt.lockedUntil = now.Add(loginLockTTL)
|
||||
}
|
||||
s.loginAttempts[key] = attempt
|
||||
}
|
||||
|
||||
func loginAttemptKey(username string, clientKeys ...string) string {
|
||||
parts := []string{strings.ToLower(strings.TrimSpace(username))}
|
||||
for _, value := range clientKeys {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
parts = append(parts, value)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "|")
|
||||
}
|
||||
|
||||
func randomDigits(count int) string {
|
||||
|
||||
@@ -2,6 +2,9 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
@@ -35,7 +38,7 @@ func TestBootstrapShowsDefaultPasswordOnlyBeforeChange(t *testing.T) {
|
||||
if payload["isDefaultPassword"] != true || payload["defaultPassword"] != "admin" {
|
||||
t.Fatalf("unexpected bootstrap payload: %#v", payload)
|
||||
}
|
||||
if err := store.ChangeAdminPassword(context.Background(), "admin", "admin", "changed"); err != nil {
|
||||
if err := store.ChangeAdminPassword(context.Background(), "admin", "admin", "changed-password"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payload, err = service.Bootstrap(context.Background())
|
||||
@@ -46,3 +49,91 @@ func TestBootstrapShowsDefaultPasswordOnlyBeforeChange(t *testing.T) {
|
||||
t.Fatalf("default password leaked after change: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeAdminPasswordPersistsAfterReopen(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
dbPath := filepath.Join(root, "test.sqlite")
|
||||
cfg := &config.Config{
|
||||
StorageDir: root,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: dbPath,
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
},
|
||||
}
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := store.EnsureDefaultAdmin(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := store.ChangeAdminPassword(context.Background(), "admin", "admin", "persisted-password"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = store.Close()
|
||||
|
||||
reopened, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer reopened.Close()
|
||||
if _, ok, err := reopened.VerifyAdminPassword(context.Background(), "admin", "persisted-password"); err != nil || !ok {
|
||||
t.Fatalf("new password did not persist, ok=%v err=%v", ok, err)
|
||||
}
|
||||
if _, ok, err := reopened.VerifyAdminPassword(context.Background(), "admin", "admin"); err != nil || ok {
|
||||
t.Fatalf("old password still works, ok=%v err=%v", ok, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLocksAfterRepeatedFailures(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
StorageDir: root,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "test.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
},
|
||||
}
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if err := store.EnsureDefaultAdmin(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
service := NewService(store)
|
||||
for i := 0; i < loginMaxFails; i++ {
|
||||
if _, _, ok, err := service.Login(context.Background(), "admin", "wrong", "bad-captcha", "00000", "127.0.0.1"); err != nil || ok {
|
||||
t.Fatalf("failed login %d returned ok=%v err=%v", i, ok, err)
|
||||
}
|
||||
}
|
||||
captcha, err := service.NewCaptcha()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
service.mu.Lock()
|
||||
answer := service.captchas[captcha.ID].answer
|
||||
service.mu.Unlock()
|
||||
if _, _, ok, err := service.Login(context.Background(), "admin", "admin", captcha.ID, answer, "127.0.0.1"); err != nil || ok {
|
||||
t.Fatalf("locked login should fail without error, ok=%v err=%v", ok, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCookieUsesSecureForForwardedHTTPS(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/admin/auth/login", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
res := httptest.NewRecorder()
|
||||
SetSessionCookieForRequest(res, req, "session-id")
|
||||
cookies := res.Result().Cookies()
|
||||
if len(cookies) != 1 {
|
||||
t.Fatalf("expected one cookie, got %d", len(cookies))
|
||||
}
|
||||
if !cookies[0].Secure {
|
||||
t.Fatalf("expected secure cookie for forwarded https: %#v", cookies[0])
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user