@@ -0,0 +1,263 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"ymhut-box/server/feedback-mailer/internal/config"
|
||||
)
|
||||
|
||||
type dialect struct {
|
||||
name string
|
||||
driverName string
|
||||
}
|
||||
|
||||
func dialectFor(provider string) dialect {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "mysql":
|
||||
return dialect{name: "mysql", driverName: "mysql"}
|
||||
case "postgres", "pgsql":
|
||||
return dialect{name: "postgres", driverName: "pgx"}
|
||||
default:
|
||||
return dialect{name: "sqlite", driverName: "sqlite"}
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) rebind(query string) string {
|
||||
if d.name != "postgres" {
|
||||
return query
|
||||
}
|
||||
var builder strings.Builder
|
||||
index := 1
|
||||
inSingle := false
|
||||
for i := 0; i < len(query); i++ {
|
||||
ch := query[i]
|
||||
if ch == '\'' {
|
||||
inSingle = !inSingle
|
||||
builder.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '?' && !inSingle {
|
||||
builder.WriteByte('$')
|
||||
builder.WriteString(strconv.Itoa(index))
|
||||
index++
|
||||
continue
|
||||
}
|
||||
builder.WriteByte(ch)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (d dialect) boolValue(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d dialect) insertIgnore(table string, columns, conflict []string) string {
|
||||
placeholders := make([]string, len(columns))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
switch d.name {
|
||||
case "mysql":
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1)
|
||||
case "postgres":
|
||||
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO NOTHING"
|
||||
default:
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) upsert(table string, columns, conflict []string) string {
|
||||
placeholders := make([]string, len(columns))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
updateColumns := make([]string, 0, len(columns))
|
||||
conflictSet := map[string]bool{}
|
||||
for _, column := range conflict {
|
||||
conflictSet[column] = true
|
||||
}
|
||||
for _, column := range columns {
|
||||
if conflictSet[column] {
|
||||
continue
|
||||
}
|
||||
switch d.name {
|
||||
case "mysql":
|
||||
updateColumns = append(updateColumns, fmt.Sprintf("%s = VALUES(%s)", column, column))
|
||||
case "postgres":
|
||||
updateColumns = append(updateColumns, fmt.Sprintf("%s = EXCLUDED.%s", column, column))
|
||||
default:
|
||||
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", column, column))
|
||||
}
|
||||
}
|
||||
if len(updateColumns) == 0 {
|
||||
return d.insertIgnore(table, columns, conflict)
|
||||
}
|
||||
switch d.name {
|
||||
case "mysql":
|
||||
return base + " ON DUPLICATE KEY UPDATE " + strings.Join(updateColumns, ", ")
|
||||
case "postgres":
|
||||
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updateColumns, ", ")
|
||||
default:
|
||||
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updateColumns, ", ")
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) textType() string {
|
||||
return "TEXT"
|
||||
}
|
||||
|
||||
func (d dialect) idType() string {
|
||||
switch d.name {
|
||||
case "mysql":
|
||||
return "BIGINT PRIMARY KEY AUTO_INCREMENT"
|
||||
case "postgres":
|
||||
return "BIGSERIAL PRIMARY KEY"
|
||||
default:
|
||||
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) columnDefault(def string) string {
|
||||
def = strings.ReplaceAll(def, `DEFAULT ""`, `DEFAULT ''`)
|
||||
if d.name == "mysql" {
|
||||
def = strings.ReplaceAll(def, `TEXT NOT NULL DEFAULT ''`, `VARCHAR(3000) NOT NULL DEFAULT ''`)
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func openSQLDatabase(cfg config.DatabaseConfig, baseDir string) (*sql.DB, dialect, error) {
|
||||
d := dialectFor(cfg.Provider)
|
||||
dsn, err := databaseDSN(cfg, baseDir)
|
||||
if err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
if d.name == "sqlite" && !strings.HasPrefix(strings.ToLower(dsn), "file:") {
|
||||
if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
}
|
||||
conn, err := sql.Open(d.driverName, dsn)
|
||||
if err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
conn.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
conn.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, d, err
|
||||
}
|
||||
return conn, d, nil
|
||||
}
|
||||
|
||||
func databaseDSN(cfg config.DatabaseConfig, baseDir string) (string, error) {
|
||||
provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
||||
switch provider {
|
||||
case "", "sqlite":
|
||||
path := strings.TrimSpace(cfg.SQLitePath)
|
||||
if path == "" {
|
||||
path = filepath.Join(baseDir, "storage", "feedback.sqlite")
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(path), "file:") {
|
||||
return path, nil
|
||||
}
|
||||
if !filepath.IsAbs(path) && !strings.HasPrefix(strings.ToLower(path), "file:") {
|
||||
path = filepath.Join(baseDir, path)
|
||||
}
|
||||
return filepath.Clean(path), nil
|
||||
case "mysql":
|
||||
if strings.TrimSpace(cfg.DSN) != "" {
|
||||
return cfg.DSN, nil
|
||||
}
|
||||
if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
|
||||
return "", errors.New("mysql host, name and user are required")
|
||||
}
|
||||
host := cfg.Host
|
||||
if cfg.Port > 0 {
|
||||
host = host + ":" + strconv.Itoa(cfg.Port)
|
||||
}
|
||||
mysqlCfg := mysql.NewConfig()
|
||||
mysqlCfg.User = cfg.User
|
||||
mysqlCfg.Passwd = cfg.Password
|
||||
mysqlCfg.Net = "tcp"
|
||||
mysqlCfg.Addr = host
|
||||
mysqlCfg.DBName = cfg.Name
|
||||
mysqlCfg.ParseTime = true
|
||||
mysqlCfg.Loc = time.Local
|
||||
mysqlCfg.Params = map[string]string{"charset": "utf8mb4"}
|
||||
if cfg.SSLMode != "" && cfg.SSLMode != "disable" {
|
||||
mysqlCfg.TLSConfig = cfg.SSLMode
|
||||
}
|
||||
return mysqlCfg.FormatDSN(), nil
|
||||
case "postgres", "pgsql":
|
||||
if strings.TrimSpace(cfg.DSN) != "" {
|
||||
return cfg.DSN, nil
|
||||
}
|
||||
if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
|
||||
return "", errors.New("postgres host, name and user are required")
|
||||
}
|
||||
host := cfg.Host
|
||||
if cfg.Port > 0 {
|
||||
host = host + ":" + strconv.Itoa(cfg.Port)
|
||||
}
|
||||
u := url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(cfg.User, cfg.Password),
|
||||
Host: host,
|
||||
Path: "/" + cfg.Name,
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("sslmode", defaultString(cfg.SSLMode, "disable"))
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported database provider %q", cfg.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabase(cfg config.DatabaseConfig, baseDir string) error {
|
||||
conn, d, err := openSQLDatabase(cfg, baseDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
create := `CREATE TEMPORARY TABLE ymhut_connection_test (id INTEGER)`
|
||||
if d.name == "postgres" {
|
||||
create = `CREATE TEMP TABLE ymhut_connection_test (id INTEGER)`
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
_ = tx.Rollback()
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user