From dc46a2e5b5d834eb865b0592b670bcfd497f2527 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 17 Aug 2021 16:53:09 +0200 Subject: [PATCH 1/2] Added timeout to network request --- main.go | 132 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 107 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index 247cf09..d86c650 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,30 @@ package main import ( + "context" + "errors" "fmt" "io" "io/ioutil" + "net" "net/http" "os" "os/exec" "runtime" - "strings" + "sync" + "time" ) func main() { arch := GetSystemInfo() - installer_dir := SetupTemporaryDirectory(arch) - installer_path := DownloadInstaller(installer_dir, arch) - RunInstaller(installer_path) - Cleanup(installer_dir) + installer_dir, err := SetupTemporaryDirectory(arch) + checkError(err) + installer_path, err := DownloadInstaller(installer_dir, arch) + checkError(err) + err = RunInstaller(installer_path) + checkError(err) + err = Cleanup(installer_dir) + checkError(err) } func GetSystemInfo() (arch string) { @@ -32,57 +40,75 @@ func GetSystemInfo() (arch string) { return } -func SetupTemporaryDirectory(arch string) (installer_dir string) { - installer_dir, err := ioutil.TempDir("", "vscode-winsta11er") - checkError(err) - - return installer_dir +func SetupTemporaryDirectory(arch string) (installer_dir string, err error) { + return ioutil.TempDir("", "vscode-winsta11er") } -func DownloadInstaller(installer_dir, arch string) (installer_path string) { - var downloadUrl = strings.Replace("https://update.code.visualstudio.com/latest/win32-$arch-user/stable", "$arch", arch, 1) +func DownloadInstaller(installer_dir, arch string) (installer_path string, err error) { + downloadUrl := fmt.Sprintf("https://update.code.visualstudio.com/latest/win32-%s-user/stable", arch) fmt.Printf("Downloading installer from %s.\n", downloadUrl) - file, err := os.CreateTemp(installer_dir, strings.Replace("vscode-win32-$arch-user*.exe", "$arch", arch, 1)) - checkError(err) + file, err := os.CreateTemp(installer_dir, fmt.Sprintf("vscode-win32-%s-user*.exe", arch)) + if err != nil { + return "", err + } + defer file.Close() client := http.Client{ CheckRedirect: func(r *http.Request, via []*http.Request) error { r.URL.Opaque = r.URL.Path return nil }, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 15 * time.Second, + ResponseHeaderTimeout: 15 * time.Second, + }, } // Put content on file resp, err := client.Get(downloadUrl) - checkError(err) + if err != nil { + return "", err + } defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return "", errors.New("invalid response status code") + } - _, err = io.Copy(file, resp.Body) - checkError(err) - defer file.Close() + written, err := copyWithTimeout(context.Background(), file, resp.Body) + if err != nil { + return "", err + } + fmt.Println(written) fmt.Printf("Downloaded installer to file %s.\n", file.Name()) installer_path = file.Name() - return installer_path + return installer_path, nil } -func RunInstaller(installer_path string) { +func RunInstaller(installer_path string) error { path, err := exec.LookPath(installer_path) - checkError(err) + if err != nil { + return err + } cmd := exec.Command(path, "/verysilent", "/mergetasks=!runcode") stdout, err := cmd.Output() - checkError(err) + if err != nil { + return err + } fmt.Println(string(stdout)) + return nil } -func Cleanup(installer_dir string) { - err := os.RemoveAll(installer_dir) - checkError(err) +func Cleanup(installer_dir string) error { + return os.RemoveAll(installer_dir) } func checkError(e error) { @@ -90,3 +116,59 @@ func checkError(e error) { panic(e) } } + +func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + // Every 5 seconds, ensure at least 200 bytes (40 bytes/second average) are read + interval := 5 + minCopyBytes := int64(200) + prevWritten := int64(0) + written := int64(0) + + done := make(chan error) + mu := sync.Mutex{} + t := time.NewTicker(time.Duration(interval) * time.Second) + defer t.Stop() + + // Read the stream, 32KB at a time + go func() { + buf := make([]byte, 32<<10) + for { + readBytes, readErr := src.Read(buf) + if readBytes > 0 { + writeBytes, writeErr := dst.Write(buf[0:readBytes]) + mu.Lock() + written += int64(writeBytes) + mu.Unlock() + if writeErr != nil { + done <- writeErr + return + } + } + if readErr != nil { + if readErr != io.EOF { + done <- readErr + } else { + done <- nil + } + return + } + } + }() + + for { + select { + case <-ctx.Done(): + return written, ctx.Err() + case <-t.C: + mu.Lock() + if written < prevWritten+minCopyBytes { + mu.Unlock() + return written, fmt.Errorf("stream stalled: received %d bytes over the last %d seconds", written, interval) + } + prevWritten = written + mu.Unlock() + case e := <-done: + return written, e + } + } +} From 28b8a21e968ace7e0e14d3eadc751703b78f8684 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 17 Aug 2021 17:46:30 +0200 Subject: [PATCH 2/2] Validate SHA-256 hash of downloaded file --- main.go | 108 ++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 90 insertions(+), 18 deletions(-) diff --git a/main.go b/main.go index d86c650..086a43f 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,11 @@ package main import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -19,7 +23,9 @@ func main() { arch := GetSystemInfo() installer_dir, err := SetupTemporaryDirectory(arch) checkError(err) - installer_path, err := DownloadInstaller(installer_dir, arch) + release_info, err := GetReleaseInfo(arch) + checkError(err) + installer_path, err := DownloadInstaller(installer_dir, arch, release_info) checkError(err) err = RunInstaller(installer_path) checkError(err) @@ -44,10 +50,44 @@ func SetupTemporaryDirectory(arch string) (installer_dir string, err error) { return ioutil.TempDir("", "vscode-winsta11er") } -func DownloadInstaller(installer_dir, arch string) (installer_path string, err error) { - downloadUrl := fmt.Sprintf("https://update.code.visualstudio.com/latest/win32-%s-user/stable", arch) +type ReleaseInfo struct { + Url string `json:"url"` + Name string `json:"name"` + Sha256Hash string `json:"sha256hash"` +} - fmt.Printf("Downloading installer from %s.\n", downloadUrl) +func GetReleaseInfo(arch string) (info *ReleaseInfo, err error) { + apiUrl := fmt.Sprintf("https://update.code.visualstudio.com/api/update/win32-%s-user/stable/latest", arch) + fmt.Printf("Requesting hash from %s.\n", apiUrl) + client := http.Client{ + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", apiUrl, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", "cli/vscode-winsta11er") + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, errors.New("invalid response status code") + } + info = &ReleaseInfo{} + err = json.NewDecoder(resp.Body).Decode(info) + if err != nil { + return nil, err + } + if info.Url == "" || info.Name == "" || info.Sha256Hash == "" { + return nil, errors.New("missing required fields in API response") + } + return info, nil +} + +func DownloadInstaller(installer_dir, arch string, info *ReleaseInfo) (installer_path string, err error) { + fmt.Printf("Downloading installer from %s.\n", info.Url) file, err := os.CreateTemp(installer_dir, fmt.Sprintf("vscode-win32-%s-user*.exe", arch)) if err != nil { @@ -56,10 +96,9 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e defer file.Close() client := http.Client{ - CheckRedirect: func(r *http.Request, via []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, + // Disable timeout here because file can take a while to download on slow connections + // Instead, we're using a function that reads from the stream and makes sure data is flowing constantly + Timeout: 0, Transport: &http.Transport{ Dial: (&net.Dialer{ Timeout: 30 * time.Second, @@ -69,8 +108,13 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e }, } - // Put content on file - resp, err := client.Get(downloadUrl) + // Request the file + req, err := http.NewRequest("GET", info.Url, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", "cli/vscode-winsta11er") + resp, err := client.Do(req) if err != nil { return "", err } @@ -79,11 +123,11 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e return "", errors.New("invalid response status code") } - written, err := copyWithTimeout(context.Background(), file, resp.Body) + // Copy the stream to the file and calculate the hash + _, err = copyWithTimeout(context.Background(), file, resp.Body, info.Sha256Hash) if err != nil { return "", err } - fmt.Println(written) fmt.Printf("Downloaded installer to file %s.\n", file.Name()) installer_path = file.Name() @@ -117,13 +161,19 @@ func checkError(e error) { } } -func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { +func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader, expectedHash string) (int64, error) { // Every 5 seconds, ensure at least 200 bytes (40 bytes/second average) are read interval := 5 minCopyBytes := int64(200) prevWritten := int64(0) written := int64(0) + expectedHashBytes, err := hex.DecodeString(expectedHash) + if err != nil { + return 0, fmt.Errorf("error decoding hash from hex: %v", err) + } + + h := sha256.New() done := make(chan error) mu := sync.Mutex{} t := time.NewTicker(time.Duration(interval) * time.Second) @@ -131,11 +181,23 @@ func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, // Read the stream, 32KB at a time go func() { - buf := make([]byte, 32<<10) + var ( + writeErr, readErr, hashErr error + writeBytes, readBytes int + buf = make([]byte, 32<<10) + ) for { - readBytes, readErr := src.Read(buf) + readBytes, readErr = src.Read(buf) if readBytes > 0 { - writeBytes, writeErr := dst.Write(buf[0:readBytes]) + // Add to the hash + _, hashErr = h.Write(buf[0:readBytes]) + if hashErr != nil { + done <- hashErr + return + } + + // Write to disk and update the number of bytes written + writeBytes, writeErr = dst.Write(buf[0:readBytes]) mu.Lock() written += int64(writeBytes) mu.Unlock() @@ -145,11 +207,21 @@ func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, } } if readErr != nil { + // If error is EOF, means we read the entire file, so don't consider that as error if readErr != io.EOF { done <- readErr - } else { - done <- nil + return } + + // Compute and compare the checksum + hash := h.Sum(nil) + if !bytes.Equal(expectedHashBytes, hash[:]) { + done <- errors.New("downloaded file's hash doesn't match") + return + } + + // No error + done <- nil return } }