diff --git a/pkg/consumer/tar_extractor_test.go b/pkg/consumer/tar_extractor_test.go index 5f7f93d..072ecc7 100644 --- a/pkg/consumer/tar_extractor_test.go +++ b/pkg/consumer/tar_extractor_test.go @@ -3,6 +3,7 @@ package consumer_test import ( "archive/tar" "bytes" + "io" "os" "path" "testing" @@ -110,7 +111,7 @@ func TestTarExtractor_Consume(t *testing.T) { r.NoError(err) // Create a reader from the tar file bytes - reader := bytes.NewReader(tarFileBytes) + reader := io.MultiReader(bytes.NewReader(tarFileBytes), bytes.NewReader(make([]byte, 1024))) // Create a temporary directory to extract the tar file tmpDir, err := os.MkdirTemp("", "tarExtractorTest-") @@ -120,15 +121,15 @@ func TestTarExtractor_Consume(t *testing.T) { tarConsumer := consumer.TarExtractor{} targetDir := path.Join(tmpDir, "extract") - r.NoError(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)))) + r.NoError(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)+1024))) // Check if the extraction was successful checkTarExtraction(t, targetDir) // Test with incorrect expectedBytes - _, _ = reader.Seek(0, 0) + reader = io.MultiReader(bytes.NewReader(tarFileBytes), bytes.NewReader(make([]byte, 1024))) targetDir = path.Join(tmpDir, "extract-fail") - r.Error(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)-1))) + r.Error(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)+1024-1))) } func checkTarExtraction(t *testing.T, targetDir string) { diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index ce37807..1355c0c 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -119,6 +119,18 @@ func TarFile(r *bufio.Reader, destDir string, overwrite bool) error { return fmt.Errorf("error creating links: %w", err) } + // Read the rest of the bytes from the archive and verify they are all null bytes + // This is for validation that the byte count is correct + padding, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("error reading padding bytes: %w", err) + } + for _, b := range padding { + if b != 0x00 { + return fmt.Errorf("unexpected non-null byte in padding: %x", b) + } + } + elapsed := time.Since(startTime).Seconds() logger.Debug(). Str("extractor", "tar").