265 lines
8.1 KiB
Go
265 lines
8.1 KiB
Go
package web
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"ymhut-box/server/unified-management/internal/config"
|
|
"ymhut-box/server/unified-management/internal/db"
|
|
)
|
|
|
|
type setupRouter struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
type setupRequest struct {
|
|
Provider string `json:"provider"`
|
|
BaseURL string `json:"baseUrl"`
|
|
SQLitePath string `json:"sqlitePath"`
|
|
MySQLDSN string `json:"mysqlDsn"`
|
|
MySQL setupMySQLConfig `json:"mysql"`
|
|
}
|
|
|
|
// setupMySQLConfig carries the structured MySQL connection fields that the
// setup UI submits; buildMySQLDSN turns them into a DSN string.
type setupMySQLConfig struct {
	// Host is the server host name or IP address.
	Host string `json:"host"`
	// Port is the TCP port; non-positive values mean "use the default".
	Port int `json:"port"`
	// Database is the schema name to connect to.
	Database string `json:"database"`
	// Username is the MySQL account name.
	Username string `json:"username"`
	// Password is the MySQL account password.
	Password string `json:"password"`
	// Charset is the connection character set.
	Charset string `json:"charset"`
	// ParseTime is emitted as the DSN's parseTime parameter.
	ParseTime bool `json:"parseTime"`
	// TLS is emitted as the DSN's tls parameter when non-empty.
	TLS string `json:"tls"`
}
|
|
|
|
func NewSetupRouter(cfg *config.Config) http.Handler {
|
|
return withSecurity(&setupRouter{cfg: cfg})
|
|
}
|
|
|
|
func (r *setupRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
path := cleanPath(req.URL.Path)
|
|
switch {
|
|
case path == "/" || path == "/setup":
|
|
r.serveSetup(w, req)
|
|
case strings.HasPrefix(path, "/setup/assets/"):
|
|
serveStaticAsset(w, req, r.cfg.SetupWebDir, "setup/dist", strings.TrimPrefix(path, "/setup/"))
|
|
case path == "/api/setup/status":
|
|
writeJSON(w, http.StatusOK, r.status())
|
|
case path == "/api/setup/database/test":
|
|
r.handleDatabaseTest(w, req)
|
|
case path == "/api/setup/complete":
|
|
r.handleComplete(w, req)
|
|
default:
|
|
if strings.HasPrefix(path, "/api/") {
|
|
writeError(w, http.StatusServiceUnavailable, "SETUP_REQUIRED", errors.New("system setup is required"))
|
|
return
|
|
}
|
|
http.Redirect(w, req, "/setup", http.StatusFound)
|
|
}
|
|
}
|
|
|
|
func (r *setupRouter) status() map[string]any {
|
|
return map[string]any{
|
|
"ok": true,
|
|
"initialized": r.cfg.Initialized,
|
|
"baseDir": r.cfg.BaseDir,
|
|
"configPath": r.cfg.ConfigPath,
|
|
"defaults": map[string]any{
|
|
"provider": firstNonEmpty(r.cfg.Database.Provider, "sqlite"),
|
|
"sqlitePath": relativeToBase(r.cfg.BaseDir, r.cfg.Database.SQLitePath),
|
|
"mysqlDsn": maskDSN(r.cfg.Database.MySQLDSN),
|
|
"baseUrl": r.cfg.BaseURL,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (r *setupRouter) handleDatabaseTest(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
|
return
|
|
}
|
|
next, body, err := r.decodeSetupDatabase(req)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
|
return
|
|
}
|
|
started := time.Now()
|
|
if err := db.TestDatabase(next); err != nil {
|
|
writeError(w, http.StatusBadGateway, "DATABASE_TEST_FAILED", err)
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{
|
|
"ok": true,
|
|
"provider": next.Provider,
|
|
"latencyMs": time.Since(started).Milliseconds(),
|
|
"maskedDsn": maskedDatabaseTarget(r.cfg.BaseDir, next),
|
|
"normalized": map[string]any{
|
|
"provider": next.Provider,
|
|
"baseUrl": firstNonEmpty(body.BaseURL, r.cfg.BaseURL),
|
|
"sqlitePath": relativeToBase(r.cfg.BaseDir, next.SQLitePath),
|
|
"mysqlDsn": maskDSN(next.MySQLDSN),
|
|
},
|
|
})
|
|
}
|
|
|
|
func (r *setupRouter) handleComplete(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
|
return
|
|
}
|
|
nextDB, body, err := r.decodeSetupDatabase(req)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
|
return
|
|
}
|
|
if err := db.TestDatabase(nextDB); err != nil {
|
|
writeError(w, http.StatusBadGateway, "DATABASE_TEST_FAILED", err)
|
|
return
|
|
}
|
|
next := *r.cfg
|
|
next.Initialized = true
|
|
next.BaseURL = firstNonEmpty(strings.TrimSpace(body.BaseURL), next.BaseURL)
|
|
next.Database = nextDB
|
|
if strings.EqualFold(next.Database.Provider, "mysql") {
|
|
next.Database.FailoverEnabled = true
|
|
next.Database.HotSyncEnabled = true
|
|
}
|
|
store, err := db.Open(&next)
|
|
if err != nil {
|
|
writeError(w, http.StatusInternalServerError, "DATABASE_OPEN_FAILED", err)
|
|
return
|
|
}
|
|
if err := store.EnsureDefaultAdmin(req.Context()); err != nil {
|
|
_ = store.Close()
|
|
writeError(w, http.StatusInternalServerError, "ADMIN_INIT_FAILED", err)
|
|
return
|
|
}
|
|
_ = store.Close()
|
|
if err := config.Save(&next); err != nil {
|
|
writeError(w, http.StatusInternalServerError, "CONFIG_SAVE_FAILED", err)
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "initialized": true, "message": "Setup completed. Restart the service, then open /admin/login."})
|
|
}
|
|
|
|
func (r *setupRouter) decodeSetupDatabase(req *http.Request) (config.DatabaseConfig, setupRequest, error) {
|
|
var body setupRequest
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return config.DatabaseConfig{}, body, err
|
|
}
|
|
next := r.cfg.Database
|
|
next.Provider = strings.ToLower(strings.TrimSpace(firstNonEmpty(body.Provider, next.Provider, "sqlite")))
|
|
if body.SQLitePath != "" {
|
|
next.SQLitePath = body.SQLitePath
|
|
}
|
|
if next.SQLitePath != "" && !filepath.IsAbs(next.SQLitePath) && !strings.HasPrefix(strings.ToLower(next.SQLitePath), "file:") {
|
|
next.SQLitePath = filepath.Join(r.cfg.BaseDir, next.SQLitePath)
|
|
}
|
|
if next.Provider == "sqlite" {
|
|
next.MySQLDSN = ""
|
|
} else if body.MySQLDSN != "" {
|
|
next.MySQLDSN = body.MySQLDSN
|
|
} else if body.MySQL.Host != "" || body.MySQL.Database != "" || body.MySQL.Username != "" {
|
|
dsn, err := buildMySQLDSN(body.MySQL)
|
|
if err != nil {
|
|
return config.DatabaseConfig{}, body, err
|
|
}
|
|
next.MySQLDSN = dsn
|
|
}
|
|
if next.Provider != "sqlite" && next.Provider != "mysql" {
|
|
return config.DatabaseConfig{}, body, errors.New("provider must be sqlite or mysql")
|
|
}
|
|
if next.Provider == "mysql" && strings.TrimSpace(next.MySQLDSN) == "" {
|
|
return config.DatabaseConfig{}, body, errors.New("mysql connection is required")
|
|
}
|
|
if next.MaxOpenConns <= 0 {
|
|
next.MaxOpenConns = 10
|
|
}
|
|
if next.MaxIdleConns <= 0 {
|
|
next.MaxIdleConns = 4
|
|
}
|
|
if next.ConnMaxLifetimeSeconds <= 0 {
|
|
next.ConnMaxLifetimeSeconds = 300
|
|
}
|
|
if next.HealthIntervalSec <= 0 {
|
|
next.HealthIntervalSec = 30
|
|
}
|
|
return next, body, nil
|
|
}
|
|
|
|
func (r *setupRouter) serveSetup(w http.ResponseWriter, req *http.Request) {
|
|
index := filepath.Join(r.cfg.SetupWebDir, "index.html")
|
|
if tryServeDiskFile(w, req, r.cfg.SetupWebDir, "index.html") {
|
|
return
|
|
}
|
|
if serveEmbeddedFile(w, req, "setup/dist/index.html") {
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
_, _ = w.Write([]byte(`<!doctype html><html lang="zh-CN"><head><meta charset="utf-8"><title>YMhut Setup</title></head><body><main><h1>YMhut Setup</h1><p>Setup frontend is not built. Run npm install && npm run build in web/setup.</p><p>` + index + `</p></main></body></html>`))
|
|
}
|
|
|
|
func buildMySQLDSN(input setupMySQLConfig) (string, error) {
|
|
host := strings.TrimSpace(input.Host)
|
|
if host == "" {
|
|
host = "127.0.0.1"
|
|
}
|
|
port := input.Port
|
|
if port <= 0 {
|
|
port = 3306
|
|
}
|
|
database := strings.TrimSpace(input.Database)
|
|
username := strings.TrimSpace(input.Username)
|
|
if database == "" {
|
|
return "", errors.New("mysql database is required")
|
|
}
|
|
if username == "" {
|
|
return "", errors.New("mysql username is required")
|
|
}
|
|
params := url.Values{}
|
|
params.Set("charset", firstNonEmpty(strings.TrimSpace(input.Charset), "utf8mb4"))
|
|
params.Set("parseTime", strconv.FormatBool(input.ParseTime))
|
|
if tls := strings.TrimSpace(input.TLS); tls != "" {
|
|
params.Set("tls", tls)
|
|
}
|
|
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", username, input.Password, host, port, database, params.Encode()), nil
|
|
}
|
|
|
|
// maskDSN hides the password in a MySQL DSN of the form
// "[user[:password]@][net[(addr)]]/dbname[?params]" by replacing it with
// "******". The input is trimmed first. Values without a "user:password@"
// section are returned trimmed but otherwise unchanged.
//
// A password may itself contain '@' or '/'. The credentials separator is
// therefore taken as the last '@' before the last '/'. If there is no '@'
// in that prefix, the last '@' in the whole value is used.
func maskDSN(value string) string {
	value = strings.TrimSpace(value)
	if value == "" {
		return ""
	}

	head := value
	if slash := strings.LastIndexByte(value, '/'); slash >= 0 {
		head = value[:slash]
	}
	at := strings.LastIndexByte(head, '@')
	if at < 0 {
		// e.g. URL-style "mysql://user:pw@host" where the last '/' precedes '@'.
		at = strings.LastIndexByte(value, '@')
	}
	if at < 0 {
		return value
	}

	// The password starts after the first ':' of the credentials section.
	colon := strings.IndexByte(value[:at], ':')
	if colon < 0 {
		return value
	}
	return value[:colon+1] + "******" + value[at:]
}
|
|
|
|
func maskedDatabaseTarget(base string, cfg config.DatabaseConfig) string {
|
|
if strings.EqualFold(cfg.Provider, "mysql") {
|
|
return maskDSN(cfg.MySQLDSN)
|
|
}
|
|
return relativeToBase(base, cfg.SQLitePath)
|
|
}
|
|
|
|
// relativeToBase renders value for display, using forward slashes.
//
// When base is set and value lies strictly inside it, the path relative to
// base is returned. Otherwise value itself is returned. That covers paths
// outside base, value equal to base, and paths filepath.Rel cannot relate.
// Blank input yields "".
func relativeToBase(base, value string) string {
	if strings.TrimSpace(value) == "" {
		return ""
	}
	if base != "" {
		rel, err := filepath.Rel(base, value)
		// filepath.Rel returns a cleaned path, so leaving base shows up as a
		// leading ".." element. A name like "..cache.db" does not leave base.
		escapes := rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
		if err == nil && rel != "." && !escapes {
			return filepath.ToSlash(rel)
		}
	}
	return filepath.ToSlash(value)
}
|