431 lines
12 KiB
Go
431 lines
12 KiB
Go
package database
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
|
||
"software-download-center/models"
|
||
"software-download-center/utils"
|
||
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
gormLogger "gorm.io/gorm/logger"
|
||
)
|
||
|
||
var (
|
||
DB *gorm.DB
|
||
dbType string // "sqlite" 或 "mysql"
|
||
dbLogger *utils.Logger
|
||
)
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
Type string // "sqlite" 或 "mysql"
|
||
DSN string // 数据库连接字符串
|
||
Host string // MySQL 主机
|
||
Port string // MySQL 端口
|
||
User string // MySQL 用户名
|
||
Password string // MySQL 密码
|
||
Database string // MySQL 数据库名
|
||
TablePrefix string // 表前缀
|
||
}
|
||
|
||
// InitDB 初始化数据库(延迟初始化,允许失败)
|
||
func InitDB() error {
|
||
logger := utils.NewLogger()
|
||
dbLogger = logger
|
||
|
||
// 检测操作系统
|
||
osInfo := utils.DetectOS()
|
||
logger.System(fmt.Sprintf("🖥️ 检测到操作系统: %s (%s)", osInfo.OS, osInfo.Arch))
|
||
|
||
// 检查是否已初始化
|
||
if IsDBInitialized() {
|
||
// 从配置文件读取数据库配置
|
||
fileConfig, err := LoadDBConfig()
|
||
if err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 读取数据库配置失败: %s,使用环境变量", err.Error()))
|
||
config := getDatabaseConfig(osInfo)
|
||
return connectDB(config, logger)
|
||
}
|
||
|
||
config := &DatabaseConfig{
|
||
Type: fileConfig.Type,
|
||
Host: fileConfig.Host,
|
||
Port: fileConfig.Port,
|
||
User: fileConfig.User,
|
||
Password: fileConfig.Password,
|
||
Database: fileConfig.Database,
|
||
TablePrefix: fileConfig.TablePrefix,
|
||
DSN: fileConfig.DSN,
|
||
}
|
||
return connectDB(config, logger)
|
||
}
|
||
|
||
// 未初始化,不强制连接
|
||
logger.System("ℹ️ 数据库未初始化,等待管理员配置")
|
||
return nil
|
||
}
|
||
|
||
// connectDB 连接数据库
|
||
func connectDB(config *DatabaseConfig, logger *utils.Logger) error {
|
||
|
||
// 确保 data 目录存在(仅 SQLite 需要)
|
||
if config.Type == "sqlite" {
|
||
if err := os.MkdirAll(config.DSN, 0755); err != nil {
|
||
return fmt.Errorf("创建数据目录失败: %w", err)
|
||
}
|
||
}
|
||
|
||
var err error
|
||
dbType = config.Type
|
||
|
||
if config.Type == "mysql" {
|
||
// 使用 MySQL
|
||
logger.System("📊 使用 MySQL 数据库")
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||
|
||
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("MySQL 连接失败: %w", err)
|
||
}
|
||
} else {
|
||
// 使用 SQLite
|
||
logger.System("📊 使用 SQLite 数据库")
|
||
|
||
osInfo := utils.DetectOS()
|
||
// 检查 CGO 支持
|
||
if !osInfo.IsCGO {
|
||
logger.Warn("⚠️ 检测到 CGO 未启用,SQLite 可能需要 CGO 支持")
|
||
}
|
||
|
||
dbPath := filepath.Join(config.DSN, "app.db")
|
||
logger.System(fmt.Sprintf("📁 数据库文件路径: %s", dbPath))
|
||
|
||
DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("SQLite 连接失败: %w", err)
|
||
}
|
||
}
|
||
|
||
if DB == nil {
|
||
return fmt.Errorf("数据库连接失败")
|
||
}
|
||
|
||
// 设置表前缀
|
||
models.SetTablePrefix(config.TablePrefix)
|
||
|
||
// 自动迁移
|
||
logger.System("🔄 开始数据库迁移...")
|
||
if err := DB.AutoMigrate(
|
||
&models.User{},
|
||
&models.Route{},
|
||
); err != nil {
|
||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||
}
|
||
logger.System("✅ 数据库迁移完成")
|
||
|
||
// 记录数据库信息
|
||
var userCount int64
|
||
DB.Model(&models.User{}).Count(&userCount)
|
||
logger.System(fmt.Sprintf("📊 数据库类型: %s", strings.ToUpper(dbType)))
|
||
logger.System(fmt.Sprintf("👥 当前用户数: %d", userCount))
|
||
|
||
return nil
|
||
}
|
||
|
||
// InitDBWithConfig 使用配置初始化数据库
|
||
func InitDBWithConfig(config *DatabaseConfig) error {
|
||
logger := utils.NewLogger()
|
||
dbLogger = logger
|
||
|
||
// 连接数据库
|
||
if err := connectDB(config, logger); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 保存配置
|
||
fileConfig := &DBConfigFile{
|
||
Type: config.Type,
|
||
Host: config.Host,
|
||
Port: config.Port,
|
||
User: config.User,
|
||
Password: config.Password,
|
||
Database: config.Database,
|
||
TablePrefix: config.TablePrefix,
|
||
DSN: config.DSN,
|
||
Initialized: true,
|
||
}
|
||
|
||
if err := SaveDBConfig(fileConfig); err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 保存数据库配置失败: %s", err.Error()))
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getDatabaseConfig 获取数据库配置(从环境变量)
|
||
func getDatabaseConfig(osInfo *utils.OSInfo) *DatabaseConfig {
|
||
config := &DatabaseConfig{
|
||
Type: getEnvOrDefault("DB_TYPE", "sqlite"),
|
||
Host: getEnvOrDefault("DB_HOST", "localhost"),
|
||
Port: getEnvOrDefault("DB_PORT", "3306"),
|
||
User: getEnvOrDefault("DB_USER", "root"),
|
||
Password: getEnvOrDefault("DB_PASSWORD", ""),
|
||
Database: getEnvOrDefault("DB_NAME", "software_download_center"),
|
||
TablePrefix: getEnvOrDefault("DB_TABLE_PREFIX", ""),
|
||
}
|
||
|
||
// 数据目录
|
||
config.DSN = osInfo.DataDir
|
||
if config.DSN == "" {
|
||
config.DSN = "data"
|
||
}
|
||
|
||
return config
|
||
}
|
||
|
||
// GetDatabaseConfigFromFile 从配置文件获取数据库配置
|
||
func GetDatabaseConfigFromFile() (*DatabaseConfig, error) {
|
||
fileConfig, err := LoadDBConfig()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &DatabaseConfig{
|
||
Type: fileConfig.Type,
|
||
Host: fileConfig.Host,
|
||
Port: fileConfig.Port,
|
||
User: fileConfig.User,
|
||
Password: fileConfig.Password,
|
||
Database: fileConfig.Database,
|
||
TablePrefix: fileConfig.TablePrefix,
|
||
DSN: fileConfig.DSN,
|
||
}, nil
|
||
}
|
||
|
||
// GetDatabaseConfig 获取当前数据库配置(供外部调用)
|
||
func GetDatabaseConfig() *DatabaseConfig {
|
||
// 优先从配置文件读取
|
||
if config, err := GetDatabaseConfigFromFile(); err == nil && config != nil {
|
||
return config
|
||
}
|
||
// 回退到环境变量
|
||
osInfo := utils.DetectOS()
|
||
return getDatabaseConfig(osInfo)
|
||
}
|
||
|
||
// VerifyMySQLPassword 验证 MySQL 密码
|
||
func VerifyMySQLPassword(password string) error {
|
||
config := GetDatabaseConfig()
|
||
if config.Type != "mysql" {
|
||
return fmt.Errorf("当前数据库类型不是 MySQL")
|
||
}
|
||
|
||
// 尝试使用提供的密码连接数据库
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
config.User, password, config.Host, config.Port, config.Database)
|
||
|
||
testDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("密码验证失败: %w", err)
|
||
}
|
||
|
||
// 关闭测试连接
|
||
sqlDB, _ := testDB.DB()
|
||
if sqlDB != nil {
|
||
sqlDB.Close()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateMySQLPassword 更新 MySQL root 密码
|
||
func UpdateMySQLPassword(newPassword string) error {
|
||
config := GetDatabaseConfig()
|
||
if config.Type != "mysql" {
|
||
return fmt.Errorf("当前数据库类型不是 MySQL")
|
||
}
|
||
|
||
// 使用当前密码连接
|
||
currentDSN := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||
|
||
db, err := gorm.Open(mysql.Open(currentDSN), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("连接数据库失败: %w", err)
|
||
}
|
||
|
||
sqlDB, _ := db.DB()
|
||
defer sqlDB.Close()
|
||
|
||
// 执行 ALTER USER 语句更新密码
|
||
// MySQL 8.0+ 使用 ALTER USER,旧版本使用 SET PASSWORD
|
||
updateSQL := fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", config.User, "%", newPassword)
|
||
|
||
// 尝试使用 ALTER USER(MySQL 5.7.6+ 和 8.0+)
|
||
_, err = sqlDB.Exec(updateSQL)
|
||
if err != nil {
|
||
// 如果失败,尝试使用 SET PASSWORD(兼容旧版本)
|
||
setPasswordSQL := fmt.Sprintf("SET PASSWORD FOR '%s'@'%s' = PASSWORD('%s')", config.User, "%", newPassword)
|
||
_, err2 := sqlDB.Exec(setPasswordSQL)
|
||
if err2 != nil {
|
||
return fmt.Errorf("更新密码失败: %w (ALTER USER 失败: %v)", err2, err)
|
||
}
|
||
}
|
||
|
||
// 刷新权限
|
||
_, err = sqlDB.Exec("FLUSH PRIVILEGES")
|
||
if err != nil {
|
||
// 刷新权限失败不影响密码更新,只记录警告
|
||
utils.NewLogger().Warn(fmt.Sprintf("刷新权限失败: %s", err.Error()))
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getEnvOrDefault 获取环境变量或返回默认值
|
||
func getEnvOrDefault(key, defaultValue string) string {
|
||
if value := os.Getenv(key); value != "" {
|
||
return value
|
||
}
|
||
return defaultValue
|
||
}
|
||
|
||
// GetDB 获取数据库实例
|
||
func GetDB() *gorm.DB {
|
||
return DB
|
||
}
|
||
|
||
// GetDBType 获取数据库类型
|
||
func GetDBType() string {
|
||
return dbType
|
||
}
|
||
|
||
// ConvertDatabase 转换数据库(MySQL <-> SQLite)
|
||
func ConvertDatabase(targetType string, logger *utils.Logger) error {
|
||
if dbType == targetType {
|
||
return fmt.Errorf("数据库类型已经是 %s", targetType)
|
||
}
|
||
|
||
logger.System(fmt.Sprintf("🔄 开始数据库转换: %s -> %s", strings.ToUpper(dbType), strings.ToUpper(targetType)))
|
||
|
||
// 导出当前数据库数据
|
||
users, routes, err := exportData()
|
||
if err != nil {
|
||
return fmt.Errorf("导出数据失败: %w", err)
|
||
}
|
||
logger.System(fmt.Sprintf("📤 已导出 %d 个用户, %d 个路由", len(users), len(routes)))
|
||
|
||
// 关闭当前数据库连接
|
||
if DB != nil {
|
||
sqlDB, _ := DB.DB()
|
||
if sqlDB != nil {
|
||
sqlDB.Close()
|
||
}
|
||
}
|
||
|
||
// 初始化新数据库
|
||
osInfo := utils.DetectOS()
|
||
config := getDatabaseConfig(osInfo)
|
||
config.Type = targetType
|
||
dbType = targetType
|
||
|
||
var newDB *gorm.DB
|
||
if targetType == "mysql" {
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||
newDB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
} else {
|
||
dbPath := filepath.Join(config.DSN, "app.db")
|
||
newDB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||
})
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("连接新数据库失败: %w", err)
|
||
}
|
||
|
||
// 迁移表结构
|
||
if err := newDB.AutoMigrate(&models.User{}, &models.Route{}); err != nil {
|
||
return fmt.Errorf("迁移表结构失败: %w", err)
|
||
}
|
||
|
||
// 导入数据
|
||
if err := importData(newDB, users, routes, logger); err != nil {
|
||
return fmt.Errorf("导入数据失败: %w", err)
|
||
}
|
||
|
||
// 更新全局数据库实例
|
||
DB = newDB
|
||
logger.System(fmt.Sprintf("✅ 数据库转换完成: %s", strings.ToUpper(targetType)))
|
||
|
||
return nil
|
||
}
|
||
|
||
// exportData 导出数据
|
||
func exportData() ([]models.User, []models.Route, error) {
|
||
var users []models.User
|
||
var routes []models.Route
|
||
|
||
if err := DB.Find(&users).Error; err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
if err := DB.Find(&routes).Error; err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
return users, routes, nil
|
||
}
|
||
|
||
// importData 导入数据
|
||
func importData(db *gorm.DB, users []models.User, routes []models.Route, logger *utils.Logger) error {
|
||
// 导入用户
|
||
if len(users) > 0 {
|
||
if err := db.Create(&users).Error; err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 导入用户时出现错误: %s", err.Error()))
|
||
// 尝试逐个导入
|
||
for _, user := range users {
|
||
if err := db.Create(&user).Error; err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 跳过用户 %s: %s", user.Username, err.Error()))
|
||
}
|
||
}
|
||
} else {
|
||
logger.System(fmt.Sprintf("✅ 成功导入 %d 个用户", len(users)))
|
||
}
|
||
}
|
||
|
||
// 导入路由
|
||
if len(routes) > 0 {
|
||
if err := db.Create(&routes).Error; err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 导入路由时出现错误: %s", err.Error()))
|
||
// 尝试逐个导入
|
||
for _, route := range routes {
|
||
if err := db.Create(&route).Error; err != nil {
|
||
logger.Warn(fmt.Sprintf("⚠️ 跳过路由 %s %s: %s", route.Method, route.Path, err.Error()))
|
||
}
|
||
}
|
||
} else {
|
||
logger.System(fmt.Sprintf("✅ 成功导入 %d 个路由", len(routes)))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|