package http

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"github.com/aptly-dev/aptly/aptly"
	"github.com/aptly-dev/aptly/utils"
	"github.com/mxk/go-flowrate/flowrate"
	"github.com/pkg/errors"
	"github.com/smira/go-ftp-protocol/protocol"
)

// Compile-time check that downloaderImpl implements aptly.Downloader.
var (
	_ aptly.Downloader = (*downloaderImpl)(nil)
)

// downloaderImpl is implementation of Downloader interface
type downloaderImpl struct {
	// progress receives status messages; may be nil.
	progress aptly.Progress
	// aggWriter receives a copy of every byte written by download (progress
	// reporting and/or rate limiting); never nil once NewDownloader returns.
	aggWriter io.Writer
	// maxTries is the maximum number of attempts per request; values below 1
	// are treated as a single attempt.
	maxTries int
	client   *http.Client
}

// NewDownloader creates a new Downloader that makes up to maxTries attempts
// per request and, when downLimit > 0, throttles the aggregate download rate
// to downLimit bytes/sec. progress may be nil, in which case no status
// messages are printed and downloaded bytes are only written to disk.
func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly.Downloader {
	defaults := http.DefaultTransport.(*http.Transport)

	transport := http.Transport{}
	transport.Proxy = defaults.Proxy
	transport.ResponseHeaderTimeout = 30 * time.Second
	transport.TLSHandshakeTimeout = defaults.TLSHandshakeTimeout
	transport.ExpectContinueTimeout = defaults.ExpectContinueTimeout
	// Disable transparent gzip decoding so the bytes written to disk and
	// checksummed are exactly the bytes the server sent.
	transport.DisableCompression = true
	initTransport(&transport)
	transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})
	transport.RegisterProtocol("ar+https", NewGCPRoundTripper(&transport))

	downloader := &downloaderImpl{
		progress:  progress,
		maxTries:  maxTries,
		aggWriter: io.Writer(progress),
		client: &http.Client{
			Transport: &transport,
		},
	}

	if progress == nil {
		downloader.aggWriter = io.Discard
	}

	if downLimit > 0 {
		downloader.aggWriter = flowrate.NewWriter(downloader.aggWriter, downLimit)
	}

	downloader.client.CheckRedirect = downloader.checkRedirect

	return downloader
}

// checkRedirect is installed as the client's CheckRedirect hook. It reports
// each redirect through progress (when set). Installing a custom hook
// replaces net/http's default policy entirely, so the default limit of 10
// consecutive redirects is re-applied here; otherwise a redirect loop would
// be followed indefinitely.
func (downloader *downloaderImpl) checkRedirect(req *http.Request, via []*http.Request) error {
	const maxRedirects = 10
	if len(via) >= maxRedirects {
		return errors.New("stopped after 10 redirects")
	}
	if downloader.progress != nil {
		downloader.progress.Printf("Following redirect to %s...\n", req.URL)
	}
	return nil
}

// GetProgress returns the Progress object passed to NewDownloader (may be nil).
func (downloader *downloaderImpl) GetProgress() aptly.Progress {
	return downloader.progress
}

// GetLength returns the size in bytes of the resource at url, determined
// with a HEAD request. Errors accepted by retryableError are retried
// immediately (no back-off) up to maxTries attempts in total, stopping early
// once ctx is done. A non-2xx status is returned as *Error. A 200 response
// that reports ContentLength -1 is treated as an existing, empty file.
func (downloader *downloaderImpl) GetLength(ctx context.Context, url string) (int64, error) {
	req, err := downloader.newRequest(ctx, "HEAD", url)
	if err != nil {
		return -1, err
	}

	tries := downloader.maxTries
	if tries < 1 {
		// Always make at least one attempt; otherwise resp would remain nil.
		tries = 1
	}

	var resp *http.Response
	for ; tries > 0; tries-- {
		resp, err = downloader.client.Do(req)
		if err == nil || !retryableError(err) || ctx.Err() != nil {
			break
		}
	}

	if err != nil {
		return -1, errors.Wrap(err, url)
	}
	// Even a HEAD response body must be closed to release its resources.
	if resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return -1, &Error{Code: resp.StatusCode, URL: url}
	}

	if resp.ContentLength < 0 {
		// an existing, but zero-length file can be reported with ContentLength -1
		if resp.StatusCode == 200 && resp.ContentLength == -1 {
			return 0, nil
		}
		return -1, fmt.Errorf("could not determine length of %s", url)
	}

	return resp.ContentLength, nil
}

// Download fetches url into destination without checksum verification.
func (downloader *downloaderImpl) Download(ctx context.Context, url string, destination string) error {
	return downloader.DownloadWithChecksum(ctx, url, destination, nil, false)
}

// retryableError reports whether a failed request should be attempted again.
// It strips errors.Wrap and *url.Error wrappers. It then treats every error
// as retryable except context.Canceled. The explicit cases only document the
// failure kinds the retry logic is aimed at; they all agree with the default.
func retryableError(err error) bool {
	// unwrap errors.Wrap
	err = errors.Cause(err)

	// unwrap *url.Error
	if wrapped, ok := err.(*url.Error); ok {
		err = wrapped.Err
	}

	switch err {
	case context.Canceled:
		return false
	case io.EOF:
		return true
	case io.ErrUnexpectedEOF:
		return true
	}

	switch err.(type) {
	case *net.OpError:
		return true
	case syscall.Errno:
		return true
	case net.Error:
		return true
	}

	// Note: make all errors retryable
	return true
}

// newRequest builds a body-less request bound to ctx with connection reuse
// disabled (req.Close). When the request goes directly (no proxy) over http
// or https, the full request URI is placed in URL.Opaque with every '+'
// escaped as %2b. NOTE(review): presumably this stops servers from decoding
// '+' as a space in package file names — confirm.
func (downloader *downloaderImpl) newRequest(ctx context.Context, method, url string) (*http.Request, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return nil, errors.Wrap(err, url)
	}
	req.Close = true

	direct := true
	if transport, ok := downloader.client.Transport.(*http.Transport); ok && transport.Proxy != nil {
		// A proxy lookup error is deliberately ignored here; the request is then
		// treated as direct.
		if proxyURL, _ := transport.Proxy(req); proxyURL != nil {
			direct = false
		}
	}

	if direct && (req.URL.Scheme == "http" || req.URL.Scheme == "https") {
		req.URL.Opaque = strings.ReplaceAll(req.URL.RequestURI(), "+", "%2b")
		// The query is already part of Opaque; clear it so it is not appended twice.
		req.URL.RawQuery = ""
	}

	return req, nil
}

// DownloadWithChecksum downloads url into destination. Failed attempts
// accepted by retryableError are retried, up to maxTries attempts in total,
// with exponential back-off (1s, doubling, capped at 5m). Retrying stops
// early once ctx is done. The body is written to destination+".down" first
// and renamed into place only after a successful attempt.
//
// When expected is non-nil, the download is verified against it. On a match,
// *expected is overwritten with the full set of computed checksums. On a
// mismatch, the attempt fails (and may be retried) unless ignoreMismatch is
// set, in which case only a warning is printed.
func (downloader *downloaderImpl) DownloadWithChecksum(ctx context.Context, url string, destination string,
	expected *utils.ChecksumInfo, ignoreMismatch bool) error {

	if downloader.progress != nil {
		downloader.progress.Printf("Downloading: %s\n", url)
		defer downloader.progress.Flush()
	}

	req, err := downloader.newRequest(ctx, "GET", url)
	if err != nil {
		// Must stop here: a nil request handed to client.Do would panic.
		if downloader.progress != nil {
			downloader.progress.Printf("Download Error: %s\n", url)
		}
		return err
	}

	const (
		delayMax        = 5 * time.Minute
		delayMultiplier = 2
	)
	delay := 1 * time.Second

	tries := downloader.maxTries
	if tries < 1 {
		// Always make at least one attempt.
		tries = 1
	}

	var temppath string
retry:
	for {
		temppath, err = downloader.download(req, url, destination, expected, ignoreMismatch)
		if err == nil {
			break
		}

		tries--
		// Give up without sleeping when the error is fatal, no attempts remain,
		// or the context is already done (every further attempt would fail).
		if !retryableError(err) || tries <= 0 || ctx.Err() != nil {
			if downloader.progress != nil {
				downloader.progress.Printf("Error: %s \n", err)
			}
			break
		}

		if downloader.progress != nil {
			downloader.progress.Printf("Error (retrying): %s\n", err)
		}

		// Back off, but wake up immediately if the context is cancelled.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			err = errors.Wrap(ctx.Err(), url)
			break retry
		case <-timer.C:
		}

		// Sleep exponentially at the next retry, but no longer than delayMax
		delay *= delayMultiplier
		if delay > delayMax {
			delay = delayMax
		}

		if downloader.progress != nil {
			downloader.progress.Printf("Retrying %d %s...\n", tries, url)
		}
	}

	// still an error after retrying, giving up
	if err != nil {
		if downloader.progress != nil {
			downloader.progress.Printf("Download Error: %s\n", url)
		}
		return err
	}

	if err = os.Rename(temppath, destination); err != nil {
		_ = os.Remove(temppath)
		return errors.Wrap(err, url)
	}

	return nil
}

// download performs a single GET attempt. It streams the response body to
// destination+".down" and to aggWriter, and also to a checksum writer when
// expected is non-nil. It returns the temporary file path on success. On
// failure the temporary file is removed (if it was created) and "" is
// returned. Checksum handling is described on DownloadWithChecksum.
func (downloader *downloaderImpl) download(req *http.Request, url, destination string,
	expected *utils.ChecksumInfo, ignoreMismatch bool) (string, error) {

	resp, err := downloader.client.Do(req)
	if err != nil {
		return "", errors.Wrap(err, url)
	}
	if resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", &Error{Code: resp.StatusCode, URL: url}
	}

	if err = os.MkdirAll(filepath.Dir(destination), 0777); err != nil {
		return "", errors.Wrap(err, url)
	}

	temppath := destination + ".down"

	outfile, err := os.Create(temppath)
	if err != nil {
		return "", errors.Wrap(err, url)
	}

	checksummer := utils.NewChecksumWriter()
	writers := []io.Writer{outfile, downloader.aggWriter}
	if expected != nil {
		writers = append(writers, checksummer)
	}

	_, err = io.Copy(io.MultiWriter(writers...), resp.Body)
	// Close explicitly and check the result. A failed close can mean the data
	// never reached disk. Also, some platforms (e.g. Windows) refuse to delete
	// a file that is still open.
	if closeErr := outfile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(temppath)
		return "", errors.Wrap(err, url)
	}

	if expected == nil {
		return temppath, nil
	}

	actual := checksummer.Sum()

	var mismatch error
	switch {
	case actual.Size != expected.Size:
		mismatch = fmt.Errorf("%s: size check mismatch %d != %d", url, actual.Size, expected.Size)
	case expected.MD5 != "" && actual.MD5 != expected.MD5:
		mismatch = fmt.Errorf("%s: md5 hash mismatch %#v != %#v", url, actual.MD5, expected.MD5)
	case expected.SHA1 != "" && actual.SHA1 != expected.SHA1:
		mismatch = fmt.Errorf("%s: sha1 hash mismatch %#v != %#v", url, actual.SHA1, expected.SHA1)
	case expected.SHA256 != "" && actual.SHA256 != expected.SHA256:
		mismatch = fmt.Errorf("%s: sha256 hash mismatch %#v != %#v", url, actual.SHA256, expected.SHA256)
	case expected.SHA512 != "" && actual.SHA512 != expected.SHA512:
		mismatch = fmt.Errorf("%s: sha512 hash mismatch %#v != %#v", url, actual.SHA512, expected.SHA512)
	}

	if mismatch == nil {
		// update checksums if they match, so that they contain exactly expected set
		*expected = actual
		return temppath, nil
	}

	if !ignoreMismatch {
		_ = os.Remove(temppath)
		return "", mismatch
	}

	if downloader.progress != nil {
		downloader.progress.Printf("WARNING: %s\n", mismatch.Error())
	}
	return temppath, nil
}