diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d032a41 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,19 @@ +name: ci + +on: + push: + branches: + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + + - name: Run tests + run: go test -v -race ./... diff --git a/README.md b/README.md index b0c04e9..d9d039e 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,14 @@ dl-pipe https://example.invalid/my-file.tar | tar x ``` +You may also provide the parts of a multipart tar file and it will be reassembled. + +``` +dl-pipe https://example.invalid/my-file.tar.part1 https://example.invalid/my-file.tar.part2 https://example.invalid/my-file.tar.part3 | tar x +``` + +We use this to work around the 5TB size limit of most object storage providers. + We also provide an expected hash via the `-hash` option to ensure that the download content is correct. Make sure you set `set -eo pipefail` to ensure your script stops on errors. Install with `go install github.com/zeta-chain/dl-pipe@latest`. 
\ No newline at end of file diff --git a/cmd/dl-pipe/main.go b/cmd/dl-pipe/main.go index d2eda62..b36f60b 100644 --- a/cmd/dl-pipe/main.go +++ b/cmd/dl-pipe/main.go @@ -71,7 +71,7 @@ const progressFuncInterval = time.Second * 10 func getProgressFunc() dlpipe.ProgressFunc { prevLength := uint64(0) - return func(currentLength uint64, totalLength uint64) { + return func(currentLength uint64, totalLength uint64, currentPart int, totalParts int) { currentLengthStr := humanize.Bytes(currentLength) totalLengthStr := humanize.Bytes(totalLength) @@ -81,7 +81,12 @@ func getProgressFunc() dlpipe.ProgressFunc { percent := float64(currentLength) / float64(totalLength) * 100 - fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s\n", currentLengthStr, totalLengthStr, percent, rateStr) + partStr := "" + if totalParts > 1 { + partStr = fmt.Sprintf(" (part %d of %d)", currentPart+1, totalParts) + } + + fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s%s\n", currentLengthStr, totalLengthStr, percent, rateStr, partStr) } } @@ -101,9 +106,9 @@ func main() { flag.BoolVar(&progress, "progress", false, "Show download progress") flag.Parse() - url := flag.Arg(0) - if url == "" { - fmt.Fprintf(os.Stderr, ("URL is required")) + urls := flag.Args() + if len(urls) == 0 { + fmt.Fprintf(os.Stderr, ("URL(s) are required")) os.Exit(1) } @@ -119,9 +124,9 @@ func main() { headerMap[parts[0]] = parts[1] } - err := dlpipe.DownloadURL( + err := dlpipe.DownloadURLMultipart( ctx, - url, + urls, os.Stdout, dlpipe.WithHeaders(headerMap), getHashOpt(hash), diff --git a/download.go b/download.go index 05cfde8..ac2e64c 100644 --- a/download.go +++ b/download.go @@ -3,11 +3,13 @@ package dlpipe import ( "bytes" "context" + "errors" "fmt" "hash" "io" "net/http" "strings" + "sync" "time" "github.com/miolini/datacounter" @@ -57,7 +59,7 @@ func WithHeaders(headers map[string]string) DownloadOpt { } } -type ProgressFunc func(currentLength uint64, totalLength uint64) +type ProgressFunc 
func(currentLength, totalLength uint64, currentPart, totalParts int) func WithProgressFunc(progressFunc ProgressFunc, interval time.Duration) DownloadOpt { return func(d *downloader) { @@ -115,7 +117,7 @@ func DefaultRetryParameters() RetryParameters { type downloader struct { // these fields are set once - url string + urls []string writer *datacounter.WriterCounter httpClient *http.Client retryParameters RetryParameters @@ -129,6 +131,9 @@ type downloader struct { // these fields are updated at runtime contentLength int64 + urlsPosition int + + sync.RWMutex } func (d *downloader) progressReportLoop(ctx context.Context) { @@ -137,7 +142,9 @@ func (d *downloader) progressReportLoop(ctx context.Context) { for { select { case <-t.C: - d.progressFunc(d.writer.Count(), uint64(d.contentLength)) + d.RLock() + d.progressFunc(d.writer.Count(), uint64(d.contentLength), d.urlsPosition, d.totalPartCount()) + d.RUnlock() case <-ctx.Done(): return } @@ -145,7 +152,9 @@ func (d *downloader) progressReportLoop(ctx context.Context) { } func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) + d.RLock() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.urls[d.urlsPosition], nil) + d.RUnlock() if err != nil { return nil, NonRetryableWrapf("create request: %w", err) } @@ -176,7 +185,9 @@ func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) { } if resp.StatusCode != http.StatusPartialContent { - return nil, NonRetryableWrapf("unexpected status code on subsequent read: %d", resp.StatusCode) + // this error should be retried since cloudflare r2 sometimes ignores the range request and + // returns 200 + return nil, fmt.Errorf("unexpected status code on subsequent read: %d", resp.StatusCode) } // Validate we are receiving the right portion of partial content @@ -212,15 +223,23 @@ func (d *downloader) run(ctx context.Context) error { if d.progressFunc != nil { go 
d.progressReportLoop(ctx) } - for { + d.resetWriterPosition() + + for d.urlsPosition < d.totalPartCount() { body, err := d.runInner(ctx) - if err != nil { - return err - } - defer body.Close() - _, err = io.Copy(d.writer, body) if err == nil { - break + defer body.Close() + _, err = io.Copy(d.writer, body) + if err == nil { + d.Lock() + d.urlsPosition++ + d.resetWriterPosition() + d.Unlock() + continue + } + } + if errors.Is(err, ErrNonRetryable{}) { + return err } err = d.retryParameters.Wait(ctx, d.writer.Count()) if err != nil { @@ -236,9 +255,22 @@ func (d *downloader) run(ctx context.Context) error { return nil } +func (d *downloader) resetWriterPosition() { + d.writer = datacounter.NewWriterCounter(d.tmpWriter) + d.contentLength = 0 +} + +func (d *downloader) totalPartCount() int { + return len(d.urls) +} + func DownloadURL(ctx context.Context, url string, writer io.Writer, opts ...DownloadOpt) error { + return DownloadURLMultipart(ctx, []string{url}, writer, opts...) +} + +func DownloadURLMultipart(ctx context.Context, urls []string, writer io.Writer, opts ...DownloadOpt) error { d := &downloader{ - url: url, + urls: urls, tmpWriter: writer, httpClient: &http.Client{ Transport: &http.Transport{ @@ -254,6 +286,5 @@ func DownloadURL(ctx context.Context, url string, writer io.Writer, opts ...Down } opt(d) } - d.writer = datacounter.NewWriterCounter(d.tmpWriter) return d.run(ctx) } diff --git a/download_test.go b/download_test.go index e6bb23c..212dca6 100644 --- a/download_test.go +++ b/download_test.go @@ -4,12 +4,14 @@ import ( "context" "crypto/rand" "crypto/sha256" + "errors" "fmt" "io" "log" "net/http" "net/http/httptest" "os" + "path/filepath" "sync" "testing" @@ -26,12 +28,12 @@ func TestUninterruptedDownload(t *testing.T) { r := require.New(t) ctx := context.Background() - serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, 0) + serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1) defer cleanup() hasher 
:= sha256.New() - err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash)) + err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash)) r.NoError(err) givenHash := hasher.Sum(nil) @@ -42,12 +44,12 @@ func TestUninterruptedMismatch(t *testing.T) { r := require.New(t) ctx := context.Background() - serverURL, _, cleanup := serveInterruptedTestFile(t, fileSize, 0) + serverURLs, _, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1) defer cleanup() hasher := sha256.New() - err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, []byte{})) + err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, []byte{})) r.Error(err) } @@ -55,41 +57,83 @@ func TestInterruptedDownload(t *testing.T) { r := require.New(t) ctx := context.Background() - serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, interruptAt) + serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 1) defer cleanup() hasher := sha256.New() - err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash)) + err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash)) r.NoError(err) } -// derrived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166 -func serveInterruptedTestFile(t *testing.T, fileSize, interruptAt int64) (serverURL string, sha256Hash []byte, cleanup func()) { - rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd") - assert.NoError(t, err) - filePath := rndFile.Name() +func TestDownloadMultipart(t *testing.T) { + r := require.New(t) + ctx := context.Background() + + serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 10) + defer cleanup() hasher := sha256.New() - _, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize)) - assert.NoError(t, err) - 
assert.NoError(t, rndFile.Close()) - mux := http.NewServeMux() - mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - log.Printf("Serving random interrupted file (size: %d, interuptAt: %d), Range: %s", fileSize, interruptAt, request.Header.Get(rangeHeader)) + err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash)) + r.NoError(err) +} + +func TestDownloadMultipartInterrupted(t *testing.T) { + r := require.New(t) + ctx := context.Background() + + serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 10) + defer cleanup() + + hasher := sha256.New() - http.ServeFile(&interruptibleHTTPWriter{ - ResponseWriter: writer, - writer: writer, - interruptAt: interruptAt, - }, request, filePath) + err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash)) + r.NoError(err) +} + +func TestErrNonRetryable(t *testing.T) { + err := NonRetryableWrapf("test") + require.True(t, errors.Is(err, ErrNonRetryable{})) +} - }) +// derived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166 +func serveInterruptedTestFiles(t *testing.T, fileSize, interruptAt int64, parts int) ([]string, []byte, func()) { + mux := http.NewServeMux() server := httptest.NewServer(mux) + hasher := sha256.New() + filePaths := []string{} + urls := []string{} + + for i := 0; i < parts; i++ { + rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd") + assert.NoError(t, err) + filePath := rndFile.Name() + filePaths = append(filePaths, filePath) + filePathBase := filepath.Base(filePath) - return server.URL, hasher.Sum(nil), func() { - _ = os.Remove(filePath) + _, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize)) + assert.NoError(t, err) + assert.NoError(t, rndFile.Close()) + + mux.HandleFunc(filePath, func(writer http.ResponseWriter, request *http.Request) { + 
log.Printf("Serving random interrupted file %s (size: %d, interuptAt: %d), Range: %s", filePathBase, fileSize, interruptAt, request.Header.Get(rangeHeader)) + + http.ServeFile(&interruptibleHTTPWriter{ + ResponseWriter: writer, + writer: writer, + interruptAt: interruptAt, + }, request, filePath) + + }) + urls = append(urls, server.URL+filePath) + + } + + return urls, hasher.Sum(nil), func() { + for _, path := range filePaths { + _ = os.Remove(path) + } } } diff --git a/errors.go b/errors.go index 24f6d32..5933867 100644 --- a/errors.go +++ b/errors.go @@ -28,6 +28,11 @@ func (e ErrNonRetryable) Unwrap() error { return e.inner } +func (e ErrNonRetryable) Is(target error) bool { + _, ok := target.(ErrNonRetryable) + return ok +} + func NonRetryableWrap(err error) error { return ErrNonRetryable{inner: err} }