Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cosmosutils/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func getMinitiadBinaryURL(vm, version string) (string, error) {
}

// FindBinaryDir walks versionDir to find the directory that contains the named
// executable. This avoids hardcoding assumptions about how a release tarball is
// binary. This avoids hardcoding assumptions about how a release tarball is
// structured, so the code stays correct even if a future tarball places the
// binary inside a subdirectory.
func FindBinaryDir(versionDir, binaryName string) (string, error) {
Expand All @@ -269,7 +269,7 @@ func FindBinaryDir(versionDir, binaryName string) (string, error) {
if err != nil {
return err
}
if !info.IsDir() && info.Name() == binaryName && info.Mode()&0o111 != 0 {
if !info.IsDir() && info.Name() == binaryName {
result = filepath.Dir(path)
return filepath.SkipAll
}
Expand Down
12 changes: 8 additions & 4 deletions cosmosutils/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,17 @@ func TestFindBinaryDir(t *testing.T) {
wantRel: filepath.Join("a", "b", "c"),
},
{
name: "non-executable file is ignored",
name: "finds binary before executable permissions are restored",
layout: func(root string) {
os.MkdirAll(root, 0o755)
os.WriteFile(filepath.Join(root, "minitiad"), []byte("data"), 0o644)
if err := os.MkdirAll(root, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(root, "minitiad"), []byte("data"), 0o644); err != nil {
t.Fatal(err)
}
},
binaryName: "minitiad",
wantRel: "",
wantRel: ".",
},
{
name: "wrong name is ignored",
Expand Down
49 changes: 43 additions & 6 deletions io/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ func DownloadAndExtractTarGz(url, tarballPath, extractedPath string) error {
}

func ExtractTarGz(src string, dest string) error {
destRoot, err := filepath.Abs(dest)
if err != nil {
return err
}

file, err := os.Open(src)
if err != nil {
return err
Expand All @@ -60,23 +65,27 @@ func ExtractTarGz(src string, dest string) error {
return err
}

target := filepath.Join(dest, header.Name)
target, err := safeArchivePath(destRoot, header.Name)
if err != nil {
return err
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, os.ModePerm); err != nil {
return err
}
case tar.TypeReg:
file, err := os.Create(target)
if err != nil {
if err := os.MkdirAll(filepath.Dir(target), os.ModePerm); err != nil {
return err
}
_, err = io.Copy(file, tarReader)
file, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(header.Mode))
if err != nil {
return err
}
err = file.Close()
if err != nil {
if err := writeTarFile(file, tarReader); err != nil {
return err
}
if err := os.Chmod(target, os.FileMode(header.Mode)); err != nil {
return err
}
default:
Expand All @@ -86,6 +95,34 @@ func ExtractTarGz(src string, dest string) error {
return nil
}

func writeTarFile(file *os.File, src io.Reader) (err error) {
defer func() {
if closeErr := file.Close(); err == nil && closeErr != nil {
err = closeErr
}
}()

_, err = io.Copy(file, src)
return err
}

func safeArchivePath(destRoot, entryName string) (string, error) {
cleanName := filepath.Clean(entryName)
if cleanName == "." {
return destRoot, nil
}

target := filepath.Join(destRoot, cleanName)
rel, err := filepath.Rel(destRoot, target)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return "", fmt.Errorf("unsafe archive entry path: %s", entryName)
}
return target, nil
}

func SetLibraryPaths(binaryDir string) error {
envKey, envValue, err := LibraryPathEnv(binaryDir)
if err != nil {
Expand Down
66 changes: 66 additions & 0 deletions io/filesystem_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package io

import (
"archive/tar"
"compress/gzip"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -59,6 +62,69 @@ func TestExtractTarGz(t *testing.T) {
err := ExtractTarGz("./invalid.tar.gz", "./invalid")
assert.Error(t, err)
})

t.Run("PreservesExtractedFileMode", func(t *testing.T) {
tmpDir := t.TempDir()
tarballPath := filepath.Join(tmpDir, "test.tar.gz")
extractDir := filepath.Join(tmpDir, "extract")

file, err := os.Create(tarballPath)
assert.NoError(t, err)

gzw := gzip.NewWriter(file)
tw := tar.NewWriter(gzw)

content := []byte("#!/bin/sh\necho ok\n")
header := &tar.Header{
Name: "minitiad",
Mode: 0o755,
Size: int64(len(content)),
Typeflag: tar.TypeReg,
}
assert.NoError(t, tw.WriteHeader(header))
_, err = tw.Write(content)
assert.NoError(t, err)
assert.NoError(t, tw.Close())
assert.NoError(t, gzw.Close())
assert.NoError(t, file.Close())

err = ExtractTarGz(tarballPath, extractDir)
assert.NoError(t, err)

info, err := os.Stat(filepath.Join(extractDir, "minitiad"))
assert.NoError(t, err)
assert.Equal(t, os.FileMode(0o755), info.Mode().Perm())
})

t.Run("RejectsPathTraversalEntries", func(t *testing.T) {
tmpDir := t.TempDir()
tarballPath := filepath.Join(tmpDir, "test.tar.gz")
extractDir := filepath.Join(tmpDir, "extract")

file, err := os.Create(tarballPath)
assert.NoError(t, err)

gzw := gzip.NewWriter(file)
tw := tar.NewWriter(gzw)

content := []byte("bad\n")
header := &tar.Header{
Name: "../escape",
Mode: 0o644,
Size: int64(len(content)),
Typeflag: tar.TypeReg,
}
assert.NoError(t, tw.WriteHeader(header))
_, err = tw.Write(content)
assert.NoError(t, err)
assert.NoError(t, tw.Close())
assert.NoError(t, gzw.Close())
assert.NoError(t, file.Close())

err = ExtractTarGz(tarballPath, extractDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsafe archive entry path")
})
}

func TestSetLibraryPaths(t *testing.T) {
Expand Down
Loading