package auth import ( "bytes" "context" "crypto/rand" "crypto/subtle" "encoding/base64" "encoding/hex" "encoding/json" "image" "image/color" "image/draw" "image/png" "net/http" "strings" "sync" "time" "ymhut-box/server/unified-management/internal/db" ) const ( SessionCookie = "ymhut_unified_session" captchaTTL = 5 * time.Minute sessionTTL = 12 * time.Hour ) type Service struct { store *db.Store mu sync.Mutex captchas map[string]captchaEntry sessions map[string]sessionEntry } type captchaEntry struct { answer string expiresAt time.Time } type sessionEntry struct { username string csrf string expiresAt time.Time } type Captcha struct { ID string `json:"captchaId"` Image string `json:"image"` } func NewService(store *db.Store) *Service { return &Service{ store: store, captchas: map[string]captchaEntry{}, sessions: map[string]sessionEntry{}, } } func (s *Service) Bootstrap(ctx context.Context) (map[string]any, error) { isDefault, err := s.store.IsDefaultAdminPassword(ctx) if err != nil { return nil, err } payload := map[string]any{ "ok": true, "defaultUsername": "admin", "defaultPassword": "", "isDefaultPassword": isDefault, } if isDefault { payload["defaultPassword"] = "admin" } return payload, nil } func (s *Service) NewCaptcha() (Captcha, error) { answer := randomDigits(5) id := randomToken(16) imageBytes, err := renderCaptcha(answer) if err != nil { return Captcha{}, err } s.mu.Lock() s.cleanupLocked() s.captchas[id] = captchaEntry{answer: answer, expiresAt: time.Now().Add(captchaTTL)} s.mu.Unlock() return Captcha{ ID: id, Image: "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageBytes), }, nil } func (s *Service) Login(ctx context.Context, username, password, captchaID, captcha string) (string, string, bool, error) { if !s.consumeCaptcha(captchaID, captcha) { return "", "", false, nil } user, ok, err := s.store.VerifyAdminPassword(ctx, username, password) if err != nil || !ok { return "", "", false, err } sessionID := randomToken(32) csrf := randomToken(32) s.mu.Lock() s.cleanupLocked() s.sessions[sessionID] = sessionEntry{username: user.Username, csrf: csrf, expiresAt: time.Now().Add(sessionTTL)} s.mu.Unlock() return sessionID, csrf, true, nil } func (s *Service) Logout(w http.ResponseWriter, r *http.Request) { if cookie, err := r.Cookie(SessionCookie); err == nil { s.mu.Lock() delete(s.sessions, cookie.Value) s.mu.Unlock() } clearCookie(w) } func (s *Service) UserForRequest(r *http.Request) (string, string, bool) { cookie, err := r.Cookie(SessionCookie) if err != nil || cookie.Value == "" { return "", "", false } s.mu.Lock() defer s.mu.Unlock() s.cleanupLocked() session, ok := s.sessions[cookie.Value] if !ok { return "", "", false } return session.username, session.csrf, true } func (s *Service) Require(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, csrf, ok := s.UserForRequest(r) if !ok { writeJSON(w, http.StatusUnauthorized, map[string]any{"ok": false, "error": "UNAUTHORIZED", "message": "Login required"}) return } if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { actual := r.Header.Get("X-CSRF-Token") if actual == "" || subtle.ConstantTimeCompare([]byte(csrf), []byte(actual)) != 1 { writeJSON(w, http.StatusForbidden, map[string]any{"ok": false, "error": "CSRF_INVALID", "message": "Invalid CSRF token"}) return } } next.ServeHTTP(w, r) }) } func SetSessionCookie(w http.ResponseWriter, sessionID string) { http.SetCookie(w, &http.Cookie{ Name: SessionCookie, Value: sessionID, Path: "/", MaxAge: int(sessionTTL.Seconds()), HttpOnly: true, SameSite: http.SameSiteLaxMode, }) } func clearCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{Name: SessionCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: true, SameSite: http.SameSiteLaxMode}) } func (s *Service) consumeCaptcha(id, answer string) bool { id = strings.TrimSpace(id) answer = strings.TrimSpace(answer) s.mu.Lock() defer s.mu.Unlock() s.cleanupLocked() entry, ok := s.captchas[id] if ok { delete(s.captchas, id) } if !ok || time.Now().After(entry.expiresAt) { return false } return subtle.ConstantTimeCompare([]byte(strings.ToLower(entry.answer)), []byte(strings.ToLower(answer))) == 1 } func (s *Service) cleanupLocked() { now := time.Now() for id, entry := range s.captchas { if now.After(entry.expiresAt) { delete(s.captchas, id) } } for id, entry := range s.sessions { if now.After(entry.expiresAt) { delete(s.sessions, id) } } } func randomDigits(count int) string { data := make([]byte, count) if _, err := rand.Read(data); err != nil { return "12345" } var builder strings.Builder for _, value := range data { builder.WriteByte('0' + value%10) } return builder.String() } func randomToken(bytesLen int) string { data := make([]byte, bytesLen) if _, err := rand.Read(data); err != nil { return hex.EncodeToString([]byte(time.Now().Format(time.RFC3339Nano))) } return hex.EncodeToString(data) } func renderCaptcha(answer string) ([]byte, error) { img := image.NewRGBA(image.Rect(0, 0, 180, 64)) draw.Draw(img, img.Bounds(), &image.Uniform{color.RGBA{246, 248, 244, 255}}, image.Point{}, draw.Src) for i := 0; i < 26; i++ { x := (i*37 + 13) % 180 y := (i*19 + 7) % 64 img.Set(x, y, color.RGBA{111, 119, 130, 255}) } for index, digit := range answer { drawDigit(img, int(digit-'0'), 18+index*32, 13, color.RGBA{28, 61, 89, 255}) } var buffer bytes.Buffer if err := png.Encode(&buffer, img); err != nil { return nil, err } return buffer.Bytes(), nil } var segments = [10][7]bool{ {true, true, true, true, true, true, false}, {false, true, true, false, false, false, false}, {true, true, false, true, true, false, true}, {true, true, true, true, false, false, true}, {false, true, true, false, false, true, true}, {true, false, true, true, false, true, true}, {true, false, true, true, true, true, true}, {true, true, true, false, false, false, false}, {true, true, true, true, true, true, true}, {true, true, true, true, false, true, true}, } func drawDigit(img *image.RGBA, digit, x, y int, col color.Color) { if digit < 0 || digit > 9 { return } thick := 4 width := 22 height := 36 drawSegment := func(rect image.Rectangle) { draw.Draw(img, rect, &image.Uniform{col}, image.Point{}, draw.Src) } if segments[digit][0] { drawSegment(image.Rect(x+thick, y, x+width-thick, y+thick)) } if segments[digit][1] { drawSegment(image.Rect(x+width-thick, y+thick, x+width, y+height/2)) } if segments[digit][2] { drawSegment(image.Rect(x+width-thick, y+height/2, x+width, y+height-thick)) } if segments[digit][3] { drawSegment(image.Rect(x+thick, y+height-thick, x+width-thick, y+height)) } if segments[digit][4] { drawSegment(image.Rect(x, y+height/2, x+thick, y+height-thick)) } if segments[digit][5] { drawSegment(image.Rect(x, y+thick, x+thick, y+height/2)) } if segments[digit][6] { drawSegment(image.Rect(x+thick, y+height/2-thick/2, x+width-thick, y+height/2+thick/2)) } } func writeJSON(w http.ResponseWriter, status int, payload map[string]any) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) _ = jsonNewEncoder(w).Encode(payload) } type jsonEncoder interface { Encode(v any) error } func jsonNewEncoder(w http.ResponseWriter) jsonEncoder { return json.NewEncoder(w) }