@@ -0,0 +1,173 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
)
|
||||
|
||||
type dialect struct {
|
||||
name string
|
||||
driver string
|
||||
}
|
||||
|
||||
func dialectFor(provider string) dialect {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "mysql":
|
||||
return dialect{name: "mysql", driver: "mysql"}
|
||||
default:
|
||||
return dialect{name: "sqlite", driver: "sqlite"}
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) rebind(query string) string {
|
||||
return query
|
||||
}
|
||||
|
||||
func (d dialect) idType() string {
|
||||
if d.name == "mysql" {
|
||||
return "BIGINT PRIMARY KEY AUTO_INCREMENT"
|
||||
}
|
||||
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
}
|
||||
|
||||
func (d dialect) boolExpr(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d dialect) upsert(table string, columns, conflict []string) string {
|
||||
placeholders := make([]string, len(columns))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
conflictSet := map[string]bool{}
|
||||
for _, column := range conflict {
|
||||
conflictSet[column] = true
|
||||
}
|
||||
updates := []string{}
|
||||
for _, column := range columns {
|
||||
if conflictSet[column] {
|
||||
continue
|
||||
}
|
||||
if d.name == "mysql" {
|
||||
updates = append(updates, fmt.Sprintf("%s = VALUES(%s)", column, column))
|
||||
} else {
|
||||
updates = append(updates, fmt.Sprintf("%s = excluded.%s", column, column))
|
||||
}
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
if d.name == "mysql" {
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1)
|
||||
}
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1)
|
||||
}
|
||||
if d.name == "mysql" {
|
||||
return base + " ON DUPLICATE KEY UPDATE " + strings.Join(updates, ", ")
|
||||
}
|
||||
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updates, ", ")
|
||||
}
|
||||
|
||||
func (d dialect) limitOffset(limit, offset int) string {
|
||||
return fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)
|
||||
}
|
||||
|
||||
func openSQLDatabase(cfg config.DatabaseConfig) (*sql.DB, dialect, error) {
|
||||
d := dialectFor(cfg.Provider)
|
||||
dsn := strings.TrimSpace(cfg.SQLitePath)
|
||||
if d.name == "mysql" {
|
||||
dsn = strings.TrimSpace(cfg.MySQLDSN)
|
||||
if dsn == "" {
|
||||
return nil, d, errors.New("mysql_dsn is required")
|
||||
}
|
||||
} else {
|
||||
if dsn == "" {
|
||||
return nil, d, errors.New("sqlite path is required")
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(dsn), "file:") {
|
||||
if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
}
|
||||
}
|
||||
conn, err := sql.Open(d.driver, dsn)
|
||||
if err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
conn.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
conn.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, d, err
|
||||
}
|
||||
return conn, d, nil
|
||||
}
|
||||
|
||||
func TestDatabase(cfg config.DatabaseConfig) error {
|
||||
if cfg.Provider == "" {
|
||||
cfg.Provider = "sqlite"
|
||||
}
|
||||
if cfg.MaxOpenConns <= 0 {
|
||||
cfg.MaxOpenConns = 1
|
||||
}
|
||||
if cfg.MaxIdleConns <= 0 {
|
||||
cfg.MaxIdleConns = 1
|
||||
}
|
||||
if cfg.ConnMaxLifetimeSeconds <= 0 {
|
||||
cfg.ConnMaxLifetimeSeconds = 60
|
||||
}
|
||||
conn, d, err := openSQLDatabase(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
create := "CREATE TEMPORARY TABLE ymhut_unified_connection_test (id INTEGER)"
|
||||
if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
_ = tx.Rollback()
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
items := make([]string, n)
|
||||
for i := range items {
|
||||
items[i] = "?"
|
||||
}
|
||||
return strings.Join(items, ", ")
|
||||
}
|
||||
|
||||
func atoiDefault(value string, fallback int) int {
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,65 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
)
|
||||
|
||||
func TestOpenImportsJSONPrototypeIntoSQLite(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
path := filepath.Join(root, "unified.sqlite")
|
||||
prototype := state{
|
||||
Admins: []adminRow{{
|
||||
ID: 1,
|
||||
Username: "admin",
|
||||
PasswordHash: passwordHash("admin"),
|
||||
PasswordChanged: false,
|
||||
CreatedAt: "2026-01-01T00:00:00Z",
|
||||
UpdatedAt: "2026-01-01T00:00:00Z",
|
||||
}},
|
||||
Feedbacks: []Feedback{{Code: "FB-20260101-ABCDEF", Title: "Imported", Type: "issue", Severity: "normal", Body: "hello"}},
|
||||
Sources: []Source{{CategoryID: "ip", CategoryName: "IP", SourceID: "ip-demo", Name: "IP Demo", APIURL: "https://example.com/ip", Enabled: true, ClientVisible: true}},
|
||||
}
|
||||
data, err := json.Marshal(prototype)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o640); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
store, err := Open(&config.Config{
|
||||
StorageDir: root,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: path,
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
ConnMaxLifetimeSeconds: 60,
|
||||
},
|
||||
UploadGuard: config.UploadGuardConfig{MaxZipFiles: 80, MaxDecompressedBytes: 30 << 20, MaxSingleFileBytes: 8 << 20, MaxCompressionRatio: 120, MaxReadableTextBytes: 256 << 10, AllowUnexpectedZipFiles: true},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if _, _, err := store.VerifyAdminPassword(context.Background(), "admin", "admin"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := store.GetFeedback("FB-20260101-ABCDEF"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count, err := store.CountSources(); err != nil || count != 1 {
|
||||
t.Fatalf("CountSources = %d, %v", count, err)
|
||||
}
|
||||
matches, _ := filepath.Glob(path + ".json-prototype-*.bak")
|
||||
if len(matches) != 1 {
|
||||
t.Fatalf("expected prototype backup, got %v", matches)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user