mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-06-05 05:20:34 +00:00
26098f6c8d
Use a CDN-backed Debian mirror to make tests run faster, hopefully for everyone. Logging redirects is important for seeing exactly what is going on while items are being downloaded.
271 lines
6.7 KiB
Go
271 lines
6.7 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/aptly-dev/aptly/aptly"
|
|
"github.com/aptly-dev/aptly/utils"
|
|
"github.com/mxk/go-flowrate/flowrate"
|
|
"github.com/pkg/errors"
|
|
"github.com/smira/go-ftp-protocol/protocol"
|
|
)
|
|
|
|
// Check interface
|
|
var (
|
|
_ aptly.Downloader = (*downloaderImpl)(nil)
|
|
)
|
|
|
|
// downloaderImpl is the implementation of the aptly.Downloader interface.
// It fetches http(s):// URLs and, through a registered round tripper,
// ftp:// URLs.
type downloaderImpl struct {
	// progress receives status messages (via Printf) and is the sink that
	// aggWriter ultimately writes to. It may be nil; callers nil-check it.
	progress aptly.Progress
	// aggWriter is fed every byte of every download. It is either progress
	// itself or progress wrapped in a flowrate writer (when NewDownloader got
	// a positive download limit). Because a single writer is shared by all
	// transfers, the limit throttles them together.
	aggWriter io.Writer
	// maxTries bounds the number of attempts for a request that keeps failing
	// with an error accepted by retryableError.
	maxTries int
	// client performs all requests. Its transport also handles "ftp" URLs, and
	// its CheckRedirect hook is downloaderImpl.checkRedirect.
	client *http.Client
}
|
|
|
|
// NewDownloader creates new instance of Downloader which specified number
|
|
// of threads and download limit in bytes/sec
|
|
func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly.Downloader {
|
|
transport := http.Transport{}
|
|
transport.Proxy = http.DefaultTransport.(*http.Transport).Proxy
|
|
transport.ResponseHeaderTimeout = 30 * time.Second
|
|
transport.TLSHandshakeTimeout = http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout
|
|
transport.ExpectContinueTimeout = http.DefaultTransport.(*http.Transport).ExpectContinueTimeout
|
|
transport.DisableCompression = true
|
|
initTransport(&transport)
|
|
transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})
|
|
|
|
downloader := &downloaderImpl{
|
|
progress: progress,
|
|
maxTries: maxTries,
|
|
client: &http.Client{
|
|
Transport: &transport,
|
|
},
|
|
}
|
|
|
|
downloader.client.CheckRedirect = downloader.checkRedirect
|
|
if downLimit > 0 {
|
|
downloader.aggWriter = flowrate.NewWriter(progress, downLimit)
|
|
} else {
|
|
downloader.aggWriter = progress
|
|
}
|
|
|
|
return downloader
|
|
}
|
|
|
|
func (downloader *downloaderImpl) checkRedirect(req *http.Request, via []*http.Request) error {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Following redirect to %s...\n", req.URL)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProgress returns Progress object
|
|
func (downloader *downloaderImpl) GetProgress() aptly.Progress {
|
|
return downloader.progress
|
|
}
|
|
|
|
// GetLength returns the size in bytes of the resource at url, as reported by
// the Content-Length of a HEAD response.
//
// Failures accepted by retryableError are retried for up to maxTries
// attempts. At least one attempt is always made, and retrying stops early once
// ctx is done. A non-2xx status is returned as *Error. A response whose length
// is unknown is reported as an error.
func (downloader *downloaderImpl) GetLength(ctx context.Context, url string) (int64, error) {
	req, err := downloader.newRequest(ctx, "HEAD", url)
	if err != nil {
		return -1, err
	}

	// Guard against a non-positive maxTries: without at least one attempt
	// resp would stay nil and be dereferenced below.
	attempts := downloader.maxTries
	if attempts < 1 {
		attempts = 1
	}

	var resp *http.Response
	for attempt := 1; ; attempt++ {
		resp, err = downloader.client.Do(req)
		if err == nil || attempt >= attempts || !retryableError(err) || ctx.Err() != nil {
			break
		}
	}

	if err != nil {
		return -1, errors.Wrap(err, url)
	}
	// Always close the (normally empty) HEAD body so the transport can release
	// the connection.
	if resp.Body != nil {
		defer resp.Body.Close()
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return -1, &Error{Code: resp.StatusCode, URL: url}
	}

	if resp.ContentLength < 0 {
		return -1, fmt.Errorf("could not determine length of %s", url)
	}

	return resp.ContentLength, nil
}
|
|
|
|
// Download starts new download task
|
|
func (downloader *downloaderImpl) Download(ctx context.Context, url string, destination string) error {
|
|
return downloader.DownloadWithChecksum(ctx, url, destination, nil, false)
|
|
}
|
|
|
|
// retryableError reports whether err looks like a transient transport failure
// that is worth another attempt. Such failures are an (unexpected) end of
// stream, a *net.OpError, a syscall.Errno, or any other net.Error.
//
// Errors are first unwrapped from errors.Wrap and from *url.Error (the wrapper
// returned by http.Client.Do).
//
// Context cancellation and deadline expiry are never retryable: every further
// attempt made with the same context would fail immediately.
func retryableError(err error) bool {
	// unwrap errors.Wrap
	err = errors.Cause(err)

	// unwrap *url.Error
	if wrapped, ok := err.(*url.Error); ok {
		err = wrapped.Err
	}

	switch err {
	case context.Canceled, context.DeadlineExceeded:
		// context.DeadlineExceeded implements net.Error (Timeout() == true), so
		// it must be rejected before the type switch below accepts it.
		return false
	case io.EOF, io.ErrUnexpectedEOF:
		return true
	}

	switch err.(type) {
	case *net.OpError, syscall.Errno, net.Error:
		return true
	}

	return false
}
|
|
|
|
// newRequest builds a method request for url that is bound to ctx and asks for
// the connection to be closed afterwards (req.Close).
//
// When the request goes straight to an http(s) origin (no proxy), the request
// URI is sent verbatim through URL.Opaque, with every '+' escaped as %2b and
// the query folded into the opaque part. NOTE(review): presumably this works
// around servers that decode a literal '+' in a path as a space. Confirm.
func (downloader *downloaderImpl) newRequest(ctx context.Context, method, url string) (*http.Request, error) {
	req, err := http.NewRequest(method, url, nil)
	if err != nil {
		return nil, errors.Wrap(err, url)
	}
	req.Close = true
	req = req.WithContext(ctx)

	// Determine whether the request will be proxied. A proxy lookup error is
	// deliberately treated as "no proxy" here: the transport repeats the same
	// lookup when sending the request and reports the error itself.
	// Guarding the type assertion and a nil Proxy func avoids panics if
	// the client transport was customized.
	proxied := false
	if transport, ok := downloader.client.Transport.(*http.Transport); ok && transport.Proxy != nil {
		proxyURL, _ := transport.Proxy(req)
		proxied = proxyURL != nil
	}

	if !proxied && (req.URL.Scheme == "http" || req.URL.Scheme == "https") {
		req.URL.Opaque = strings.Replace(req.URL.RequestURI(), "+", "%2b", -1)
		req.URL.RawQuery = ""
	}

	return req, nil
}
|
|
|
|
// DownloadWithChecksum downloads url to destination, optionally verifying the
// content against expected (see download for the verification rules and the
// meaning of ignoreMismatch).
//
// Failures accepted by retryableError are retried for up to maxTries
// attempts. At least one attempt is always made, and retrying stops early once
// ctx is done. The data is written to a temporary file that is renamed to
// destination only after a successful attempt, so destination is never left
// partially written.
func (downloader *downloaderImpl) DownloadWithChecksum(ctx context.Context, url string, destination string,
	expected *utils.ChecksumInfo, ignoreMismatch bool) error {

	if downloader.progress != nil {
		downloader.progress.Printf("Downloading %s...\n", url)
	}

	// A request-construction failure (e.g. a malformed URL) must be reported
	// here: passing a nil request on to the client would panic.
	req, err := downloader.newRequest(ctx, "GET", url)
	if err != nil {
		return err
	}

	// Guard against a non-positive maxTries, which would otherwise make no
	// attempt at all and then try to rename an empty path.
	attempts := downloader.maxTries
	if attempts < 1 {
		attempts = 1
	}

	var temppath string
	for attempt := 1; ; attempt++ {
		temppath, err = downloader.download(req, url, destination, expected, ignoreMismatch)
		if err == nil || attempt >= attempts || !retryableError(err) || ctx.Err() != nil {
			break
		}
	}

	// still an error after retrying, giving up
	if err != nil {
		return err
	}

	if err = os.Rename(temppath, destination); err != nil {
		os.Remove(temppath)
		return errors.Wrap(err, url)
	}

	return nil
}
|
|
|
|
func (downloader *downloaderImpl) download(req *http.Request, url, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool) (string, error) {
|
|
resp, err := downloader.client.Do(req)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
if resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
return "", &Error{Code: resp.StatusCode, URL: url}
|
|
}
|
|
|
|
err = os.MkdirAll(filepath.Dir(destination), 0777)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
temppath := destination + ".down"
|
|
|
|
outfile, err := os.Create(temppath)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
defer outfile.Close()
|
|
|
|
checksummer := utils.NewChecksumWriter()
|
|
writers := []io.Writer{outfile, downloader.aggWriter}
|
|
|
|
if expected != nil {
|
|
writers = append(writers, checksummer)
|
|
}
|
|
|
|
w := io.MultiWriter(writers...)
|
|
|
|
_, err = io.Copy(w, resp.Body)
|
|
if err != nil {
|
|
os.Remove(temppath)
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
if expected != nil {
|
|
actual := checksummer.Sum()
|
|
|
|
if actual.Size != expected.Size {
|
|
err = fmt.Errorf("%s: size check mismatch %d != %d", url, actual.Size, expected.Size)
|
|
} else if expected.MD5 != "" && actual.MD5 != expected.MD5 {
|
|
err = fmt.Errorf("%s: md5 hash mismatch %#v != %#v", url, actual.MD5, expected.MD5)
|
|
} else if expected.SHA1 != "" && actual.SHA1 != expected.SHA1 {
|
|
err = fmt.Errorf("%s: sha1 hash mismatch %#v != %#v", url, actual.SHA1, expected.SHA1)
|
|
} else if expected.SHA256 != "" && actual.SHA256 != expected.SHA256 {
|
|
err = fmt.Errorf("%s: sha256 hash mismatch %#v != %#v", url, actual.SHA256, expected.SHA256)
|
|
} else if expected.SHA512 != "" && actual.SHA512 != expected.SHA512 {
|
|
err = fmt.Errorf("%s: sha512 hash mismatch %#v != %#v", url, actual.SHA512, expected.SHA512)
|
|
}
|
|
|
|
if err != nil {
|
|
if ignoreMismatch {
|
|
downloader.progress.Printf("WARNING: %s\n", err.Error())
|
|
} else {
|
|
os.Remove(temppath)
|
|
return "", err
|
|
}
|
|
} else {
|
|
// update checksums if they match, so that they contain exactly expected set
|
|
*expected = actual
|
|
}
|
|
}
|
|
|
|
return temppath, nil
|
|
}
|