From c9f2c6398bd7988984f9919f80bacb5b44352f57 Mon Sep 17 00:00:00 2001 From: Pierre-Emmanuel Jacquier <15922119+pierre-emmanuelJ@users.noreply.github.com> Date: Sun, 21 Mar 2021 23:11:02 +0100 Subject: [PATCH] Forward HTTP headers on request Signed-off-by: Pierre-Emmanuel Jacquier <15922119+pierre-emmanuelJ@users.noreply.github.com> --- pkg/server/handlers.go | 21 +++++++++++-- pkg/server/server.go | 2 +- pkg/server/xtreamHandles.go | 6 ++-- pkg/xtream-proxy/xtream-proxy.go | 52 ++++++++++++++++++++++---------- 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index e11e4dfd..4979e016 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -71,7 +71,7 @@ func (c *Config) stream(ctx *gin.Context, oriURL *url.URL) { return } - copyHttpHeader(req.Header, ctx.Request.Header) + mergeHttpHeader(req.Header, ctx.Request.Header) resp, err := client.Do(req) if err != nil { @@ -80,7 +80,7 @@ func (c *Config) stream(ctx *gin.Context, oriURL *url.URL) { } defer resp.Body.Close() - copyHttpHeader(ctx.Writer.Header(), resp.Header) + mergeHttpHeader(ctx.Writer.Header(), resp.Header) ctx.Status(resp.StatusCode) ctx.Stream(func(w io.Writer) bool { io.Copy(w, resp.Body) // nolint: errcheck @@ -98,9 +98,24 @@ func (c *Config) xtreamStream(ctx *gin.Context, oriURL *url.URL) { c.stream(ctx, oriURL) } -func copyHttpHeader(dst, src http.Header) { +type values []string + +func (vs values) contains(s string) bool { + for _, v := range vs { + if v == s { + return true + } + } + + return false +} + +func mergeHttpHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { + if values(dst.Values(k)).contains(v) { + continue + } dst.Add(k, v) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 23d8087c..8c58523b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,11 +28,11 @@ import ( "path/filepath" "strings" + "github.com/gin-contrib/cors" "github.com/jamesnetherton/m3u" "github.com/pierre-emmanuelJ/iptv-proxy/pkg/config" uuid "github.com/satori/go.uuid" - "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) diff --git a/pkg/server/xtreamHandles.go b/pkg/server/xtreamHandles.go index 2c9f1dae..85485313 100644 --- a/pkg/server/xtreamHandles.go +++ b/pkg/server/xtreamHandles.go @@ -334,7 +334,7 @@ func (c *Config) hlsXtreamStream(ctx *gin.Context, oriURL *url.URL) { return } - copyHttpHeader(req.Header, ctx.Request.Header) + mergeHttpHeader(req.Header, ctx.Request.Header) resp, err := client.Do(req) if err != nil { @@ -361,7 +361,7 @@ func (c *Config) hlsXtreamStream(ctx *gin.Context, oriURL *url.URL) { return } - copyHttpHeader(hlsReq.Header, ctx.Request.Header) + mergeHttpHeader(hlsReq.Header, ctx.Request.Header) hlsResp, err := client.Do(hlsReq) if err != nil { @@ -378,7 +378,7 @@ func (c *Config) hlsXtreamStream(ctx *gin.Context, oriURL *url.URL) { body := string(b) body = strings.ReplaceAll(body, "/"+c.XtreamUser.String()+"/"+c.XtreamPassword.String()+"/", "/"+c.User.String()+"/"+c.Password.String()+"/") - copyHttpHeader(ctx.Request.Header, hlsResp.Header) + mergeHttpHeader(ctx.Writer.Header(), hlsResp.Header) ctx.Data(http.StatusOK, hlsResp.Header.Get("Content-Type"), []byte(body)) return diff --git a/pkg/xtream-proxy/xtream-proxy.go b/pkg/xtream-proxy/xtream-proxy.go index 9f12a7b5..df6afab0 100644 --- a/pkg/xtream-proxy/xtream-proxy.go +++ b/pkg/xtream-proxy/xtream-proxy.go @@ -104,36 +104,46 @@ func (c *Client) Action(config *config.ProxyConfig, action string, q url.Values) case getLiveCategories: respBody, err = c.GetLiveCategories() case getLiveStreams: - respBody, err = c.GetLiveStreams("") + categoryID := "" + if len(q["category_id"]) > 0 { + categoryID = q["category_id"][0] + } + respBody, err = c.GetLiveStreams(categoryID) case getVodCategories: respBody, err = c.GetVideoOnDemandCategories() case getVodStreams: - respBody, err = c.GetVideoOnDemandStreams("") + categoryID := "" + if len(q["category_id"]) > 0 { + categoryID = q["category_id"][0] + } + respBody, err = c.GetVideoOnDemandStreams(categoryID) case getVodInfo: - if len(q["vod_id"]) < 1 { - err = fmt.Errorf(`bad body url query parameters: missing "vod_id"`) - httpcode = http.StatusBadRequest + httpcode, err = validateParams(q, "vod_id") + if err != nil { return } respBody, err = c.GetVideoOnDemandInfo(q["vod_id"][0]) case getSeriesCategories: respBody, err = c.GetSeriesCategories() case getSeries: - respBody, err = c.GetSeries("") + categoryID := "" + if len(q["category_id"]) > 0 { + categoryID = q["category_id"][0] + } + respBody, err = c.GetSeries(categoryID) case getSerieInfo: - if len(q["series_id"]) < 1 { - err = fmt.Errorf(`bad body url query parameters: missing "series_id"`) - httpcode = http.StatusBadRequest + httpcode, err = validateParams(q, "series_id") + if err != nil { return } respBody, err = c.GetSeriesInfo(q["series_id"][0]) case getShortEPG: - if len(q["stream_id"]) < 1 { - err = fmt.Errorf(`bad body url query parameters: missing "stream_id"`) - httpcode = http.StatusBadRequest + limit := 0 + + httpcode, err = validateParams(q, "stream_id") + if err != nil { return } - limit := 0 if len(q["limit"]) > 0 { limit, err = strconv.Atoi(q["limit"][0]) if err != nil { @@ -143,9 +153,8 @@ func (c *Client) Action(config *config.ProxyConfig, action string, q url.Values) } respBody, err = c.GetShortEPG(q["stream_id"][0], limit) case getSimpleDataTable: - if len(q["stream_id"]) < 1 { - err = fmt.Errorf(`bad body url query parameters: missing "stream_id"`) - httpcode = http.StatusBadRequest + httpcode, err = validateParams(q, "stream_id") + if err != nil { return } respBody, err = c.GetEPG(q["stream_id"][0]) @@ -155,3 +164,14 @@ func (c *Client) Action(config *config.ProxyConfig, action string, q url.Values) return } + +func validateParams(u url.Values, params ...string) (int, error) { + for _, p := range params { + if len(u[p]) < 1 { + return http.StatusBadRequest, fmt.Errorf("missing %q", p) + } + + } + + return 0, nil +}