package db import ( "context" "database/sql" "errors" "fmt" "os" "path/filepath" "strconv" "strings" "time" _ "github.com/go-sql-driver/mysql" _ "modernc.org/sqlite" "ymhut-box/server/unified-management/internal/config" ) type dialect struct { name string driver string } func dialectFor(provider string) dialect { switch strings.ToLower(strings.TrimSpace(provider)) { case "mysql": return dialect{name: "mysql", driver: "mysql"} default: return dialect{name: "sqlite", driver: "sqlite"} } } func (d dialect) rebind(query string) string { return query } func (d dialect) idType() string { if d.name == "mysql" { return "BIGINT PRIMARY KEY AUTO_INCREMENT" } return "INTEGER PRIMARY KEY AUTOINCREMENT" } func (d dialect) keyTextType() string { if d.name == "mysql" { return "VARCHAR(191)" } return "TEXT" } func (d dialect) shortTextType() string { if d.name == "mysql" { return "VARCHAR(255)" } return "TEXT" } func (d dialect) mediumTextType() string { if d.name == "mysql" { return "VARCHAR(1024)" } return "TEXT" } func (d dialect) longTextType() string { if d.name == "mysql" { return "LONGTEXT" } return "TEXT" } func (d dialect) quoteIdent(identifier string) string { return "`" + strings.ReplaceAll(identifier, "`", "``") + "`" } func (d dialect) columnList(columns []string) string { quoted := make([]string, len(columns)) for index, column := range columns { quoted[index] = d.quoteIdent(column) } return strings.Join(quoted, ", ") } func (d dialect) boolExpr(value bool) int { if value { return 1 } return 0 } func (d dialect) upsert(table string, columns, conflict []string) string { placeholders := make([]string, len(columns)) for i := range placeholders { placeholders[i] = "?" } base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, d.columnList(columns), strings.Join(placeholders, ", ")) conflictSet := map[string]bool{} for _, column := range conflict { conflictSet[column] = true } updates := []string{} for _, column := range columns { if conflictSet[column] { continue } quoted := d.quoteIdent(column) if d.name == "mysql" { updates = append(updates, fmt.Sprintf("%s = VALUES(%s)", quoted, quoted)) } else { updates = append(updates, fmt.Sprintf("%s = excluded.%s", quoted, quoted)) } } if len(updates) == 0 { if d.name == "mysql" { return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1) } return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1) } if d.name == "mysql" { return base + " ON DUPLICATE KEY UPDATE " + strings.Join(updates, ", ") } return base + " ON CONFLICT (" + d.columnList(conflict) + ") DO UPDATE SET " + strings.Join(updates, ", ") } func (d dialect) limitOffset(limit, offset int) string { return fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset) } func openSQLDatabase(cfg config.DatabaseConfig) (*sql.DB, dialect, error) { d := dialectFor(cfg.Provider) dsn := strings.TrimSpace(cfg.SQLitePath) if d.name == "mysql" { dsn = strings.TrimSpace(cfg.MySQLDSN) if dsn == "" { return nil, d, errors.New("mysql_dsn is required") } } else { if dsn == "" { return nil, d, errors.New("sqlite path is required") } if !strings.HasPrefix(strings.ToLower(dsn), "file:") { if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil { return nil, d, err } } } conn, err := sql.Open(d.driver, dsn) if err != nil { return nil, d, err } conn.SetMaxOpenConns(cfg.MaxOpenConns) conn.SetMaxIdleConns(cfg.MaxIdleConns) conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := conn.PingContext(ctx); err != nil { _ = conn.Close() return nil, d, err } return conn, d, nil } func TestDatabase(cfg config.DatabaseConfig) error { if cfg.Provider == "" { cfg.Provider = "sqlite" } if cfg.MaxOpenConns <= 0 { cfg.MaxOpenConns = 1 } if cfg.MaxIdleConns <= 0 { cfg.MaxIdleConns = 1 } if cfg.ConnMaxLifetimeSeconds <= 0 { cfg.ConnMaxLifetimeSeconds = 60 } conn, d, err := openSQLDatabase(cfg) if err != nil { return err } defer conn.Close() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := conn.PingContext(ctx); err != nil { return err } tx, err := conn.BeginTx(ctx, nil) if err != nil { return err } create := "CREATE TEMPORARY TABLE ymhut_unified_connection_test (id INTEGER)" if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil { _ = tx.Rollback() return err } _ = tx.Rollback() return nil } func placeholders(n int) string { items := make([]string, n) for i := range items { items[i] = "?" } return strings.Join(items, ", ") } func atoiDefault(value string, fallback int) int { parsed, err := strconv.Atoi(value) if err != nil { return fallback } return parsed }