264 lines
7.0 KiB
Go
264 lines
7.0 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
mysql "github.com/go-sql-driver/mysql"
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
_ "modernc.org/sqlite"
|
|
|
|
"ymhut-box/server/feedback-mailer/internal/config"
|
|
)
|
|
|
|
// dialect describes one supported SQL flavor: the canonical key the query
// helpers switch on, and the driver name handed to database/sql.
type dialect struct {
	// name is the dialect key the helper methods branch on;
	// driverName is the database/sql driver name passed to sql.Open.
	name, driverName string
}
|
|
|
|
func dialectFor(provider string) dialect {
|
|
switch strings.ToLower(strings.TrimSpace(provider)) {
|
|
case "mysql":
|
|
return dialect{name: "mysql", driverName: "mysql"}
|
|
case "postgres", "pgsql":
|
|
return dialect{name: "postgres", driverName: "pgx"}
|
|
default:
|
|
return dialect{name: "sqlite", driverName: "sqlite"}
|
|
}
|
|
}
|
|
|
|
func (d dialect) rebind(query string) string {
|
|
if d.name != "postgres" {
|
|
return query
|
|
}
|
|
var builder strings.Builder
|
|
index := 1
|
|
inSingle := false
|
|
for i := 0; i < len(query); i++ {
|
|
ch := query[i]
|
|
if ch == '\'' {
|
|
inSingle = !inSingle
|
|
builder.WriteByte(ch)
|
|
continue
|
|
}
|
|
if ch == '?' && !inSingle {
|
|
builder.WriteByte('$')
|
|
builder.WriteString(strconv.Itoa(index))
|
|
index++
|
|
continue
|
|
}
|
|
builder.WriteByte(ch)
|
|
}
|
|
return builder.String()
|
|
}
|
|
|
|
func (d dialect) boolValue(value bool) int {
|
|
if value {
|
|
return 1
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d dialect) insertIgnore(table string, columns, conflict []string) string {
|
|
placeholders := make([]string, len(columns))
|
|
for i := range placeholders {
|
|
placeholders[i] = "?"
|
|
}
|
|
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
|
switch d.name {
|
|
case "mysql":
|
|
return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1)
|
|
case "postgres":
|
|
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO NOTHING"
|
|
default:
|
|
return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1)
|
|
}
|
|
}
|
|
|
|
func (d dialect) upsert(table string, columns, conflict []string) string {
|
|
placeholders := make([]string, len(columns))
|
|
for i := range placeholders {
|
|
placeholders[i] = "?"
|
|
}
|
|
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
|
updateColumns := make([]string, 0, len(columns))
|
|
conflictSet := map[string]bool{}
|
|
for _, column := range conflict {
|
|
conflictSet[column] = true
|
|
}
|
|
for _, column := range columns {
|
|
if conflictSet[column] {
|
|
continue
|
|
}
|
|
switch d.name {
|
|
case "mysql":
|
|
updateColumns = append(updateColumns, fmt.Sprintf("%s = VALUES(%s)", column, column))
|
|
case "postgres":
|
|
updateColumns = append(updateColumns, fmt.Sprintf("%s = EXCLUDED.%s", column, column))
|
|
default:
|
|
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", column, column))
|
|
}
|
|
}
|
|
if len(updateColumns) == 0 {
|
|
return d.insertIgnore(table, columns, conflict)
|
|
}
|
|
switch d.name {
|
|
case "mysql":
|
|
return base + " ON DUPLICATE KEY UPDATE " + strings.Join(updateColumns, ", ")
|
|
case "postgres":
|
|
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updateColumns, ", ")
|
|
default:
|
|
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updateColumns, ", ")
|
|
}
|
|
}
|
|
|
|
func (d dialect) textType() string {
|
|
return "TEXT"
|
|
}
|
|
|
|
func (d dialect) idType() string {
|
|
switch d.name {
|
|
case "mysql":
|
|
return "BIGINT PRIMARY KEY AUTO_INCREMENT"
|
|
case "postgres":
|
|
return "BIGSERIAL PRIMARY KEY"
|
|
default:
|
|
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
|
}
|
|
}
|
|
|
|
func (d dialect) columnDefault(def string) string {
|
|
def = strings.ReplaceAll(def, `DEFAULT ""`, `DEFAULT ''`)
|
|
if d.name == "mysql" {
|
|
def = strings.ReplaceAll(def, `TEXT NOT NULL DEFAULT ''`, `VARCHAR(3000) NOT NULL DEFAULT ''`)
|
|
}
|
|
return def
|
|
}
|
|
|
|
func openSQLDatabase(cfg config.DatabaseConfig, baseDir string) (*sql.DB, dialect, error) {
|
|
d := dialectFor(cfg.Provider)
|
|
dsn, err := databaseDSN(cfg, baseDir)
|
|
if err != nil {
|
|
return nil, d, err
|
|
}
|
|
if d.name == "sqlite" && !strings.HasPrefix(strings.ToLower(dsn), "file:") {
|
|
if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
|
|
return nil, d, err
|
|
}
|
|
}
|
|
conn, err := sql.Open(d.driverName, dsn)
|
|
if err != nil {
|
|
return nil, d, err
|
|
}
|
|
conn.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
conn.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := conn.PingContext(ctx); err != nil {
|
|
_ = conn.Close()
|
|
return nil, d, err
|
|
}
|
|
return conn, d, nil
|
|
}
|
|
|
|
func databaseDSN(cfg config.DatabaseConfig, baseDir string) (string, error) {
|
|
provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
|
switch provider {
|
|
case "", "sqlite":
|
|
path := strings.TrimSpace(cfg.SQLitePath)
|
|
if path == "" {
|
|
path = filepath.Join(baseDir, "storage", "feedback.sqlite")
|
|
}
|
|
if strings.HasPrefix(strings.ToLower(path), "file:") {
|
|
return path, nil
|
|
}
|
|
if !filepath.IsAbs(path) && !strings.HasPrefix(strings.ToLower(path), "file:") {
|
|
path = filepath.Join(baseDir, path)
|
|
}
|
|
return filepath.Clean(path), nil
|
|
case "mysql":
|
|
if strings.TrimSpace(cfg.DSN) != "" {
|
|
return cfg.DSN, nil
|
|
}
|
|
if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
|
|
return "", errors.New("mysql host, name and user are required")
|
|
}
|
|
host := cfg.Host
|
|
if cfg.Port > 0 {
|
|
host = host + ":" + strconv.Itoa(cfg.Port)
|
|
}
|
|
mysqlCfg := mysql.NewConfig()
|
|
mysqlCfg.User = cfg.User
|
|
mysqlCfg.Passwd = cfg.Password
|
|
mysqlCfg.Net = "tcp"
|
|
mysqlCfg.Addr = host
|
|
mysqlCfg.DBName = cfg.Name
|
|
mysqlCfg.ParseTime = true
|
|
mysqlCfg.Loc = time.Local
|
|
mysqlCfg.Params = map[string]string{"charset": "utf8mb4"}
|
|
if cfg.SSLMode != "" && cfg.SSLMode != "disable" {
|
|
mysqlCfg.TLSConfig = cfg.SSLMode
|
|
}
|
|
return mysqlCfg.FormatDSN(), nil
|
|
case "postgres", "pgsql":
|
|
if strings.TrimSpace(cfg.DSN) != "" {
|
|
return cfg.DSN, nil
|
|
}
|
|
if cfg.Host == "" || cfg.Name == "" || cfg.User == "" {
|
|
return "", errors.New("postgres host, name and user are required")
|
|
}
|
|
host := cfg.Host
|
|
if cfg.Port > 0 {
|
|
host = host + ":" + strconv.Itoa(cfg.Port)
|
|
}
|
|
u := url.URL{
|
|
Scheme: "postgres",
|
|
User: url.UserPassword(cfg.User, cfg.Password),
|
|
Host: host,
|
|
Path: "/" + cfg.Name,
|
|
}
|
|
params := url.Values{}
|
|
params.Set("sslmode", defaultString(cfg.SSLMode, "disable"))
|
|
u.RawQuery = params.Encode()
|
|
return u.String(), nil
|
|
default:
|
|
return "", fmt.Errorf("unsupported database provider %q", cfg.Provider)
|
|
}
|
|
}
|
|
|
|
func TestDatabase(cfg config.DatabaseConfig, baseDir string) error {
|
|
conn, d, err := openSQLDatabase(cfg, baseDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := conn.PingContext(ctx); err != nil {
|
|
return err
|
|
}
|
|
tx, err := conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
create := `CREATE TEMPORARY TABLE ymhut_connection_test (id INTEGER)`
|
|
if d.name == "postgres" {
|
|
create = `CREATE TEMP TABLE ymhut_connection_test (id INTEGER)`
|
|
}
|
|
if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
_ = tx.Rollback()
|
|
return nil
|
|
}
|