@@ -0,0 +1,115 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DBConfigFile 数据库配置文件结构
|
||||
type DBConfigFile struct {
|
||||
Type string `json:"type"` // "sqlite" 或 "mysql"
|
||||
Host string `json:"host"` // MySQL 主机
|
||||
Port string `json:"port"` // MySQL 端口
|
||||
User string `json:"user"` // MySQL 用户名
|
||||
Password string `json:"password"` // MySQL 密码
|
||||
Database string `json:"database"` // MySQL 数据库名
|
||||
TablePrefix string `json:"table_prefix"` // 表前缀
|
||||
DSN string `json:"dsn"` // SQLite 数据目录
|
||||
Initialized bool `json:"initialized"` // 是否已初始化
|
||||
}
|
||||
|
||||
var (
|
||||
configFile *DBConfigFile
|
||||
configFileLock sync.RWMutex
|
||||
configFilePath = "data/db-config.json"
|
||||
)
|
||||
|
||||
// LoadDBConfig 加载数据库配置
|
||||
func LoadDBConfig() (*DBConfigFile, error) {
|
||||
configFileLock.RLock()
|
||||
if configFile != nil {
|
||||
configFileLock.RUnlock()
|
||||
return configFile, nil
|
||||
}
|
||||
configFileLock.RUnlock()
|
||||
|
||||
configFileLock.Lock()
|
||||
defer configFileLock.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if configFile != nil {
|
||||
return configFile, nil
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(configFilePath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 读取配置文件
|
||||
data, err := os.ReadFile(configFilePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// 配置文件不存在,返回默认配置
|
||||
configFile = &DBConfigFile{
|
||||
Type: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Password: "",
|
||||
Database: "software_download_center",
|
||||
TablePrefix: "",
|
||||
DSN: "data",
|
||||
Initialized: false,
|
||||
}
|
||||
return configFile, nil
|
||||
}
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
configFile = &DBConfigFile{}
|
||||
if err := json.Unmarshal(data, configFile); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return configFile, nil
|
||||
}
|
||||
|
||||
// SaveDBConfig 保存数据库配置
|
||||
func SaveDBConfig(config *DBConfigFile) error {
|
||||
configFileLock.Lock()
|
||||
defer configFileLock.Unlock()
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(configFilePath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 序列化配置
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(configFilePath, data, 0644); err != nil {
|
||||
return fmt.Errorf("写入配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
configFile = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDBInitialized 检查数据库是否已初始化
|
||||
func IsDBInitialized() bool {
|
||||
config, err := LoadDBConfig()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.Initialized
|
||||
}
|
||||
@@ -0,0 +1,430 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"software-download-center/models"
|
||||
"software-download-center/utils"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
gormLogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
DB *gorm.DB
|
||||
dbType string // "sqlite" 或 "mysql"
|
||||
dbLogger *utils.Logger
|
||||
)
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Type string // "sqlite" 或 "mysql"
|
||||
DSN string // 数据库连接字符串
|
||||
Host string // MySQL 主机
|
||||
Port string // MySQL 端口
|
||||
User string // MySQL 用户名
|
||||
Password string // MySQL 密码
|
||||
Database string // MySQL 数据库名
|
||||
TablePrefix string // 表前缀
|
||||
}
|
||||
|
||||
// InitDB 初始化数据库(延迟初始化,允许失败)
|
||||
func InitDB() error {
|
||||
logger := utils.NewLogger()
|
||||
dbLogger = logger
|
||||
|
||||
// 检测操作系统
|
||||
osInfo := utils.DetectOS()
|
||||
logger.System(fmt.Sprintf("🖥️ 检测到操作系统: %s (%s)", osInfo.OS, osInfo.Arch))
|
||||
|
||||
// 检查是否已初始化
|
||||
if IsDBInitialized() {
|
||||
// 从配置文件读取数据库配置
|
||||
fileConfig, err := LoadDBConfig()
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 读取数据库配置失败: %s,使用环境变量", err.Error()))
|
||||
config := getDatabaseConfig(osInfo)
|
||||
return connectDB(config, logger)
|
||||
}
|
||||
|
||||
config := &DatabaseConfig{
|
||||
Type: fileConfig.Type,
|
||||
Host: fileConfig.Host,
|
||||
Port: fileConfig.Port,
|
||||
User: fileConfig.User,
|
||||
Password: fileConfig.Password,
|
||||
Database: fileConfig.Database,
|
||||
TablePrefix: fileConfig.TablePrefix,
|
||||
DSN: fileConfig.DSN,
|
||||
}
|
||||
return connectDB(config, logger)
|
||||
}
|
||||
|
||||
// 未初始化,不强制连接
|
||||
logger.System("ℹ️ 数据库未初始化,等待管理员配置")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectDB 连接数据库
|
||||
func connectDB(config *DatabaseConfig, logger *utils.Logger) error {
|
||||
|
||||
// 确保 data 目录存在(仅 SQLite 需要)
|
||||
if config.Type == "sqlite" {
|
||||
if err := os.MkdirAll(config.DSN, 0755); err != nil {
|
||||
return fmt.Errorf("创建数据目录失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
dbType = config.Type
|
||||
|
||||
if config.Type == "mysql" {
|
||||
// 使用 MySQL
|
||||
logger.System("📊 使用 MySQL 数据库")
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||||
|
||||
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("MySQL 连接失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 使用 SQLite
|
||||
logger.System("📊 使用 SQLite 数据库")
|
||||
|
||||
osInfo := utils.DetectOS()
|
||||
// 检查 CGO 支持
|
||||
if !osInfo.IsCGO {
|
||||
logger.Warn("⚠️ 检测到 CGO 未启用,SQLite 可能需要 CGO 支持")
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(config.DSN, "app.db")
|
||||
logger.System(fmt.Sprintf("📁 数据库文件路径: %s", dbPath))
|
||||
|
||||
DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("SQLite 连接失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if DB == nil {
|
||||
return fmt.Errorf("数据库连接失败")
|
||||
}
|
||||
|
||||
// 设置表前缀
|
||||
models.SetTablePrefix(config.TablePrefix)
|
||||
|
||||
// 自动迁移
|
||||
logger.System("🔄 开始数据库迁移...")
|
||||
if err := DB.AutoMigrate(
|
||||
&models.User{},
|
||||
&models.Route{},
|
||||
); err != nil {
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
logger.System("✅ 数据库迁移完成")
|
||||
|
||||
// 记录数据库信息
|
||||
var userCount int64
|
||||
DB.Model(&models.User{}).Count(&userCount)
|
||||
logger.System(fmt.Sprintf("📊 数据库类型: %s", strings.ToUpper(dbType)))
|
||||
logger.System(fmt.Sprintf("👥 当前用户数: %d", userCount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitDBWithConfig 使用配置初始化数据库
|
||||
func InitDBWithConfig(config *DatabaseConfig) error {
|
||||
logger := utils.NewLogger()
|
||||
dbLogger = logger
|
||||
|
||||
// 连接数据库
|
||||
if err := connectDB(config, logger); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
fileConfig := &DBConfigFile{
|
||||
Type: config.Type,
|
||||
Host: config.Host,
|
||||
Port: config.Port,
|
||||
User: config.User,
|
||||
Password: config.Password,
|
||||
Database: config.Database,
|
||||
TablePrefix: config.TablePrefix,
|
||||
DSN: config.DSN,
|
||||
Initialized: true,
|
||||
}
|
||||
|
||||
if err := SaveDBConfig(fileConfig); err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 保存数据库配置失败: %s", err.Error()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDatabaseConfig 获取数据库配置(从环境变量)
|
||||
func getDatabaseConfig(osInfo *utils.OSInfo) *DatabaseConfig {
|
||||
config := &DatabaseConfig{
|
||||
Type: getEnvOrDefault("DB_TYPE", "sqlite"),
|
||||
Host: getEnvOrDefault("DB_HOST", "localhost"),
|
||||
Port: getEnvOrDefault("DB_PORT", "3306"),
|
||||
User: getEnvOrDefault("DB_USER", "root"),
|
||||
Password: getEnvOrDefault("DB_PASSWORD", ""),
|
||||
Database: getEnvOrDefault("DB_NAME", "software_download_center"),
|
||||
TablePrefix: getEnvOrDefault("DB_TABLE_PREFIX", ""),
|
||||
}
|
||||
|
||||
// 数据目录
|
||||
config.DSN = osInfo.DataDir
|
||||
if config.DSN == "" {
|
||||
config.DSN = "data"
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// GetDatabaseConfigFromFile 从配置文件获取数据库配置
|
||||
func GetDatabaseConfigFromFile() (*DatabaseConfig, error) {
|
||||
fileConfig, err := LoadDBConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DatabaseConfig{
|
||||
Type: fileConfig.Type,
|
||||
Host: fileConfig.Host,
|
||||
Port: fileConfig.Port,
|
||||
User: fileConfig.User,
|
||||
Password: fileConfig.Password,
|
||||
Database: fileConfig.Database,
|
||||
TablePrefix: fileConfig.TablePrefix,
|
||||
DSN: fileConfig.DSN,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDatabaseConfig 获取当前数据库配置(供外部调用)
|
||||
func GetDatabaseConfig() *DatabaseConfig {
|
||||
// 优先从配置文件读取
|
||||
if config, err := GetDatabaseConfigFromFile(); err == nil && config != nil {
|
||||
return config
|
||||
}
|
||||
// 回退到环境变量
|
||||
osInfo := utils.DetectOS()
|
||||
return getDatabaseConfig(osInfo)
|
||||
}
|
||||
|
||||
// VerifyMySQLPassword 验证 MySQL 密码
|
||||
func VerifyMySQLPassword(password string) error {
|
||||
config := GetDatabaseConfig()
|
||||
if config.Type != "mysql" {
|
||||
return fmt.Errorf("当前数据库类型不是 MySQL")
|
||||
}
|
||||
|
||||
// 尝试使用提供的密码连接数据库
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, password, config.Host, config.Port, config.Database)
|
||||
|
||||
testDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 关闭测试连接
|
||||
sqlDB, _ := testDB.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateMySQLPassword 更新 MySQL root 密码
|
||||
func UpdateMySQLPassword(newPassword string) error {
|
||||
config := GetDatabaseConfig()
|
||||
if config.Type != "mysql" {
|
||||
return fmt.Errorf("当前数据库类型不是 MySQL")
|
||||
}
|
||||
|
||||
// 使用当前密码连接
|
||||
currentDSN := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||||
|
||||
db, err := gorm.Open(mysql.Open(currentDSN), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, _ := db.DB()
|
||||
defer sqlDB.Close()
|
||||
|
||||
// 执行 ALTER USER 语句更新密码
|
||||
// MySQL 8.0+ 使用 ALTER USER,旧版本使用 SET PASSWORD
|
||||
updateSQL := fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", config.User, "%", newPassword)
|
||||
|
||||
// 尝试使用 ALTER USER(MySQL 5.7.6+ 和 8.0+)
|
||||
_, err = sqlDB.Exec(updateSQL)
|
||||
if err != nil {
|
||||
// 如果失败,尝试使用 SET PASSWORD(兼容旧版本)
|
||||
setPasswordSQL := fmt.Sprintf("SET PASSWORD FOR '%s'@'%s' = PASSWORD('%s')", config.User, "%", newPassword)
|
||||
_, err2 := sqlDB.Exec(setPasswordSQL)
|
||||
if err2 != nil {
|
||||
return fmt.Errorf("更新密码失败: %w (ALTER USER 失败: %v)", err2, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新权限
|
||||
_, err = sqlDB.Exec("FLUSH PRIVILEGES")
|
||||
if err != nil {
|
||||
// 刷新权限失败不影响密码更新,只记录警告
|
||||
utils.NewLogger().Warn(fmt.Sprintf("刷新权限失败: %s", err.Error()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvOrDefault 获取环境变量或返回默认值
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetDB 获取数据库实例
|
||||
func GetDB() *gorm.DB {
|
||||
return DB
|
||||
}
|
||||
|
||||
// GetDBType 获取数据库类型
|
||||
func GetDBType() string {
|
||||
return dbType
|
||||
}
|
||||
|
||||
// ConvertDatabase 转换数据库(MySQL <-> SQLite)
|
||||
func ConvertDatabase(targetType string, logger *utils.Logger) error {
|
||||
if dbType == targetType {
|
||||
return fmt.Errorf("数据库类型已经是 %s", targetType)
|
||||
}
|
||||
|
||||
logger.System(fmt.Sprintf("🔄 开始数据库转换: %s -> %s", strings.ToUpper(dbType), strings.ToUpper(targetType)))
|
||||
|
||||
// 导出当前数据库数据
|
||||
users, routes, err := exportData()
|
||||
if err != nil {
|
||||
return fmt.Errorf("导出数据失败: %w", err)
|
||||
}
|
||||
logger.System(fmt.Sprintf("📤 已导出 %d 个用户, %d 个路由", len(users), len(routes)))
|
||||
|
||||
// 关闭当前数据库连接
|
||||
if DB != nil {
|
||||
sqlDB, _ := DB.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化新数据库
|
||||
osInfo := utils.DetectOS()
|
||||
config := getDatabaseConfig(osInfo)
|
||||
config.Type = targetType
|
||||
dbType = targetType
|
||||
|
||||
var newDB *gorm.DB
|
||||
if targetType == "mysql" {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||||
newDB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
} else {
|
||||
dbPath := filepath.Join(config.DSN, "app.db")
|
||||
newDB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 迁移表结构
|
||||
if err := newDB.AutoMigrate(&models.User{}, &models.Route{}); err != nil {
|
||||
return fmt.Errorf("迁移表结构失败: %w", err)
|
||||
}
|
||||
|
||||
// 导入数据
|
||||
if err := importData(newDB, users, routes, logger); err != nil {
|
||||
return fmt.Errorf("导入数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新全局数据库实例
|
||||
DB = newDB
|
||||
logger.System(fmt.Sprintf("✅ 数据库转换完成: %s", strings.ToUpper(targetType)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// exportData 导出数据
|
||||
func exportData() ([]models.User, []models.Route, error) {
|
||||
var users []models.User
|
||||
var routes []models.Route
|
||||
|
||||
if err := DB.Find(&users).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := DB.Find(&routes).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return users, routes, nil
|
||||
}
|
||||
|
||||
// importData 导入数据
|
||||
func importData(db *gorm.DB, users []models.User, routes []models.Route, logger *utils.Logger) error {
|
||||
// 导入用户
|
||||
if len(users) > 0 {
|
||||
if err := db.Create(&users).Error; err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 导入用户时出现错误: %s", err.Error()))
|
||||
// 尝试逐个导入
|
||||
for _, user := range users {
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 跳过用户 %s: %s", user.Username, err.Error()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.System(fmt.Sprintf("✅ 成功导入 %d 个用户", len(users)))
|
||||
}
|
||||
}
|
||||
|
||||
// 导入路由
|
||||
if len(routes) > 0 {
|
||||
if err := db.Create(&routes).Error; err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 导入路由时出现错误: %s", err.Error()))
|
||||
// 尝试逐个导入
|
||||
for _, route := range routes {
|
||||
if err := db.Create(&route).Error; err != nil {
|
||||
logger.Warn(fmt.Sprintf("⚠️ 跳过路由 %s %s: %s", route.Method, route.Path, err.Error()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.System(fmt.Sprintf("✅ 成功导入 %d 个路由", len(routes)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user