mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-06-06 05:30:57 +00:00
f0a370db24
Apply retries as global, config-level option `downloadRetries` so that it can be applied to any aptly command which downloads objects. Unwrap `errors.Wrap` which is used in downloader. Unwrap `*url.Error` which should be the actual error returned from the HTTP client, catch more cases, be more specific around failures.
262 lines
6.4 KiB
Go
262 lines
6.4 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/aptly-dev/aptly/aptly"
|
|
"github.com/aptly-dev/aptly/utils"
|
|
"github.com/mxk/go-flowrate/flowrate"
|
|
"github.com/pkg/errors"
|
|
"github.com/smira/go-ftp-protocol/protocol"
|
|
)
|
|
|
|
// Check interface
|
|
var (
|
|
_ aptly.Downloader = (*downloaderImpl)(nil)
|
|
)
|
|
|
|
// downloaderImpl is implementation of Downloader interface
|
|
type downloaderImpl struct {
|
|
progress aptly.Progress
|
|
aggWriter io.Writer
|
|
maxTries int
|
|
client *http.Client
|
|
}
|
|
|
|
// NewDownloader creates new instance of Downloader which specified number
|
|
// of threads and download limit in bytes/sec
|
|
func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly.Downloader {
|
|
transport := http.Transport{}
|
|
transport.Proxy = http.DefaultTransport.(*http.Transport).Proxy
|
|
transport.ResponseHeaderTimeout = 30 * time.Second
|
|
transport.TLSHandshakeTimeout = http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout
|
|
transport.ExpectContinueTimeout = http.DefaultTransport.(*http.Transport).ExpectContinueTimeout
|
|
transport.DisableCompression = true
|
|
initTransport(&transport)
|
|
transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})
|
|
|
|
downloader := &downloaderImpl{
|
|
progress: progress,
|
|
maxTries: maxTries,
|
|
client: &http.Client{
|
|
Transport: &transport,
|
|
},
|
|
}
|
|
|
|
if downLimit > 0 {
|
|
downloader.aggWriter = flowrate.NewWriter(progress, downLimit)
|
|
} else {
|
|
downloader.aggWriter = progress
|
|
}
|
|
|
|
return downloader
|
|
}
|
|
|
|
// GetProgress returns Progress object
|
|
func (downloader *downloaderImpl) GetProgress() aptly.Progress {
|
|
return downloader.progress
|
|
}
|
|
|
|
// GetLength of given url
|
|
func (downloader *downloaderImpl) GetLength(ctx context.Context, url string) (int64, error) {
|
|
req, err := downloader.newRequest(ctx, "HEAD", url)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
var resp *http.Response
|
|
|
|
maxTries := downloader.maxTries
|
|
for maxTries > 0 {
|
|
resp, err = downloader.client.Do(req)
|
|
if err != nil && retryableError(err) {
|
|
maxTries--
|
|
} else {
|
|
// stop retrying
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return -1, errors.Wrap(err, url)
|
|
}
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
return -1, &Error{Code: resp.StatusCode, URL: url}
|
|
}
|
|
|
|
if resp.ContentLength < 0 {
|
|
return -1, fmt.Errorf("could not determine length of %s", url)
|
|
}
|
|
|
|
return resp.ContentLength, nil
|
|
}
|
|
|
|
// Download starts new download task
|
|
func (downloader *downloaderImpl) Download(ctx context.Context, url string, destination string) error {
|
|
return downloader.DownloadWithChecksum(ctx, url, destination, nil, false)
|
|
}
|
|
|
|
func retryableError(err error) bool {
|
|
// unwrap errors.Wrap
|
|
err = errors.Cause(err)
|
|
|
|
// unwrap *url.Error
|
|
if wrapped, ok := err.(*url.Error); ok {
|
|
err = wrapped.Err
|
|
}
|
|
|
|
switch err {
|
|
case io.EOF:
|
|
return true
|
|
case io.ErrUnexpectedEOF:
|
|
return true
|
|
}
|
|
|
|
switch err.(type) {
|
|
case *net.OpError:
|
|
return true
|
|
case syscall.Errno:
|
|
return true
|
|
case net.Error:
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (downloader *downloaderImpl) newRequest(ctx context.Context, method, url string) (*http.Request, error) {
|
|
req, err := http.NewRequest(method, url, nil)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, url)
|
|
}
|
|
req.Close = true
|
|
req = req.WithContext(ctx)
|
|
|
|
proxyURL, _ := downloader.client.Transport.(*http.Transport).Proxy(req)
|
|
if proxyURL == nil && (req.URL.Scheme == "http" || req.URL.Scheme == "https") {
|
|
req.URL.Opaque = strings.Replace(req.URL.RequestURI(), "+", "%2b", -1)
|
|
req.URL.RawQuery = ""
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// DownloadWithChecksum starts new download task with checksum verification
|
|
func (downloader *downloaderImpl) DownloadWithChecksum(ctx context.Context, url string, destination string,
|
|
expected *utils.ChecksumInfo, ignoreMismatch bool) error {
|
|
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Downloading %s...\n", url)
|
|
}
|
|
req, err := downloader.newRequest(ctx, "GET", url)
|
|
|
|
var temppath string
|
|
maxTries := downloader.maxTries
|
|
for maxTries > 0 {
|
|
temppath, err = downloader.download(req, url, destination, expected, ignoreMismatch)
|
|
|
|
if err != nil && retryableError(err) {
|
|
maxTries--
|
|
} else {
|
|
// get out of the loop
|
|
break
|
|
}
|
|
}
|
|
|
|
// still an error after retrying, giving up
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = os.Rename(temppath, destination)
|
|
if err != nil {
|
|
os.Remove(temppath)
|
|
return errors.Wrap(err, url)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (downloader *downloaderImpl) download(req *http.Request, url, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool) (string, error) {
|
|
resp, err := downloader.client.Do(req)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
if resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
return "", &Error{Code: resp.StatusCode, URL: url}
|
|
}
|
|
|
|
err = os.MkdirAll(filepath.Dir(destination), 0777)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
temppath := destination + ".down"
|
|
|
|
outfile, err := os.Create(temppath)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
defer outfile.Close()
|
|
|
|
checksummer := utils.NewChecksumWriter()
|
|
writers := []io.Writer{outfile, downloader.aggWriter}
|
|
|
|
if expected != nil {
|
|
writers = append(writers, checksummer)
|
|
}
|
|
|
|
w := io.MultiWriter(writers...)
|
|
|
|
_, err = io.Copy(w, resp.Body)
|
|
if err != nil {
|
|
os.Remove(temppath)
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
if expected != nil {
|
|
actual := checksummer.Sum()
|
|
|
|
if actual.Size != expected.Size {
|
|
err = fmt.Errorf("%s: size check mismatch %d != %d", url, actual.Size, expected.Size)
|
|
} else if expected.MD5 != "" && actual.MD5 != expected.MD5 {
|
|
err = fmt.Errorf("%s: md5 hash mismatch %#v != %#v", url, actual.MD5, expected.MD5)
|
|
} else if expected.SHA1 != "" && actual.SHA1 != expected.SHA1 {
|
|
err = fmt.Errorf("%s: sha1 hash mismatch %#v != %#v", url, actual.SHA1, expected.SHA1)
|
|
} else if expected.SHA256 != "" && actual.SHA256 != expected.SHA256 {
|
|
err = fmt.Errorf("%s: sha256 hash mismatch %#v != %#v", url, actual.SHA256, expected.SHA256)
|
|
} else if expected.SHA512 != "" && actual.SHA512 != expected.SHA512 {
|
|
err = fmt.Errorf("%s: sha512 hash mismatch %#v != %#v", url, actual.SHA512, expected.SHA512)
|
|
}
|
|
|
|
if err != nil {
|
|
if ignoreMismatch {
|
|
downloader.progress.Printf("WARNING: %s\n", err.Error())
|
|
} else {
|
|
os.Remove(temppath)
|
|
return "", err
|
|
}
|
|
} else {
|
|
// update checksums if they match, so that they contain exactly expected set
|
|
*expected = actual
|
|
}
|
|
}
|
|
|
|
return temppath, nil
|
|
}
|