@@ -3,11 +3,8 @@ package web
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -24,18 +21,7 @@ type setupRequest struct {
|
||||
BaseURL string `json:"baseUrl"`
|
||||
SQLitePath string `json:"sqlitePath"`
|
||||
MySQLDSN string `json:"mysqlDsn"`
|
||||
MySQL setupMySQLConfig `json:"mysql"`
|
||||
}
|
||||
|
||||
type setupMySQLConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Charset string `json:"charset"`
|
||||
ParseTime bool `json:"parseTime"`
|
||||
TLS string `json:"tls"`
|
||||
MySQL config.MySQLInput `json:"mysql"`
|
||||
}
|
||||
|
||||
func NewSetupRouter(cfg *config.Config) http.Handler {
|
||||
@@ -73,7 +59,7 @@ func (r *setupRouter) status() map[string]any {
|
||||
"defaults": map[string]any{
|
||||
"provider": firstNonEmpty(r.cfg.Database.Provider, "sqlite"),
|
||||
"sqlitePath": relativeToBase(r.cfg.BaseDir, r.cfg.Database.SQLitePath),
|
||||
"mysqlDsn": maskDSN(r.cfg.Database.MySQLDSN),
|
||||
"mysqlDsn": config.MaskDSN(r.cfg.Database.MySQLDSN),
|
||||
"baseUrl": r.cfg.BaseURL,
|
||||
},
|
||||
}
|
||||
@@ -103,7 +89,7 @@ func (r *setupRouter) handleDatabaseTest(w http.ResponseWriter, req *http.Reques
|
||||
"provider": next.Provider,
|
||||
"baseUrl": firstNonEmpty(body.BaseURL, r.cfg.BaseURL),
|
||||
"sqlitePath": relativeToBase(r.cfg.BaseDir, next.SQLitePath),
|
||||
"mysqlDsn": maskDSN(next.MySQLDSN),
|
||||
"mysqlDsn": config.MaskDSN(next.MySQLDSN),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -153,44 +139,18 @@ func (r *setupRouter) decodeSetupDatabase(req *http.Request) (config.DatabaseCon
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
return config.DatabaseConfig{}, body, err
|
||||
}
|
||||
next := r.cfg.Database
|
||||
next.Provider = strings.ToLower(strings.TrimSpace(firstNonEmpty(body.Provider, next.Provider, "sqlite")))
|
||||
if body.SQLitePath != "" {
|
||||
next.SQLitePath = body.SQLitePath
|
||||
incoming := config.DatabaseConfig{
|
||||
Provider: body.Provider,
|
||||
SQLitePath: body.SQLitePath,
|
||||
MySQLDSN: body.MySQLDSN,
|
||||
MySQLHost: body.MySQL.Host,
|
||||
MySQLPort: body.MySQL.Port,
|
||||
MySQLDatabase: body.MySQL.Database,
|
||||
MySQLUser: body.MySQL.Username,
|
||||
MySQLPassword: body.MySQL.Password,
|
||||
}
|
||||
if next.SQLitePath != "" && !filepath.IsAbs(next.SQLitePath) && !strings.HasPrefix(strings.ToLower(next.SQLitePath), "file:") {
|
||||
next.SQLitePath = filepath.Join(r.cfg.BaseDir, next.SQLitePath)
|
||||
}
|
||||
if next.Provider == "sqlite" {
|
||||
next.MySQLDSN = ""
|
||||
} else if body.MySQLDSN != "" {
|
||||
next.MySQLDSN = body.MySQLDSN
|
||||
} else if body.MySQL.Host != "" || body.MySQL.Database != "" || body.MySQL.Username != "" {
|
||||
dsn, err := buildMySQLDSN(body.MySQL)
|
||||
if err != nil {
|
||||
return config.DatabaseConfig{}, body, err
|
||||
}
|
||||
next.MySQLDSN = dsn
|
||||
}
|
||||
if next.Provider != "sqlite" && next.Provider != "mysql" {
|
||||
return config.DatabaseConfig{}, body, errors.New("provider must be sqlite or mysql")
|
||||
}
|
||||
if next.Provider == "mysql" && strings.TrimSpace(next.MySQLDSN) == "" {
|
||||
return config.DatabaseConfig{}, body, errors.New("mysql connection is required")
|
||||
}
|
||||
if next.MaxOpenConns <= 0 {
|
||||
next.MaxOpenConns = 10
|
||||
}
|
||||
if next.MaxIdleConns <= 0 {
|
||||
next.MaxIdleConns = 4
|
||||
}
|
||||
if next.ConnMaxLifetimeSeconds <= 0 {
|
||||
next.ConnMaxLifetimeSeconds = 300
|
||||
}
|
||||
if next.HealthIntervalSec <= 0 {
|
||||
next.HealthIntervalSec = 30
|
||||
}
|
||||
return next, body, nil
|
||||
next, err := config.NormalizeDatabase(r.cfg.BaseDir, r.cfg.Database, incoming, false)
|
||||
return next, body, err
|
||||
}
|
||||
|
||||
func (r *setupRouter) serveSetup(w http.ResponseWriter, req *http.Request) {
|
||||
@@ -205,48 +165,9 @@ func (r *setupRouter) serveSetup(w http.ResponseWriter, req *http.Request) {
|
||||
_, _ = w.Write([]byte(`<!doctype html><html lang="zh-CN"><head><meta charset="utf-8"><title>YMhut Setup</title></head><body><main><h1>YMhut Setup</h1><p>Setup frontend is not built. Run npm install && npm run build in web/setup.</p><p>` + index + `</p></main></body></html>`))
|
||||
}
|
||||
|
||||
func buildMySQLDSN(input setupMySQLConfig) (string, error) {
|
||||
host := strings.TrimSpace(input.Host)
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
port := input.Port
|
||||
if port <= 0 {
|
||||
port = 3306
|
||||
}
|
||||
database := strings.TrimSpace(input.Database)
|
||||
username := strings.TrimSpace(input.Username)
|
||||
if database == "" {
|
||||
return "", errors.New("mysql database is required")
|
||||
}
|
||||
if username == "" {
|
||||
return "", errors.New("mysql username is required")
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("charset", firstNonEmpty(strings.TrimSpace(input.Charset), "utf8mb4"))
|
||||
params.Set("parseTime", strconv.FormatBool(input.ParseTime))
|
||||
if tls := strings.TrimSpace(input.TLS); tls != "" {
|
||||
params.Set("tls", tls)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", username, input.Password, host, port, database, params.Encode()), nil
|
||||
}
|
||||
|
||||
func maskDSN(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
at := strings.Index(value, "@")
|
||||
colon := strings.Index(value, ":")
|
||||
if at > -1 && colon > -1 && colon < at {
|
||||
return value[:colon+1] + "******" + value[at:]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func maskedDatabaseTarget(base string, cfg config.DatabaseConfig) string {
|
||||
if strings.EqualFold(cfg.Provider, "mysql") {
|
||||
return maskDSN(cfg.MySQLDSN)
|
||||
return config.MaskDSN(cfg.MySQLDSN)
|
||||
}
|
||||
return relativeToBase(base, cfg.SQLitePath)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user