@@ -0,0 +1,204 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Store) EnsureDefaultAdmin(ctx context.Context) error {
|
||||
if err := s.ensureDefaultAdminOn(s.localDB, s.localDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.RLock()
|
||||
remote, remoteDialect := s.remoteDB, s.remoteDialect
|
||||
s.mu.RUnlock()
|
||||
if remote != nil && remote != s.localDB {
|
||||
if err := s.ensureDefaultAdminOn(remote, remoteDialect); err != nil {
|
||||
s.markFailover(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ensureDefaultAdminOn(conn *sql.DB, d dialect) error {
|
||||
if conn == nil {
|
||||
return errors.New("database is not available")
|
||||
}
|
||||
var count int
|
||||
if err := conn.QueryRow(d.rebind(`SELECT COUNT(*) FROM admin_users WHERE username = ?`), "admin").Scan(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
now := Now()
|
||||
_, err := conn.Exec(d.rebind(`INSERT INTO admin_users (username, password_hash, password_changed, created_at, updated_at) VALUES (?, ?, 0, ?, ?)`),
|
||||
"admin", passwordHash("admin"), now, now)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) VerifyAdminPassword(ctx context.Context, username, password string) (AdminUser, bool, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
username = "admin"
|
||||
}
|
||||
user, ok, err := s.verifyAdminPasswordOn(s.localDB, s.localDialect, username, password)
|
||||
if err == nil && (ok || user.Username != "") {
|
||||
return user, ok, nil
|
||||
}
|
||||
if err != nil {
|
||||
return user, ok, err
|
||||
}
|
||||
s.mu.RLock()
|
||||
remote, remoteDialect := s.remoteDB, s.remoteDialect
|
||||
s.mu.RUnlock()
|
||||
if remote != nil && remote != s.localDB {
|
||||
user, ok, err := s.verifyAdminPasswordOn(remote, remoteDialect, username, password)
|
||||
if err != nil {
|
||||
s.markFailover(err)
|
||||
}
|
||||
return user, ok, err
|
||||
}
|
||||
return user, ok, nil
|
||||
}
|
||||
|
||||
func (s *Store) verifyAdminPasswordOn(conn *sql.DB, d dialect, username, password string) (AdminUser, bool, error) {
|
||||
if conn == nil {
|
||||
return AdminUser{}, false, errors.New("database is not available")
|
||||
}
|
||||
var row adminRow
|
||||
var changed int
|
||||
err := conn.QueryRow(d.rebind(`SELECT id, username, password_hash, password_changed, created_at, updated_at FROM admin_users WHERE username = ?`), username).
|
||||
Scan(&row.ID, &row.Username, &row.PasswordHash, &changed, &row.CreatedAt, &row.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return AdminUser{}, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return AdminUser{}, false, err
|
||||
}
|
||||
row.PasswordChanged = changed == 1
|
||||
user := AdminUser{ID: row.ID, Username: row.Username, PasswordChanged: row.PasswordChanged, CreatedAt: row.CreatedAt, UpdatedAt: row.UpdatedAt}
|
||||
return user, subtleConstantCompare(row.PasswordHash, password), nil
|
||||
}
|
||||
|
||||
func (s *Store) IsDefaultAdminPassword(ctx context.Context) (bool, error) {
|
||||
user, ok, err := s.VerifyAdminPassword(ctx, "admin", "admin")
|
||||
if err != nil || !ok {
|
||||
return false, err
|
||||
}
|
||||
return !user.PasswordChanged, nil
|
||||
}
|
||||
|
||||
func (s *Store) ChangeAdminPassword(ctx context.Context, username, current, next string) error {
|
||||
_, err := s.ChangeAdminPasswordWithWarning(ctx, username, current, next)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ChangeAdminPasswordWithWarning(ctx context.Context, username, current, next string) (string, error) {
|
||||
next = strings.TrimSpace(next)
|
||||
if err := validateAdminPasswordChange(current, next); err != nil {
|
||||
return "", err
|
||||
}
|
||||
username = firstNonEmpty(strings.TrimSpace(username), "admin")
|
||||
_, ok, err := s.verifyAdminPasswordOn(s.localDB, s.localDialect, username, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
remoteOK, remoteErr := s.verifyRemoteAdminPassword(username, current)
|
||||
if remoteErr != nil {
|
||||
s.markFailover(remoteErr)
|
||||
}
|
||||
if !remoteOK {
|
||||
return "", errors.New("当前密码不正确")
|
||||
}
|
||||
}
|
||||
hash := passwordHash(next)
|
||||
now := Now()
|
||||
if err := s.changeAdminPasswordOn(s.localDB, s.localDialect, username, hash, now, true); err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.mu.RLock()
|
||||
remote, remoteDialect := s.remoteDB, s.remoteDialect
|
||||
s.mu.RUnlock()
|
||||
if remote != nil && remote != s.localDB {
|
||||
if err := s.changeAdminPasswordOn(remote, remoteDialect, username, hash, now, false); err != nil {
|
||||
s.markFailover(err)
|
||||
return "远端 MySQL 同步失败,密码已持久化到本地 SQLite", nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func validateAdminPasswordChange(current, next string) error {
|
||||
if next == "" {
|
||||
return errors.New("new password is required")
|
||||
}
|
||||
if len([]rune(next)) < 8 {
|
||||
return errors.New("new password must be at least 8 characters")
|
||||
}
|
||||
if strings.EqualFold(next, "admin") {
|
||||
return errors.New("new password cannot be admin")
|
||||
}
|
||||
if strings.TrimSpace(current) != "" && subtleConstantCompare(passwordHash(current), next) {
|
||||
return errors.New("new password must be different from current password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) verifyRemoteAdminPassword(username, password string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
remote, remoteDialect := s.remoteDB, s.remoteDialect
|
||||
s.mu.RUnlock()
|
||||
if remote == nil || remote == s.localDB {
|
||||
return false, nil
|
||||
}
|
||||
_, ok, err := s.verifyAdminPasswordOn(remote, remoteDialect, username, password)
|
||||
return ok, err
|
||||
}
|
||||
|
||||
func (s *Store) changeAdminPasswordOn(conn *sql.DB, d dialect, username, hash, updatedAt string, insertIfMissing bool) error {
|
||||
if conn == nil {
|
||||
return errors.New("database is not available")
|
||||
}
|
||||
result, err := conn.Exec(d.rebind(`UPDATE admin_users SET password_hash = ?, password_changed = 1, updated_at = ? WHERE username = ?`), hash, updatedAt, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows > 0 {
|
||||
return nil
|
||||
}
|
||||
if !insertIfMissing {
|
||||
return errors.New("admin user not found")
|
||||
}
|
||||
_, err = conn.Exec(d.rebind(`INSERT INTO admin_users (username, password_hash, password_changed, created_at, updated_at) VALUES (?, ?, 1, ?, ?)`), username, hash, updatedAt, updatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func passwordHash(password string) string {
|
||||
sum := sha256.Sum256([]byte("ymhut-unified|" + strings.TrimSpace(password)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func subtleConstantCompare(hash, password string) bool {
|
||||
expected := passwordHash(password)
|
||||
return subtleConstantTimeCompare([]byte(hash), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
func subtleConstantTimeCompare(a, b []byte) int {
|
||||
if len(a) != len(b) {
|
||||
return 0
|
||||
}
|
||||
var v byte
|
||||
for i := range a {
|
||||
v |= a[i] ^ b[i]
|
||||
}
|
||||
if v == 0 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user