package db

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	mysql "github.com/go-sql-driver/mysql"
	// Blank imports register the "pgx" and "sqlite" database/sql drivers
	// referenced by dialectFor.
	_ "github.com/jackc/pgx/v5/stdlib"
	_ "modernc.org/sqlite"

	"ymhut-box/server/feedback-mailer/internal/config"
)

// dialect describes one supported SQL backend: name is the canonical provider
// ("sqlite", "mysql" or "postgres") used to pick SQL syntax, and driverName is
// the database/sql driver passed to sql.Open.
type dialect struct {
	name       string
	driverName string
}

// dialectFor maps a configured provider string (case-insensitive, surrounding
// whitespace ignored) to its dialect. "postgres" and "pgsql" both select
// PostgreSQL; any other value, including "", falls back to SQLite.
func dialectFor(provider string) dialect {
	switch strings.ToLower(strings.TrimSpace(provider)) {
	case "mysql":
		return dialect{name: "mysql", driverName: "mysql"}
	case "postgres", "pgsql":
		return dialect{name: "postgres", driverName: "pgx"}
	default:
		return dialect{name: "sqlite", driverName: "sqlite"}
	}
}

// rebind converts '?' placeholders to PostgreSQL's numbered $1, $2, ... form.
// Question marks inside single-quoted string literals are copied verbatim; a
// doubled '' escape toggles the literal state twice and so stays balanced.
// Non-PostgreSQL dialects get the query back unchanged.
//
// NOTE(review): double-quoted identifiers and SQL comments are not tracked, so
// a '?' or an unbalanced quote inside them would be mis-handled.
func (d dialect) rebind(query string) string {
	if d.name != "postgres" {
		return query
	}
	var builder strings.Builder
	builder.Grow(len(query))
	index := 1
	inSingle := false
	for i := 0; i < len(query); i++ {
		ch := query[i]
		switch {
		case ch == '\'':
			inSingle = !inSingle
		case ch == '?' && !inSingle:
			builder.WriteByte('$')
			builder.WriteString(strconv.Itoa(index))
			index++
			continue
		}
		builder.WriteByte(ch)
	}
	return builder.String()
}

// boolValue encodes a Go bool as the integer 1 (true) or 0 (false).
func (d dialect) boolValue(value bool) int {
	if value {
		return 1
	}
	return 0
}

// plainInsert builds "INSERT INTO table (c1, c2) VALUES (?, ?)" with one '?'
// placeholder per column; insertIgnore and upsert append the dialect-specific
// conflict handling.
func (d dialect) plainInsert(table string, columns []string) string {
	placeholders := make([]string, len(columns))
	for i := range placeholders {
		placeholders[i] = "?"
	}
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
}

// insertIgnore builds an INSERT that silently skips rows that would violate a
// unique constraint: INSERT IGNORE on MySQL, ON CONFLICT ... DO NOTHING on
// PostgreSQL and INSERT OR IGNORE on SQLite. conflict lists the conflict-target
// columns and is only used for PostgreSQL. When it is empty the target is
// omitted, which PostgreSQL accepts for DO NOTHING (an empty "()" target would
// be a syntax error).
//
// Table and column names are interpolated verbatim and must not come from
// untrusted input.
func (d dialect) insertIgnore(table string, columns, conflict []string) string {
	base := d.plainInsert(table, columns)
	switch d.name {
	case "mysql":
		return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1)
	case "postgres":
		if len(conflict) == 0 {
			return base + " ON CONFLICT DO NOTHING"
		}
		return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO NOTHING"
	default:
		return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1)
	}
}

// upsert builds an INSERT that, when the row already exists, overwrites every
// column not listed in conflict with the newly supplied value. It uses
// ON DUPLICATE KEY UPDATE col = VALUES(col) on MySQL, and
// ON CONFLICT (conflict...) DO UPDATE SET col = excluded.col on PostgreSQL and
// SQLite. If every column is a conflict column there is nothing to update, and
// the statement degrades to insertIgnore.
//
// On MySQL, conflict only decides which columns are left out of the update;
// the duplicate check itself applies to every primary/unique key. PostgreSQL
// and SQLite need at least one conflict column here, because an empty list
// yields an invalid "ON CONFLICT ()" clause.
//
// Table and column names are interpolated verbatim and must not come from
// untrusted input.
func (d dialect) upsert(table string, columns, conflict []string) string {
	isConflict := make(map[string]bool, len(conflict))
	for _, column := range conflict {
		isConflict[column] = true
	}

	assignments := make([]string, 0, len(columns))
	for _, column := range columns {
		if isConflict[column] {
			continue
		}
		switch d.name {
		case "mysql":
			// VALUES(col) is deprecated as of MySQL 8.0.20 but still accepted;
			// MariaDB uses this form as well.
			assignments = append(assignments, fmt.Sprintf("%s = VALUES(%s)", column, column))
		case "postgres":
			assignments = append(assignments, fmt.Sprintf("%s = EXCLUDED.%s", column, column))
		default:
			assignments = append(assignments, fmt.Sprintf("%s = excluded.%s", column, column))
		}
	}
	if len(assignments) == 0 {
		return d.insertIgnore(table, columns, conflict)
	}

	base := d.plainInsert(table, columns)
	if d.name == "mysql" {
		return base + " ON DUPLICATE KEY UPDATE " + strings.Join(assignments, ", ")
	}
	// PostgreSQL and SQLite share the same upsert syntax.
	return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(assignments, ", ")
}

// textType returns the column type used for free-form text; TEXT is accepted
// by every supported backend.
func (d dialect) textType() string {
	return "TEXT"
}

// idType returns the column definition for an auto-incrementing integer
// primary key in this dialect.
func (d dialect) idType() string {
	switch d.name {
	case "mysql":
		return "BIGINT PRIMARY KEY AUTO_INCREMENT"
	case "postgres":
		return "BIGSERIAL PRIMARY KEY"
	default:
		return "INTEGER PRIMARY KEY AUTOINCREMENT"
	}
}

// columnDefault adapts a column-definition fragment to the dialect. A
// double-quoted empty default (DEFAULT "") is rewritten to the standard
// single-quoted DEFAULT ''. On MySQL, "TEXT NOT NULL DEFAULT ''" additionally
// becomes VARCHAR(3000), presumably because MySQL (before 8.0.13) rejects
// literal defaults on TEXT columns.
func (d dialect) columnDefault(def string) string {
	def = strings.ReplaceAll(def, `DEFAULT ""`, `DEFAULT ''`)
	if d.name == "mysql" {
		def = strings.ReplaceAll(def, `TEXT NOT NULL DEFAULT ''`, `VARCHAR(3000) NOT NULL DEFAULT ''`)
	}
	return def
}

// openSQLDatabase opens the database described by cfg, applies the pool
// settings from cfg and verifies connectivity with a ping bounded by five
// seconds. For a plain SQLite file path (not a "file:" URI), the parent
// directory is created with mode 0750 if it is missing. The returned dialect is
// valid even when err is non-nil. On error, no connection pool is left open.
func openSQLDatabase(cfg config.DatabaseConfig, baseDir string) (*sql.DB, dialect, error) {
	d := dialectFor(cfg.Provider)
	dsn, err := databaseDSN(cfg, baseDir)
	if err != nil {
		return nil, d, err
	}
	// "file:" URIs are left to the driver; only plain paths get their parent
	// directory created up front.
	if d.name == "sqlite" && !strings.HasPrefix(strings.ToLower(dsn), "file:") {
		if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
			return nil, d, fmt.Errorf("create sqlite directory: %w", err)
		}
	}
	conn, err := sql.Open(d.driverName, dsn)
	if err != nil {
		return nil, d, fmt.Errorf("open %s database: %w", d.name, err)
	}
	conn.SetMaxOpenConns(cfg.MaxOpenConns) // <= 0 means unlimited
	// database/sql treats n <= 0 as "retain no idle connections", which would
	// force a fresh connection for every query when the setting is left unset.
	// Only override the library default (2) when a positive value is configured.
	if cfg.MaxIdleConns > 0 {
		conn.SetMaxIdleConns(cfg.MaxIdleConns)
	}
	conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second) // <= 0: no age limit

	// sql.Open does not necessarily dial; ping so bad settings surface here.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := conn.PingContext(ctx); err != nil {
		_ = conn.Close() // the ping error is the one worth reporting
		return nil, d, fmt.Errorf("ping %s database: %w", d.name, err)
	}
	return conn, d, nil
}

// databaseDSN builds the driver connection string for cfg.
//
//   - SQLite ("" or "sqlite"): cfg.SQLitePath, defaulting to
//     storage/feedback.sqlite. "file:" URIs are returned untouched. Relative
//     paths are resolved against baseDir exactly once, then cleaned.
//   - MySQL: cfg.DSN verbatim when set. Otherwise a DSN is built from host,
//     port, name, user and password, with parseTime, the local time zone and
//     charset utf8mb4. cfg.SSLMode is translated by mysqlTLSConfig.
//   - PostgreSQL ("postgres" or "pgsql"): cfg.DSN verbatim when set, otherwise
//     a postgres:// URL with sslmode from cfg.SSLMode ("disable" by default).
//
// Host, name and user are required when no DSN is given. Any other provider is
// an error.
func databaseDSN(cfg config.DatabaseConfig, baseDir string) (string, error) {
	switch strings.ToLower(strings.TrimSpace(cfg.Provider)) {
	case "", "sqlite":
		path := strings.TrimSpace(cfg.SQLitePath)
		if path == "" {
			// Kept relative on purpose: it is joined with baseDir once below, so
			// a relative baseDir is not applied twice.
			path = filepath.Join("storage", "feedback.sqlite")
		}
		if strings.HasPrefix(strings.ToLower(path), "file:") {
			return path, nil
		}
		if !filepath.IsAbs(path) {
			path = filepath.Join(baseDir, path)
		}
		return filepath.Clean(path), nil
	case "mysql":
		if strings.TrimSpace(cfg.DSN) != "" {
			return cfg.DSN, nil
		}
		if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
			return "", errors.New("mysql host, name and user are required")
		}
		mysqlCfg := mysql.NewConfig()
		mysqlCfg.User = cfg.User
		mysqlCfg.Passwd = cfg.Password
		mysqlCfg.Net = "tcp"
		mysqlCfg.Addr = dsnHostPort(cfg.Host, cfg.Port)
		mysqlCfg.DBName = cfg.Name
		mysqlCfg.ParseTime = true
		mysqlCfg.Loc = time.Local
		mysqlCfg.Params = map[string]string{"charset": "utf8mb4"}
		if mode := mysqlTLSConfig(cfg.SSLMode); mode != "" {
			mysqlCfg.TLSConfig = mode
		}
		return mysqlCfg.FormatDSN(), nil
	case "postgres", "pgsql":
		if strings.TrimSpace(cfg.DSN) != "" {
			return cfg.DSN, nil
		}
		if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
			return "", errors.New("postgres host, name and user are required")
		}
		u := url.URL{
			Scheme: "postgres",
			User:   url.UserPassword(cfg.User, cfg.Password),
			Host:   dsnHostPort(cfg.Host, cfg.Port),
			Path:   "/" + cfg.Name,
		}
		params := url.Values{}
		params.Set("sslmode", defaultString(cfg.SSLMode, "disable"))
		u.RawQuery = params.Encode()
		return u.String(), nil
	default:
		return "", fmt.Errorf("unsupported database provider %q", cfg.Provider)
	}
}

// dsnHostPort appends ":port" to host when port is positive. It uses
// net.JoinHostPort so IPv6 literals come out bracketed ("[::1]:5432" rather
// than the unparseable "::1:5432"). Brackets already around host are stripped
// first so they are not doubled. A non-positive port returns host unchanged.
func dsnHostPort(host string, port int) string {
	if port <= 0 {
		return host
	}
	bare := strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	return net.JoinHostPort(bare, strconv.Itoa(port))
}

// mysqlTLSConfig translates the shared SSLMode setting into a
// go-sql-driver/mysql "tls" value. The setting also drives PostgreSQL's
// sslmode, so libpq-style names are mapped to their closest driver equivalent;
// the driver would otherwise reject them as unknown TLS config names:
//
//	"" / "disable"            -> "" (no TLS)
//	"allow" / "prefer"        -> "preferred" (TLS if offered, plaintext fallback)
//	"require"                 -> "skip-verify" (TLS without certificate checks)
//	"verify-ca" / "verify-full" -> "true" (TLS with certificate verification)
//
// Matching is case-insensitive. Any other value (e.g. "true", "skip-verify" or
// a name registered with mysql.RegisterTLSConfig) is passed through trimmed
// but otherwise unchanged.
func mysqlTLSConfig(sslMode string) string {
	mode := strings.TrimSpace(sslMode)
	switch strings.ToLower(mode) {
	case "", "disable":
		return ""
	case "allow", "prefer":
		return "preferred"
	case "require":
		return "skip-verify"
	case "verify-ca", "verify-full":
		return "true"
	default:
		return mode
	}
}

// TestDatabase checks that cfg describes a reachable, usable database. It
// opens the connection (which already pings it), then creates a temporary
// table inside a transaction that is always rolled back. It returns the first
// error encountered, or nil when the probe succeeds.
func TestDatabase(cfg config.DatabaseConfig, baseDir string) error {
	conn, _, err := openSQLDatabase(cfg, baseDir)
	if err != nil {
		return err
	}
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin test transaction: %w", err)
	}
	// The probe only has to prove DDL works; the rollback discards it, and a
	// rollback failure does not change the outcome of the check.
	defer func() { _ = tx.Rollback() }()

	// CREATE TEMPORARY TABLE is accepted by SQLite, MySQL and PostgreSQL alike.
	if _, err := tx.ExecContext(ctx, `CREATE TEMPORARY TABLE ymhut_connection_test (id INTEGER)`); err != nil {
		return fmt.Errorf("create temporary test table: %w", err)
	}
	return nil
}