Files
YMhut-box-C-/server/unified-management/internal/db/dialect.go
T
QWQLwToo 962a2f2143
build-winui / winui (push) Waiting to run
更新 update 门户站点界面和后台功能
2026-06-27 18:09:11 +08:00

215 lines
4.8 KiB
Go

package db
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "modernc.org/sqlite"
"ymhut-box/server/unified-management/internal/config"
)
// dialect describes the SQL flavour this package talks to.
//
// name selects dialect-specific SQL fragments ("mysql" or "sqlite").
// driver is the database/sql driver name handed to sql.Open.
type dialect struct {
	name, driver string
}
// dialectFor maps a configured provider name onto a dialect.
//
// Surrounding whitespace and letter case are ignored. Only "mysql" selects
// MySQL; every other value, including the empty string, falls back to SQLite.
func dialectFor(provider string) dialect {
	normalized := strings.ToLower(strings.TrimSpace(provider))
	if normalized != "mysql" {
		return dialect{name: "sqlite", driver: "sqlite"}
	}
	return dialect{name: "mysql", driver: "mysql"}
}
// rebind adapts a query written with "?" placeholders to the active dialect.
//
// Every query builder in this file emits "?" for both mysql and sqlite, so the
// query is returned unchanged. Presumably this is kept as the single hook for a
// future dialect that needs a different placeholder syntax.
func (d dialect) rebind(q string) string {
	return q
}
// idType returns the column definition for an auto-incrementing integer
// primary key. MySQL gets BIGINT ... AUTO_INCREMENT; every other dialect
// (i.e. SQLite) gets INTEGER ... AUTOINCREMENT.
func (d dialect) idType() string {
	switch d.name {
	case "mysql":
		return "BIGINT PRIMARY KEY AUTO_INCREMENT"
	default:
		return "INTEGER PRIMARY KEY AUTOINCREMENT"
	}
}
// keyTextType returns the column type for short text that may be used as a
// key or index. It is VARCHAR(191) on MySQL and TEXT elsewhere.
//
// NOTE(review): 191 is presumably chosen so an indexed utf8mb4 column stays
// within InnoDB's 767-byte key-prefix limit (191*4 = 764). Confirm before
// changing.
func (d dialect) keyTextType() string {
	if d.name != "mysql" {
		return "TEXT"
	}
	return "VARCHAR(191)"
}
// shortTextType returns the column type for short, non-key text values:
// VARCHAR(255) on MySQL, TEXT on any other dialect.
func (d dialect) shortTextType() string {
	columnType := "TEXT"
	if d.name == "mysql" {
		columnType = "VARCHAR(255)"
	}
	return columnType
}
// mediumTextType returns the column type for medium-length text such as URLs
// or descriptions (an assumption from the name; check call sites). It is
// VARCHAR(1024) on MySQL and TEXT elsewhere.
func (d dialect) mediumTextType() string {
	switch d.name {
	case "mysql":
		return "VARCHAR(1024)"
	default:
		return "TEXT"
	}
}
// longTextType returns the column type for unbounded text: LONGTEXT on MySQL,
// TEXT on any other dialect.
func (d dialect) longTextType() string {
	if d.name != "mysql" {
		return "TEXT"
	}
	return "LONGTEXT"
}
// quoteIdent wraps identifier in backticks for use as a column name.
//
// Any backtick already inside identifier is doubled, so the name cannot close
// the quoted identifier early. The same quoting is used for every dialect.
func (d dialect) quoteIdent(identifier string) string {
	const tick = "`"
	escaped := strings.ReplaceAll(identifier, tick, tick+tick)
	return tick + escaped + tick
}
// columnList renders columns as a comma-separated list of quoted identifiers,
// e.g. ["id", "name"] becomes "`id`, `name`". An empty or nil slice yields "".
func (d dialect) columnList(columns []string) string {
	var b strings.Builder
	for i, column := range columns {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(d.quoteIdent(column))
	}
	return b.String()
}
// boolExpr converts value to the integer 1 (true) or 0 (false). It behaves
// the same for every dialect, presumably so booleans can be stored portably in
// integer columns (confirm at call sites).
func (d dialect) boolExpr(value bool) int {
	n := 0
	if value {
		n = 1
	}
	return n
}
// upsert builds an INSERT for table with one "?" placeholder per entry in
// columns, in column order. When a row with the same key already exists, every
// column not listed in conflict is overwritten with the incoming value.
//
// table is interpolated verbatim; it is NOT quoted or escaped. It must be a
// trusted identifier chosen by code, never user input. Column names are quoted
// with quoteIdent.
//
// Generated forms:
//   - mysql:  INSERT ... ON DUPLICATE KEY UPDATE `c` = VALUES(`c`), ...
//   - sqlite: INSERT ... ON CONFLICT (`k`, ...) DO UPDATE SET `c` = excluded.`c`, ...
//
// Special cases:
//   - If every column is a conflict column, there is nothing to update. The
//     statement then becomes INSERT IGNORE (mysql) or INSERT OR IGNORE (sqlite).
//   - If conflict is empty on sqlite, the conflict target is omitted
//     ("ON CONFLICT DO UPDATE"), which matches any uniqueness constraint.
//     The previous "ON CONFLICT ()" was always a syntax error. Omitting the
//     target needs SQLite 3.35+.
//
// NOTE(review): MySQL documents VALUES() inside ON DUPLICATE KEY UPDATE as
// deprecated since 8.0.20, though it is still accepted.
func (d dialect) upsert(table string, columns, conflict []string) string {
	isMySQL := d.name == "mysql"

	inConflict := make(map[string]struct{}, len(conflict))
	for _, column := range conflict {
		inConflict[column] = struct{}{}
	}

	// Assignments for every non-key column, taking the value from the row that
	// failed to insert.
	updates := make([]string, 0, len(columns))
	for _, column := range columns {
		if _, skip := inConflict[column]; skip {
			continue
		}
		quoted := d.quoteIdent(column)
		if isMySQL {
			updates = append(updates, fmt.Sprintf("%s = VALUES(%s)", quoted, quoted))
		} else {
			updates = append(updates, fmt.Sprintf("%s = excluded.%s", quoted, quoted))
		}
	}

	verb := "INSERT INTO"
	if len(updates) == 0 {
		// Nothing to overwrite: just skip rows that would collide.
		if isMySQL {
			verb = "INSERT IGNORE INTO"
		} else {
			verb = "INSERT OR IGNORE INTO"
		}
	}
	query := fmt.Sprintf("%s %s (%s) VALUES (%s)", verb, table, d.columnList(columns), placeholders(len(columns)))
	if len(updates) == 0 {
		return query
	}

	set := strings.Join(updates, ", ")
	switch {
	case isMySQL:
		// MySQL's clause fires on any unique/primary key; the conflict list
		// only decides which columns are excluded from the update set.
		return query + " ON DUPLICATE KEY UPDATE " + set
	case len(conflict) == 0:
		return query + " ON CONFLICT DO UPDATE SET " + set
	default:
		return query + " ON CONFLICT (" + d.columnList(conflict) + ") DO UPDATE SET " + set
	}
}
// limitOffset returns a pagination suffix such as " LIMIT 20 OFFSET 40",
// including the leading space so it can be appended directly to a query.
// The values are rendered as-is, without validation, and the syntax is the
// same for every dialect.
func (d dialect) limitOffset(limit, offset int) string {
	const clause = " LIMIT %d OFFSET %d"
	return fmt.Sprintf(clause, limit, offset)
}
// openSQLDatabase opens a connection pool for the provider in cfg and confirms
// it is reachable.
//
// The backend is chosen by dialectFor(cfg.Provider):
//   - "mysql" uses cfg.MySQLDSN.
//   - Anything else uses SQLite with cfg.SQLitePath.
//
// Both values are whitespace-trimmed and must be non-empty. For a plain SQLite
// path (not a "file:" URI), the parent directory is created first with mode
// 0750.
//
// Pool limits come straight from cfg.MaxOpenConns, cfg.MaxIdleConns and
// cfg.ConnMaxLifetimeSeconds. The pool is then pinged with a 5-second timeout.
//
// The dialect is returned on every path, including errors. On error the
// *sql.DB is nil, and any pool that was opened has been closed. Driver errors
// are wrapped with %w, so errors.Is / errors.As still see the original cause.
func openSQLDatabase(cfg config.DatabaseConfig) (*sql.DB, dialect, error) {
	d := dialectFor(cfg.Provider)

	var dsn string
	if d.name == "mysql" {
		dsn = strings.TrimSpace(cfg.MySQLDSN)
		if dsn == "" {
			return nil, d, errors.New("mysql_dsn is required")
		}
	} else {
		dsn = strings.TrimSpace(cfg.SQLitePath)
		if dsn == "" {
			return nil, d, errors.New("sqlite path is required")
		}
		// "file:" URIs go to the driver untouched. Only plain filesystem paths
		// get their parent directory created up front.
		if !strings.HasPrefix(strings.ToLower(dsn), "file:") {
			if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
				return nil, d, fmt.Errorf("prepare sqlite directory: %w", err)
			}
		}
	}

	conn, err := sql.Open(d.driver, dsn)
	if err != nil {
		// The DSN is deliberately kept out of the message; for MySQL it may
		// contain credentials.
		return nil, d, fmt.Errorf("open %s database: %w", d.name, err)
	}
	conn.SetMaxOpenConns(cfg.MaxOpenConns)
	conn.SetMaxIdleConns(cfg.MaxIdleConns)
	conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second)

	// sql.Open need not connect. The ping forces a real connection so bad
	// configuration surfaces here rather than on first use.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := conn.PingContext(ctx); err != nil {
		_ = conn.Close() // best effort; the ping error is the one worth reporting
		return nil, d, fmt.Errorf("ping %s database: %w", d.name, err)
	}
	return conn, d, nil
}
// TestDatabase checks that cfg points at a reachable database on which the
// configured account can create a table.
//
// Despite the name, this is ordinary package API (it lives in dialect.go, not
// a _test.go file). Presumably it backs a "test connection" action; confirm
// against callers.
//
// Defaults are applied to the local copy of cfg before connecting:
//   - An empty Provider becomes "sqlite".
//   - Non-positive MaxOpenConns and MaxIdleConns become 1.
//   - A non-positive ConnMaxLifetimeSeconds becomes 60.
//
// openSQLDatabase already pings the server. This function then only creates a
// TEMPORARY table inside a transaction that is always rolled back. The pool is
// closed before returning.
//
// NOTE(review): for a plain SQLite path, opening may create the parent
// directory (and presumably the database file) as a side effect.
func TestDatabase(cfg config.DatabaseConfig) error {
	if cfg.Provider == "" {
		cfg.Provider = "sqlite"
	}
	if cfg.MaxOpenConns <= 0 {
		cfg.MaxOpenConns = 1
	}
	if cfg.MaxIdleConns <= 0 {
		cfg.MaxIdleConns = 1
	}
	if cfg.ConnMaxLifetimeSeconds <= 0 {
		cfg.ConnMaxLifetimeSeconds = 60
	}

	// openSQLDatabase pings with its own timeout, so no second ping is needed.
	conn, d, err := openSQLDatabase(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin test transaction: %w", err)
	}
	// The transaction only scopes the probe and is always rolled back. A
	// rollback failure is ignored because the pool is closed on return anyway.
	defer func() { _ = tx.Rollback() }()

	const create = "CREATE TEMPORARY TABLE ymhut_unified_connection_test (id INTEGER)"
	if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil {
		return fmt.Errorf("create temporary test table: %w", err)
	}
	return nil
}
// placeholders returns n comma-separated "?" bind markers, for example
// "?, ?, ?" when n is 3. It returns "" when n is 0 and panics when n is
// negative.
func placeholders(n int) string {
	return strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
}
// atoiDefault parses value as a base-10 int and returns fallback when parsing
// fails. Failure includes an empty string, surrounding whitespace, non-digits
// and out-of-range numbers.
func atoiDefault(value string, fallback int) int {
	if n, err := strconv.Atoi(value); err == nil {
		return n
	}
	return fallback
}