diff --git a/README.MD b/README.MD index 5db6b3e5..cf731449 100644 --- a/README.MD +++ b/README.MD @@ -52,13 +52,16 @@ 例: ```javascript { - "debug": false, // 是否打印调试日志, 默认为否 - "show_serve_info": false, // 是否打印访问信息, 默认为否 (这个选项对于压缩日志文件十分有用) - "ignore_serve_error": true, // 是否忽略http serve error, 默认打开 (这个选项对于压缩日志文件十分有用) - "public_host": "example.com", // public host, 同 CLUSTER_IP - "public_port": 8080, // 实际开放的公网端口, 同 CLUSTER_PUBLIC_PORT - "port": 80, // 要监听的本地端口, 同 CLUSTER_PORT - "cluster_id": "${CLUSTER_ID}", // CLUSTER_ID - "cluster_secret": "${CLUSTER_SECRET}" // CLUSTER_SECRET + "debug": false, // 是否打印调试日志, 默认为否 + "show_serve_info": false, // 是否打印访问信息, 默认为否 (这个选项对于压缩日志文件十分有用) + "ignore_serve_error": true, // 是否忽略http serve error, 默认打开 (这个选项对于压缩日志文件十分有用) + "public_host": "example.com", // public host, 同 CLUSTER_IP + "public_port": 8080, // 实际开放的公网端口, 同 CLUSTER_PUBLIC_PORT + "port": 80, // 要监听的本地端口, 同 CLUSTER_PORT + "cluster_id": "${CLUSTER_ID}", // CLUSTER_ID + "cluster_secret": "${CLUSTER_SECRET}", // CLUSTER_SECRET + "hijack": false, // 是否启动 bmclapi 劫持代理 + "hijack_port": 8090, // 劫持代理监听的端口 + "anti_hijack_dns": "8.8.8.8:32" // 用于绕过DNS劫持的DNS服务器 } ``` diff --git a/config.json b/config.json index 97023bfb..1e425b06 100644 --- a/config.json +++ b/config.json @@ -1,11 +1,14 @@ { - "debug": false, - "show_serve_info": false, - "ignore_serve_error": true, - "nohttps": false, - "public_host": "example.com", - "public_port": 8080, - "port": 4000, - "cluster_id": "${CLUSTER_ID}", - "cluster_secret": "${CLUSTER_SECRET}" + "debug": false, + "show_serve_info": false, + "ignore_serve_error": true, + "nohttps": false, + "public_host": "example.com", + "public_port": 8080, + "port": 4000, + "cluster_id": "${CLUSTER_ID}", + "cluster_secret": "${CLUSTER_SECRET}", + "hijack": true, + "hijack_port": 8090, + "anti_hijack_dns": "8.8.8.8:53" } \ No newline at end of file diff --git a/src/cluster.go b/src/cluster.go index 5ef598f5..d7565763 100644 --- a/src/cluster.go +++ b/src/cluster.go @@ -3,11 +3,11 @@ package main import ( "context" "crypto" - "crypto/tls" "encoding/hex" "errors" "fmt" "io" + "net" "net/http" "os" "path/filepath" @@ -31,6 +31,7 @@ type Cluster struct { version string useragent string prefix string + byoc bool cachedir string maxConn int @@ -50,10 +51,16 @@ type Cluster struct { } func NewCluster( - ctx context.Context, + ctx context.Context, cacheDir string, host string, publicPort uint16, username string, password string, - version string, address string) (cr *Cluster) { + version string, address string, + byoc bool, dialer *net.Dialer, +) (cr *Cluster) { + transport := &http.Transport{} + if dialer != nil { + transport.DialContext = dialer.DialContext + } cr = &Cluster{ ctx: ctx, @@ -64,17 +71,14 @@ func NewCluster( version: version, useragent: "openbmclapi-cluster/" + version, prefix: "https://openbmclapi.bangbang93.com", + byoc: byoc, - cachedir: "cache", + cachedir: cacheDir, maxConn: 400, client: &http.Client{ - Timeout: time.Second * 60, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - // InsecureSkipVerify: true, // Skip verify because the author was lazy - }, - }, + Timeout: time.Second * 60, + Transport: transport, }, Server: &http.Server{ Addr: address, @@ -99,7 +103,7 @@ func (cr *Cluster) Connect() bool { header.Set("User-Agent", cr.useragent) connectCh := make(chan struct{}, 0) - connected := sync.OnceFunc(func(){ + connected := sync.OnceFunc(func() { close(connectCh) }) @@ -152,7 +156,7 @@ func (cr *Cluster) Enable() (err error) { "host": cr.host, "port": cr.publicPort, "version": cr.version, - "byoc": USE_HTTPS, + "byoc": cr.byoc, }) if err != nil { return @@ -397,7 +401,7 @@ RESYNC: } // sort the files in descending order of size - sort.Slice(files, func(i, j int) bool { return files[i].Size > files[j].Size }) + sort.Slice(files, func(i, j int) bool { return files[i].Size < files[j].Size }) var stats syncStats stats.slots = make(chan struct{}, cr.maxConn) @@ -505,7 +509,9 @@ func (cr *Cluster) dlhandle(ctx context.Context, f *FileInfo) (err error) { } for i := 0; i < 3; i++ { - cr.downloadFileBuf(ctx, f, hashMethod, buf) + if err = cr.downloadFileBuf(ctx, f, hashMethod, buf); err == nil { + return + } } return } @@ -597,7 +603,8 @@ func (cr *Cluster) downloadFileBuf(ctx context.Context, f *FileInfo, hashMethod hw := hashMethod.New() - if fd, err = os.Create(cr.getHashPath(f.Hash)); err != nil { + hspt := cr.getHashPath(f.Hash) + if fd, err = os.Create(hspt); err != nil { return } @@ -612,13 +619,28 @@ func (cr *Cluster) downloadFileBuf(ctx context.Context, f *FileInfo, hashMethod } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != f.Hash { err = fmt.Errorf("File hash not match, got %s, expect %s", hs, f.Hash) } - if DEBUG && err != nil { - f0, _ := os.Open(cr.getHashPath(f.Hash)) - b0, _ := io.ReadAll(f0) - if len(b0) < 16*1024 { - logDebug("File content:", (string)(b0), "//for", f.Path) + if err != nil { + if config.Debug { + f0, _ := os.Open(cr.getHashPath(f.Hash)) + b0, _ := io.ReadAll(f0) + if len(b0) < 16*1024 { + logDebug("File content:", (string)(b0), "//for", f.Path) + } } + return } + + if config.Hijack { + if !strings.HasPrefix(f.Path, "/openbmclapi/download/") { + target := filepath.Join(hijackPath, filepath.FromSlash(f.Path)) + dir := filepath.Dir(target) + os.MkdirAll(dir, 0755) + if rp, err := filepath.Rel(dir, hspt); err == nil { + os.Symlink(rp, target) + } + } + } + return } diff --git a/src/handler.go b/src/handler.go index f9f7560f..a9ae0f94 100644 --- a/src/handler.go +++ b/src/handler.go @@ -18,12 +18,12 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { method := req.Method url := req.URL rawpath := url.EscapedPath() - if SHOW_SERVE_INFO { - go logInfo("serve url:", url.String()) + if config.ShowServeInfo { + logInfo("serve url:", url.String()) } switch { case strings.HasPrefix(rawpath, "/download/"): - if method == "GET" { + if method == http.MethodGet { hash := rawpath[len("/download/"):] path := cr.getHashPath(hash) if ufile.IsNotExist(path) { @@ -68,7 +68,7 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } _, err = rw.Write(buf[:n]) if err != nil { - if !IGNORE_SERVE_ERROR { + if !config.IgnoreServeError { logError("Error when serving download:", err) } return @@ -80,7 +80,7 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } case strings.HasPrefix(rawpath, "/measure/"): - if method == "GET" { + if method == http.MethodGet { if req.Header.Get("x-openbmclapi-secret") != cr.password { rw.WriteHeader(http.StatusForbidden) return diff --git a/src/hijacker.go b/src/hijacker.go new file mode 100644 index 00000000..e19d45e2 --- /dev/null +++ b/src/hijacker.go @@ -0,0 +1,81 @@ +package main + +import ( + "context" + "io" + "net" + "net/http" + "os" + "path/filepath" + "time" +) + +func getDialerWithDNS(dnsaddr string) *net.Dialer { + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, dnsaddr) + }, + } + return &net.Dialer{ + Resolver: resolver, + } +} + +type HjProxy struct { + dialer *net.Dialer + path string + client *http.Client +} + +func NewHjProxy(dialer *net.Dialer, path string) (h *HjProxy) { + return &HjProxy{ + dialer: dialer, + path: path, + client: &http.Client{ + Timeout: time.Second * 15, + Transport: &http.Transport{ + DialContext: dialer.DialContext, + }, + }, + } +} + +const hijackingHost = "bmclapi2.bangbang93.com" + +func (h *HjProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.Method == http.MethodGet || req.Method == http.MethodHead { + target := filepath.Join(h.path, filepath.Clean(filepath.FromSlash(req.URL.Path))) + fd, err := os.Open(target) + if err == nil { + defer fd.Close() + if stat, err := fd.Stat(); err == nil { + if !stat.IsDir() { + modTime := stat.ModTime() + http.ServeContent(rw, req, filepath.Base(target), modTime, fd) + return + } + } + } + } + + u := *req.URL + u.Scheme = "https" + u.Host = hijackingHost + req2, err := http.NewRequestWithContext(req.Context(), req.Method, u.String(), req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadGateway) + return + } + res, err := h.client.Do(req2) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadGateway) + return + } + defer res.Body.Close() + for k, v := range res.Header { + rw.Header()[k] = v + } + rw.WriteHeader(res.StatusCode) + io.Copy(rw, res.Body) +} diff --git a/src/logger.go b/src/logger.go index 39e0b024..70167086 100644 --- a/src/logger.go +++ b/src/logger.go @@ -55,13 +55,13 @@ func logXf(x string, format string, args ...any) { } func logDebug(args ...any) { - if DEBUG { + if config.Debug { logX("DBUG", args...) } } func logDebugf(format string, args ...any) { - if DEBUG { + if config.Debug { logXf("DBUG", format, args...) } } diff --git a/src/main.go b/src/main.go index b7cefb6a..646d0c1f 100644 --- a/src/main.go +++ b/src/main.go @@ -2,98 +2,113 @@ package main import ( "context" + "encoding/json" "errors" "fmt" + "net" "net/http" "os" "os/signal" + "path/filepath" "strconv" "syscall" "time" - - json "github.com/KpnmServer/go-util/json" ) var ( - DEBUG bool = false - SHOW_SERVE_INFO bool = false - IGNORE_SERVE_ERROR bool = true - HOST string = "" - PORT uint16 = 0 - PUBLIC_PORT uint16 = 0 - CLUSTER_ID string = "username" - CLUSTER_SECRET string = "password" - USE_HTTPS bool = false - SyncFileInterval = time.Minute * 10 KeepAliveInterval = time.Second * 60 ) +type Config struct { + Debug bool `json:"debug"` + ShowServeInfo bool `json:"show_serve_info"` + IgnoreServeError bool `json:"ignore_serve_error"` + Nohttps bool `json:"nohttps"` + PublicHost string `json:"public_host"` + PublicPort uint16 `json:"public_port"` + Port uint16 `json:"port"` + ClusterId string `json:"cluster_id"` + ClusterSecret string `json:"cluster_secret"` + Hijack bool `json:"hijack"` + HijackPort uint16 `json:"hijack_port"` + AntiHijackDNS string `json:"anti_hijack_dns"` +} + +var config Config + func readConfig() { - // TODO: Use struct with json tag - { // read config file - var ( - fd *os.File - obj json.JsonObj = nil - err error - n int - ) - fd, err = os.Open("config.json") - if err != nil { - panic(err) - } - defer fd.Close() - obj = make(json.JsonObj) - err = json.ReadJson(fd, &obj) - if err != nil { - panic(err) - } - DEBUG = obj.Has("debug") && obj.GetBool("debug") - SHOW_SERVE_INFO = obj.Has("show_serve_info") && obj.GetBool("show_serve_info") - IGNORE_SERVE_ERROR = obj.Has("ignore_serve_error") && obj.GetBool("ignore_serve_error") - if os.Getenv("CLUSTER_IP") != "" { - HOST = os.Getenv("CLUSTER_IP") - } else { - HOST = obj.GetString("public_host") - } - if os.Getenv("CLUSTER_PORT") != "" { - n, err = strconv.Atoi(os.Getenv("CLUSTER_PORT")) - if err != nil { - panic(err) - } - PORT = (uint16)(n) - } else if obj.Has("port") { - PORT = obj.GetUInt16("port") - } - if os.Getenv("CLUSTER_PUBLIC_PORT") != "" { - n, err = strconv.Atoi(os.Getenv("CLUSTER_PUBLIC_PORT")) - if err != nil { - panic(err) - } - PUBLIC_PORT = (uint16)(n) - } else if obj.Has("public_port") { - PUBLIC_PORT = obj.GetUInt16("public_port") - } else { - PUBLIC_PORT = PORT - } - if os.Getenv("CLUSTER_ID") != "" { - CLUSTER_ID = os.Getenv("CLUSTER_ID") - } else { - CLUSTER_ID = obj.GetString("cluster_id") + const configPath = "config.json" + data, err := os.ReadFile(configPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + logError("Cannot read config:", err) + os.Exit(1) } - if os.Getenv("CLUSTER_SECRET") != "" { - CLUSTER_SECRET = os.Getenv("CLUSTER_SECRET") + logError("Config file not exists, create one") + config = Config{ + Debug: false, + ShowServeInfo: false, + IgnoreServeError: true, + Nohttps: false, + PublicHost: "example.com", + PublicPort: 8080, + Port: 4000, + ClusterId: "${CLUSTER_ID}", + ClusterSecret: "${CLUSTER_SECRET}", + Hijack: false, + HijackPort: 8090, + AntiHijackDNS: "8.8.8.8:53", + } + } else if err = json.Unmarshal(data, &config); err != nil { + logError("Cannot parse config:", err) + os.Exit(1) + } + + if data, err = json.MarshalIndent(config, "", " "); err != nil { + logError("Cannot encode config:", err) + os.Exit(1) + } + if err = os.WriteFile(configPath, data, 0600); err != nil { + logError("Cannot write config:", err) + os.Exit(1) + } + + if os.Getenv("DEBUG") == "true" { + config.Debug = true + } + if v := os.Getenv("CLUSTER_IP"); v != "" { + config.PublicHost = v + } + if v := os.Getenv("CLUSTER_PORT"); v != "" { + if n, err := strconv.Atoi(v); err != nil { + logErrorf("Cannot parse CLUSTER_PORT %q: %v", v, err) } else { - CLUSTER_SECRET = obj.GetString("cluster_secret") + config.Port = (uint16)(n) } - if byoc := os.Getenv("CLUSTER_BYOC"); byoc != "" { - USE_HTTPS = byoc != "true" + } + if v := os.Getenv("CLUSTER_PUBLIC_PORT"); v != "" { + if n, err := strconv.Atoi(v); err != nil { + logErrorf("Cannot parse CLUSTER_PUBLIC_PORT %q: %v", v, err) } else { - USE_HTTPS = !(obj.Has("nohttps") && obj.GetBool("nohttps")) + config.PublicPort = (uint16)(n) } } + if v := os.Getenv("CLUSTER_ID"); v != "" { + config.ClusterId = v + } + if v := os.Getenv("CLUSTER_SECRET"); v != "" { + config.ClusterSecret = v + } + if byoc := os.Getenv("CLUSTER_BYOC"); byoc != "" { + config.Nohttps = byoc == "true" + } } +const cacheDir = "cache" + +var hijackPath = filepath.Join(cacheDir, "__hijack") + func main() { defer func() { if err := recover(); err != nil { @@ -115,26 +130,53 @@ START: ctx, cancel := context.WithCancel(bgctx) readConfig() - cluster := NewCluster(ctx, HOST, PUBLIC_PORT, CLUSTER_ID, CLUSTER_SECRET, VERSION, fmt.Sprintf("%s:%d", "0.0.0.0", PORT)) - logInfof("Starting Go-OpenBmclApi v%s", VERSION) - - { - logInfof("Fetching file list") - fl := cluster.GetFileList() - if fl == nil { - logError("Cluster filelist is nil, exit") - os.Exit(1) + var ( + dialer *net.Dialer + hjproxy *HjProxy + hjServer *http.Server + ) + if config.Hijack { + dialer = getDialerWithDNS(config.AntiHijackDNS) + hjproxy = NewHjProxy(dialer, hijackPath) + hjServer = &http.Server{ + Addr: fmt.Sprintf("0.0.0.0:%d", config.HijackPort), + Handler: hjproxy, } - cluster.SyncFiles(fl, ctx) + go func() { + logInfof("Hijack server start at %q", hjServer.Addr) + err := hjServer.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + logError("Error on server:", err) + os.Exit(1) + } + }() + time.Sleep(time.Second * 100000) } + cluster := NewCluster(ctx, cacheDir, + config.PublicHost, config.PublicPort, + config.ClusterId, config.ClusterSecret, VERSION, + fmt.Sprintf("%s:%d", "0.0.0.0", config.Port), + config.Nohttps, dialer) + + logInfof("Starting Go-OpenBmclApi v%s", VERSION) + + // { + // logInfof("Fetching file list") + // fl := cluster.GetFileList() + // if fl == nil { + // logError("Cluster filelist is nil, exit") + // os.Exit(1) + // } + // cluster.SyncFiles(fl, ctx) + // } if !cluster.Connect() { os.Exit(1) } var certFile, keyFile string - if USE_HTTPS { + if !config.Nohttps { pair, err := cluster.RequestCert() if err != nil { logError("Error when requesting cert key pair:", err) @@ -149,9 +191,9 @@ START: go func() { defer close(exitCh) - logInfof("Server start at \"%s\"", cluster.Server.Addr) + logInfof("Server start at %q", cluster.Server.Addr) var err error - if USE_HTTPS { + if !config.Nohttps { err = cluster.Server.ListenAndServeTLS(certFile, keyFile) } else { err = cluster.Server.ListenAndServe() @@ -186,6 +228,7 @@ START: logWarn("Closing server ...") cancel() shutExit := make(chan struct{}, 0) + go hjServer.Shutdown(shutCtx) go func() { defer close(shutExit) defer cancelShut()