@@ -0,0 +1,294 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionCookie = "ymhut_unified_session"
|
||||
captchaTTL = 5 * time.Minute
|
||||
sessionTTL = 12 * time.Hour
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
store *db.Store
|
||||
mu sync.Mutex
|
||||
captchas map[string]captchaEntry
|
||||
sessions map[string]sessionEntry
|
||||
}
|
||||
|
||||
type captchaEntry struct {
|
||||
answer string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type sessionEntry struct {
|
||||
username string
|
||||
csrf string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type Captcha struct {
|
||||
ID string `json:"captchaId"`
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
func NewService(store *db.Store) *Service {
|
||||
return &Service{
|
||||
store: store,
|
||||
captchas: map[string]captchaEntry{},
|
||||
sessions: map[string]sessionEntry{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Bootstrap(ctx context.Context) (map[string]any, error) {
|
||||
isDefault, err := s.store.IsDefaultAdminPassword(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := map[string]any{
|
||||
"ok": true,
|
||||
"defaultUsername": "admin",
|
||||
"defaultPassword": "",
|
||||
"isDefaultPassword": isDefault,
|
||||
}
|
||||
if isDefault {
|
||||
payload["defaultPassword"] = "admin"
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewCaptcha() (Captcha, error) {
|
||||
answer := randomDigits(5)
|
||||
id := randomToken(16)
|
||||
imageBytes, err := renderCaptcha(answer)
|
||||
if err != nil {
|
||||
return Captcha{}, err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.cleanupLocked()
|
||||
s.captchas[id] = captchaEntry{answer: answer, expiresAt: time.Now().Add(captchaTTL)}
|
||||
s.mu.Unlock()
|
||||
return Captcha{
|
||||
ID: id,
|
||||
Image: "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, username, password, captchaID, captcha string) (string, string, bool, error) {
|
||||
if !s.consumeCaptcha(captchaID, captcha) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
user, ok, err := s.store.VerifyAdminPassword(ctx, username, password)
|
||||
if err != nil || !ok {
|
||||
return "", "", false, err
|
||||
}
|
||||
sessionID := randomToken(32)
|
||||
csrf := randomToken(32)
|
||||
s.mu.Lock()
|
||||
s.cleanupLocked()
|
||||
s.sessions[sessionID] = sessionEntry{username: user.Username, csrf: csrf, expiresAt: time.Now().Add(sessionTTL)}
|
||||
s.mu.Unlock()
|
||||
return sessionID, csrf, true, nil
|
||||
}
|
||||
|
||||
func (s *Service) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
if cookie, err := r.Cookie(SessionCookie); err == nil {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, cookie.Value)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
clearCookie(w)
|
||||
}
|
||||
|
||||
func (s *Service) UserForRequest(r *http.Request) (string, string, bool) {
|
||||
cookie, err := r.Cookie(SessionCookie)
|
||||
if err != nil || cookie.Value == "" {
|
||||
return "", "", false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupLocked()
|
||||
session, ok := s.sessions[cookie.Value]
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.username, session.csrf, true
|
||||
}
|
||||
|
||||
func (s *Service) Require(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, csrf, ok := s.UserForRequest(r)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"ok": false, "error": "UNAUTHORIZED", "message": "Login required"})
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
|
||||
actual := r.Header.Get("X-CSRF-Token")
|
||||
if actual == "" || subtle.ConstantTimeCompare([]byte(csrf), []byte(actual)) != 1 {
|
||||
writeJSON(w, http.StatusForbidden, map[string]any{"ok": false, "error": "CSRF_INVALID", "message": "Invalid CSRF token"})
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func SetSessionCookie(w http.ResponseWriter, sessionID string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookie,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: int(sessionTTL.Seconds()),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{Name: SessionCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: true, SameSite: http.SameSiteLaxMode})
|
||||
}
|
||||
|
||||
func (s *Service) consumeCaptcha(id, answer string) bool {
|
||||
id = strings.TrimSpace(id)
|
||||
answer = strings.TrimSpace(answer)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupLocked()
|
||||
entry, ok := s.captchas[id]
|
||||
if ok {
|
||||
delete(s.captchas, id)
|
||||
}
|
||||
if !ok || time.Now().After(entry.expiresAt) {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(strings.ToLower(entry.answer)), []byte(strings.ToLower(answer))) == 1
|
||||
}
|
||||
|
||||
func (s *Service) cleanupLocked() {
|
||||
now := time.Now()
|
||||
for id, entry := range s.captchas {
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(s.captchas, id)
|
||||
}
|
||||
}
|
||||
for id, entry := range s.sessions {
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomDigits(count int) string {
|
||||
data := make([]byte, count)
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
return "12345"
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, value := range data {
|
||||
builder.WriteByte('0' + value%10)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func randomToken(bytesLen int) string {
|
||||
data := make([]byte, bytesLen)
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
return hex.EncodeToString([]byte(time.Now().Format(time.RFC3339Nano)))
|
||||
}
|
||||
return hex.EncodeToString(data)
|
||||
}
|
||||
|
||||
func renderCaptcha(answer string) ([]byte, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, 180, 64))
|
||||
draw.Draw(img, img.Bounds(), &image.Uniform{color.RGBA{246, 248, 244, 255}}, image.Point{}, draw.Src)
|
||||
for i := 0; i < 26; i++ {
|
||||
x := (i*37 + 13) % 180
|
||||
y := (i*19 + 7) % 64
|
||||
img.Set(x, y, color.RGBA{111, 119, 130, 255})
|
||||
}
|
||||
for index, digit := range answer {
|
||||
drawDigit(img, int(digit-'0'), 18+index*32, 13, color.RGBA{28, 61, 89, 255})
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
if err := png.Encode(&buffer, img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
var segments = [10][7]bool{
|
||||
{true, true, true, true, true, true, false},
|
||||
{false, true, true, false, false, false, false},
|
||||
{true, true, false, true, true, false, true},
|
||||
{true, true, true, true, false, false, true},
|
||||
{false, true, true, false, false, true, true},
|
||||
{true, false, true, true, false, true, true},
|
||||
{true, false, true, true, true, true, true},
|
||||
{true, true, true, false, false, false, false},
|
||||
{true, true, true, true, true, true, true},
|
||||
{true, true, true, true, false, true, true},
|
||||
}
|
||||
|
||||
func drawDigit(img *image.RGBA, digit, x, y int, col color.Color) {
|
||||
if digit < 0 || digit > 9 {
|
||||
return
|
||||
}
|
||||
thick := 4
|
||||
width := 22
|
||||
height := 36
|
||||
drawSegment := func(rect image.Rectangle) {
|
||||
draw.Draw(img, rect, &image.Uniform{col}, image.Point{}, draw.Src)
|
||||
}
|
||||
if segments[digit][0] {
|
||||
drawSegment(image.Rect(x+thick, y, x+width-thick, y+thick))
|
||||
}
|
||||
if segments[digit][1] {
|
||||
drawSegment(image.Rect(x+width-thick, y+thick, x+width, y+height/2))
|
||||
}
|
||||
if segments[digit][2] {
|
||||
drawSegment(image.Rect(x+width-thick, y+height/2, x+width, y+height-thick))
|
||||
}
|
||||
if segments[digit][3] {
|
||||
drawSegment(image.Rect(x+thick, y+height-thick, x+width-thick, y+height))
|
||||
}
|
||||
if segments[digit][4] {
|
||||
drawSegment(image.Rect(x, y+height/2, x+thick, y+height-thick))
|
||||
}
|
||||
if segments[digit][5] {
|
||||
drawSegment(image.Rect(x, y+thick, x+thick, y+height/2))
|
||||
}
|
||||
if segments[digit][6] {
|
||||
drawSegment(image.Rect(x+thick, y+height/2-thick/2, x+width-thick, y+height/2+thick/2))
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload map[string]any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_ = jsonNewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
type jsonEncoder interface {
|
||||
Encode(v any) error
|
||||
}
|
||||
|
||||
func jsonNewEncoder(w http.ResponseWriter) jsonEncoder {
|
||||
return json.NewEncoder(w)
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
func TestBootstrapShowsDefaultPasswordOnlyBeforeChange(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
StorageDir: root,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: root + "/test.sqlite",
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
},
|
||||
}
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if err := store.EnsureDefaultAdmin(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
service := NewService(store)
|
||||
payload, err := service.Bootstrap(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if payload["isDefaultPassword"] != true || payload["defaultPassword"] != "admin" {
|
||||
t.Fatalf("unexpected bootstrap payload: %#v", payload)
|
||||
}
|
||||
if err := store.ChangeAdminPassword(context.Background(), "admin", "admin", "changed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payload, err = service.Bootstrap(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if payload["isDefaultPassword"] != false || payload["defaultPassword"] != "" {
|
||||
t.Fatalf("default password leaked after change: %#v", payload)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const DefaultListen = ":33550"
|
||||
|
||||
var Version = "0.1.0"
|
||||
|
||||
type Config struct {
|
||||
BaseDir string `json:"base_dir"`
|
||||
ConfigPath string `json:"-"`
|
||||
Initialized bool `json:"initialized"`
|
||||
Listen string `json:"listen"`
|
||||
BaseURL string `json:"base_url"`
|
||||
StorageDir string `json:"storage_dir"`
|
||||
DataDir string `json:"data_dir"`
|
||||
UpdatePublicDir string `json:"update_public_dir"`
|
||||
UpdateNoticeDir string `json:"update_notice_dir"`
|
||||
DownloadsDir string `json:"downloads_dir"`
|
||||
AdminWebDir string `json:"admin_web_dir"`
|
||||
PortalWebDir string `json:"portal_web_dir"`
|
||||
SetupWebDir string `json:"setup_web_dir"`
|
||||
LegacyUpdateDir string `json:"legacy_update_dir"`
|
||||
LegacyFeedbackDir string `json:"legacy_feedback_dir"`
|
||||
LegacyUpdateNoticeDir string `json:"legacy_update_notice_dir"`
|
||||
ClientSignatureKey string `json:"client_signature_key"`
|
||||
PackageEncryptionKey string `json:"package_encryption_key"`
|
||||
TimestampWindowSeconds int64 `json:"timestamp_window_seconds"`
|
||||
MaxRequestBytes int64 `json:"max_request_bytes"`
|
||||
MaxPackageBytes int64 `json:"max_package_bytes"`
|
||||
Database DatabaseConfig `json:"database"`
|
||||
UploadGuard UploadGuardConfig `json:"upload_guard"`
|
||||
SourceCheckSeconds int `json:"source_check_seconds"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Provider string `json:"provider"`
|
||||
SQLitePath string `json:"sqlite_path"`
|
||||
MySQLDSN string `json:"mysql_dsn"`
|
||||
FailoverEnabled bool `json:"failover_enabled"`
|
||||
HotSyncEnabled bool `json:"hot_sync_enabled"`
|
||||
HealthIntervalSec int `json:"health_interval_sec"`
|
||||
MaxOpenConns int `json:"max_open_conns"`
|
||||
MaxIdleConns int `json:"max_idle_conns"`
|
||||
ConnMaxLifetimeSeconds int `json:"conn_max_lifetime_seconds"`
|
||||
}
|
||||
|
||||
type UploadGuardConfig struct {
|
||||
MaxZipFiles int `json:"max_zip_files"`
|
||||
MaxDecompressedBytes int64 `json:"max_decompressed_bytes"`
|
||||
MaxSingleFileBytes int64 `json:"max_single_file_bytes"`
|
||||
MaxCompressionRatio float64 `json:"max_compression_ratio"`
|
||||
MaxReadableTextBytes int64 `json:"max_readable_text_bytes"`
|
||||
AllowUnexpectedZipFiles bool `json:"allow_unexpected_zip_files"`
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
root, err := ResolveBaseDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := defaults(root)
|
||||
path := firstNonEmpty(os.Getenv("YMHUT_UNIFIED_CONFIG"), filepath.Join(root, "config.json"))
|
||||
if data, err := os.ReadFile(path); err == nil {
|
||||
if err := json.Unmarshal(data, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Initialized = true
|
||||
}
|
||||
cfg.BaseDir = root
|
||||
cfg.ConfigPath = path
|
||||
applyEnv(cfg)
|
||||
normalize(root, cfg)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func defaults(root string) *Config {
|
||||
return &Config{
|
||||
BaseDir: root,
|
||||
ConfigPath: filepath.Join(root, "config.json"),
|
||||
Initialized: false,
|
||||
Listen: DefaultListen,
|
||||
BaseURL: "https://update.ymhut.cn",
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
DataDir: filepath.Join(root, "data"),
|
||||
UpdatePublicDir: filepath.Join(root, "data", "update", "public"),
|
||||
UpdateNoticeDir: filepath.Join(root, "data", "update-notice"),
|
||||
DownloadsDir: filepath.Join(root, "data", "update", "public", "downloads"),
|
||||
AdminWebDir: filepath.Join(root, "web", "admin", "dist"),
|
||||
PortalWebDir: filepath.Join(root, "web", "portal", "dist"),
|
||||
SetupWebDir: filepath.Join(root, "web", "setup", "dist"),
|
||||
LegacyUpdateDir: filepath.Clean(filepath.Join(root, "..", "update")),
|
||||
LegacyFeedbackDir: filepath.Clean(filepath.Join(root, "..", "feedback-mailer")),
|
||||
LegacyUpdateNoticeDir: filepath.Clean(filepath.Join(root, "..", "..", "update-notice")),
|
||||
ClientSignatureKey: "ymhut-box-feedback-client-v1",
|
||||
PackageEncryptionKey: "ymhut-box-feedback-package-v1",
|
||||
TimestampWindowSeconds: 600,
|
||||
MaxRequestBytes: 12 * 1024 * 1024,
|
||||
MaxPackageBytes: 10 * 1024 * 1024,
|
||||
SourceCheckSeconds: 300,
|
||||
Database: DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HotSyncEnabled: true,
|
||||
HealthIntervalSec: 30,
|
||||
MaxOpenConns: 10,
|
||||
MaxIdleConns: 4,
|
||||
ConnMaxLifetimeSeconds: 300,
|
||||
},
|
||||
UploadGuard: UploadGuardConfig{
|
||||
MaxZipFiles: 80,
|
||||
MaxDecompressedBytes: 30 * 1024 * 1024,
|
||||
MaxSingleFileBytes: 8 * 1024 * 1024,
|
||||
MaxCompressionRatio: 120,
|
||||
MaxReadableTextBytes: 256 * 1024,
|
||||
AllowUnexpectedZipFiles: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyEnv(cfg *Config) {
|
||||
if value := os.Getenv("YMHUT_BASE_DIR"); value != "" {
|
||||
cfg.BaseDir = value
|
||||
}
|
||||
if value := os.Getenv("PORT"); value != "" {
|
||||
cfg.Listen = ":" + value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_LISTEN"); value != "" {
|
||||
cfg.Listen = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_BASE_URL"); value != "" {
|
||||
cfg.BaseURL = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_STORAGE_DIR"); value != "" {
|
||||
cfg.StorageDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_DATA_DIR"); value != "" {
|
||||
cfg.DataDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_UPDATE_PUBLIC_DIR"); value != "" {
|
||||
cfg.UpdatePublicDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_UPDATE_NOTICE_DIR"); value != "" {
|
||||
cfg.UpdateNoticeDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_DOWNLOADS_DIR"); value != "" {
|
||||
cfg.DownloadsDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_LEGACY_UPDATE_DIR"); value != "" {
|
||||
cfg.LegacyUpdateDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_LEGACY_FEEDBACK_DIR"); value != "" {
|
||||
cfg.LegacyFeedbackDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_LEGACY_UPDATE_NOTICE_DIR"); value != "" {
|
||||
cfg.LegacyUpdateNoticeDir = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_DB_PROVIDER"); value != "" {
|
||||
cfg.Database.Provider = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_SQLITE_PATH"); value != "" {
|
||||
cfg.Database.SQLitePath = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_MYSQL_DSN"); value != "" {
|
||||
cfg.Database.MySQLDSN = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_CLIENT_SIGNATURE_KEY"); value != "" {
|
||||
cfg.ClientSignatureKey = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_PACKAGE_ENCRYPTION_KEY"); value != "" {
|
||||
cfg.PackageEncryptionKey = value
|
||||
}
|
||||
if value := os.Getenv("YMHUT_TIMESTAMP_WINDOW_SECONDS"); value != "" {
|
||||
if parsed, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
cfg.TimestampWindowSeconds = parsed
|
||||
}
|
||||
}
|
||||
if value := os.Getenv("YMHUT_MAX_REQUEST_BYTES"); value != "" {
|
||||
if parsed, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
cfg.MaxRequestBytes = parsed
|
||||
}
|
||||
}
|
||||
if value := os.Getenv("YMHUT_MAX_PACKAGE_BYTES"); value != "" {
|
||||
if parsed, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
cfg.MaxPackageBytes = parsed
|
||||
}
|
||||
}
|
||||
if value := os.Getenv("YMHUT_SOURCE_CHECK_SECONDS"); value != "" {
|
||||
if parsed, err := strconv.Atoi(value); err == nil {
|
||||
cfg.SourceCheckSeconds = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalize(root string, cfg *Config) {
|
||||
cfg.BaseDir = absPath(root, firstNonEmpty(cfg.BaseDir, root))
|
||||
if cfg.ConfigPath == "" {
|
||||
cfg.ConfigPath = filepath.Join(cfg.BaseDir, "config.json")
|
||||
}
|
||||
if cfg.Listen == "" {
|
||||
cfg.Listen = DefaultListen
|
||||
}
|
||||
if cfg.StorageDir == "" {
|
||||
cfg.StorageDir = filepath.Join(cfg.BaseDir, "storage")
|
||||
}
|
||||
cfg.StorageDir = absPath(cfg.BaseDir, cfg.StorageDir)
|
||||
if cfg.DataDir == "" {
|
||||
cfg.DataDir = filepath.Join(cfg.BaseDir, "data")
|
||||
}
|
||||
cfg.DataDir = absPath(cfg.BaseDir, cfg.DataDir)
|
||||
if cfg.UpdatePublicDir == "" {
|
||||
cfg.UpdatePublicDir = filepath.Join(cfg.DataDir, "update", "public")
|
||||
}
|
||||
cfg.UpdatePublicDir = absPath(cfg.BaseDir, cfg.UpdatePublicDir)
|
||||
if cfg.UpdateNoticeDir == "" {
|
||||
cfg.UpdateNoticeDir = filepath.Join(cfg.DataDir, "update-notice")
|
||||
}
|
||||
cfg.UpdateNoticeDir = absPath(cfg.BaseDir, cfg.UpdateNoticeDir)
|
||||
if cfg.DownloadsDir == "" {
|
||||
cfg.DownloadsDir = filepath.Join(cfg.UpdatePublicDir, "downloads")
|
||||
}
|
||||
cfg.DownloadsDir = absPath(cfg.BaseDir, cfg.DownloadsDir)
|
||||
if cfg.AdminWebDir == "" {
|
||||
cfg.AdminWebDir = filepath.Join(cfg.BaseDir, "web", "admin", "dist")
|
||||
}
|
||||
cfg.AdminWebDir = absPath(cfg.BaseDir, cfg.AdminWebDir)
|
||||
if cfg.PortalWebDir == "" {
|
||||
cfg.PortalWebDir = filepath.Join(cfg.BaseDir, "web", "portal", "dist")
|
||||
}
|
||||
cfg.PortalWebDir = absPath(cfg.BaseDir, cfg.PortalWebDir)
|
||||
if cfg.SetupWebDir == "" {
|
||||
cfg.SetupWebDir = filepath.Join(cfg.BaseDir, "web", "setup", "dist")
|
||||
}
|
||||
cfg.SetupWebDir = absPath(cfg.BaseDir, cfg.SetupWebDir)
|
||||
if cfg.LegacyUpdateDir == "" {
|
||||
cfg.LegacyUpdateDir = filepath.Clean(filepath.Join(cfg.BaseDir, "..", "update"))
|
||||
}
|
||||
cfg.LegacyUpdateDir = absPath(cfg.BaseDir, cfg.LegacyUpdateDir)
|
||||
if cfg.LegacyFeedbackDir == "" {
|
||||
cfg.LegacyFeedbackDir = filepath.Clean(filepath.Join(cfg.BaseDir, "..", "feedback-mailer"))
|
||||
}
|
||||
cfg.LegacyFeedbackDir = absPath(cfg.BaseDir, cfg.LegacyFeedbackDir)
|
||||
if cfg.LegacyUpdateNoticeDir == "" {
|
||||
cfg.LegacyUpdateNoticeDir = filepath.Clean(filepath.Join(cfg.BaseDir, "..", "..", "update-notice"))
|
||||
}
|
||||
cfg.LegacyUpdateNoticeDir = absPath(cfg.BaseDir, cfg.LegacyUpdateNoticeDir)
|
||||
if cfg.Database.Provider == "" {
|
||||
cfg.Database.Provider = "sqlite"
|
||||
}
|
||||
if cfg.Database.SQLitePath == "" {
|
||||
cfg.Database.SQLitePath = filepath.Join(cfg.StorageDir, "unified.sqlite")
|
||||
}
|
||||
cfg.Database.SQLitePath = absPath(cfg.BaseDir, cfg.Database.SQLitePath)
|
||||
if cfg.Database.HealthIntervalSec <= 0 {
|
||||
cfg.Database.HealthIntervalSec = 30
|
||||
}
|
||||
if cfg.Database.MaxOpenConns <= 0 {
|
||||
cfg.Database.MaxOpenConns = 10
|
||||
}
|
||||
if cfg.Database.MaxIdleConns <= 0 {
|
||||
cfg.Database.MaxIdleConns = 4
|
||||
}
|
||||
if cfg.Database.ConnMaxLifetimeSeconds <= 0 {
|
||||
cfg.Database.ConnMaxLifetimeSeconds = 300
|
||||
}
|
||||
if cfg.ClientSignatureKey == "" {
|
||||
cfg.ClientSignatureKey = "ymhut-box-feedback-client-v1"
|
||||
}
|
||||
if cfg.PackageEncryptionKey == "" {
|
||||
cfg.PackageEncryptionKey = "ymhut-box-feedback-package-v1"
|
||||
}
|
||||
if cfg.TimestampWindowSeconds <= 0 {
|
||||
cfg.TimestampWindowSeconds = 600
|
||||
}
|
||||
if cfg.MaxRequestBytes <= 0 {
|
||||
cfg.MaxRequestBytes = 12 * 1024 * 1024
|
||||
}
|
||||
if cfg.MaxPackageBytes <= 0 {
|
||||
cfg.MaxPackageBytes = 10 * 1024 * 1024
|
||||
}
|
||||
if cfg.UploadGuard.MaxZipFiles <= 0 {
|
||||
cfg.UploadGuard.MaxZipFiles = 80
|
||||
}
|
||||
if cfg.UploadGuard.MaxDecompressedBytes <= 0 {
|
||||
cfg.UploadGuard.MaxDecompressedBytes = 30 * 1024 * 1024
|
||||
}
|
||||
if cfg.UploadGuard.MaxSingleFileBytes <= 0 {
|
||||
cfg.UploadGuard.MaxSingleFileBytes = 8 * 1024 * 1024
|
||||
}
|
||||
if cfg.UploadGuard.MaxCompressionRatio <= 0 {
|
||||
cfg.UploadGuard.MaxCompressionRatio = 120
|
||||
}
|
||||
if cfg.UploadGuard.MaxReadableTextBytes <= 0 {
|
||||
cfg.UploadGuard.MaxReadableTextBytes = 256 * 1024
|
||||
}
|
||||
if cfg.SourceCheckSeconds <= 0 {
|
||||
cfg.SourceCheckSeconds = 300
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveBaseDir() (string, error) {
|
||||
if value := os.Getenv("YMHUT_BASE_DIR"); value != "" {
|
||||
return filepath.Abs(value)
|
||||
}
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
if filepath.Base(cwd) == "unified-management" {
|
||||
return filepath.Abs(cwd)
|
||||
}
|
||||
candidate := filepath.Join(cwd, "server", "unified-management")
|
||||
if info, err := os.Stat(candidate); err == nil && info.IsDir() {
|
||||
return filepath.Abs(candidate)
|
||||
}
|
||||
}
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return os.Getwd()
|
||||
}
|
||||
return filepath.Abs(filepath.Dir(exe))
|
||||
}
|
||||
|
||||
func Save(cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
normalize(firstNonEmpty(cfg.BaseDir, "."), cfg)
|
||||
if err := os.MkdirAll(filepath.Dir(cfg.ConfigPath), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(cfg.ConfigPath, data, 0o600)
|
||||
}
|
||||
|
||||
func absPath(base, value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return value
|
||||
}
|
||||
if filepath.IsAbs(value) || strings.HasPrefix(strings.ToLower(value), "file:") {
|
||||
return filepath.Clean(value)
|
||||
}
|
||||
return filepath.Clean(filepath.Join(base, value))
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
webassets "ymhut-box/server/unified-management/web"
|
||||
)
|
||||
|
||||
type Check struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func Preflight(cfg *Config) []Check {
|
||||
checks := []Check{
|
||||
checkDir("storage", cfg.StorageDir, true),
|
||||
checkParent("sqlite", cfg.Database.SQLitePath),
|
||||
checkDir("update public", cfg.UpdatePublicDir, false),
|
||||
checkDir("update notice", cfg.UpdateNoticeDir, false),
|
||||
checkDir("downloads", cfg.DownloadsDir, false),
|
||||
checkFile("legacy update-info", filepath.Join(cfg.UpdatePublicDir, "update-info.json"), false),
|
||||
checkFile("legacy media-types", filepath.Join(cfg.UpdatePublicDir, "media-types.json"), false),
|
||||
checkFile("version notice index", filepath.Join(cfg.UpdateNoticeDir, "total.json"), false),
|
||||
checkWebBuild("admin web dist", cfg.AdminWebDir, "admin/dist"),
|
||||
checkWebBuild("portal web dist", cfg.PortalWebDir, "portal/dist"),
|
||||
checkWebBuild("setup web dist", cfg.SetupWebDir, "setup/dist"),
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
func checkDir(name, path string, create bool) Check {
|
||||
if create {
|
||||
if err := os.MkdirAll(path, 0o750); err != nil {
|
||||
return Check{Name: name, Status: "error", Path: path, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return Check{Name: name, Status: "missing", Path: path, Message: "directory not found"}
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return Check{Name: name, Status: "error", Path: path, Message: "path is not a directory"}
|
||||
}
|
||||
return Check{Name: name, Status: "ok", Path: path}
|
||||
}
|
||||
|
||||
func checkParent(name, path string) Check {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return Check{Name: name, Status: "error", Path: path, Message: err.Error()}
|
||||
}
|
||||
return Check{Name: name, Status: "ok", Path: path}
|
||||
}
|
||||
|
||||
func checkFile(name, path string, required bool) Check {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
status := "missing"
|
||||
if required {
|
||||
status = "error"
|
||||
}
|
||||
return Check{Name: name, Status: status, Path: path, Message: "file not found"}
|
||||
}
|
||||
if info.IsDir() {
|
||||
return Check{Name: name, Status: "error", Path: path, Message: "path is a directory"}
|
||||
}
|
||||
return Check{Name: name, Status: "ok", Path: path}
|
||||
}
|
||||
|
||||
func checkWebBuild(name, path, embedRoot string) Check {
|
||||
dir := checkDir(name, path, false)
|
||||
if dir.Status != "ok" {
|
||||
if embeddedWebBuildOK(embedRoot) {
|
||||
return Check{Name: name, Status: "ok", Path: path, Message: "using embedded frontend assets"}
|
||||
}
|
||||
dir.Message = "frontend dist missing; run npm install && npm run build, or build release with -tags embed_web"
|
||||
return dir
|
||||
}
|
||||
index := filepath.Join(path, "index.html")
|
||||
if file := checkFile(name+" index", index, true); file.Status != "ok" {
|
||||
if embeddedWebBuildOK(embedRoot) {
|
||||
return Check{Name: name, Status: "ok", Path: path, Message: "disk index missing; using embedded frontend assets"}
|
||||
}
|
||||
return Check{Name: name, Status: "missing", Path: index, Message: "index.html missing; run npm run build"}
|
||||
}
|
||||
assets := filepath.Join(path, "assets")
|
||||
if assetDir := checkDir(name+" assets", assets, false); assetDir.Status != "ok" {
|
||||
if embeddedWebBuildOK(embedRoot) {
|
||||
return Check{Name: name, Status: "ok", Path: path, Message: "disk assets missing; using embedded frontend assets"}
|
||||
}
|
||||
return Check{Name: name, Status: "missing", Path: assets, Message: "assets directory missing; run npm run build"}
|
||||
}
|
||||
return Check{Name: name, Status: "ok", Path: path}
|
||||
}
|
||||
|
||||
func embeddedWebBuildOK(embedRoot string) bool {
|
||||
if !webassets.Embedded {
|
||||
return false
|
||||
}
|
||||
if _, err := webassets.ReadFile(filepath.ToSlash(filepath.Join(embedRoot, "index.html"))); err != nil {
|
||||
return false
|
||||
}
|
||||
entries, err := webassets.ReadDir(filepath.ToSlash(filepath.Join(embedRoot, "assets")))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func FormatPreflight(checks []Check) []string {
|
||||
lines := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
line := fmt.Sprintf("[%s] %s", check.Status, check.Name)
|
||||
if check.Path != "" {
|
||||
line += " -> " + check.Path
|
||||
}
|
||||
if check.Message != "" {
|
||||
line += " (" + check.Message + ")"
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return lines
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
)
|
||||
|
||||
type dialect struct {
|
||||
name string
|
||||
driver string
|
||||
}
|
||||
|
||||
func dialectFor(provider string) dialect {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "mysql":
|
||||
return dialect{name: "mysql", driver: "mysql"}
|
||||
default:
|
||||
return dialect{name: "sqlite", driver: "sqlite"}
|
||||
}
|
||||
}
|
||||
|
||||
func (d dialect) rebind(query string) string {
|
||||
return query
|
||||
}
|
||||
|
||||
func (d dialect) idType() string {
|
||||
if d.name == "mysql" {
|
||||
return "BIGINT PRIMARY KEY AUTO_INCREMENT"
|
||||
}
|
||||
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
}
|
||||
|
||||
func (d dialect) boolExpr(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d dialect) upsert(table string, columns, conflict []string) string {
|
||||
placeholders := make([]string, len(columns))
|
||||
for i := range placeholders {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
conflictSet := map[string]bool{}
|
||||
for _, column := range conflict {
|
||||
conflictSet[column] = true
|
||||
}
|
||||
updates := []string{}
|
||||
for _, column := range columns {
|
||||
if conflictSet[column] {
|
||||
continue
|
||||
}
|
||||
if d.name == "mysql" {
|
||||
updates = append(updates, fmt.Sprintf("%s = VALUES(%s)", column, column))
|
||||
} else {
|
||||
updates = append(updates, fmt.Sprintf("%s = excluded.%s", column, column))
|
||||
}
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
if d.name == "mysql" {
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT IGNORE INTO", 1)
|
||||
}
|
||||
return strings.Replace(base, "INSERT INTO", "INSERT OR IGNORE INTO", 1)
|
||||
}
|
||||
if d.name == "mysql" {
|
||||
return base + " ON DUPLICATE KEY UPDATE " + strings.Join(updates, ", ")
|
||||
}
|
||||
return base + " ON CONFLICT (" + strings.Join(conflict, ", ") + ") DO UPDATE SET " + strings.Join(updates, ", ")
|
||||
}
|
||||
|
||||
func (d dialect) limitOffset(limit, offset int) string {
|
||||
return fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)
|
||||
}
|
||||
|
||||
func openSQLDatabase(cfg config.DatabaseConfig) (*sql.DB, dialect, error) {
|
||||
d := dialectFor(cfg.Provider)
|
||||
dsn := strings.TrimSpace(cfg.SQLitePath)
|
||||
if d.name == "mysql" {
|
||||
dsn = strings.TrimSpace(cfg.MySQLDSN)
|
||||
if dsn == "" {
|
||||
return nil, d, errors.New("mysql_dsn is required")
|
||||
}
|
||||
} else {
|
||||
if dsn == "" {
|
||||
return nil, d, errors.New("sqlite path is required")
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(dsn), "file:") {
|
||||
if err := os.MkdirAll(filepath.Dir(dsn), 0o750); err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
}
|
||||
}
|
||||
conn, err := sql.Open(d.driver, dsn)
|
||||
if err != nil {
|
||||
return nil, d, err
|
||||
}
|
||||
conn.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
conn.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
conn.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetimeSeconds) * time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, d, err
|
||||
}
|
||||
return conn, d, nil
|
||||
}
|
||||
|
||||
func TestDatabase(cfg config.DatabaseConfig) error {
|
||||
if cfg.Provider == "" {
|
||||
cfg.Provider = "sqlite"
|
||||
}
|
||||
if cfg.MaxOpenConns <= 0 {
|
||||
cfg.MaxOpenConns = 1
|
||||
}
|
||||
if cfg.MaxIdleConns <= 0 {
|
||||
cfg.MaxIdleConns = 1
|
||||
}
|
||||
if cfg.ConnMaxLifetimeSeconds <= 0 {
|
||||
cfg.ConnMaxLifetimeSeconds = 60
|
||||
}
|
||||
conn, d, err := openSQLDatabase(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
create := "CREATE TEMPORARY TABLE ymhut_unified_connection_test (id INTEGER)"
|
||||
if _, err := tx.ExecContext(ctx, d.rebind(create)); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
_ = tx.Rollback()
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
items := make([]string, n)
|
||||
for i := range items {
|
||||
items[i] = "?"
|
||||
}
|
||||
return strings.Join(items, ", ")
|
||||
}
|
||||
|
||||
func atoiDefault(value string, fallback int) int {
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,65 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
)
|
||||
|
||||
func TestOpenImportsJSONPrototypeIntoSQLite(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
path := filepath.Join(root, "unified.sqlite")
|
||||
prototype := state{
|
||||
Admins: []adminRow{{
|
||||
ID: 1,
|
||||
Username: "admin",
|
||||
PasswordHash: passwordHash("admin"),
|
||||
PasswordChanged: false,
|
||||
CreatedAt: "2026-01-01T00:00:00Z",
|
||||
UpdatedAt: "2026-01-01T00:00:00Z",
|
||||
}},
|
||||
Feedbacks: []Feedback{{Code: "FB-20260101-ABCDEF", Title: "Imported", Type: "issue", Severity: "normal", Body: "hello"}},
|
||||
Sources: []Source{{CategoryID: "ip", CategoryName: "IP", SourceID: "ip-demo", Name: "IP Demo", APIURL: "https://example.com/ip", Enabled: true, ClientVisible: true}},
|
||||
}
|
||||
data, err := json.Marshal(prototype)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o640); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
store, err := Open(&config.Config{
|
||||
StorageDir: root,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: path,
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
ConnMaxLifetimeSeconds: 60,
|
||||
},
|
||||
UploadGuard: config.UploadGuardConfig{MaxZipFiles: 80, MaxDecompressedBytes: 30 << 20, MaxSingleFileBytes: 8 << 20, MaxCompressionRatio: 120, MaxReadableTextBytes: 256 << 10, AllowUnexpectedZipFiles: true},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if _, _, err := store.VerifyAdminPassword(context.Background(), "admin", "admin"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := store.GetFeedback("FB-20260101-ABCDEF"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count, err := store.CountSources(); err != nil || count != 1 {
|
||||
t.Fatalf("CountSources = %d, %v", count, err)
|
||||
}
|
||||
matches, _ := filepath.Glob(path + ".json-prototype-*.bak")
|
||||
if len(matches) != 1 {
|
||||
t.Fatalf("expected prototype backup, got %v", matches)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,466 @@
|
||||
package feedback
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
const PackageMagic = "YMHUTFB1"
|
||||
|
||||
var feedbackCodePattern = regexp.MustCompile(`^FB-[0-9]{8}-[A-F0-9]{6}$`)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
}
|
||||
|
||||
type submissionPayload struct {
|
||||
FeedbackCode string `json:"feedbackCode"`
|
||||
Title string `json:"title"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
Contact string `json:"contact"`
|
||||
BodyLength int `json:"bodyLength"`
|
||||
PackageEncrypted bool `json:"packageEncrypted"`
|
||||
Encryption string `json:"encryption"`
|
||||
PackageBytes int64 `json:"packageBytes"`
|
||||
PackageSha256 string `json:"packageSha256"`
|
||||
PlainPackageBytes int64 `json:"plainPackageBytes"`
|
||||
PlainPackageSha256 string `json:"plainPackageSha256"`
|
||||
CreatedAt json.RawMessage `json:"createdAt"`
|
||||
}
|
||||
|
||||
type packageInfo struct {
|
||||
Request map[string]any
|
||||
Summary string
|
||||
Files []string
|
||||
}
|
||||
|
||||
func NewService(cfg *config.Config, store *db.Store) *Service {
|
||||
return &Service{cfg: cfg, store: store}
|
||||
}
|
||||
|
||||
func (s *Service) Submit(r *http.Request) (db.Feedback, error) {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if item, err := s.submitMultipart(r); err == nil {
|
||||
return item, nil
|
||||
} else if hasSignedFields(r) {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
}
|
||||
return s.submitSimple(r)
|
||||
}
|
||||
|
||||
func (s *Service) submitSimple(r *http.Request) (db.Feedback, error) {
|
||||
if strings.Contains(r.Header.Get("Content-Type"), "application/json") {
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
item := db.Feedback{
|
||||
Code: db.NewFeedbackCode(),
|
||||
Title: value(payload, "title", "客户端反馈"),
|
||||
Type: value(payload, "type", "issue"),
|
||||
Severity: value(payload, "severity", "normal"),
|
||||
Contact: value(payload, "contact", ""),
|
||||
Body: value(payload, "body", value(payload, "message", "")),
|
||||
Status: "new",
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
}
|
||||
if strings.TrimSpace(item.Body) == "" {
|
||||
item.Body = "No feedback body provided."
|
||||
}
|
||||
return item, s.store.InsertFeedback(item)
|
||||
}
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
_ = r.ParseForm()
|
||||
}
|
||||
item := db.Feedback{
|
||||
Code: db.NewFeedbackCode(),
|
||||
Title: firstNonEmpty(r.FormValue("title"), r.FormValue("subject"), "客户端反馈"),
|
||||
Type: firstNonEmpty(r.FormValue("type"), r.FormValue("category"), "issue"),
|
||||
Severity: firstNonEmpty(r.FormValue("severity"), r.FormValue("priority"), "normal"),
|
||||
Contact: firstNonEmpty(r.FormValue("contact"), r.FormValue("email")),
|
||||
Body: firstNonEmpty(r.FormValue("body"), r.FormValue("message"), r.FormValue("description")),
|
||||
Status: "new",
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
}
|
||||
if strings.TrimSpace(item.Body) == "" {
|
||||
item.Body = "No feedback body provided."
|
||||
}
|
||||
return item, s.store.InsertFeedback(item)
|
||||
}
|
||||
|
||||
func (s *Service) submitMultipart(r *http.Request) (db.Feedback, error) {
|
||||
if s.cfg.MaxRequestBytes > 0 {
|
||||
if r.ContentLength > s.cfg.MaxRequestBytes {
|
||||
return db.Feedback{}, errors.New("request is too large")
|
||||
}
|
||||
r.Body = http.MaxBytesReader(nilResponseWriter{}, r.Body, s.cfg.MaxRequestBytes)
|
||||
}
|
||||
if err := r.ParseMultipartForm(s.cfg.MaxRequestBytes); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
payloadText := strings.TrimSpace(r.FormValue("payload"))
|
||||
timestamp := strings.TrimSpace(r.FormValue("timestamp"))
|
||||
nonce := strings.TrimSpace(r.FormValue("nonce"))
|
||||
packageSha256 := strings.ToLower(strings.TrimSpace(r.FormValue("packageSha256")))
|
||||
signature := strings.ToLower(strings.TrimSpace(r.FormValue("signature")))
|
||||
if payloadText == "" || timestamp == "" || nonce == "" || packageSha256 == "" || signature == "" {
|
||||
return db.Feedback{}, errors.New("signed multipart fields are required")
|
||||
}
|
||||
if !validTimestamp(timestamp, s.cfg.TimestampWindowSeconds) {
|
||||
return db.Feedback{}, errors.New("timestamp outside accepted window")
|
||||
}
|
||||
if !isHexSHA256(packageSha256) {
|
||||
return db.Feedback{}, errors.New("invalid package hash")
|
||||
}
|
||||
expected := SignWithKey(s.cfg.ClientSignatureKey, timestamp, nonce, packageSha256, payloadText)
|
||||
if !hmac.Equal([]byte(expected), []byte(signature)) {
|
||||
return db.Feedback{}, errors.New("invalid request signature")
|
||||
}
|
||||
var payload submissionPayload
|
||||
if err := json.Unmarshal([]byte(payloadText), &payload); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
if err := validatePayload(payload); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
code := NormalizeCode(payload.FeedbackCode)
|
||||
if code == "" {
|
||||
code = db.NewFeedbackCode()
|
||||
}
|
||||
if existing, err := s.store.GetFeedback(code); err == nil {
|
||||
return existing, nil
|
||||
}
|
||||
file, _, err := r.FormFile("package")
|
||||
if err != nil {
|
||||
return db.Feedback{}, errors.New("missing package file")
|
||||
}
|
||||
defer file.Close()
|
||||
data, err := readUploadedPackage(file, s.cfg.MaxPackageBytes)
|
||||
if err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
if !bytes.HasPrefix(data, []byte(PackageMagic)) {
|
||||
return db.Feedback{}, errors.New("encrypted package format is invalid")
|
||||
}
|
||||
if !hmac.Equal([]byte(sha256Hex(data)), []byte(packageSha256)) {
|
||||
return db.Feedback{}, errors.New("package hash mismatch")
|
||||
}
|
||||
plain, err := DecryptPackage(data, s.cfg.PackageEncryptionKey)
|
||||
if err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
if !isZipBytes(plain) {
|
||||
return db.Feedback{}, errors.New("decrypted package is not a zip")
|
||||
}
|
||||
if payload.PlainPackageSha256 != "" && isHexSHA256(payload.PlainPackageSha256) {
|
||||
if !hmac.Equal([]byte(sha256Hex(plain)), []byte(strings.ToLower(payload.PlainPackageSha256))) {
|
||||
return db.Feedback{}, errors.New("decrypted package hash mismatch")
|
||||
}
|
||||
}
|
||||
info, err := ReadFeedbackPackageWithGuard(plain, s.cfg.UploadGuard)
|
||||
if err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
dir := filepath.Join(s.cfg.StorageDir, "feedback")
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
encryptedPath := filepath.Join(dir, code+".ymfb")
|
||||
packagePath := filepath.Join(dir, code+".zip")
|
||||
if err := os.WriteFile(encryptedPath, data, 0o640); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
if err := os.WriteFile(packagePath, plain, 0o640); err != nil {
|
||||
return db.Feedback{}, err
|
||||
}
|
||||
item := buildRecord(code, payload, info, encryptedPath, packagePath, packageSha256, strings.ToLower(payload.PlainPackageSha256), r.RemoteAddr)
|
||||
return item, s.store.InsertFeedback(item)
|
||||
}
|
||||
|
||||
func hasSignedFields(r *http.Request) bool {
|
||||
if r.MultipartForm == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range []string{"payload", "timestamp", "nonce", "packageSha256", "signature"} {
|
||||
if strings.TrimSpace(r.FormValue(key)) == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func NormalizeCode(code string) string {
|
||||
code = strings.ToUpper(strings.TrimSpace(code))
|
||||
if feedbackCodePattern.MatchString(code) {
|
||||
return code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SignWithKey(key, timestamp, nonce, packageSha256, payload string) string {
|
||||
material := timestamp + "\n" + nonce + "\n" + packageSha256 + "\n" + payload
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
_, _ = mac.Write([]byte(material))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
func validTimestamp(value string, windowSeconds int64) bool {
|
||||
if !regexp.MustCompile(`^[0-9]{10,}$`).MatchString(value) {
|
||||
return false
|
||||
}
|
||||
seconds, err := time.ParseDuration(value + "s")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
delta := time.Now().Unix() - int64(seconds.Seconds())
|
||||
if delta < 0 {
|
||||
delta = -delta
|
||||
}
|
||||
return delta <= windowSeconds
|
||||
}
|
||||
|
||||
func validatePayload(payload submissionPayload) error {
|
||||
if payload.Title == "" || payload.Type == "" || payload.Severity == "" {
|
||||
return errors.New("payload title, type and severity are required")
|
||||
}
|
||||
if payload.PackageBytes <= 0 || payload.PackageSha256 == "" || payload.PlainPackageSha256 == "" {
|
||||
return errors.New("payload package hashes are required")
|
||||
}
|
||||
if !payload.PackageEncrypted || payload.Encryption != PackageMagic {
|
||||
return errors.New("encrypted package is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readUploadedPackage(file multipart.File, maxBytes int64) ([]byte, error) {
|
||||
limit := maxBytes + 1
|
||||
if limit <= 1 {
|
||||
limit = 10*1024*1024 + 1
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(file, limit))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if maxBytes > 0 && int64(len(data)) > maxBytes {
|
||||
return nil, errors.New("feedback package is too large")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func DecryptPackage(data []byte, keyMaterial string) ([]byte, error) {
|
||||
if len(data) < len(PackageMagic)+12+16 || !bytes.HasPrefix(data, []byte(PackageMagic)) {
|
||||
return nil, errors.New("encrypted package format is invalid")
|
||||
}
|
||||
if keyMaterial == "" {
|
||||
keyMaterial = "ymhut-box-feedback-package-v1"
|
||||
}
|
||||
key := sha256.Sum256([]byte(keyMaterial))
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset := len(PackageMagic)
|
||||
nonce := data[offset : offset+12]
|
||||
offset += 12
|
||||
tag := data[offset : offset+16]
|
||||
offset += 16
|
||||
ciphertext := data[offset:]
|
||||
combined := append(append([]byte{}, ciphertext...), tag...)
|
||||
return gcm.Open(nil, nonce, combined, []byte(PackageMagic))
|
||||
}
|
||||
|
||||
func ReadFeedbackPackageWithGuard(plain []byte, guard config.UploadGuardConfig) (packageInfo, error) {
|
||||
reader, err := zip.NewReader(bytes.NewReader(plain), int64(len(plain)))
|
||||
if err != nil {
|
||||
return packageInfo{}, err
|
||||
}
|
||||
files := []string{}
|
||||
texts := map[string]string{}
|
||||
var total uint64
|
||||
for _, entry := range reader.File {
|
||||
if entry.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
cleanName, err := safeZipName(entry.Name)
|
||||
if err != nil {
|
||||
return packageInfo{}, err
|
||||
}
|
||||
if len(files)+1 > guard.MaxZipFiles {
|
||||
return packageInfo{}, errors.New("zip contains too many files")
|
||||
}
|
||||
if entry.UncompressedSize64 > uint64(guard.MaxSingleFileBytes) {
|
||||
return packageInfo{}, errors.New("zip entry is too large")
|
||||
}
|
||||
total += entry.UncompressedSize64
|
||||
if total > uint64(guard.MaxDecompressedBytes) {
|
||||
return packageInfo{}, errors.New("zip decompressed size is too large")
|
||||
}
|
||||
if entry.CompressedSize64 == 0 && entry.UncompressedSize64 > 0 {
|
||||
return packageInfo{}, errors.New("zip entry has invalid compression metadata")
|
||||
}
|
||||
if entry.CompressedSize64 > 0 && float64(entry.UncompressedSize64)/float64(entry.CompressedSize64) > guard.MaxCompressionRatio {
|
||||
return packageInfo{}, errors.New("zip compression ratio is suspicious")
|
||||
}
|
||||
files = append(files, cleanName)
|
||||
if cleanName != "feedback.json" && cleanName != "summary.txt" {
|
||||
if !guard.AllowUnexpectedZipFiles && !strings.HasPrefix(cleanName, "attachments/") {
|
||||
return packageInfo{}, errors.New("zip contains unexpected file")
|
||||
}
|
||||
continue
|
||||
}
|
||||
text, err := readZipText(entry, guard.MaxReadableTextBytes)
|
||||
if err != nil {
|
||||
return packageInfo{}, err
|
||||
}
|
||||
texts[cleanName] = text
|
||||
}
|
||||
request := map[string]any{}
|
||||
if raw := texts["feedback.json"]; raw != "" {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err == nil {
|
||||
if nested, ok := parsed["request"].(map[string]any); ok {
|
||||
request = nested
|
||||
} else {
|
||||
request = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(request) == 0 && texts["feedback.json"] == "" {
|
||||
return packageInfo{}, errors.New("feedback.json is missing")
|
||||
}
|
||||
return packageInfo{Request: request, Summary: texts["summary.txt"], Files: files}, nil
|
||||
}
|
||||
|
||||
func safeZipName(name string) (string, error) {
|
||||
name = strings.ReplaceAll(name, "\\", "/")
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" || strings.Contains(name, "\x00") || strings.HasPrefix(name, "/") {
|
||||
return "", errors.New("unsafe zip entry name")
|
||||
}
|
||||
clean := path.Clean(name)
|
||||
if clean == "." || clean == ".." || strings.HasPrefix(clean, "../") {
|
||||
return "", errors.New("unsafe zip entry path")
|
||||
}
|
||||
return clean, nil
|
||||
}
|
||||
|
||||
func readZipText(entry *zip.File, maxBytes int64) (string, error) {
|
||||
if int64(entry.UncompressedSize64) > maxBytes {
|
||||
return "", nil
|
||||
}
|
||||
reader, err := entry.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer reader.Close()
|
||||
data, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return "", nil
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func buildRecord(code string, payload submissionPayload, info packageInfo, encryptedPath, packagePath, packageSha256, plainPackageSha256, remoteAddr string) db.Feedback {
|
||||
title := firstNonEmpty(textFromMap(info.Request, "title"), payload.Title, "未命名反馈")
|
||||
typ := firstNonEmpty(textFromMap(info.Request, "type"), payload.Type, "issue")
|
||||
severity := firstNonEmpty(textFromMap(info.Request, "severity"), payload.Severity, "normal")
|
||||
contact := firstNonEmpty(textFromMap(info.Request, "contact"), payload.Contact)
|
||||
body := firstNonEmpty(textFromMap(info.Request, "body"), info.Summary)
|
||||
return db.Feedback{
|
||||
Code: code,
|
||||
Title: title,
|
||||
Type: typ,
|
||||
Severity: severity,
|
||||
Contact: contact,
|
||||
Body: body,
|
||||
Status: "new",
|
||||
StatusDetail: "反馈已接收,等待后台处理。",
|
||||
SourceChannel: "winui",
|
||||
PackagePath: packagePath,
|
||||
EncryptedPackagePath: encryptedPath,
|
||||
PackageSha256: packageSha256,
|
||||
PlainPackageSha256: plainPackageSha256,
|
||||
SummaryText: info.Summary,
|
||||
IncludedFiles: strings.Join(info.Files, ", "),
|
||||
RemoteAddr: remoteAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func value(payload map[string]any, key, fallback string) string {
|
||||
if raw, ok := payload[key].(string); ok && strings.TrimSpace(raw) != "" {
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func textFromMap(values map[string]any, key string) string {
|
||||
if value, ok := values[key].(string); ok {
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isHexSHA256(value string) bool {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
if len(value) != 64 {
|
||||
return false
|
||||
}
|
||||
_, err := hex.DecodeString(value)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func isZipBytes(data []byte) bool {
|
||||
return bytes.HasPrefix(data, []byte("PK\x03\x04")) ||
|
||||
bytes.HasPrefix(data, []byte("PK\x05\x06")) ||
|
||||
bytes.HasPrefix(data, []byte("PK\x07\x08"))
|
||||
}
|
||||
|
||||
type nilResponseWriter struct{}
|
||||
|
||||
func (nilResponseWriter) Header() http.Header { return http.Header{} }
|
||||
func (nilResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (nilResponseWriter) WriteHeader(int) {}
|
||||
@@ -0,0 +1,158 @@
|
||||
package feedback
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
func TestSignedMultipartSubmission(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
service := NewService(cfg, store)
|
||||
plain := zipBytes(t, map[string]string{
|
||||
"feedback.json": `{"request":{"title":"Crash on launch","type":"issue","severity":"major","contact":"user@example.com","body":"It crashes."}}`,
|
||||
"summary.txt": "launch failure",
|
||||
})
|
||||
encrypted := encryptPackageForTest(t, plain, cfg.PackageEncryptionKey)
|
||||
encryptedHash := sha256HexTest(encrypted)
|
||||
plainHash := sha256HexTest(plain)
|
||||
timestamp := itoa(int(time.Now().Unix()))
|
||||
payload := `{"feedbackCode":"FB-20260625-ABCDEF","title":"Crash on launch","type":"issue","severity":"major","contact":"user@example.com","bodyLength":11,"packageEncrypted":true,"encryption":"YMHUTFB1","packageBytes":` + itoa(len(encrypted)) + `,"packageSha256":"` + encryptedHash + `","plainPackageBytes":` + itoa(len(plain)) + `,"plainPackageSha256":"` + plainHash + `","createdAt":"2026-06-25T00:00:00Z"}`
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
_ = writer.WriteField("payload", payload)
|
||||
_ = writer.WriteField("timestamp", timestamp)
|
||||
_ = writer.WriteField("nonce", "abc123")
|
||||
_ = writer.WriteField("packageSha256", encryptedHash)
|
||||
_ = writer.WriteField("signature", SignWithKey(cfg.ClientSignatureKey, timestamp, "abc123", encryptedHash, payload))
|
||||
part, err := writer.CreateFormFile("package", "feedback.ymfb")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = io.Copy(part, bytes.NewReader(encrypted))
|
||||
_ = writer.Close()
|
||||
req := httptest.NewRequest("POST", "/", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
item, err := service.Submit(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if item.Code != "FB-20260625-ABCDEF" || !strings.Contains(item.IncludedFiles, "feedback.json") {
|
||||
t.Fatalf("unexpected item: %#v", item)
|
||||
}
|
||||
if _, err := store.GetFeedback(item.Code); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipPathEscapeRejected(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
cfg := testConfig(root)
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
plain := zipBytes(t, map[string]string{"../escape.txt": "bad", "feedback.json": `{}`})
|
||||
if _, err := ReadFeedbackPackageWithGuard(plain, cfg.UploadGuard); err == nil {
|
||||
t.Fatal("expected unsafe zip entry error")
|
||||
}
|
||||
}
|
||||
|
||||
func testConfig(root string) *config.Config {
|
||||
return &config.Config{
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
ClientSignatureKey: "ymhut-box-feedback-client-v1",
|
||||
PackageEncryptionKey: "ymhut-box-feedback-package-v1",
|
||||
TimestampWindowSeconds: 600,
|
||||
MaxRequestBytes: 12 << 20,
|
||||
MaxPackageBytes: 10 << 20,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
ConnMaxLifetimeSeconds: 60,
|
||||
},
|
||||
UploadGuard: config.UploadGuardConfig{MaxZipFiles: 80, MaxDecompressedBytes: 30 << 20, MaxSingleFileBytes: 8 << 20, MaxCompressionRatio: 120, MaxReadableTextBytes: 256 << 10, AllowUnexpectedZipFiles: true},
|
||||
}
|
||||
}
|
||||
|
||||
func zipBytes(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
writer := zip.NewWriter(&buf)
|
||||
for name, body := range files {
|
||||
entry, err := writer.Create(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = entry.Write([]byte(body))
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func encryptPackageForTest(t *testing.T, plain []byte, keyMaterial string) []byte {
|
||||
t.Helper()
|
||||
key := sha256.Sum256([]byte(keyMaterial))
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nonce := []byte("123456789012")
|
||||
sealed := gcm.Seal(nil, nonce, plain, []byte(PackageMagic))
|
||||
ciphertext := sealed[:len(sealed)-gcm.Overhead()]
|
||||
tag := sealed[len(sealed)-gcm.Overhead():]
|
||||
out := []byte(PackageMagic)
|
||||
out = append(out, nonce...)
|
||||
out = append(out, tag...)
|
||||
out = append(out, ciphertext...)
|
||||
return out
|
||||
}
|
||||
|
||||
func sha256HexTest(data []byte) string {
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func itoa(value int) string {
|
||||
if value == 0 {
|
||||
return "0"
|
||||
}
|
||||
var buf [20]byte
|
||||
i := len(buf)
|
||||
for value > 0 {
|
||||
i--
|
||||
buf[i] = byte('0' + value%10)
|
||||
value /= 10
|
||||
}
|
||||
return string(buf[i:])
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
func Snapshot(cfg *config.Config, store *db.Store) map[string]any {
|
||||
return map[string]any{
|
||||
"ok": true,
|
||||
"version": config.Version,
|
||||
"service": map[string]any{
|
||||
"name": "YMhut Unified Management",
|
||||
"baseUrl": cfg.BaseURL,
|
||||
},
|
||||
"database": store.Status(),
|
||||
"preflight": config.Preflight(cfg),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
package legacy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
Name string `json:"name"`
|
||||
Raw string `json:"raw"`
|
||||
Parsed map[string]any `json:"parsed"`
|
||||
Path string `json:"path"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
Revisions []db.LegacyJsonRevision `json:"revisions"`
|
||||
}
|
||||
|
||||
type SaveRequest struct {
|
||||
Raw string `json:"raw"`
|
||||
Parsed map[string]any `json:"parsed"`
|
||||
Note string `json:"note"`
|
||||
}
|
||||
|
||||
func NewService(cfg *config.Config, store *db.Store) *Service {
|
||||
return &Service{cfg: cfg, store: store}
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, name string) (Document, error) {
|
||||
fileName, err := fileNameFor(name)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
path := filepath.Join(s.cfg.UpdatePublicDir, fileName)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
parsed, formatted, err := parseAndFormat(name, data)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
revisions, _ := s.store.ListLegacyRevisions(name, 20)
|
||||
updatedAt := ""
|
||||
if info, err := os.Stat(path); err == nil {
|
||||
updatedAt = info.ModTime().UTC().Format(time.RFC3339)
|
||||
}
|
||||
return Document{Name: name, Raw: formatted, Parsed: parsed, Path: path, UpdatedAt: updatedAt, Revisions: revisions}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Validate(ctx context.Context, name string, req SaveRequest) (Document, error) {
|
||||
raw, err := requestRaw(req)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
parsed, formatted, err := parseAndFormat(name, []byte(raw))
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
return Document{Name: name, Raw: formatted, Parsed: parsed}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Save(ctx context.Context, name string, req SaveRequest, actor string) (Document, error) {
|
||||
fileName, err := fileNameFor(name)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
raw, err := requestRaw(req)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
parsed, formatted, err := parseAndFormat(name, []byte(raw))
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
path := filepath.Join(s.cfg.UpdatePublicDir, fileName)
|
||||
current, _ := os.ReadFile(path)
|
||||
if len(current) > 0 {
|
||||
_, _ = s.store.SaveLegacyRevision(name, string(current), "auto backup before save", actor)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
if err := atomicWrite(path, []byte(formatted)); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
_, _ = s.store.SaveLegacyRevision(name, formatted, req.Note, actor)
|
||||
_ = s.store.InsertAudit(db.AuditLog{Actor: actor, Type: "legacy_json.saved", Target: name, Message: "Legacy JSON saved"})
|
||||
revisions, _ := s.store.ListLegacyRevisions(name, 20)
|
||||
return Document{Name: name, Raw: formatted, Parsed: parsed, Path: path, UpdatedAt: db.Now(), Revisions: revisions}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Restore(ctx context.Context, name string, revisionID int64, actor string) (Document, error) {
|
||||
fileName, err := fileNameFor(name)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
revision, err := s.store.GetLegacyRevision(name, revisionID)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
parsed, formatted, err := parseAndFormat(name, []byte(revision.Raw))
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
path := filepath.Join(s.cfg.UpdatePublicDir, fileName)
|
||||
current, _ := os.ReadFile(path)
|
||||
if len(current) > 0 {
|
||||
_, _ = s.store.SaveLegacyRevision(name, string(current), "auto backup before restore", actor)
|
||||
}
|
||||
if err := atomicWrite(path, []byte(formatted)); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
_, _ = s.store.SaveLegacyRevision(name, formatted, "restored revision", actor)
|
||||
_ = s.store.InsertAudit(db.AuditLog{Actor: actor, Type: "legacy_json.restored", Target: name, Message: "Legacy JSON restored"})
|
||||
revisions, _ := s.store.ListLegacyRevisions(name, 20)
|
||||
return Document{Name: name, Raw: formatted, Parsed: parsed, Path: path, UpdatedAt: db.Now(), Revisions: revisions}, nil
|
||||
}
|
||||
|
||||
func fileNameFor(name string) (string, error) {
|
||||
switch strings.TrimSpace(name) {
|
||||
case "update-info":
|
||||
return "update-info.json", nil
|
||||
case "media-types":
|
||||
return "media-types.json", nil
|
||||
default:
|
||||
return "", errors.New("unsupported legacy document")
|
||||
}
|
||||
}
|
||||
|
||||
func requestRaw(req SaveRequest) (string, error) {
|
||||
if strings.TrimSpace(req.Raw) != "" {
|
||||
return req.Raw, nil
|
||||
}
|
||||
if req.Parsed == nil {
|
||||
return "", errors.New("raw or parsed JSON is required")
|
||||
}
|
||||
data, err := json.Marshal(req.Parsed)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func parseAndFormat(name string, data []byte) (map[string]any, string, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
var parsed map[string]any
|
||||
if err := decoder.Decode(&parsed); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := validate(name, parsed); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
out, err := json.MarshalIndent(parsed, "", " ")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return parsed, string(out) + "\n", nil
|
||||
}
|
||||
|
||||
func validate(name string, parsed map[string]any) error {
|
||||
switch name {
|
||||
case "update-info":
|
||||
if _, ok := parsed["app_version"]; !ok {
|
||||
if _, ok := parsed["title"]; !ok {
|
||||
return errors.New("update-info requires app_version or title")
|
||||
}
|
||||
}
|
||||
case "media-types":
|
||||
if _, ok := parsed["categories"].([]any); !ok {
|
||||
return errors.New("media-types requires categories array")
|
||||
}
|
||||
if _, ok := parsed["layout_version"]; !ok {
|
||||
parsed["layout_version"] = "1.0.0"
|
||||
}
|
||||
if _, ok := parsed["last_updated"]; !ok {
|
||||
parsed["last_updated"] = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func atomicWrite(path string, data []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpName := tmp.Name()
|
||||
defer os.Remove(tmpName)
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Chmod(tmpName, 0o640); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmpName, path)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package legacy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
func TestSaveValidateAndRestoreLegacyJSON(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
public := filepath.Join(root, "public")
|
||||
if err := os.MkdirAll(public, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path := filepath.Join(public, "media-types.json")
|
||||
if err := os.WriteFile(path, []byte(`{"layout_version":"1","categories":[]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
UpdatePublicDir: public,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
ConnMaxLifetimeSeconds: 60,
|
||||
},
|
||||
}
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
service := NewService(cfg, store)
|
||||
if _, err := service.Validate(context.Background(), "media-types", SaveRequest{Raw: `{"not_categories":[]}`}); err == nil {
|
||||
t.Fatal("expected validation failure")
|
||||
}
|
||||
saved, err := service.Save(context.Background(), "media-types", SaveRequest{Raw: `{"categories":[{"id":"image","name":"Image","subcategories":[]}]}`, Note: "test"}, "admin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if saved.Parsed["layout_version"] == nil {
|
||||
t.Fatal("layout_version was not filled")
|
||||
}
|
||||
revisions, err := store.ListLegacyRevisions("media-types", 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(revisions) < 2 {
|
||||
t.Fatalf("expected auto backup and saved revision, got %d", len(revisions))
|
||||
}
|
||||
restored, err := service.Restore(context.Background(), "media-types", revisions[len(revisions)-1].ID, "admin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if restored.Parsed["categories"] == nil {
|
||||
t.Fatal("restored document missing categories")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,497 @@
|
||||
package notices
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
Notice db.ReleaseNotice `json:"notice"`
|
||||
Raw string `json:"raw"`
|
||||
Parsed map[string]any `json:"parsed"`
|
||||
Path string `json:"path"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
Revisions []db.ReleaseNoticeRevision `json:"revisions"`
|
||||
}
|
||||
|
||||
type SaveRequest struct {
|
||||
Raw string `json:"raw"`
|
||||
Parsed map[string]any `json:"parsed"`
|
||||
Note string `json:"note"`
|
||||
}
|
||||
|
||||
func NewService(cfg *config.Config, store *db.Store) *Service {
|
||||
return &Service{cfg: cfg, store: store}
|
||||
}
|
||||
|
||||
func (s *Service) Import(ctx context.Context) error {
|
||||
if strings.TrimSpace(s.cfg.UpdateNoticeDir) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := os.Stat(s.cfg.UpdateNoticeDir); errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err := s.importTotalIndex(); err != nil {
|
||||
return err
|
||||
}
|
||||
entries, err := os.ReadDir(s.cfg.UpdateNoticeDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(strings.ToLower(entry.Name()), ".json") || strings.EqualFold(entry.Name(), "total.json") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(s.cfg.UpdateNoticeDir, entry.Name())
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item, _, _, err := parseNotice(data, versionFromFile(entry.Name()), entry.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.store.UpsertReleaseNotice(item); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) List(limit int) ([]db.ReleaseNotice, error) {
|
||||
return s.store.ListReleaseNotices(limit)
|
||||
}
|
||||
|
||||
func (s *Service) Get(version string) (Document, error) {
|
||||
item, err := s.store.GetReleaseNotice(version)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
raw := item.RawJSON
|
||||
parsed, formatted, err := parseAndFormat([]byte(raw), version, item.NoticeFile)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
revisions, _ := s.store.ListReleaseNoticeRevisions(version, 20)
|
||||
updatedAt := item.UpdatedAt
|
||||
path := filepath.Join(s.cfg.UpdateNoticeDir, firstNonEmpty(item.NoticeFile, version+".json"))
|
||||
if info, err := os.Stat(path); err == nil {
|
||||
updatedAt = info.ModTime().UTC().Format(time.RFC3339)
|
||||
}
|
||||
return Document{Notice: item, Raw: formatted, Parsed: parsed, Path: path, UpdatedAt: updatedAt, Revisions: revisions}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Validate(ctx context.Context, version string, req SaveRequest) (Document, error) {
|
||||
raw, err := requestRaw(req)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
item, parsed, formatted, err := parseNotice([]byte(raw), version, version+".json")
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
return Document{Notice: item, Raw: formatted, Parsed: parsed}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Save(ctx context.Context, version string, req SaveRequest, actor string) (Document, error) {
|
||||
raw, err := requestRaw(req)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
item, parsed, formatted, err := parseNotice([]byte(raw), version, version+".json")
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
item.RawJSON = formatted
|
||||
if current, err := s.store.GetReleaseNotice(item.Version); err == nil && current.RawJSON != "" {
|
||||
_, _ = s.store.SaveReleaseNoticeRevision(item.Version, current.RawJSON, "auto backup before save", actor)
|
||||
}
|
||||
saved, err := s.store.UpsertReleaseNotice(item)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
_, _ = s.store.SaveReleaseNoticeRevision(saved.Version, formatted, req.Note, actor)
|
||||
if err := s.writeNoticeFile(saved, formatted); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
if err := s.writeTotalIndex(saved, parsed); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
if err := s.syncLegacyUpdateInfo(saved, parsed); err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
_ = s.store.InsertAudit(db.AuditLog{Actor: actor, Type: "release_notice.saved", Target: saved.Version, Message: "Release notice saved"})
|
||||
return s.Get(saved.Version)
|
||||
}
|
||||
|
||||
func (s *Service) Restore(ctx context.Context, version string, revisionID int64, actor string) (Document, error) {
|
||||
revision, err := s.store.GetReleaseNoticeRevision(version, revisionID)
|
||||
if err != nil {
|
||||
return Document{}, err
|
||||
}
|
||||
return s.Save(ctx, version, SaveRequest{Raw: revision.RawJSON, Note: "restored revision"}, actor)
|
||||
}
|
||||
|
||||
func (s *Service) importTotalIndex() error {
|
||||
path := filepath.Join(s.cfg.UpdateNoticeDir, "total.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(data, &root); err != nil {
|
||||
return err
|
||||
}
|
||||
if latest, ok := root["latest"].(map[string]any); ok {
|
||||
raw, _ := json.MarshalIndent(latest, "", " ")
|
||||
item, _, _, err := parseNotice(raw, stringValue(root, "latest_version"), stringValue(root, "latest_notice_file"))
|
||||
if err == nil {
|
||||
_, _ = s.store.UpsertReleaseNotice(item)
|
||||
}
|
||||
}
|
||||
for _, entry := range arrayValue(root, "versions") {
|
||||
version := stringValue(entry, "version")
|
||||
if version == "" {
|
||||
continue
|
||||
}
|
||||
raw, _ := json.MarshalIndent(entry, "", " ")
|
||||
item, _, _, err := parseNotice(raw, version, stringValue(entry, "notice_file"))
|
||||
if err == nil {
|
||||
_, _ = s.store.UpsertReleaseNotice(item)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) writeNoticeFile(item db.ReleaseNotice, raw string) error {
|
||||
if err := os.MkdirAll(s.cfg.UpdateNoticeDir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
return atomicWrite(filepath.Join(s.cfg.UpdateNoticeDir, firstNonEmpty(item.NoticeFile, item.Version+".json")), []byte(raw))
|
||||
}
|
||||
|
||||
func (s *Service) writeTotalIndex(item db.ReleaseNotice, parsed map[string]any) error {
|
||||
path := filepath.Join(s.cfg.UpdateNoticeDir, "total.json")
|
||||
root := map[string]any{"schema_version": 1, "product": "YMhut Box", "versions": []any{}}
|
||||
if data, err := os.ReadFile(path); err == nil {
|
||||
_ = json.Unmarshal(data, &root)
|
||||
}
|
||||
root["latest_version"] = newestVersion(stringValue(root, "latest_version"), item.Version)
|
||||
if root["latest_version"] == item.Version {
|
||||
root["latest_notice_file"] = item.NoticeFile
|
||||
root["latest"] = latestMap(item, parsed)
|
||||
}
|
||||
root["last_updated"] = db.Now()
|
||||
versions := arrayValue(root, "versions")
|
||||
next := make([]any, 0, len(versions)+1)
|
||||
found := false
|
||||
for _, entry := range versions {
|
||||
if stringValue(entry, "version") == item.Version {
|
||||
next = append(next, summaryMap(item, parsed))
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
next = append(next, entry)
|
||||
}
|
||||
if !found {
|
||||
next = append(next, summaryMap(item, parsed))
|
||||
}
|
||||
sort.SliceStable(next, func(i, j int) bool {
|
||||
return compareVersion(stringValue(next[i], "version"), stringValue(next[j], "version")) > 0
|
||||
})
|
||||
root["versions"] = next
|
||||
data, err := json.MarshalIndent(root, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return atomicWrite(path, append(data, '\n'))
|
||||
}
|
||||
|
||||
func (s *Service) syncLegacyUpdateInfo(item db.ReleaseNotice, parsed map[string]any) error {
|
||||
path := filepath.Join(s.cfg.UpdatePublicDir, "update-info.json")
|
||||
payload := map[string]any{}
|
||||
if data, err := os.ReadFile(path); err == nil {
|
||||
_ = json.Unmarshal(data, &payload)
|
||||
}
|
||||
payload["app_version"] = item.Version
|
||||
setNonEmpty(payload, "build", item.Build)
|
||||
setNonEmpty(payload, "channel", item.Channel)
|
||||
setNonEmpty(payload, "title", item.Title)
|
||||
setNonEmpty(payload, "message", item.Message)
|
||||
setNonEmpty(payload, "message_md", item.MessageMD)
|
||||
setNonEmpty(payload, "release_notes", item.ReleaseNotes)
|
||||
setNonEmpty(payload, "release_notes_md", item.ReleaseNotesMD)
|
||||
setNonEmpty(payload, "download_url", item.DownloadURL)
|
||||
payload["last_updated"] = firstNonEmpty(item.PublishedAt, db.Now())
|
||||
for _, key := range []string{"update_notes", "last_update_notes", "download_mirrors", "detected_packages", "detected_product", "category_list", "home_notes", "tool_metadata", "api_keys"} {
|
||||
if value, ok := parsed[key]; ok {
|
||||
payload[key] = value
|
||||
}
|
||||
}
|
||||
data, err := json.MarshalIndent(payload, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
return atomicWrite(path, append(data, '\n'))
|
||||
}
|
||||
|
||||
func parseAndFormat(data []byte, fallbackVersion, noticeFile string) (map[string]any, string, error) {
|
||||
_, parsed, formatted, err := parseNotice(data, fallbackVersion, noticeFile)
|
||||
return parsed, formatted, err
|
||||
}
|
||||
|
||||
func parseNotice(data []byte, fallbackVersion, noticeFile string) (db.ReleaseNotice, map[string]any, string, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
var parsed map[string]any
|
||||
if err := decoder.Decode(&parsed); err != nil {
|
||||
return db.ReleaseNotice{}, nil, "", err
|
||||
}
|
||||
version := firstNonEmpty(stringValue(parsed, "app_version"), stringValue(parsed, "version"), fallbackVersion)
|
||||
if version == "" {
|
||||
return db.ReleaseNotice{}, nil, "", errors.New("version or app_version is required")
|
||||
}
|
||||
if noticeFile == "" {
|
||||
noticeFile = version + ".json"
|
||||
}
|
||||
formattedBytes, err := json.MarshalIndent(parsed, "", " ")
|
||||
if err != nil {
|
||||
return db.ReleaseNotice{}, nil, "", err
|
||||
}
|
||||
formatted := string(formattedBytes) + "\n"
|
||||
item := db.ReleaseNotice{
|
||||
Version: version,
|
||||
Build: stringValue(parsed, "build"),
|
||||
Channel: firstNonEmpty(stringValue(parsed, "channel"), "stable"),
|
||||
Title: firstNonEmpty(stringValue(parsed, "title"), "YMhut Box "+version),
|
||||
Message: firstNonEmpty(stringValue(parsed, "message"), stringValue(parsed, "summary"), stringValue(parsed, "home_notes")),
|
||||
ReleaseNotes: firstNonEmpty(stringValue(parsed, "release_notes"), stringValue(parsed, "summary")),
|
||||
MessageMD: stringValue(parsed, "message_md"),
|
||||
ReleaseNotesMD: stringValue(parsed, "release_notes_md"),
|
||||
DownloadURL: stringValue(parsed, "download_url"),
|
||||
NoticeFile: noticeFile,
|
||||
RawJSON: formatted,
|
||||
PublishedAt: normalizeTime(firstNonEmpty(stringValue(parsed, "published_at"), stringValue(parsed, "release_date"), stringValue(parsed, "last_updated"))),
|
||||
}
|
||||
return item, parsed, formatted, nil
|
||||
}
|
||||
|
||||
func requestRaw(req SaveRequest) (string, error) {
|
||||
if strings.TrimSpace(req.Raw) != "" {
|
||||
return req.Raw, nil
|
||||
}
|
||||
if req.Parsed == nil {
|
||||
return "", errors.New("raw or parsed JSON is required")
|
||||
}
|
||||
data, err := json.Marshal(req.Parsed)
|
||||
return string(data), err
|
||||
}
|
||||
|
||||
func latestMap(item db.ReleaseNotice, parsed map[string]any) map[string]any {
|
||||
out := summaryMap(item, parsed)
|
||||
out["title"] = item.Title
|
||||
out["message"] = item.Message
|
||||
out["download_url"] = item.DownloadURL
|
||||
out["release_notes"] = item.ReleaseNotes
|
||||
out["message_md"] = item.MessageMD
|
||||
out["release_notes_md"] = item.ReleaseNotesMD
|
||||
return out
|
||||
}
|
||||
|
||||
func summaryMap(item db.ReleaseNotice, parsed map[string]any) map[string]any {
|
||||
out := map[string]any{
|
||||
"version": item.Version,
|
||||
"build": item.Build,
|
||||
"channel": item.Channel,
|
||||
"release_date": dateOnly(item.PublishedAt),
|
||||
"notice_file": item.NoticeFile,
|
||||
"summary": firstNonEmpty(item.Message, item.ReleaseNotes),
|
||||
}
|
||||
if value, ok := parsed["highlights"]; ok {
|
||||
out["highlights"] = value
|
||||
}
|
||||
if value, ok := parsed["categories"]; ok {
|
||||
out["categories"] = value
|
||||
} else if value, ok := parsed["update_notes"]; ok {
|
||||
out["categories"] = value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func PublicNotice(item db.ReleaseNotice) map[string]any {
|
||||
return map[string]any{
|
||||
"version": item.Version,
|
||||
"build": item.Build,
|
||||
"channel": item.Channel,
|
||||
"title": item.Title,
|
||||
"message": item.Message,
|
||||
"release_notes": item.ReleaseNotes,
|
||||
"message_md": item.MessageMD,
|
||||
"release_notes_md": item.ReleaseNotesMD,
|
||||
"download_url": item.DownloadURL,
|
||||
"notice_file": item.NoticeFile,
|
||||
"published_at": item.PublishedAt,
|
||||
"updated_at": item.UpdatedAt,
|
||||
"releaseNotes": item.ReleaseNotes,
|
||||
"messageMarkdown": item.MessageMD,
|
||||
"releaseNotesMarkdown": item.ReleaseNotesMD,
|
||||
}
|
||||
}
|
||||
|
||||
func PublicList(items []db.ReleaseNotice) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
out = append(out, PublicNotice(item))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func arrayValue(root map[string]any, key string) []any {
|
||||
if values, ok := root[key].([]any); ok {
|
||||
return values
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
|
||||
func stringValue(root any, key string) string {
|
||||
obj, ok := root.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
switch value := obj[key].(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value)
|
||||
case json.Number:
|
||||
return value.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func setNonEmpty(target map[string]any, key, value string) {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
target[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeTime(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, value); err == nil {
|
||||
return t.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if t, err := time.Parse("2006-01-02", value); err == nil {
|
||||
return t.UTC().Format(time.RFC3339)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func dateOnly(value string) string {
|
||||
if t, err := time.Parse(time.RFC3339, value); err == nil {
|
||||
return t.UTC().Format("2006-01-02")
|
||||
}
|
||||
if len(value) >= 10 {
|
||||
return value[:10]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func newestVersion(current, candidate string) string {
|
||||
if current == "" || compareVersion(candidate, current) >= 0 {
|
||||
return candidate
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func compareVersion(a, b string) int {
|
||||
as := strings.Split(a, ".")
|
||||
bs := strings.Split(b, ".")
|
||||
for len(as) < 4 {
|
||||
as = append(as, "0")
|
||||
}
|
||||
for len(bs) < 4 {
|
||||
bs = append(bs, "0")
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
ai := parseInt(as[i])
|
||||
bi := parseInt(bs[i])
|
||||
if ai > bi {
|
||||
return 1
|
||||
}
|
||||
if ai < bi {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseInt(value string) int {
|
||||
n := 0
|
||||
for _, r := range value {
|
||||
if r < '0' || r > '9' {
|
||||
break
|
||||
}
|
||||
n = n*10 + int(r-'0')
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func versionFromFile(name string) string {
|
||||
return strings.TrimSuffix(filepath.Base(name), filepath.Ext(name))
|
||||
}
|
||||
|
||||
func atomicWrite(path string, data []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpName := tmp.Name()
|
||||
defer os.Remove(tmpName)
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Chmod(tmpName, 0o640); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmpName, path)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package notices
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
func TestSaveNoticeSyncsFilesAndLegacyUpdateInfo(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
public := filepath.Join(root, "public")
|
||||
noticeDir := filepath.Join(root, "update-notice")
|
||||
if err := os.MkdirAll(public, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(noticeDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writeJSON(t, filepath.Join(public, "update-info.json"), map[string]any{"app_version": "1.0.0"})
|
||||
writeJSON(t, filepath.Join(noticeDir, "total.json"), map[string]any{"schema_version": 1, "versions": []any{}})
|
||||
|
||||
store, err := db.Open(&config.Config{
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
UpdatePublicDir: public,
|
||||
UpdateNoticeDir: noticeDir,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
ConnMaxLifetimeSeconds: 60,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
service := NewService(&config.Config{UpdatePublicDir: public, UpdateNoticeDir: noticeDir}, store)
|
||||
raw := `{"app_version":"2.0.1","title":"YMhut Box 2.0.1","message":"hello","release_notes":"notes","release_notes_md":"## Notes","download_url":"https://update.ymhut.cn/downloads/app.exe","update_notes":{"发布":"说明"}}`
|
||||
doc, err := service.Save(context.Background(), "2.0.1", SaveRequest{Raw: raw, Note: "test"}, "admin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if doc.Notice.Version != "2.0.1" {
|
||||
t.Fatalf("unexpected version %q", doc.Notice.Version)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(noticeDir, "2.0.1.json")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateInfo := readJSONFile(t, filepath.Join(public, "update-info.json"))
|
||||
if updateInfo["app_version"] != "2.0.1" || updateInfo["release_notes"] != "notes" {
|
||||
t.Fatalf("legacy update-info not synced: %#v", updateInfo)
|
||||
}
|
||||
total := readJSONFile(t, filepath.Join(noticeDir, "total.json"))
|
||||
if total["latest_version"] != "2.0.1" {
|
||||
t.Fatalf("total index not synced: %#v", total)
|
||||
}
|
||||
revisions, err := store.ListReleaseNoticeRevisions("2.0.1", 10)
|
||||
if err != nil || len(revisions) == 0 {
|
||||
t.Fatalf("expected revision, got %d, %v", len(revisions), err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(t *testing.T, path string, payload any) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func readJSONFile(t *testing.T, path string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package releases
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
"ymhut-box/server/unified-management/internal/notices"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
notices *notices.Service
|
||||
}
|
||||
|
||||
type Package struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Platform string `json:"platform"`
|
||||
Arch string `json:"arch"`
|
||||
URL string `json:"url"`
|
||||
SHA256 string `json:"sha256"`
|
||||
Size int64 `json:"size"`
|
||||
Required bool `json:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
FileName string `json:"fileName"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func NewService(cfg *config.Config, store *db.Store, noticeService ...*notices.Service) *Service {
|
||||
service := &Service{cfg: cfg, store: store}
|
||||
if len(noticeService) > 0 {
|
||||
service.notices = noticeService[0]
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *Service) LegacyUpdateInfo(r *http.Request) map[string]any {
|
||||
payload := readJSON(filepath.Join(s.cfg.UpdatePublicDir, "update-info.json"))
|
||||
manifest := s.Manifest(r)
|
||||
for _, key := range []string{"app_version", "download_url", "download_mirrors", "detected_product", "detected_packages", "packages", "modules", "manifest_version", "release_notes", "release_notes_md", "message", "message_md", "notices", "latest_notice"} {
|
||||
if value, ok := manifest[key]; ok {
|
||||
payload[key] = value
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func (s *Service) Manifest(r *http.Request) map[string]any {
|
||||
payload := readJSON(filepath.Join(s.cfg.UpdatePublicDir, "update-info.json"))
|
||||
packages := s.ScanPackages(r)
|
||||
modules := readJSON(filepath.Join(s.cfg.UpdatePublicDir, "modules.json"))["modules"]
|
||||
if modules == nil {
|
||||
modules = []any{}
|
||||
}
|
||||
payload["manifest_version"] = 2
|
||||
payload["service_version"] = config.Version
|
||||
payload["packages"] = packages
|
||||
payload["modules"] = modules
|
||||
payload["assets"] = []any{}
|
||||
payload["generated_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if s.notices != nil {
|
||||
if items, err := s.notices.List(50); err == nil && len(items) > 0 {
|
||||
publicNotices := notices.PublicList(items)
|
||||
payload["notices"] = publicNotices
|
||||
payload["latest_notice"] = publicNotices[0]
|
||||
latestNotice := items[0]
|
||||
setIfMissing(payload, "app_version", latestNotice.Version)
|
||||
setIfMissing(payload, "title", latestNotice.Title)
|
||||
setIfMissing(payload, "message", latestNotice.Message)
|
||||
setIfMissing(payload, "message_md", latestNotice.MessageMD)
|
||||
setIfMissing(payload, "release_notes", latestNotice.ReleaseNotes)
|
||||
setIfMissing(payload, "release_notes_md", latestNotice.ReleaseNotesMD)
|
||||
setIfMissing(payload, "download_url", latestNotice.DownloadURL)
|
||||
}
|
||||
}
|
||||
if len(packages) > 0 {
|
||||
latest := packages[0]
|
||||
payload["app_version"] = latest.Version
|
||||
payload["download_url"] = latest.URL
|
||||
payload["download_mirrors"] = []map[string]any{{
|
||||
"id": "primary",
|
||||
"name": "官方直连",
|
||||
"url": latest.URL,
|
||||
"type": "direct",
|
||||
"sha256": latest.SHA256,
|
||||
"enabled": true,
|
||||
}}
|
||||
payload["detected_product"] = latest.Name
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func setIfMissing(payload map[string]any, key, value string) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return
|
||||
}
|
||||
if existing, ok := payload[key].(string); !ok || strings.TrimSpace(existing) == "" {
|
||||
payload[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ScanPackages(r *http.Request) []Package {
|
||||
entries, err := os.ReadDir(s.cfg.DownloadsDir)
|
||||
if err != nil {
|
||||
return []Package{}
|
||||
}
|
||||
base := requestBaseURL(r, s.cfg.BaseURL)
|
||||
items := []Package{}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
lower := strings.ToLower(name)
|
||||
if !(strings.HasSuffix(lower, ".exe") || strings.HasSuffix(lower, ".msix") || strings.HasSuffix(lower, ".appinstaller") || strings.HasSuffix(lower, ".msi")) {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
version := detectVersion(name)
|
||||
platform, arch := detectPlatform(name)
|
||||
product := detectProduct(name)
|
||||
url := base + "/downloads/" + name
|
||||
items = append(items, Package{
|
||||
ID: strings.ToLower(strings.ReplaceAll(product+"-"+platform+"-"+arch+"-"+version, " ", "-")),
|
||||
Name: product,
|
||||
Version: version,
|
||||
Platform: platform,
|
||||
Arch: arch,
|
||||
URL: url,
|
||||
SHA256: sha256File(filepath.Join(s.cfg.DownloadsDir, name)),
|
||||
Size: info.Size(),
|
||||
Required: strings.Contains(strings.ToLower(product), "ymhut"),
|
||||
Enabled: true,
|
||||
FileName: name,
|
||||
UpdatedAt: info.ModTime().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return compareVersion(items[i].Version, items[j].Version) > 0
|
||||
})
|
||||
return items
|
||||
}
|
||||
|
||||
func (s *Service) StaticJSON(name string) map[string]any {
|
||||
return readJSON(filepath.Join(s.cfg.UpdatePublicDir, name))
|
||||
}
|
||||
|
||||
func readJSON(path string) map[string]any {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func requestBaseURL(r *http.Request, fallback string) string {
|
||||
if r != nil {
|
||||
scheme := r.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
if r.Host != "" {
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
}
|
||||
return strings.TrimRight(fallback, "/")
|
||||
}
|
||||
|
||||
var versionPattern = regexp.MustCompile(`\d+\.\d+\.\d+(?:\.\d+)?`)
|
||||
|
||||
func detectVersion(name string) string {
|
||||
match := versionPattern.FindString(name)
|
||||
if match == "" {
|
||||
return "0.0.0"
|
||||
}
|
||||
return match
|
||||
}
|
||||
|
||||
func detectPlatform(name string) (string, string) {
|
||||
lower := strings.ToLower(name)
|
||||
platform := "windows"
|
||||
if strings.Contains(lower, "appinstaller") || strings.HasSuffix(lower, ".msix") || strings.HasSuffix(lower, ".exe") || strings.HasSuffix(lower, ".msi") {
|
||||
platform = "windows"
|
||||
}
|
||||
arch := "x64"
|
||||
if strings.Contains(lower, "arm64") {
|
||||
arch = "arm64"
|
||||
} else if strings.Contains(lower, "x86") && !strings.Contains(lower, "x64") {
|
||||
arch = "x86"
|
||||
}
|
||||
return platform, arch
|
||||
}
|
||||
|
||||
func detectProduct(name string) string {
|
||||
if strings.Contains(strings.ToLower(name), "ymhut") {
|
||||
return "YMhut Box"
|
||||
}
|
||||
return "YMhut Package"
|
||||
}
|
||||
|
||||
func compareVersion(a, b string) int {
|
||||
as := strings.Split(a, ".")
|
||||
bs := strings.Split(b, ".")
|
||||
for len(as) < 4 {
|
||||
as = append(as, "0")
|
||||
}
|
||||
for len(bs) < 4 {
|
||||
bs = append(bs, "0")
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
ai, _ := strconv.Atoi(as[i])
|
||||
bi, _ := strconv.Atoi(bs[i])
|
||||
if ai > bi {
|
||||
return 1
|
||||
}
|
||||
if ai < bi {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func sha256File(path string) string {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer file.Close()
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package releases
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCompareVersion(t *testing.T) {
|
||||
cases := []struct {
|
||||
a string
|
||||
b string
|
||||
want int
|
||||
}{
|
||||
{"2.0.6.31", "2.0.6.2", 1},
|
||||
{"2.0.10", "2.0.9", 1},
|
||||
{"2.0.6.2", "2.0.6.31", -1},
|
||||
{"2.0.6", "2.0.6.0", 0},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := compareVersion(tc.a, tc.b); got != tc.want {
|
||||
t.Fatalf("compareVersion(%q, %q) = %d, want %d", tc.a, tc.b, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectPackageMetadata(t *testing.T) {
|
||||
platform, arch := detectPlatform("YMhutBox_2.0.6.31_x64.msix")
|
||||
if platform != "windows" || arch != "x64" {
|
||||
t.Fatalf("detectPlatform returned %s/%s", platform, arch)
|
||||
}
|
||||
if version := detectVersion("YMhut_Box_WinUI_Setup_2.0.6.31.exe"); version != "2.0.6.31" {
|
||||
t.Fatalf("detectVersion returned %q", version)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
client *http.Client
|
||||
stop chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type legacyMedia struct {
|
||||
Categories []legacyCategory `json:"categories"`
|
||||
}
|
||||
|
||||
type legacyCategory struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Subcategories []legacySubcategory `json:"subcategories"`
|
||||
}
|
||||
|
||||
type legacySubcategory struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
APIURL string `json:"api_url"`
|
||||
ThumbnailURL string `json:"thumbnail_url"`
|
||||
RefreshInterval int `json:"refresh_interval"`
|
||||
SupportedFormats []string `json:"supported_formats"`
|
||||
Downloadable bool `json:"downloadable"`
|
||||
}
|
||||
|
||||
func NewService(cfg *config.Config, store *db.Store) *Service {
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Start(ctx context.Context) {
|
||||
s.once.Do(func() {
|
||||
go s.loop()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Stop() {
|
||||
close(s.stop)
|
||||
}
|
||||
|
||||
func (s *Service) loop() {
|
||||
ticker := time.NewTicker(time.Duration(s.cfg.SourceCheckSeconds) * time.Second)
|
||||
defer ticker.Stop()
|
||||
s.CheckDue(context.Background())
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.CheckDue(context.Background())
|
||||
case <-s.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ImportLegacyMediaTypesIfEmpty(ctx context.Context) error {
|
||||
count, err := s.store.CountSources()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
return s.ImportLegacyMediaTypes(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) ImportLegacyMediaTypes(ctx context.Context) error {
|
||||
data, err := os.ReadFile(filepath.Join(s.cfg.UpdatePublicDir, "media-types.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var legacy legacyMedia
|
||||
if err := json.Unmarshal(data, &legacy); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, category := range legacy.Categories {
|
||||
for _, sub := range category.Subcategories {
|
||||
if strings.TrimSpace(sub.APIURL) == "" {
|
||||
continue
|
||||
}
|
||||
formats, _ := json.Marshal(sub.SupportedFormats)
|
||||
_, err := s.store.UpsertSource(db.Source{
|
||||
CategoryID: defaultString(category.ID, "media"),
|
||||
CategoryName: defaultString(category.Name, category.ID),
|
||||
SourceID: defaultString(sub.ID, category.ID+"-"+sub.Name),
|
||||
Name: defaultString(sub.Name, sub.ID),
|
||||
Description: sub.Description,
|
||||
Method: "GET",
|
||||
APIURL: sub.APIURL,
|
||||
ThumbnailURL: sub.ThumbnailURL,
|
||||
ProxyMode: "client_direct",
|
||||
TimeoutMS: 8000,
|
||||
RetryCount: 1,
|
||||
CheckIntervalSec: maxInt(sub.RefreshInterval, 300),
|
||||
Enabled: legacyEnabled(category.Enabled),
|
||||
ClientVisible: true,
|
||||
SupportedFormats: string(formats),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Catalog(includeHidden bool) (map[string]any, error) {
|
||||
items, err := s.store.ListSources(includeHidden)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
categories := map[string]map[string]any{}
|
||||
for _, item := range items {
|
||||
cat, ok := categories[item.CategoryID]
|
||||
if !ok {
|
||||
cat = map[string]any{
|
||||
"id": item.CategoryID,
|
||||
"name": item.CategoryName,
|
||||
"enabled": true,
|
||||
"subcategories": []map[string]any{},
|
||||
}
|
||||
categories[item.CategoryID] = cat
|
||||
}
|
||||
var formats []string
|
||||
_ = json.Unmarshal([]byte(item.SupportedFormats), &formats)
|
||||
sub := map[string]any{
|
||||
"id": item.SourceID,
|
||||
"name": item.Name,
|
||||
"description": item.Description,
|
||||
"api_url": item.APIURL,
|
||||
"urlTemplate": firstNonEmpty(item.URLTemplate, item.APIURL),
|
||||
"thumbnail_url": item.ThumbnailURL,
|
||||
"method": item.Method,
|
||||
"proxy_mode": item.ProxyMode,
|
||||
"proxyMode": item.ProxyMode,
|
||||
"refresh_interval": item.CheckIntervalSec,
|
||||
"cacheSeconds": item.CacheSeconds,
|
||||
"supported_formats": formats,
|
||||
"downloadable": true,
|
||||
"health": map[string]any{
|
||||
"status": item.LastStatus,
|
||||
"latency_ms": item.LastLatencyMS,
|
||||
"last_checked_at": item.LastCheckedAt,
|
||||
"last_error": item.LastError,
|
||||
"consecutiveFailure": item.ConsecutiveFailure,
|
||||
},
|
||||
}
|
||||
cat["subcategories"] = append(cat["subcategories"].([]map[string]any), sub)
|
||||
}
|
||||
out := []map[string]any{}
|
||||
for _, cat := range categories {
|
||||
out = append(out, cat)
|
||||
}
|
||||
return map[string]any{
|
||||
"layout_version": "2.0.0",
|
||||
"last_updated": time.Now().UTC().Format(time.RFC3339),
|
||||
"categories": out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Endpoints(includeHidden bool) ([]map[string]any, error) {
|
||||
items, err := s.store.ListSources(includeHidden)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := []map[string]any{}
|
||||
for _, item := range items {
|
||||
var formats []string
|
||||
_ = json.Unmarshal([]byte(item.SupportedFormats), &formats)
|
||||
out = append(out, map[string]any{
|
||||
"id": item.SourceID,
|
||||
"category": item.CategoryID,
|
||||
"name": item.Name,
|
||||
"method": item.Method,
|
||||
"urlTemplate": firstNonEmpty(item.URLTemplate, item.APIURL),
|
||||
"proxyMode": item.ProxyMode,
|
||||
"clientVisible": item.ClientVisible,
|
||||
"enabled": item.Enabled,
|
||||
"cacheSeconds": item.CacheSeconds,
|
||||
"supportedFormats": formats,
|
||||
"health": map[string]any{
|
||||
"status": item.LastStatus,
|
||||
"latencyMs": item.LastLatencyMS,
|
||||
"lastCheckedAt": item.LastCheckedAt,
|
||||
"lastError": item.LastError,
|
||||
"consecutiveFailure": item.ConsecutiveFailure,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckDue(ctx context.Context) {
|
||||
items, err := s.store.ListSources(true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
for _, item := range items {
|
||||
if !item.Enabled {
|
||||
continue
|
||||
}
|
||||
if item.LastCheckedAt != "" {
|
||||
if last, err := time.Parse(time.RFC3339, item.LastCheckedAt); err == nil && now.Sub(last) < time.Duration(item.CheckIntervalSec)*time.Second {
|
||||
continue
|
||||
}
|
||||
}
|
||||
_ = s.CheckOne(ctx, item)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) CheckSourceID(ctx context.Context, sourceID string) (db.Source, error) {
|
||||
item, err := s.store.GetSourceBySourceID(sourceID)
|
||||
if err != nil {
|
||||
return db.Source{}, err
|
||||
}
|
||||
return item, s.CheckOne(ctx, item)
|
||||
}
|
||||
|
||||
func (s *Service) CheckOne(ctx context.Context, item db.Source) error {
|
||||
if strings.TrimSpace(item.APIURL) == "" {
|
||||
return errors.New("source api_url is empty")
|
||||
}
|
||||
timeout := time.Duration(item.TimeoutMS) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 8 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, item.Method, item.APIURL, nil)
|
||||
if err != nil {
|
||||
_ = s.store.RecordSourceCheck(item.ID, "error", 0, err.Error())
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := s.client.Do(req)
|
||||
latency := int(time.Since(start).Milliseconds())
|
||||
if err != nil {
|
||||
_ = s.store.RecordSourceCheck(item.ID, "error", latency, err.Error())
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
status := "ok"
|
||||
message := ""
|
||||
if resp.StatusCode >= 400 {
|
||||
status = "degraded"
|
||||
message = resp.Status
|
||||
}
|
||||
return s.store.RecordSourceCheck(item.ID, status, latency, message)
|
||||
}
|
||||
|
||||
func defaultString(value, fallback string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func legacyEnabled(value *bool) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func maxInt(value, fallback int) int {
|
||||
if value > 0 {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package synclegacy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
"ymhut-box/server/unified-management/internal/notices"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
notices *notices.Service
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Ok bool `json:"ok"`
|
||||
DryRun bool `json:"dryRun"`
|
||||
Paths map[string]any `json:"paths"`
|
||||
Stats map[string]int `json:"stats"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Errors []string `json:"errors"`
|
||||
Started string `json:"startedAt"`
|
||||
Finished string `json:"finishedAt"`
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, store *db.Store, noticeService *notices.Service) *Service {
|
||||
return &Service{cfg: cfg, store: store, notices: noticeService}
|
||||
}
|
||||
|
||||
func (s *Service) Preview(ctx context.Context) Result {
|
||||
return s.run(ctx, true)
|
||||
}
|
||||
|
||||
func (s *Service) Run(ctx context.Context) Result {
|
||||
return s.run(ctx, false)
|
||||
}
|
||||
|
||||
func (s *Service) run(ctx context.Context, dryRun bool) Result {
|
||||
result := Result{
|
||||
Ok: true,
|
||||
DryRun: dryRun,
|
||||
Paths: map[string]any{
|
||||
"legacyUpdateDir": s.cfg.LegacyUpdateDir,
|
||||
"legacyFeedbackDir": s.cfg.LegacyFeedbackDir,
|
||||
"legacyUpdateNoticeDir": s.cfg.LegacyUpdateNoticeDir,
|
||||
"updatePublicDir": s.cfg.UpdatePublicDir,
|
||||
"updateNoticeDir": s.cfg.UpdateNoticeDir,
|
||||
},
|
||||
Stats: map[string]int{},
|
||||
Started: db.Now(),
|
||||
}
|
||||
defer func() {
|
||||
result.Finished = db.Now()
|
||||
if len(result.Errors) > 0 {
|
||||
result.Ok = false
|
||||
}
|
||||
}()
|
||||
s.previewPath(&result, "legacy_update", s.cfg.LegacyUpdateDir)
|
||||
s.previewPath(&result, "legacy_feedback", s.cfg.LegacyFeedbackDir)
|
||||
s.previewPath(&result, "legacy_update_notice", s.cfg.LegacyUpdateNoticeDir)
|
||||
if dryRun {
|
||||
return result
|
||||
}
|
||||
if err := s.backupCurrent(); err != nil {
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
return result
|
||||
}
|
||||
s.syncUpdatePublic(&result)
|
||||
s.syncNotices(ctx, &result)
|
||||
s.syncFeedbackSQLite(&result)
|
||||
_ = s.store.InsertAudit(db.AuditLog{Actor: "admin", Type: "legacy.sync", Target: "legacy-projects", Message: fmt.Sprintf("Legacy sync finished: ok=%v copied=%d imported=%d errors=%d", result.Ok, result.Stats["copiedFiles"], result.Stats["importedRows"], len(result.Errors))})
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Service) previewPath(result *Result, key, path string) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, key+": "+err.Error())
|
||||
result.Stats["missingPaths"]++
|
||||
return
|
||||
}
|
||||
if !info.IsDir() {
|
||||
result.Warnings = append(result.Warnings, key+": path is not a directory")
|
||||
result.Stats["missingPaths"]++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) backupCurrent() error {
|
||||
backupRoot := filepath.Join(s.cfg.StorageDir, "backups", "legacy-sync-"+time.Now().UTC().Format("20060102-150405"))
|
||||
for _, path := range []string{s.cfg.UpdatePublicDir, s.cfg.UpdateNoticeDir} {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
target := filepath.Join(backupRoot, filepath.Base(path))
|
||||
if err := copyDir(path, target); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) syncUpdatePublic(result *Result) {
|
||||
sourcePublic := filepath.Join(s.cfg.LegacyUpdateDir, "public")
|
||||
if _, err := os.Stat(sourcePublic); err != nil {
|
||||
result.Warnings = append(result.Warnings, "update public not found: "+err.Error())
|
||||
return
|
||||
}
|
||||
for _, name := range []string{"update-info.json", "media-types.json", "tool-status.json", "modules.json"} {
|
||||
source := filepath.Join(sourcePublic, name)
|
||||
if _, err := os.Stat(source); err == nil {
|
||||
if err := copyFile(source, filepath.Join(s.cfg.UpdatePublicDir, name)); err != nil {
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
} else {
|
||||
result.Stats["copiedFiles"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
sourceDownloads := filepath.Join(sourcePublic, "downloads")
|
||||
if _, err := os.Stat(sourceDownloads); err == nil {
|
||||
if err := copyDir(sourceDownloads, s.cfg.DownloadsDir); err != nil {
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
} else {
|
||||
result.Stats["copiedDirectories"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) syncNotices(ctx context.Context, result *Result) {
|
||||
for _, source := range []string{filepath.Join(s.cfg.LegacyUpdateDir, "update-notice"), s.cfg.LegacyUpdateNoticeDir} {
|
||||
if _, err := os.Stat(source); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := copyDir(source, s.cfg.UpdateNoticeDir); err != nil {
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
continue
|
||||
}
|
||||
result.Stats["copiedDirectories"]++
|
||||
}
|
||||
if s.notices != nil {
|
||||
if err := s.notices.Import(ctx); err != nil {
|
||||
result.Errors = append(result.Errors, "notice import: "+err.Error())
|
||||
} else {
|
||||
result.Stats["noticeImports"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) syncFeedbackSQLite(result *Result) {
|
||||
path := filepath.Join(s.cfg.LegacyFeedbackDir, "storage", "feedback.sqlite")
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
result.Warnings = append(result.Warnings, "feedback sqlite not found: "+err.Error())
|
||||
return
|
||||
}
|
||||
oldDB, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
return
|
||||
}
|
||||
defer oldDB.Close()
|
||||
s.importOldFeedbacks(oldDB, result)
|
||||
s.importOldComments(oldDB, result)
|
||||
s.importOldEvents(oldDB, result)
|
||||
s.importOldTags(oldDB, result)
|
||||
s.importOldMail(oldDB, result)
|
||||
s.importOldWebhooks(oldDB, result)
|
||||
}
|
||||
|
||||
func (s *Service) importOldFeedbacks(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT code, received_at, title, type, severity, category, priority, contact, body, status, status_detail, note, public_reply, handled_by, assignee, due_at, resolved_at, archived_at, sla_level, source_channel, risk_score, resolution, package_path, encrypted_package_path, package_sha256, plain_package_sha256, remote_addr, summary_text, included_files, mail_sent, updated_at, last_activity_at FROM feedbacks`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "feedbacks: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var item db.Feedback
|
||||
var mailSent int
|
||||
if err := rows.Scan(&item.Code, &item.CreatedAt, &item.Title, &item.Type, &item.Severity, &item.Category, &item.Priority, &item.Contact, &item.Body, &item.Status, &item.StatusDetail, &item.Note, &item.PublicReply, &item.HandledBy, &item.Assignee, &item.DueAt, &item.ResolvedAt, &item.ArchivedAt, &item.SLALevel, &item.SourceChannel, &item.RiskScore, &item.Resolution, &item.PackagePath, &item.EncryptedPackagePath, &item.PackageSha256, &item.PlainPackageSha256, &item.RemoteAddr, &item.SummaryText, &item.IncludedFiles, &mailSent, &item.UpdatedAt, &item.LastActivityAt); err != nil {
|
||||
result.Errors = append(result.Errors, "feedback scan: "+err.Error())
|
||||
continue
|
||||
}
|
||||
item.MailSent = mailSent == 1
|
||||
item.PackagePath = s.copyLegacyFeedbackFile(item.PackagePath, item.Code, result)
|
||||
item.EncryptedPackagePath = s.copyLegacyFeedbackFile(item.EncryptedPackagePath, item.Code, result)
|
||||
if err := s.store.InsertFeedback(item); err != nil && !strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
result.Errors = append(result.Errors, "feedback import "+item.Code+": "+err.Error())
|
||||
} else {
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) importOldComments(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT id, feedback_code, author, body, internal, created_at FROM feedback_comments`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "comments: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var comment db.FeedbackComment
|
||||
var internal int
|
||||
var oldID int64
|
||||
if err := rows.Scan(&oldID, &comment.Code, &comment.Author, &comment.Body, &internal, &comment.CreatedAt); err != nil {
|
||||
result.Errors = append(result.Errors, "comment scan: "+err.Error())
|
||||
continue
|
||||
}
|
||||
comment.Internal = internal == 1
|
||||
if _, err := s.store.InsertFeedbackComment(comment); err == nil {
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) importOldEvents(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT id, feedback_code, event_type, actor, from_value, to_value, message, created_at FROM feedback_events`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "events: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var event db.LegacyFeedbackEvent
|
||||
if err := rows.Scan(&event.ID, &event.FeedbackCode, &event.EventType, &event.Actor, &event.FromValue, &event.ToValue, &event.Message, &event.CreatedAt); err != nil {
|
||||
result.Errors = append(result.Errors, "event scan: "+err.Error())
|
||||
continue
|
||||
}
|
||||
if err := s.store.UpsertFeedbackEvent(event); err == nil {
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) importOldTags(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT feedback_code, tag, created_at FROM feedback_tags`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "tags: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var code, tag, createdAt string
|
||||
if err := rows.Scan(&code, &tag, &createdAt); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := s.store.UpsertFeedbackTag(code, tag, createdAt); err == nil {
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) importOldMail(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT id, feedback_code, kind, status, to_address, subject, attachment_path, attachment_name, error_message, created_at, sent_at FROM mail_records`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "mail_records: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var item db.LegacyMailRecord
|
||||
if err := rows.Scan(&item.ID, &item.FeedbackCode, &item.Kind, &item.Status, &item.ToAddress, &item.Subject, &item.AttachmentPath, &item.AttachmentName, &item.ErrorMessage, &item.CreatedAt, &item.SentAt); err != nil {
|
||||
result.Errors = append(result.Errors, "mail scan: "+err.Error())
|
||||
continue
|
||||
}
|
||||
item.AttachmentPath = s.copyLegacyFeedbackFile(item.AttachmentPath, item.FeedbackCode, result)
|
||||
if err := s.store.UpsertMailRecord(item); err == nil {
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) importOldWebhooks(oldDB *sql.DB, result *Result) {
|
||||
rows, err := oldDB.Query(`SELECT id, webhook_name, event, status, attempts, response_code, error_message, payload_sha256, created_at, finished_at FROM webhook_deliveries`)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "webhook_deliveries: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name, event, status, message, payload, createdAt, finishedAt string
|
||||
var attempts, response int
|
||||
if err := rows.Scan(&id, &name, &event, &status, &attempts, &response, &message, &payload, &createdAt, &finishedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
_ = s.store.InsertAudit(db.AuditLog{Actor: "legacy", Type: "webhook." + status, Target: name, Message: event + " " + message, CreatedAt: firstNonEmpty(createdAt, finishedAt, db.Now())})
|
||||
result.Stats["importedRows"]++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) copyLegacyFeedbackFile(path, code string, result *Result) string {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return ""
|
||||
}
|
||||
source := path
|
||||
if !filepath.IsAbs(source) {
|
||||
source = filepath.Join(s.cfg.LegacyFeedbackDir, source)
|
||||
}
|
||||
info, err := os.Stat(source)
|
||||
if err != nil || info.IsDir() {
|
||||
return path
|
||||
}
|
||||
target := filepath.Join(s.cfg.StorageDir, "legacy-feedback", safeName(code), filepath.Base(source))
|
||||
if err := copyFile(source, target); err != nil {
|
||||
result.Warnings = append(result.Warnings, "copy attachment: "+err.Error())
|
||||
return path
|
||||
}
|
||||
result.Stats["copiedFiles"]++
|
||||
return target
|
||||
}
|
||||
|
||||
func copyDir(source, target string) error {
|
||||
sourceInfo, err := os.Stat(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !sourceInfo.IsDir() {
|
||||
return errors.New(source + " is not a directory")
|
||||
}
|
||||
return filepath.WalkDir(source, func(path string, entry os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel, err := filepath.Rel(source, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest := filepath.Join(target, rel)
|
||||
if entry.IsDir() {
|
||||
return os.MkdirAll(dest, 0o750)
|
||||
}
|
||||
return copyFile(path, dest)
|
||||
})
|
||||
}
|
||||
|
||||
func copyFile(source, target string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
in, err := os.Open(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
_, err = io.Copy(out, in)
|
||||
return err
|
||||
}
|
||||
|
||||
func safeName(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return strings.Map(func(r rune) rune {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' {
|
||||
return r
|
||||
}
|
||||
return '-'
|
||||
}, value)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,979 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/auth"
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
"ymhut-box/server/unified-management/internal/feedback"
|
||||
"ymhut-box/server/unified-management/internal/health"
|
||||
"ymhut-box/server/unified-management/internal/legacy"
|
||||
"ymhut-box/server/unified-management/internal/notices"
|
||||
"ymhut-box/server/unified-management/internal/releases"
|
||||
"ymhut-box/server/unified-management/internal/sources"
|
||||
"ymhut-box/server/unified-management/internal/synclegacy"
|
||||
webassets "ymhut-box/server/unified-management/web"
|
||||
)
|
||||
|
||||
type router struct {
|
||||
cfg *config.Config
|
||||
store *db.Store
|
||||
auth *auth.Service
|
||||
feedback *feedback.Service
|
||||
releases *releases.Service
|
||||
sources *sources.Service
|
||||
legacy *legacy.Service
|
||||
notices *notices.Service
|
||||
syncer *synclegacy.Service
|
||||
}
|
||||
|
||||
func NewRouter(cfg *config.Config, store *db.Store, authService *auth.Service, feedbackService *feedback.Service, releaseService *releases.Service, sourceService *sources.Service, legacyService *legacy.Service, optional ...any) http.Handler {
|
||||
r := &router{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
auth: authService,
|
||||
feedback: feedbackService,
|
||||
releases: releaseService,
|
||||
sources: sourceService,
|
||||
legacy: legacyService,
|
||||
}
|
||||
for _, item := range optional {
|
||||
switch typed := item.(type) {
|
||||
case *notices.Service:
|
||||
r.notices = typed
|
||||
case *synclegacy.Service:
|
||||
r.syncer = typed
|
||||
}
|
||||
}
|
||||
return withSecurity(r)
|
||||
}
|
||||
|
||||
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch {
|
||||
case path == "/" && req.Method == http.MethodPost:
|
||||
r.handleFeedbackSubmit(w, req)
|
||||
case path == "/" && req.URL.Query().Get("api") == "status":
|
||||
r.handleFeedbackStatus(w, req)
|
||||
case isPortalRoute(path):
|
||||
r.servePortal(w, req)
|
||||
case path == "/api/auth/bootstrap" || path == "/api/admin/auth/bootstrap":
|
||||
r.handleAuthBootstrap(w, req)
|
||||
case path == "/api/auth/captcha" || path == "/api/admin/auth/captcha":
|
||||
r.handleCaptcha(w, req)
|
||||
case path == "/api/auth/login" || path == "/api/admin/auth/login":
|
||||
r.handleLogin(w, req)
|
||||
case path == "/api/auth/logout" || path == "/api/admin/auth/logout":
|
||||
r.auth.Require(http.HandlerFunc(r.handleLogout)).ServeHTTP(w, req)
|
||||
case path == "/api/admin/auth/password":
|
||||
r.auth.Require(http.HandlerFunc(r.handleChangePassword)).ServeHTTP(w, req)
|
||||
case path == "/api/client/bootstrap":
|
||||
r.handleClientBootstrap(w, req)
|
||||
case path == "/api/client/releases" || path == "/api/releases" || path == "/api/update-info":
|
||||
writeJSON(w, http.StatusOK, r.releases.Manifest(req))
|
||||
case path == "/api/client/sources":
|
||||
r.handleClientSources(w, req)
|
||||
case path == "/api/client/endpoints":
|
||||
r.handleClientEndpoints(w, req)
|
||||
case path == "/api/client/notices" || strings.HasPrefix(path, "/api/client/notices/"):
|
||||
r.handleClientNotices(w, req)
|
||||
case path == "/api/client/endpoint-calls" || path == "/api/client/source-calls":
|
||||
r.handleSourceCall(w, req)
|
||||
case path == "/update-info.json" || path == "/update-info":
|
||||
writeJSON(w, http.StatusOK, r.releases.LegacyUpdateInfo(req))
|
||||
case path == "/tool-status.json" || path == "/tool-status":
|
||||
writeJSON(w, http.StatusOK, r.releases.StaticJSON("tool-status.json"))
|
||||
case path == "/modules.json" || path == "/modules" || path == "/api/modules":
|
||||
writeJSON(w, http.StatusOK, r.releases.StaticJSON("modules.json"))
|
||||
case path == "/media-types.json" || path == "/media-types":
|
||||
r.handleLegacyMediaTypes(w, req)
|
||||
case strings.HasPrefix(path, "/downloads/"):
|
||||
r.handleDownload(w, req)
|
||||
case strings.HasPrefix(path, "/admin/assets/"):
|
||||
serveStaticAsset(w, req, r.cfg.AdminWebDir, "admin/dist", strings.TrimPrefix(path, "/admin/"))
|
||||
case strings.HasPrefix(path, "/assets/"):
|
||||
serveStaticAsset(w, req, r.cfg.PortalWebDir, "portal/dist", strings.TrimPrefix(path, "/"))
|
||||
case strings.HasPrefix(path, "/api/admin/feedbacks"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminFeedbacks)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/dashboard"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminDashboard)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/sync"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminSync)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/releases"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminReleases)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/sources"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminSources)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/endpoints"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminEndpoints)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/legacy"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminLegacy)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/database"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminDatabase)).ServeHTTP(w, req)
|
||||
case strings.HasPrefix(path, "/api/admin/system"):
|
||||
r.auth.Require(http.HandlerFunc(r.handleAdminSystem)).ServeHTTP(w, req)
|
||||
case path == "/admin" || path == "/admin/":
|
||||
http.Redirect(w, req, "/admin/dashboard", http.StatusFound)
|
||||
case path == "/admin/login" || strings.HasPrefix(path, "/admin/"):
|
||||
r.serveAdmin(w, req)
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleAuthBootstrap(w http.ResponseWriter, req *http.Request) {
|
||||
payload, err := r.auth.Bootstrap(req.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "BOOTSTRAP_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, payload)
|
||||
}
|
||||
|
||||
func (r *router) handleCaptcha(w http.ResponseWriter, req *http.Request) {
|
||||
captcha, err := r.auth.NewCaptcha()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "CAPTCHA_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "captchaId": captcha.ID, "image": captcha.Image})
|
||||
}
|
||||
|
||||
func (r *router) handleLogin(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
CaptchaID string `json:"captchaId"`
|
||||
Captcha string `json:"captcha"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
if body.Username == "" {
|
||||
body.Username = "admin"
|
||||
}
|
||||
sessionID, csrf, ok, err := r.auth.Login(req.Context(), body.Username, body.Password, body.CaptchaID, body.Captcha)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "LOGIN_FAILED", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "LOGIN_FAILED", errors.New("invalid password or captcha"))
|
||||
return
|
||||
}
|
||||
auth.SetSessionCookie(w, sessionID)
|
||||
_ = r.store.InsertAudit(db.AuditLog{Actor: body.Username, Type: "auth.login", Target: "admin", Message: "Admin login", IP: req.RemoteAddr, UserAgent: req.UserAgent()})
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "csrfToken": csrf, "user": map[string]any{"username": body.Username}})
|
||||
}
|
||||
|
||||
func (r *router) handleLogout(w http.ResponseWriter, req *http.Request) {
|
||||
r.auth.Logout(w, req)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (r *router) handleChangePassword(w http.ResponseWriter, req *http.Request) {
|
||||
var body struct {
|
||||
CurrentPassword string `json:"currentPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
if err := r.store.ChangeAdminPassword(req.Context(), "admin", body.CurrentPassword, body.NewPassword); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "PASSWORD_CHANGE_FAILED", err)
|
||||
return
|
||||
}
|
||||
_ = r.store.InsertAudit(db.AuditLog{Actor: "admin", Type: "auth.password_changed", Target: "admin", Message: "Admin password changed", IP: req.RemoteAddr, UserAgent: req.UserAgent()})
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (r *router) handleClientBootstrap(w http.ResponseWriter, req *http.Request) {
|
||||
release := r.releases.Manifest(req)
|
||||
sourceCatalog, _ := r.sources.Catalog(false)
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"ok": true,
|
||||
"serviceVersion": config.Version,
|
||||
"baseUrl": requestBaseURL(req, r.cfg.BaseURL),
|
||||
"capabilities": map[string]bool{
|
||||
"dynamicSources": true,
|
||||
"sourceHealth": true,
|
||||
"feedbackStatus": true,
|
||||
"releaseManifest": true,
|
||||
"endpointCalls": true,
|
||||
"legacyJson": true,
|
||||
},
|
||||
"endpoints": map[string]string{
|
||||
"releases": "/api/client/releases",
|
||||
"sources": "/api/client/sources",
|
||||
"clientEndpoints": "/api/client/endpoints",
|
||||
"endpointCalls": "/api/client/endpoint-calls",
|
||||
"notices": "/api/client/notices",
|
||||
"feedback": "/",
|
||||
},
|
||||
"cache": map[string]int{
|
||||
"bootstrapSeconds": 300,
|
||||
"releasesSeconds": 300,
|
||||
"sourcesSeconds": 600,
|
||||
"healthSeconds": 300,
|
||||
},
|
||||
"legacyRoutes": []string{"/update-info.json", "/update-info", "/api/update-info", "/api/releases", "/tool-status.json", "/media-types.json", "/modules.json", "/downloads/:filename"},
|
||||
"release": release,
|
||||
"sources": sourceCatalog,
|
||||
"feedback": map[string]any{"submit": "/", "status": "/?api=status&code=:code"},
|
||||
"health": health.Snapshot(r.cfg, r.store),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *router) handleClientSources(w http.ResponseWriter, req *http.Request) {
|
||||
catalog, err := r.sources.Catalog(false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "SOURCES_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, catalog)
|
||||
}
|
||||
|
||||
func (r *router) handleClientEndpoints(w http.ResponseWriter, req *http.Request) {
|
||||
items, err := r.sources.Endpoints(false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "ENDPOINTS_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
}
|
||||
|
||||
func (r *router) handleClientNotices(w http.ResponseWriter, req *http.Request) {
|
||||
if r.notices == nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": []any{}})
|
||||
return
|
||||
}
|
||||
path := cleanPath(req.URL.Path)
|
||||
if path == "/api/client/notices" {
|
||||
items, err := r.notices.List(100)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "NOTICES_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": notices.PublicList(items)})
|
||||
return
|
||||
}
|
||||
version := strings.TrimPrefix(path, "/api/client/notices/")
|
||||
if version == "" {
|
||||
http.NotFound(w, req)
|
||||
return
|
||||
}
|
||||
doc, err := r.notices.Get(version)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOTICE_NOT_FOUND", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "notice": notices.PublicNotice(doc.Notice), "raw": doc.Parsed})
|
||||
}
|
||||
|
||||
func (r *router) handleLegacyMediaTypes(w http.ResponseWriter, req *http.Request) {
|
||||
catalog, err := r.sources.Catalog(false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "MEDIA_TYPES_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, catalog)
|
||||
}
|
||||
|
||||
func (r *router) handleSourceCall(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
||||
return
|
||||
}
|
||||
var body db.SourceCall
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
body.Client = firstNonEmpty(body.Client, req.UserAgent())
|
||||
if err := r.store.RecordSourceCall(body); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "SOURCE_CALL_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (r *router) handleFeedbackSubmit(w http.ResponseWriter, req *http.Request) {
|
||||
item, err := r.feedback.Submit(req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "FEEDBACK_FAILED", err)
|
||||
return
|
||||
}
|
||||
_ = r.store.InsertAudit(db.AuditLog{Actor: "client", Type: "feedback.created", Target: item.Code, Message: item.Title, IP: req.RemoteAddr, UserAgent: req.UserAgent()})
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "code": item.Code})
|
||||
}
|
||||
|
||||
func (r *router) handleFeedbackStatus(w http.ResponseWriter, req *http.Request) {
|
||||
code := strings.TrimSpace(req.URL.Query().Get("code"))
|
||||
if code == "" {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_CODE", errors.New("code is required"))
|
||||
return
|
||||
}
|
||||
item, err := r.store.GetFeedback(code)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "feedback": item})
|
||||
}
|
||||
|
||||
func (r *router) handleAdminFeedbacks(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
if req.Method == http.MethodGet && path == "/api/admin/feedbacks" {
|
||||
if req.URL.Query().Get("page") != "" {
|
||||
page, _ := strconv.Atoi(req.URL.Query().Get("page"))
|
||||
perPage, _ := strconv.Atoi(req.URL.Query().Get("perPage"))
|
||||
items, total, err := r.store.ListFeedbacksFiltered(page, perPage, db.FeedbackFilters{
|
||||
Status: req.URL.Query().Get("status"),
|
||||
Category: req.URL.Query().Get("category"),
|
||||
Priority: req.URL.Query().Get("priority"),
|
||||
Query: req.URL.Query().Get("q"),
|
||||
Assignee: req.URL.Query().Get("assignee"),
|
||||
Sort: req.URL.Query().Get("sort"),
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "FEEDBACK_LIST_FAILED", err)
|
||||
return
|
||||
}
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if perPage <= 0 {
|
||||
perPage = 20
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "page": map[string]any{"items": items, "total": total, "page": page, "perPage": perPage}})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(req.URL.Query().Get("limit"))
|
||||
items, err := r.store.ListFeedbacks(limit)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "FEEDBACK_LIST_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodGet && path == "/api/admin/feedbacks/export" {
|
||||
items, _, err := r.store.ListFeedbacksFiltered(1, 100, db.FeedbackFilters{
|
||||
Status: req.URL.Query().Get("status"),
|
||||
Category: req.URL.Query().Get("category"),
|
||||
Priority: req.URL.Query().Get("priority"),
|
||||
Query: req.URL.Query().Get("q"),
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "EXPORT_FAILED", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="feedbacks.csv"`)
|
||||
writer := csv.NewWriter(w)
|
||||
_ = writer.Write([]string{"code", "created_at", "title", "status", "category", "priority", "contact", "status_detail", "public_reply"})
|
||||
for _, item := range items {
|
||||
_ = writer.Write([]string{item.Code, item.CreatedAt, item.Title, item.Status, item.Category, item.Priority, item.Contact, item.StatusDetail, item.PublicReply})
|
||||
}
|
||||
writer.Flush()
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodGet && strings.HasPrefix(path, "/api/admin/feedbacks/") {
|
||||
code := strings.TrimPrefix(path, "/api/admin/feedbacks/")
|
||||
detail, err := r.store.GetFeedbackDetail(code)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOT_FOUND", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "feedback": detail})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPatch && path == "/api/admin/feedbacks/bulk" {
|
||||
var body struct {
|
||||
Codes []string `json:"codes"`
|
||||
Status string `json:"status"`
|
||||
StatusDetail string `json:"statusDetail"`
|
||||
PublicReply string `json:"publicReply"`
|
||||
Assignee string `json:"assignee"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil || len(body.Codes) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", errors.New("codes are required"))
|
||||
return
|
||||
}
|
||||
if err := r.store.BulkUpdateFeedback(body.Codes, db.FeedbackUpdate{Status: body.Status, StatusDetail: body.StatusDetail, PublicReply: body.PublicReply, Assignee: body.Assignee, Actor: "admin", Tags: body.Tags}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "BULK_UPDATE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "updated": len(body.Codes)})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPost && strings.HasPrefix(path, "/api/admin/feedbacks/") && strings.HasSuffix(path, "/comments") {
|
||||
code := strings.TrimSuffix(strings.TrimPrefix(path, "/api/admin/feedbacks/"), "/comments")
|
||||
var body struct {
|
||||
Author string `json:"author"`
|
||||
Body string `json:"body"`
|
||||
Internal bool `json:"internal"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
comment, err := r.store.InsertFeedbackComment(db.FeedbackComment{Code: code, Author: firstNonEmpty(body.Author, "admin"), Body: body.Body, Internal: body.Internal})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "COMMENT_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "comment": comment})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPatch && strings.HasPrefix(path, "/api/admin/feedbacks/") {
|
||||
code := strings.TrimPrefix(path, "/api/admin/feedbacks/")
|
||||
var body struct {
|
||||
Status string `json:"status"`
|
||||
StatusDetail string `json:"statusDetail"`
|
||||
PublicReply string `json:"publicReply"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
if err := r.store.UpdateFeedbackTicket(code, db.FeedbackUpdate{Status: firstNonEmpty(body.Status, "new"), StatusDetail: body.StatusDetail, PublicReply: body.PublicReply, Actor: "admin"}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "FEEDBACK_UPDATE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
func (r *router) handleAdminLegacy(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
name := ""
|
||||
switch {
|
||||
case strings.HasPrefix(path, "/api/admin/legacy/update-info"):
|
||||
name = "update-info"
|
||||
case strings.HasPrefix(path, "/api/admin/legacy/media-types"):
|
||||
name = "media-types"
|
||||
default:
|
||||
parts := strings.Split(strings.TrimPrefix(path, "/api/admin/legacy/"), "/")
|
||||
if len(parts) > 0 {
|
||||
name = parts[0]
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
http.NotFound(w, req)
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodGet && (path == "/api/admin/legacy/update-info" || path == "/api/admin/legacy/media-types") {
|
||||
doc, err := r.legacy.Get(req.Context(), name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "LEGACY_GET_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPut && (path == "/api/admin/legacy/update-info" || path == "/api/admin/legacy/media-types") {
|
||||
var body legacy.SaveRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
doc, err := r.legacy.Save(req.Context(), name, body, "admin")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "LEGACY_SAVE_FAILED", err)
|
||||
return
|
||||
}
|
||||
if name == "media-types" {
|
||||
_ = r.sources.ImportLegacyMediaTypes(req.Context())
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPost && strings.HasSuffix(path, "/validate") {
|
||||
var body legacy.SaveRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
doc, err := r.legacy.Validate(req.Context(), name, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "LEGACY_VALIDATE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPost && strings.HasSuffix(path, "/restore") {
|
||||
var body struct {
|
||||
RevisionID int64 `json:"revisionId"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil || body.RevisionID <= 0 {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", errors.New("revisionId is required"))
|
||||
return
|
||||
}
|
||||
doc, err := r.legacy.Restore(req.Context(), name, body.RevisionID, "admin")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "LEGACY_RESTORE_FAILED", err)
|
||||
return
|
||||
}
|
||||
if name == "media-types" {
|
||||
_ = r.sources.ImportLegacyMediaTypes(req.Context())
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
func (r *router) handleAdminDatabase(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch {
|
||||
case req.Method == http.MethodGet && path == "/api/admin/database/status":
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "database": r.store.Status()})
|
||||
case req.Method == http.MethodPost && path == "/api/admin/database/test":
|
||||
var body config.DatabaseConfig
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
if body.Provider == "" {
|
||||
body.Provider = r.cfg.Database.Provider
|
||||
}
|
||||
if body.SQLitePath == "" {
|
||||
body.SQLitePath = r.cfg.Database.SQLitePath
|
||||
}
|
||||
if body.MySQLDSN == "" {
|
||||
body.MySQLDSN = r.cfg.Database.MySQLDSN
|
||||
}
|
||||
if err := db.TestDatabase(body); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "DATABASE_TEST_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
case req.Method == http.MethodPost && path == "/api/admin/database/import-sqlite":
|
||||
result, err := r.store.ImportSQLiteToRemote()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadGateway, "DATABASE_IMPORT_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "result": result})
|
||||
case req.Method == http.MethodPost && path == "/api/admin/database/sync":
|
||||
result, err := r.store.SyncNow()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadGateway, "DATABASE_SYNC_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "result": result})
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleAdminDashboard(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
if req.Method != http.MethodGet || path != "/api/admin/dashboard/overview" {
|
||||
http.NotFound(w, req)
|
||||
return
|
||||
}
|
||||
overview, err := r.store.DashboardOverview(80)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "DASHBOARD_FAILED", err)
|
||||
return
|
||||
}
|
||||
overview["health"] = health.Snapshot(r.cfg, r.store)
|
||||
writeJSON(w, http.StatusOK, overview)
|
||||
}
|
||||
|
||||
func (r *router) handleAdminSync(w http.ResponseWriter, req *http.Request) {
|
||||
if r.syncer == nil {
|
||||
writeError(w, http.StatusNotFound, "SYNC_DISABLED", errors.New("legacy sync service is not configured"))
|
||||
return
|
||||
}
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch {
|
||||
case req.Method == http.MethodGet && path == "/api/admin/sync/legacy/preview":
|
||||
writeJSON(w, http.StatusOK, r.syncer.Preview(req.Context()))
|
||||
case req.Method == http.MethodPost && path == "/api/admin/sync/legacy/run":
|
||||
writeJSON(w, http.StatusOK, r.syncer.Run(req.Context()))
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleAdminEndpoints(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodGet {
|
||||
http.NotFound(w, req)
|
||||
return
|
||||
}
|
||||
items, err := r.sources.Endpoints(true)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "ENDPOINTS_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
}
|
||||
|
||||
func (r *router) handleAdminReleases(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
if strings.HasPrefix(path, "/api/admin/releases/notices") {
|
||||
r.handleAdminReleaseNotices(w, req)
|
||||
return
|
||||
}
|
||||
switch path {
|
||||
case "/api/admin/releases":
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "manifest": r.releases.Manifest(req)})
|
||||
case "/api/admin/releases/legacy-preview":
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "updateInfo": r.releases.LegacyUpdateInfo(req), "toolStatus": r.releases.StaticJSON("tool-status.json")})
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleAdminReleaseNotices(w http.ResponseWriter, req *http.Request) {
|
||||
if r.notices == nil {
|
||||
writeError(w, http.StatusNotFound, "NOTICES_DISABLED", errors.New("release notices are not configured"))
|
||||
return
|
||||
}
|
||||
path := cleanPath(req.URL.Path)
|
||||
if req.Method == http.MethodPost && path == "/api/admin/releases/notices/import" {
|
||||
if err := r.notices.Import(req.Context()); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "NOTICE_IMPORT_FAILED", err)
|
||||
return
|
||||
}
|
||||
items, _ := r.notices.List(100)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodGet && path == "/api/admin/releases/notices" {
|
||||
items, err := r.notices.List(100)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "NOTICE_LIST_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
return
|
||||
}
|
||||
rest := strings.TrimPrefix(path, "/api/admin/releases/notices/")
|
||||
if rest == "" || rest == path {
|
||||
http.NotFound(w, req)
|
||||
return
|
||||
}
|
||||
parts := strings.Split(rest, "/")
|
||||
version := parts[0]
|
||||
if req.Method == http.MethodGet && len(parts) == 1 {
|
||||
doc, err := r.notices.Get(version)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "NOTICE_NOT_FOUND", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPut && len(parts) == 1 {
|
||||
var body notices.SaveRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
doc, err := r.notices.Save(req.Context(), version, body, "admin")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "NOTICE_SAVE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPost && len(parts) == 2 && parts[1] == "validate" {
|
||||
var body notices.SaveRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
doc, err := r.notices.Validate(req.Context(), version, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "NOTICE_VALIDATE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPost && len(parts) == 2 && parts[1] == "restore" {
|
||||
var body struct {
|
||||
RevisionID int64 `json:"revisionId"`
|
||||
}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil || body.RevisionID <= 0 {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", errors.New("revisionId is required"))
|
||||
return
|
||||
}
|
||||
doc, err := r.notices.Restore(req.Context(), version, body.RevisionID, "admin")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "NOTICE_RESTORE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "document": doc})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
func (r *router) handleAdminSources(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch {
|
||||
case req.Method == http.MethodGet && path == "/api/admin/sources":
|
||||
catalog, err := r.sources.Catalog(true)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "SOURCES_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "catalog": catalog})
|
||||
case req.Method == http.MethodPost && path == "/api/admin/sources/import-media-types":
|
||||
if err := r.sources.ImportLegacyMediaTypes(req.Context()); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "IMPORT_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
case req.Method == http.MethodPost && path == "/api/admin/sources/check":
|
||||
go r.sources.CheckDue(req.Context())
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "queued": true})
|
||||
case req.Method == http.MethodPost && strings.HasPrefix(path, "/api/admin/sources/") && strings.HasSuffix(path, "/check"):
|
||||
sourceID := strings.TrimSuffix(strings.TrimPrefix(path, "/api/admin/sources/"), "/check")
|
||||
item, err := r.sources.CheckSourceID(req.Context(), sourceID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "CHECK_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "source": item})
|
||||
case (req.Method == http.MethodPost || req.Method == http.MethodPut) && path == "/api/admin/sources":
|
||||
var item db.Source
|
||||
if err := json.NewDecoder(req.Body).Decode(&item); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
saved, err := r.store.UpsertSource(item)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "SOURCE_SAVE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "source": saved})
|
||||
case req.Method == http.MethodDelete && strings.HasPrefix(path, "/api/admin/sources/"):
|
||||
sourceID := strings.TrimPrefix(path, "/api/admin/sources/")
|
||||
if err := r.store.DeleteSource(sourceID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "SOURCE_DELETE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleAdminSystem(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch path {
|
||||
case "/api/admin/system/health":
|
||||
writeJSON(w, http.StatusOK, health.Snapshot(r.cfg, r.store))
|
||||
case "/api/admin/system/audit":
|
||||
items, err := r.store.ListAuditLogs(100)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "AUDIT_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "items": items})
|
||||
case "/api/admin/system/database/sync":
|
||||
if req.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
||||
return
|
||||
}
|
||||
finishedAt, err := r.store.CopySQLiteToRemote()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "SYNC_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "finishedAt": finishedAt})
|
||||
default:
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handleDownload(w http.ResponseWriter, req *http.Request) {
|
||||
name := strings.TrimPrefix(cleanPath(req.URL.Path), "/downloads/")
|
||||
if name == "" || strings.Contains(name, "..") || strings.ContainsAny(name, `/\`) {
|
||||
writeError(w, http.StatusForbidden, "FORBIDDEN", errors.New("invalid filename"))
|
||||
return
|
||||
}
|
||||
path := filepath.Join(r.cfg.DownloadsDir, name)
|
||||
resolved, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "PATH_FAILED", err)
|
||||
return
|
||||
}
|
||||
base, _ := filepath.Abs(r.cfg.DownloadsDir)
|
||||
if !strings.HasPrefix(resolved, base) {
|
||||
writeError(w, http.StatusForbidden, "FORBIDDEN", errors.New("path escape rejected"))
|
||||
return
|
||||
}
|
||||
http.ServeFile(w, req, resolved)
|
||||
}
|
||||
|
||||
func serveStaticAsset(w http.ResponseWriter, req *http.Request, root, embedRoot, assetPath string) {
|
||||
if strings.Contains(assetPath, "..") || strings.ContainsAny(assetPath, `\`) {
|
||||
writeError(w, http.StatusForbidden, "FORBIDDEN", errors.New("invalid asset path"))
|
||||
return
|
||||
}
|
||||
if tryServeDiskFile(w, req, root, assetPath) {
|
||||
return
|
||||
}
|
||||
if serveEmbeddedFile(w, req, embedRoot+"/"+filepath.ToSlash(assetPath)) {
|
||||
return
|
||||
}
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
|
||||
func tryServeDiskFile(w http.ResponseWriter, req *http.Request, root, assetPath string) bool {
|
||||
path := filepath.Join(root, filepath.FromSlash(assetPath))
|
||||
resolved, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "PATH_FAILED", err)
|
||||
return true
|
||||
}
|
||||
base, _ := filepath.Abs(root)
|
||||
if resolved != base && !strings.HasPrefix(resolved, base+string(os.PathSeparator)) {
|
||||
writeError(w, http.StatusForbidden, "FORBIDDEN", errors.New("path escape rejected"))
|
||||
return true
|
||||
}
|
||||
info, err := os.Stat(resolved)
|
||||
if err != nil || info.IsDir() {
|
||||
return false
|
||||
}
|
||||
http.ServeFile(w, req, resolved)
|
||||
return true
|
||||
}
|
||||
|
||||
func serveEmbeddedFile(w http.ResponseWriter, req *http.Request, name string) bool {
|
||||
if strings.Contains(name, "..") || strings.ContainsAny(name, `\`) {
|
||||
writeError(w, http.StatusForbidden, "FORBIDDEN", errors.New("invalid embedded asset path"))
|
||||
return true
|
||||
}
|
||||
data, err := webassets.ReadFile(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if contentType := mime.TypeByExtension(filepath.Ext(name)); contentType != "" {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
}
|
||||
http.ServeContent(w, req, filepath.Base(name), time.Time{}, bytes.NewReader(data))
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *router) servePortal(w http.ResponseWriter, req *http.Request) {
|
||||
index := filepath.Join(r.cfg.PortalWebDir, "index.html")
|
||||
if _, err := os.Stat(index); err == nil {
|
||||
http.ServeFile(w, req, index)
|
||||
return
|
||||
}
|
||||
if serveEmbeddedFile(w, req, "portal/dist/index.html") {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = w.Write([]byte(`<!doctype html><html><head><meta charset="utf-8"><title>YMhut Box</title></head><body><main><h1>YMhut Box</h1><p>Unified management service is running.</p><p><a href="/api/client/bootstrap">Client bootstrap</a> | <a href="/admin/login">Admin</a></p></main></body></html>`))
|
||||
}
|
||||
|
||||
func (r *router) serveAdmin(w http.ResponseWriter, req *http.Request) {
|
||||
index := filepath.Join(r.cfg.AdminWebDir, "index.html")
|
||||
if _, err := os.Stat(index); err == nil {
|
||||
http.ServeFile(w, req, index)
|
||||
return
|
||||
}
|
||||
if serveEmbeddedFile(w, req, "admin/dist/index.html") {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = w.Write([]byte(`<!doctype html><html><head><meta charset="utf-8"><title>YMhut Admin</title></head><body><main><h1>YMhut Admin</h1><p>Build web/admin to enable the Vue console.</p></main></body></html>`))
|
||||
}
|
||||
|
||||
func isPortalRoute(path string) bool {
|
||||
switch path {
|
||||
case "/", "/releases", "/sources", "/feedback", "/compatibility":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func withSecurity(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Referrer-Policy", "same-origin")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code string, err error) {
|
||||
message := ""
|
||||
if err != nil {
|
||||
message = err.Error()
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"ok": false, "error": code, "message": message})
|
||||
}
|
||||
|
||||
func cleanPath(path string) string {
|
||||
if path == "" {
|
||||
return "/"
|
||||
}
|
||||
if path != "/" {
|
||||
path = strings.TrimRight(path, "/")
|
||||
}
|
||||
if path == "" {
|
||||
return "/"
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func requestBaseURL(r *http.Request, fallback string) string {
|
||||
scheme := r.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
if r.Host != "" {
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
return strings.TrimRight(fallback, "/")
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/auth"
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
"ymhut-box/server/unified-management/internal/feedback"
|
||||
"ymhut-box/server/unified-management/internal/legacy"
|
||||
"ymhut-box/server/unified-management/internal/notices"
|
||||
"ymhut-box/server/unified-management/internal/releases"
|
||||
"ymhut-box/server/unified-management/internal/sources"
|
||||
)
|
||||
|
||||
func TestCompatibilityRoutes(t *testing.T) {
|
||||
handler, cleanup := testRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, path := range []string{"/api/client/bootstrap", "/update-info.json", "/media-types.json", "/modules.json"} {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("%s returned %d: %s", path, res.Code, res.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("%s did not return JSON: %v", path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientBootstrapAndEndpointsShape(t *testing.T) {
|
||||
handler, cleanup := testRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, path := range []string{"/api/client/bootstrap", "/api/client/endpoints", "/api/client/sources", "/api/client/notices"} {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("%s returned %d: %s", path, res.Code, res.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path == "/api/client/sources" {
|
||||
if payload["categories"] == nil {
|
||||
t.Fatalf("%s missing categories: %#v", path, payload)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if payload["ok"] != true {
|
||||
t.Fatalf("%s missing ok=true: %#v", path, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltFrontendAssetsAreServed(t *testing.T) {
|
||||
handler, cleanup := testRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, item := range []struct {
|
||||
Path string
|
||||
ContentTypes []string
|
||||
}{
|
||||
{Path: "/assets/portal.css", ContentTypes: []string{"text/css"}},
|
||||
{Path: "/assets/portal.js", ContentTypes: []string{"text/javascript", "application/javascript"}},
|
||||
{Path: "/admin/assets/admin.css", ContentTypes: []string{"text/css"}},
|
||||
{Path: "/admin/assets/admin.js", ContentTypes: []string{"text/javascript", "application/javascript"}},
|
||||
} {
|
||||
req := httptest.NewRequest(http.MethodGet, item.Path, nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("%s returned %d: %s", item.Path, res.Code, res.Body.String())
|
||||
}
|
||||
if got := res.Header().Get("Content-Type"); !containsAny(got, item.ContentTypes) {
|
||||
t.Fatalf("%s content type = %q, want one of %v", item.Path, got, item.ContentTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func containsAny(value string, needles []string) bool {
|
||||
for _, needle := range needles {
|
||||
if strings.Contains(value, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestReleaseNoticesRoutes(t *testing.T) {
|
||||
handler, cleanup := testRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, path := range []string{"/api/client/notices", "/api/client/notices/2.0.0"} {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("%s returned %d: %s", path, res.Code, res.Body.String())
|
||||
}
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/client/releases", nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(res.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if payload["notices"] == nil {
|
||||
t.Fatalf("release manifest missing notices: %#v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLegacyRequiresAuth(t *testing.T) {
|
||||
handler, cleanup := testRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/admin/legacy/media-types", bytes.NewBufferString(`{"raw":"{}"}`))
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected unauthorized, got %d", res.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func testRouter(t *testing.T) (http.Handler, func()) {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
public := filepath.Join(root, "public")
|
||||
noticeDir := filepath.Join(root, "update-notice")
|
||||
if err := os.MkdirAll(filepath.Join(public, "downloads"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
adminDist := filepath.Join(root, "admin")
|
||||
portalDist := filepath.Join(root, "portal")
|
||||
for _, dir := range []string{filepath.Join(adminDist, "assets"), filepath.Join(portalDist, "assets")} {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(portalDist, "index.html"), []byte(`<!doctype html><link rel="stylesheet" href="/assets/portal.css"><script type="module" src="/assets/portal.js"></script>`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(portalDist, "assets", "portal.css"), []byte(`body{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(portalDist, "assets", "portal.js"), []byte(`console.log("portal")`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(adminDist, "index.html"), []byte(`<!doctype html><link rel="stylesheet" href="/admin/assets/admin.css"><script type="module" src="/admin/assets/admin.js"></script>`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(adminDist, "assets", "admin.css"), []byte(`body{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(adminDist, "assets", "admin.js"), []byte(`console.log("admin")`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(noticeDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mustWriteJSON(t, filepath.Join(public, "update-info.json"), map[string]any{"app_version": "0.0.1"})
|
||||
mustWriteJSON(t, filepath.Join(public, "tool-status.json"), map[string]any{"ok": true})
|
||||
mustWriteJSON(t, filepath.Join(public, "modules.json"), map[string]any{"modules": []any{}})
|
||||
mustWriteJSON(t, filepath.Join(public, "media-types.json"), map[string]any{
|
||||
"categories": []map[string]any{{
|
||||
"id": "image", "name": "image",
|
||||
"subcategories": []map[string]any{{"id": "demo", "name": "demo", "api_url": "https://example.com/demo"}},
|
||||
}},
|
||||
})
|
||||
mustWriteJSON(t, filepath.Join(noticeDir, "total.json"), map[string]any{
|
||||
"schema_version": 1,
|
||||
"latest_version": "2.0.0",
|
||||
"latest_notice_file": "2.0.0.json",
|
||||
"latest": map[string]any{"version": "2.0.0", "title": "YMhut Box 2.0.0", "release_notes": "Initial release"},
|
||||
"versions": []map[string]any{{"version": "2.0.0", "notice_file": "2.0.0.json", "summary": "Initial release"}},
|
||||
})
|
||||
mustWriteJSON(t, filepath.Join(noticeDir, "2.0.0.json"), map[string]any{"app_version": "2.0.0", "title": "YMhut Box 2.0.0", "release_notes": "Initial release", "release_notes_md": "## Initial"})
|
||||
cfg := &config.Config{
|
||||
Listen: ":0",
|
||||
BaseURL: "https://update.ymhut.cn",
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
UpdatePublicDir: public,
|
||||
UpdateNoticeDir: noticeDir,
|
||||
DownloadsDir: filepath.Join(public, "downloads"),
|
||||
AdminWebDir: adminDist,
|
||||
PortalWebDir: portalDist,
|
||||
SourceCheckSeconds: 3600,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
FailoverEnabled: true,
|
||||
HotSyncEnabled: true,
|
||||
HealthIntervalSec: 3600,
|
||||
},
|
||||
}
|
||||
store, err := db.Open(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := store.EnsureDefaultAdmin(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sourceService := sources.NewService(cfg, store)
|
||||
if err := sourceService.ImportLegacyMediaTypesIfEmpty(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
noticeService := notices.NewService(cfg, store)
|
||||
if err := noticeService.Import(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
handler := NewRouter(
|
||||
cfg,
|
||||
store,
|
||||
auth.NewService(store),
|
||||
feedback.NewService(cfg, store),
|
||||
releases.NewService(cfg, store, noticeService),
|
||||
sourceService,
|
||||
legacy.NewService(cfg, store),
|
||||
noticeService,
|
||||
)
|
||||
return handler, func() { _ = store.Close() }
|
||||
}
|
||||
|
||||
func mustWriteJSON(t *testing.T, path string, payload any) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
"ymhut-box/server/unified-management/internal/db"
|
||||
)
|
||||
|
||||
type setupRouter struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type setupRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
SQLitePath string `json:"sqlitePath"`
|
||||
MySQLDSN string `json:"mysqlDsn"`
|
||||
MySQL setupMySQLConfig `json:"mysql"`
|
||||
}
|
||||
|
||||
type setupMySQLConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Charset string `json:"charset"`
|
||||
ParseTime bool `json:"parseTime"`
|
||||
TLS string `json:"tls"`
|
||||
}
|
||||
|
||||
func NewSetupRouter(cfg *config.Config) http.Handler {
|
||||
return withSecurity(&setupRouter{cfg: cfg})
|
||||
}
|
||||
|
||||
func (r *setupRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
path := cleanPath(req.URL.Path)
|
||||
switch {
|
||||
case path == "/" || path == "/setup":
|
||||
r.serveSetup(w, req)
|
||||
case strings.HasPrefix(path, "/setup/assets/"):
|
||||
serveStaticAsset(w, req, r.cfg.SetupWebDir, "setup/dist", strings.TrimPrefix(path, "/setup/"))
|
||||
case path == "/api/setup/status":
|
||||
writeJSON(w, http.StatusOK, r.status())
|
||||
case path == "/api/setup/database/test":
|
||||
r.handleDatabaseTest(w, req)
|
||||
case path == "/api/setup/complete":
|
||||
r.handleComplete(w, req)
|
||||
default:
|
||||
if strings.HasPrefix(path, "/api/") {
|
||||
writeError(w, http.StatusServiceUnavailable, "SETUP_REQUIRED", errors.New("system setup is required"))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, req, "/setup", http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *setupRouter) status() map[string]any {
|
||||
return map[string]any{
|
||||
"ok": true,
|
||||
"initialized": r.cfg.Initialized,
|
||||
"baseDir": r.cfg.BaseDir,
|
||||
"configPath": r.cfg.ConfigPath,
|
||||
"defaults": map[string]any{
|
||||
"provider": firstNonEmpty(r.cfg.Database.Provider, "sqlite"),
|
||||
"sqlitePath": relativeToBase(r.cfg.BaseDir, r.cfg.Database.SQLitePath),
|
||||
"mysqlDsn": maskDSN(r.cfg.Database.MySQLDSN),
|
||||
"baseUrl": r.cfg.BaseURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *setupRouter) handleDatabaseTest(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
||||
return
|
||||
}
|
||||
next, body, err := r.decodeSetupDatabase(req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
started := time.Now()
|
||||
if err := db.TestDatabase(next); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "DATABASE_TEST_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"ok": true,
|
||||
"provider": next.Provider,
|
||||
"latencyMs": time.Since(started).Milliseconds(),
|
||||
"maskedDsn": maskedDatabaseTarget(r.cfg.BaseDir, next),
|
||||
"normalized": map[string]any{
|
||||
"provider": next.Provider,
|
||||
"baseUrl": firstNonEmpty(body.BaseURL, r.cfg.BaseURL),
|
||||
"sqlitePath": relativeToBase(r.cfg.BaseDir, next.SQLitePath),
|
||||
"mysqlDsn": maskDSN(next.MySQLDSN),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *setupRouter) handleComplete(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", errors.New("POST required"))
|
||||
return
|
||||
}
|
||||
nextDB, body, err := r.decodeSetupDatabase(req)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_PAYLOAD", err)
|
||||
return
|
||||
}
|
||||
if err := db.TestDatabase(nextDB); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "DATABASE_TEST_FAILED", err)
|
||||
return
|
||||
}
|
||||
next := *r.cfg
|
||||
next.Initialized = true
|
||||
next.BaseURL = firstNonEmpty(strings.TrimSpace(body.BaseURL), next.BaseURL)
|
||||
next.Database = nextDB
|
||||
if strings.EqualFold(next.Database.Provider, "mysql") {
|
||||
next.Database.FailoverEnabled = true
|
||||
next.Database.HotSyncEnabled = true
|
||||
}
|
||||
store, err := db.Open(&next)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "DATABASE_OPEN_FAILED", err)
|
||||
return
|
||||
}
|
||||
if err := store.EnsureDefaultAdmin(req.Context()); err != nil {
|
||||
_ = store.Close()
|
||||
writeError(w, http.StatusInternalServerError, "ADMIN_INIT_FAILED", err)
|
||||
return
|
||||
}
|
||||
_ = store.Close()
|
||||
if err := config.Save(&next); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "CONFIG_SAVE_FAILED", err)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"ok": true, "initialized": true, "message": "Setup completed. Restart the service, then open /admin/login."})
|
||||
}
|
||||
|
||||
func (r *setupRouter) decodeSetupDatabase(req *http.Request) (config.DatabaseConfig, setupRequest, error) {
|
||||
var body setupRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
return config.DatabaseConfig{}, body, err
|
||||
}
|
||||
next := r.cfg.Database
|
||||
next.Provider = strings.ToLower(strings.TrimSpace(firstNonEmpty(body.Provider, next.Provider, "sqlite")))
|
||||
if body.SQLitePath != "" {
|
||||
next.SQLitePath = body.SQLitePath
|
||||
}
|
||||
if next.SQLitePath != "" && !filepath.IsAbs(next.SQLitePath) && !strings.HasPrefix(strings.ToLower(next.SQLitePath), "file:") {
|
||||
next.SQLitePath = filepath.Join(r.cfg.BaseDir, next.SQLitePath)
|
||||
}
|
||||
if next.Provider == "sqlite" {
|
||||
next.MySQLDSN = ""
|
||||
} else if body.MySQLDSN != "" {
|
||||
next.MySQLDSN = body.MySQLDSN
|
||||
} else if body.MySQL.Host != "" || body.MySQL.Database != "" || body.MySQL.Username != "" {
|
||||
dsn, err := buildMySQLDSN(body.MySQL)
|
||||
if err != nil {
|
||||
return config.DatabaseConfig{}, body, err
|
||||
}
|
||||
next.MySQLDSN = dsn
|
||||
}
|
||||
if next.Provider != "sqlite" && next.Provider != "mysql" {
|
||||
return config.DatabaseConfig{}, body, errors.New("provider must be sqlite or mysql")
|
||||
}
|
||||
if next.Provider == "mysql" && strings.TrimSpace(next.MySQLDSN) == "" {
|
||||
return config.DatabaseConfig{}, body, errors.New("mysql connection is required")
|
||||
}
|
||||
if next.MaxOpenConns <= 0 {
|
||||
next.MaxOpenConns = 10
|
||||
}
|
||||
if next.MaxIdleConns <= 0 {
|
||||
next.MaxIdleConns = 4
|
||||
}
|
||||
if next.ConnMaxLifetimeSeconds <= 0 {
|
||||
next.ConnMaxLifetimeSeconds = 300
|
||||
}
|
||||
if next.HealthIntervalSec <= 0 {
|
||||
next.HealthIntervalSec = 30
|
||||
}
|
||||
return next, body, nil
|
||||
}
|
||||
|
||||
func (r *setupRouter) serveSetup(w http.ResponseWriter, req *http.Request) {
|
||||
index := filepath.Join(r.cfg.SetupWebDir, "index.html")
|
||||
if tryServeDiskFile(w, req, r.cfg.SetupWebDir, "index.html") {
|
||||
return
|
||||
}
|
||||
if serveEmbeddedFile(w, req, "setup/dist/index.html") {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = w.Write([]byte(`<!doctype html><html lang="zh-CN"><head><meta charset="utf-8"><title>YMhut Setup</title></head><body><main><h1>YMhut Setup</h1><p>Setup frontend is not built. Run npm install && npm run build in web/setup.</p><p>` + index + `</p></main></body></html>`))
|
||||
}
|
||||
|
||||
func buildMySQLDSN(input setupMySQLConfig) (string, error) {
|
||||
host := strings.TrimSpace(input.Host)
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
port := input.Port
|
||||
if port <= 0 {
|
||||
port = 3306
|
||||
}
|
||||
database := strings.TrimSpace(input.Database)
|
||||
username := strings.TrimSpace(input.Username)
|
||||
if database == "" {
|
||||
return "", errors.New("mysql database is required")
|
||||
}
|
||||
if username == "" {
|
||||
return "", errors.New("mysql username is required")
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("charset", firstNonEmpty(strings.TrimSpace(input.Charset), "utf8mb4"))
|
||||
params.Set("parseTime", strconv.FormatBool(input.ParseTime))
|
||||
if tls := strings.TrimSpace(input.TLS); tls != "" {
|
||||
params.Set("tls", tls)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", username, input.Password, host, port, database, params.Encode()), nil
|
||||
}
|
||||
|
||||
func maskDSN(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
at := strings.Index(value, "@")
|
||||
colon := strings.Index(value, ":")
|
||||
if at > -1 && colon > -1 && colon < at {
|
||||
return value[:colon+1] + "******" + value[at:]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func maskedDatabaseTarget(base string, cfg config.DatabaseConfig) string {
|
||||
if strings.EqualFold(cfg.Provider, "mysql") {
|
||||
return maskDSN(cfg.MySQLDSN)
|
||||
}
|
||||
return relativeToBase(base, cfg.SQLitePath)
|
||||
}
|
||||
|
||||
func relativeToBase(base, value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return ""
|
||||
}
|
||||
if base != "" {
|
||||
if rel, err := filepath.Rel(base, value); err == nil && !strings.HasPrefix(rel, "..") && rel != "." {
|
||||
return filepath.ToSlash(rel)
|
||||
}
|
||||
}
|
||||
return filepath.ToSlash(value)
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ymhut-box/server/unified-management/internal/config"
|
||||
)
|
||||
|
||||
func TestSetupRouterServesBuiltAssetsAndBlocksBusinessAPI(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
setupDist := filepath.Join(root, "setup")
|
||||
if err := os.MkdirAll(filepath.Join(setupDist, "assets"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(setupDist, "index.html"), []byte(`<!doctype html><script type="module" src="/setup/assets/setup.js"></script><link rel="stylesheet" href="/setup/assets/setup.css">`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(setupDist, "assets", "setup.css"), []byte(`body{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(setupDist, "assets", "setup.js"), []byte(`console.log("setup")`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
handler := NewSetupRouter(setupConfig(root, setupDist))
|
||||
|
||||
for _, item := range []struct {
|
||||
path string
|
||||
want int
|
||||
typ string
|
||||
}{
|
||||
{"/setup", http.StatusOK, "text/html"},
|
||||
{"/setup/assets/setup.css", http.StatusOK, "text/css"},
|
||||
{"/setup/assets/setup.js", http.StatusOK, "javascript"},
|
||||
{"/api/client/bootstrap", http.StatusServiceUnavailable, "application/json"},
|
||||
} {
|
||||
req := httptest.NewRequest(http.MethodGet, item.path, nil)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != item.want {
|
||||
t.Fatalf("%s returned %d: %s", item.path, res.Code, res.Body.String())
|
||||
}
|
||||
if got := res.Header().Get("Content-Type"); !strings.Contains(got, item.typ) {
|
||||
t.Fatalf("%s content-type = %q, want %q", item.path, got, item.typ)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupSQLiteCompleteCreatesConfigAndDefaultAdmin(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
handler := NewSetupRouter(setupConfig(root, filepath.Join(root, "missing-setup-dist")))
|
||||
|
||||
body := bytes.NewBufferString(`{"provider":"sqlite","baseUrl":"https://update.ymhut.cn","sqlitePath":"storage/unified.sqlite"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/setup/complete", body)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("complete returned %d: %s", res.Code, res.Body.String())
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(root, "config.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !cfg.Initialized || cfg.Database.Provider != "sqlite" {
|
||||
t.Fatalf("unexpected config: %#v", cfg)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(root, "storage", "unified.sqlite")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupSQLiteIgnoresStructuredMySQLDefaults(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
handler := NewSetupRouter(setupConfig(root, filepath.Join(root, "missing-setup-dist")))
|
||||
|
||||
body := bytes.NewBufferString(`{"provider":"sqlite","baseUrl":"https://update.ymhut.cn","sqlitePath":"storage/unified.sqlite","mysql":{"host":"127.0.0.1","port":3306,"database":"ymhut_unified","username":"","password":"","charset":"utf8mb4","parseTime":true,"tls":"false"}}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/setup/database/test", body)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("sqlite test returned %d: %s", res.Code, res.Body.String())
|
||||
}
|
||||
if strings.Contains(res.Body.String(), "mysql username is required") {
|
||||
t.Fatalf("sqlite test should ignore structured mysql defaults: %s", res.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupStructuredMySQLValidationReturnsFailureWithoutSaving(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
handler := NewSetupRouter(setupConfig(root, filepath.Join(root, "missing-setup-dist")))
|
||||
|
||||
body := bytes.NewBufferString(`{"provider":"mysql","mysql":{"host":"127.0.0.1","port":1,"database":"ymhut","username":"root","password":"secret","charset":"utf8mb4","parseTime":true,"tls":"false"}}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/setup/database/test", body)
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if res.Code != http.StatusBadGateway {
|
||||
t.Fatalf("mysql test returned %d, want 502: %s", res.Code, res.Body.String())
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(root, "config.json")); !os.IsNotExist(err) {
|
||||
t.Fatalf("config should not be written on failed test: %v", err)
|
||||
}
|
||||
if strings.Contains(res.Body.String(), "secret") {
|
||||
t.Fatalf("response leaked password: %s", res.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func setupConfig(root, setupDist string) *config.Config {
|
||||
return &config.Config{
|
||||
BaseDir: root,
|
||||
ConfigPath: filepath.Join(root, "config.json"),
|
||||
BaseURL: "https://update.ymhut.cn",
|
||||
StorageDir: filepath.Join(root, "storage"),
|
||||
SetupWebDir: setupDist,
|
||||
Database: config.DatabaseConfig{
|
||||
Provider: "sqlite",
|
||||
SQLitePath: filepath.Join(root, "storage", "unified.sqlite"),
|
||||
HealthIntervalSec: 30,
|
||||
MaxOpenConns: 1,
|
||||
MaxIdleConns: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user