package database import ( "fmt" "os" "path/filepath" "strings" "software-download-center/models" "software-download-center/utils" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" ) var ( DB *gorm.DB dbType string // "sqlite" 或 "mysql" dbLogger *utils.Logger ) // DatabaseConfig 数据库配置 type DatabaseConfig struct { Type string // "sqlite" 或 "mysql" DSN string // 数据库连接字符串 Host string // MySQL 主机 Port string // MySQL 端口 User string // MySQL 用户名 Password string // MySQL 密码 Database string // MySQL 数据库名 TablePrefix string // 表前缀 } // InitDB 初始化数据库(延迟初始化,允许失败) func InitDB() error { logger := utils.NewLogger() dbLogger = logger // 检测操作系统 osInfo := utils.DetectOS() logger.System(fmt.Sprintf("🖥️ 检测到操作系统: %s (%s)", osInfo.OS, osInfo.Arch)) // 检查是否已初始化 if IsDBInitialized() { // 从配置文件读取数据库配置 fileConfig, err := LoadDBConfig() if err != nil { logger.Warn(fmt.Sprintf("⚠️ 读取数据库配置失败: %s,使用环境变量", err.Error())) config := getDatabaseConfig(osInfo) return connectDB(config, logger) } config := &DatabaseConfig{ Type: fileConfig.Type, Host: fileConfig.Host, Port: fileConfig.Port, User: fileConfig.User, Password: fileConfig.Password, Database: fileConfig.Database, TablePrefix: fileConfig.TablePrefix, DSN: fileConfig.DSN, } return connectDB(config, logger) } // 未初始化,不强制连接 logger.System("ℹ️ 数据库未初始化,等待管理员配置") return nil } // connectDB 连接数据库 func connectDB(config *DatabaseConfig, logger *utils.Logger) error { // 确保 data 目录存在(仅 SQLite 需要) if config.Type == "sqlite" { if err := os.MkdirAll(config.DSN, 0755); err != nil { return fmt.Errorf("创建数据目录失败: %w", err) } } var err error dbType = config.Type if config.Type == "mysql" { // 使用 MySQL logger.System("📊 使用 MySQL 数据库") dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Host, config.Port, config.Database) DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) if err != nil { return fmt.Errorf("MySQL 连接失败: %w", err) } } else { // 使用 SQLite logger.System("📊 使用 SQLite 数据库") osInfo := utils.DetectOS() // 检查 CGO 支持 if !osInfo.IsCGO { logger.Warn("⚠️ 检测到 CGO 未启用,SQLite 可能需要 CGO 支持") } dbPath := filepath.Join(config.DSN, "app.db") logger.System(fmt.Sprintf("📁 数据库文件路径: %s", dbPath)) DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) if err != nil { return fmt.Errorf("SQLite 连接失败: %w", err) } } if DB == nil { return fmt.Errorf("数据库连接失败") } // 设置表前缀 models.SetTablePrefix(config.TablePrefix) // 自动迁移 logger.System("🔄 开始数据库迁移...") if err := DB.AutoMigrate( &models.User{}, &models.Route{}, ); err != nil { return fmt.Errorf("数据库迁移失败: %w", err) } logger.System("✅ 数据库迁移完成") // 记录数据库信息 var userCount int64 DB.Model(&models.User{}).Count(&userCount) logger.System(fmt.Sprintf("📊 数据库类型: %s", strings.ToUpper(dbType))) logger.System(fmt.Sprintf("👥 当前用户数: %d", userCount)) return nil } // InitDBWithConfig 使用配置初始化数据库 func InitDBWithConfig(config *DatabaseConfig) error { logger := utils.NewLogger() dbLogger = logger // 连接数据库 if err := connectDB(config, logger); err != nil { return err } // 保存配置 fileConfig := &DBConfigFile{ Type: config.Type, Host: config.Host, Port: config.Port, User: config.User, Password: config.Password, Database: config.Database, TablePrefix: config.TablePrefix, DSN: config.DSN, Initialized: true, } if err := SaveDBConfig(fileConfig); err != nil { logger.Warn(fmt.Sprintf("⚠️ 保存数据库配置失败: %s", err.Error())) } return nil } // getDatabaseConfig 获取数据库配置(从环境变量) func getDatabaseConfig(osInfo *utils.OSInfo) *DatabaseConfig { config := &DatabaseConfig{ Type: getEnvOrDefault("DB_TYPE", "sqlite"), Host: getEnvOrDefault("DB_HOST", "localhost"), Port: getEnvOrDefault("DB_PORT", "3306"), User: getEnvOrDefault("DB_USER", "root"), Password: getEnvOrDefault("DB_PASSWORD", ""), Database: getEnvOrDefault("DB_NAME", "software_download_center"), TablePrefix: getEnvOrDefault("DB_TABLE_PREFIX", ""), } // 数据目录 config.DSN = osInfo.DataDir if config.DSN == "" { config.DSN = "data" } return config } // GetDatabaseConfigFromFile 从配置文件获取数据库配置 func GetDatabaseConfigFromFile() (*DatabaseConfig, error) { fileConfig, err := LoadDBConfig() if err != nil { return nil, err } return &DatabaseConfig{ Type: fileConfig.Type, Host: fileConfig.Host, Port: fileConfig.Port, User: fileConfig.User, Password: fileConfig.Password, Database: fileConfig.Database, TablePrefix: fileConfig.TablePrefix, DSN: fileConfig.DSN, }, nil } // GetDatabaseConfig 获取当前数据库配置(供外部调用) func GetDatabaseConfig() *DatabaseConfig { // 优先从配置文件读取 if config, err := GetDatabaseConfigFromFile(); err == nil && config != nil { return config } // 回退到环境变量 osInfo := utils.DetectOS() return getDatabaseConfig(osInfo) } // VerifyMySQLPassword 验证 MySQL 密码 func VerifyMySQLPassword(password string) error { config := GetDatabaseConfig() if config.Type != "mysql" { return fmt.Errorf("当前数据库类型不是 MySQL") } // 尝试使用提供的密码连接数据库 dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, password, config.Host, config.Port, config.Database) testDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) if err != nil { return fmt.Errorf("密码验证失败: %w", err) } // 关闭测试连接 sqlDB, _ := testDB.DB() if sqlDB != nil { sqlDB.Close() } return nil } // UpdateMySQLPassword 更新 MySQL root 密码 func UpdateMySQLPassword(newPassword string) error { config := GetDatabaseConfig() if config.Type != "mysql" { return fmt.Errorf("当前数据库类型不是 MySQL") } // 使用当前密码连接 currentDSN := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Host, config.Port, config.Database) db, err := gorm.Open(mysql.Open(currentDSN), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) if err != nil { return fmt.Errorf("连接数据库失败: %w", err) } sqlDB, _ := db.DB() defer sqlDB.Close() // 执行 ALTER USER 语句更新密码 // MySQL 8.0+ 使用 ALTER USER,旧版本使用 SET PASSWORD updateSQL := fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", config.User, "%", newPassword) // 尝试使用 ALTER USER(MySQL 5.7.6+ 和 8.0+) _, err = sqlDB.Exec(updateSQL) if err != nil { // 如果失败,尝试使用 SET PASSWORD(兼容旧版本) setPasswordSQL := fmt.Sprintf("SET PASSWORD FOR '%s'@'%s' = PASSWORD('%s')", config.User, "%", newPassword) _, err2 := sqlDB.Exec(setPasswordSQL) if err2 != nil { return fmt.Errorf("更新密码失败: %w (ALTER USER 失败: %v)", err2, err) } } // 刷新权限 _, err = sqlDB.Exec("FLUSH PRIVILEGES") if err != nil { // 刷新权限失败不影响密码更新,只记录警告 utils.NewLogger().Warn(fmt.Sprintf("刷新权限失败: %s", err.Error())) } return nil } // getEnvOrDefault 获取环境变量或返回默认值 func getEnvOrDefault(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } // GetDB 获取数据库实例 func GetDB() *gorm.DB { return DB } // GetDBType 获取数据库类型 func GetDBType() string { return dbType } // ConvertDatabase 转换数据库(MySQL <-> SQLite) func ConvertDatabase(targetType string, logger *utils.Logger) error { if dbType == targetType { return fmt.Errorf("数据库类型已经是 %s", targetType) } logger.System(fmt.Sprintf("🔄 开始数据库转换: %s -> %s", strings.ToUpper(dbType), strings.ToUpper(targetType))) // 导出当前数据库数据 users, routes, err := exportData() if err != nil { return fmt.Errorf("导出数据失败: %w", err) } logger.System(fmt.Sprintf("📤 已导出 %d 个用户, %d 个路由", len(users), len(routes))) // 关闭当前数据库连接 if DB != nil { sqlDB, _ := DB.DB() if sqlDB != nil { sqlDB.Close() } } // 初始化新数据库 osInfo := utils.DetectOS() config := getDatabaseConfig(osInfo) config.Type = targetType dbType = targetType var newDB *gorm.DB if targetType == "mysql" { dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Host, config.Port, config.Database) newDB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) } else { dbPath := filepath.Join(config.DSN, "app.db") newDB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) } if err != nil { return fmt.Errorf("连接新数据库失败: %w", err) } // 迁移表结构 if err := newDB.AutoMigrate(&models.User{}, &models.Route{}); err != nil { return fmt.Errorf("迁移表结构失败: %w", err) } // 导入数据 if err := importData(newDB, users, routes, logger); err != nil { return fmt.Errorf("导入数据失败: %w", err) } // 更新全局数据库实例 DB = newDB logger.System(fmt.Sprintf("✅ 数据库转换完成: %s", strings.ToUpper(targetType))) return nil } // exportData 导出数据 func exportData() ([]models.User, []models.Route, error) { var users []models.User var routes []models.Route if err := DB.Find(&users).Error; err != nil { return nil, nil, err } if err := DB.Find(&routes).Error; err != nil { return nil, nil, err } return users, routes, nil } // importData 导入数据 func importData(db *gorm.DB, users []models.User, routes []models.Route, logger *utils.Logger) error { // 导入用户 if len(users) > 0 { if err := db.Create(&users).Error; err != nil { logger.Warn(fmt.Sprintf("⚠️ 导入用户时出现错误: %s", err.Error())) // 尝试逐个导入 for _, user := range users { if err := db.Create(&user).Error; err != nil { logger.Warn(fmt.Sprintf("⚠️ 跳过用户 %s: %s", user.Username, err.Error())) } } } else { logger.System(fmt.Sprintf("✅ 成功导入 %d 个用户", len(users))) } } // 导入路由 if len(routes) > 0 { if err := db.Create(&routes).Error; err != nil { logger.Warn(fmt.Sprintf("⚠️ 导入路由时出现错误: %s", err.Error())) // 尝试逐个导入 for _, route := range routes { if err := db.Create(&route).Error; err != nil { logger.Warn(fmt.Sprintf("⚠️ 跳过路由 %s %s: %s", route.Method, route.Path, err.Error())) } } } else { logger.System(fmt.Sprintf("✅ 成功导入 %d 个路由", len(routes))) } } return nil }