mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-01-12 03:21:33 +00:00
309 lines
7.7 KiB
Go
309 lines
7.7 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/aptly-dev/aptly/aptly"
|
|
"github.com/aptly-dev/aptly/utils"
|
|
"github.com/mxk/go-flowrate/flowrate"
|
|
"github.com/pkg/errors"
|
|
"github.com/smira/go-ftp-protocol/protocol"
|
|
)
|
|
|
|
// Check interface
|
|
var (
|
|
_ aptly.Downloader = (*downloaderImpl)(nil)
|
|
)
|
|
|
|
// downloaderImpl is implementation of Downloader interface
|
|
type downloaderImpl struct {
|
|
progress aptly.Progress
|
|
aggWriter io.Writer
|
|
maxTries int
|
|
client *http.Client
|
|
}
|
|
|
|
// NewDownloader creates new instance of Downloader which specified number
|
|
// of threads and download limit in bytes/sec
|
|
func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly.Downloader {
|
|
transport := http.Transport{}
|
|
transport.Proxy = http.DefaultTransport.(*http.Transport).Proxy
|
|
transport.ResponseHeaderTimeout = 30 * time.Second
|
|
transport.TLSHandshakeTimeout = http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout
|
|
transport.ExpectContinueTimeout = http.DefaultTransport.(*http.Transport).ExpectContinueTimeout
|
|
transport.DisableCompression = true
|
|
initTransport(&transport)
|
|
transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})
|
|
|
|
downloader := &downloaderImpl{
|
|
progress: progress,
|
|
maxTries: maxTries,
|
|
aggWriter: io.Writer(progress),
|
|
client: &http.Client{
|
|
Transport: &transport,
|
|
},
|
|
}
|
|
|
|
if progress == nil {
|
|
downloader.aggWriter = io.Discard
|
|
}
|
|
if downLimit > 0 {
|
|
downloader.aggWriter = flowrate.NewWriter(downloader.aggWriter, downLimit)
|
|
}
|
|
|
|
downloader.client.CheckRedirect = downloader.checkRedirect
|
|
|
|
return downloader
|
|
}
|
|
|
|
func (downloader *downloaderImpl) checkRedirect(req *http.Request, _ []*http.Request) error {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Following redirect to %s...\n", req.URL)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProgress returns Progress object
|
|
func (downloader *downloaderImpl) GetProgress() aptly.Progress {
|
|
return downloader.progress
|
|
}
|
|
|
|
// GetLength of given url
|
|
func (downloader *downloaderImpl) GetLength(ctx context.Context, url string) (int64, error) {
|
|
req, err := downloader.newRequest(ctx, "HEAD", url)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
var resp *http.Response
|
|
|
|
maxTries := downloader.maxTries
|
|
for maxTries > 0 {
|
|
resp, err = downloader.client.Do(req)
|
|
if err != nil && retryableError(err) {
|
|
maxTries--
|
|
} else {
|
|
// stop retrying
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return -1, errors.Wrap(err, url)
|
|
}
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
return -1, &Error{Code: resp.StatusCode, URL: url}
|
|
}
|
|
|
|
if resp.ContentLength < 0 {
|
|
// an existing, but zero-length file can be reported with ContentLength -1
|
|
if resp.StatusCode == 200 && resp.ContentLength == -1 {
|
|
return 0, nil
|
|
}
|
|
return -1, fmt.Errorf("could not determine length of %s", url)
|
|
}
|
|
|
|
return resp.ContentLength, nil
|
|
}
|
|
|
|
// Download starts new download task
|
|
func (downloader *downloaderImpl) Download(ctx context.Context, url string, destination string) error {
|
|
return downloader.DownloadWithChecksum(ctx, url, destination, nil, false)
|
|
}
|
|
|
|
func retryableError(err error) bool {
|
|
// unwrap errors.Wrap
|
|
err = errors.Cause(err)
|
|
|
|
// unwrap *url.Error
|
|
if wrapped, ok := err.(*url.Error); ok {
|
|
err = wrapped.Err
|
|
}
|
|
|
|
switch err {
|
|
case context.Canceled:
|
|
return false
|
|
case io.EOF:
|
|
return true
|
|
case io.ErrUnexpectedEOF:
|
|
return true
|
|
}
|
|
|
|
switch err.(type) {
|
|
case *net.OpError:
|
|
return true
|
|
case syscall.Errno:
|
|
return true
|
|
case net.Error:
|
|
return true
|
|
}
|
|
// Note: make all errors retryable
|
|
return true
|
|
}
|
|
|
|
func (downloader *downloaderImpl) newRequest(ctx context.Context, method, url string) (*http.Request, error) {
|
|
req, err := http.NewRequest(method, url, nil)
|
|
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, url)
|
|
}
|
|
req.Close = true
|
|
req = req.WithContext(ctx)
|
|
|
|
proxyURL, _ := downloader.client.Transport.(*http.Transport).Proxy(req)
|
|
if proxyURL == nil && (req.URL.Scheme == "http" || req.URL.Scheme == "https") {
|
|
req.URL.Opaque = strings.Replace(req.URL.RequestURI(), "+", "%2b", -1)
|
|
req.URL.RawQuery = ""
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// DownloadWithChecksum starts new download task with checksum verification
|
|
func (downloader *downloaderImpl) DownloadWithChecksum(ctx context.Context, url string, destination string,
|
|
expected *utils.ChecksumInfo, ignoreMismatch bool) error {
|
|
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Downloading: %s\n", url)
|
|
defer downloader.progress.Flush()
|
|
}
|
|
req, err := downloader.newRequest(ctx, "GET", url)
|
|
|
|
var temppath string
|
|
maxTries := downloader.maxTries
|
|
const delayMax = time.Duration(5 * time.Minute)
|
|
delay := time.Duration(1 * time.Second)
|
|
const delayMultiplier = 2
|
|
for maxTries > 0 {
|
|
temppath, err = downloader.download(req, url, destination, expected, ignoreMismatch)
|
|
|
|
if err != nil {
|
|
if retryableError(err) {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Error (retrying): %s\n", err)
|
|
}
|
|
maxTries--
|
|
time.Sleep(delay)
|
|
// Sleep exponentially at the next retry, but no longer than delayMax
|
|
delay *= delayMultiplier
|
|
if delay > delayMax {
|
|
delay = delayMax
|
|
}
|
|
} else {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Error: %s \n", err)
|
|
}
|
|
break
|
|
}
|
|
} else {
|
|
// get out of the loop
|
|
break
|
|
}
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Retrying %d %s...\n", maxTries, url)
|
|
}
|
|
}
|
|
|
|
// still an error after retrying, giving up
|
|
if err != nil {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("Download Error: %s\n", url)
|
|
}
|
|
return err
|
|
}
|
|
|
|
err = os.Rename(temppath, destination)
|
|
if err != nil {
|
|
os.Remove(temppath)
|
|
return errors.Wrap(err, url)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (downloader *downloaderImpl) download(req *http.Request, url, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool) (string, error) {
|
|
resp, err := downloader.client.Do(req)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
if resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
return "", &Error{Code: resp.StatusCode, URL: url}
|
|
}
|
|
|
|
err = os.MkdirAll(filepath.Dir(destination), 0777)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
temppath := destination + ".down"
|
|
|
|
outfile, err := os.Create(temppath)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
defer outfile.Close()
|
|
|
|
checksummer := utils.NewChecksumWriter()
|
|
writers := []io.Writer{outfile, downloader.aggWriter}
|
|
|
|
if expected != nil {
|
|
writers = append(writers, checksummer)
|
|
}
|
|
|
|
w := io.MultiWriter(writers...)
|
|
|
|
_, err = io.Copy(w, resp.Body)
|
|
if err != nil {
|
|
os.Remove(temppath)
|
|
return "", errors.Wrap(err, url)
|
|
}
|
|
|
|
if expected != nil {
|
|
actual := checksummer.Sum()
|
|
|
|
if actual.Size != expected.Size {
|
|
err = fmt.Errorf("%s: size check mismatch %d != %d", url, actual.Size, expected.Size)
|
|
} else if expected.MD5 != "" && actual.MD5 != expected.MD5 {
|
|
err = fmt.Errorf("%s: md5 hash mismatch %#v != %#v", url, actual.MD5, expected.MD5)
|
|
} else if expected.SHA1 != "" && actual.SHA1 != expected.SHA1 {
|
|
err = fmt.Errorf("%s: sha1 hash mismatch %#v != %#v", url, actual.SHA1, expected.SHA1)
|
|
} else if expected.SHA256 != "" && actual.SHA256 != expected.SHA256 {
|
|
err = fmt.Errorf("%s: sha256 hash mismatch %#v != %#v", url, actual.SHA256, expected.SHA256)
|
|
} else if expected.SHA512 != "" && actual.SHA512 != expected.SHA512 {
|
|
err = fmt.Errorf("%s: sha512 hash mismatch %#v != %#v", url, actual.SHA512, expected.SHA512)
|
|
}
|
|
|
|
if err != nil {
|
|
if ignoreMismatch {
|
|
if downloader.progress != nil {
|
|
downloader.progress.Printf("WARNING: %s\n", err.Error())
|
|
}
|
|
} else {
|
|
os.Remove(temppath)
|
|
return "", err
|
|
}
|
|
} else {
|
|
// update checksums if they match, so that they contain exactly expected set
|
|
*expected = actual
|
|
}
|
|
}
|
|
|
|
return temppath, nil
|
|
}
|