diff --git a/main.go b/main.go index d86c650..086a43f 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,11 @@ package main import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -19,7 +23,9 @@ func main() { arch := GetSystemInfo() installer_dir, err := SetupTemporaryDirectory(arch) checkError(err) - installer_path, err := DownloadInstaller(installer_dir, arch) + release_info, err := GetReleaseInfo(arch) + checkError(err) + installer_path, err := DownloadInstaller(installer_dir, arch, release_info) checkError(err) err = RunInstaller(installer_path) checkError(err) @@ -44,10 +50,44 @@ func SetupTemporaryDirectory(arch string) (installer_dir string, err error) { return ioutil.TempDir("", "vscode-winsta11er") } -func DownloadInstaller(installer_dir, arch string) (installer_path string, err error) { - downloadUrl := fmt.Sprintf("https://update.code.visualstudio.com/latest/win32-%s-user/stable", arch) +type ReleaseInfo struct { + Url string `json:"url"` + Name string `json:"name"` + Sha256Hash string `json:"sha256hash"` +} - fmt.Printf("Downloading installer from %s.\n", downloadUrl) +func GetReleaseInfo(arch string) (info *ReleaseInfo, err error) { + apiUrl := fmt.Sprintf("https://update.code.visualstudio.com/api/update/win32-%s-user/stable/latest", arch) + fmt.Printf("Requesting hash from %s.\n", apiUrl) + client := http.Client{ + Timeout: 30 * time.Second, + } + req, err := http.NewRequest("GET", apiUrl, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", "cli/vscode-winsta11er") + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, errors.New("invalid response status code") + } + info = &ReleaseInfo{} + err = json.NewDecoder(resp.Body).Decode(info) + if err != nil { + return nil, err + } + if info.Url == "" || info.Name == "" || info.Sha256Hash == "" { + return nil, errors.New("missing required fields in API response") + } + return info, nil +} + +func DownloadInstaller(installer_dir, arch string, info *ReleaseInfo) (installer_path string, err error) { + fmt.Printf("Downloading installer from %s.\n", info.Url) file, err := os.CreateTemp(installer_dir, fmt.Sprintf("vscode-win32-%s-user*.exe", arch)) if err != nil { @@ -56,10 +96,9 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e defer file.Close() client := http.Client{ - CheckRedirect: func(r *http.Request, via []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, + // Disable timeout here because file can take a while to download on slow connections + // Instead, we're using a function that reads from the stream and makes sure data is flowing constantly + Timeout: 0, Transport: &http.Transport{ Dial: (&net.Dialer{ Timeout: 30 * time.Second, @@ -69,8 +108,13 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e }, } - // Put content on file - resp, err := client.Get(downloadUrl) + // Request the file + req, err := http.NewRequest("GET", info.Url, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", "cli/vscode-winsta11er") + resp, err := client.Do(req) if err != nil { return "", err } @@ -79,11 +123,11 @@ func DownloadInstaller(installer_dir, arch string) (installer_path string, err e return "", errors.New("invalid response status code") } - written, err := copyWithTimeout(context.Background(), file, resp.Body) + // Copy the stream to the file and calculate the hash + _, err = copyWithTimeout(context.Background(), file, resp.Body, info.Sha256Hash) if err != nil { return "", err } - fmt.Println(written) fmt.Printf("Downloaded installer to file %s.\n", file.Name()) installer_path = file.Name() @@ -117,13 +161,19 @@ func checkError(e error) { } } -func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { +func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader, expectedHash string) (int64, error) { // Every 5 seconds, ensure at least 200 bytes (40 bytes/second average) are read interval := 5 minCopyBytes := int64(200) prevWritten := int64(0) written := int64(0) + expectedHashBytes, err := hex.DecodeString(expectedHash) + if err != nil { + return 0, fmt.Errorf("error decoding hash from hex: %v", err) + } + + h := sha256.New() done := make(chan error) mu := sync.Mutex{} t := time.NewTicker(time.Duration(interval) * time.Second) @@ -131,11 +181,23 @@ func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, // Read the stream, 32KB at a time go func() { - buf := make([]byte, 32<<10) + var ( + writeErr, readErr, hashErr error + writeBytes, readBytes int + buf = make([]byte, 32<<10) + ) for { - readBytes, readErr := src.Read(buf) + readBytes, readErr = src.Read(buf) if readBytes > 0 { - writeBytes, writeErr := dst.Write(buf[0:readBytes]) + // Add to the hash + _, hashErr = h.Write(buf[0:readBytes]) + if hashErr != nil { + done <- hashErr + return + } + + // Write to disk and update the number of bytes written + writeBytes, writeErr = dst.Write(buf[0:readBytes]) mu.Lock() written += int64(writeBytes) mu.Unlock() @@ -145,11 +207,21 @@ func copyWithTimeout(ctx context.Context, dst io.Writer, src io.Reader) (int64, } } if readErr != nil { + // If error is EOF, means we read the entire file, so don't consider that as error if readErr != io.EOF { done <- readErr - } else { - done <- nil + return } + + // Compute and compare the checksum + hash := h.Sum(nil) + if !bytes.Equal(expectedHashBytes, hash[:]) { + done <- errors.New("downloaded file's hash doesn't match") + return + } + + // No error + done <- nil return } }