aptly/http/download.go
Rohan Garg 140a11c04a Continue even when a server replies with 403
Amazon S3 replies with a 403 when accessing files that don't exist.
While this should be fixed at Amazon, we can work around it in aptly
for the moment.
2014-10-25 21:41:57 +02:00

package http

import (
	"code.google.com/p/mxk/go1/flowcontrol"
	"compress/bzip2"
	"compress/gzip"
	"fmt"
	"github.com/smira/aptly/aptly"
	"github.com/smira/aptly/utils"
	"github.com/smira/go-ftp-protocol/protocol"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"strings"
)

// HTTPError is a download error associated with an HTTP status code
type HTTPError struct {
	Code int
	URL  string
}

// Error returns a human-readable description of the error
func (e *HTTPError) Error() string {
	return fmt.Sprintf("HTTP code %d while fetching %s", e.Code, e.URL)
}

// Check interface
var (
	_ aptly.Downloader = (*downloaderImpl)(nil)
)

// downloaderImpl is an implementation of the Downloader interface
type downloaderImpl struct {
	queue     chan *downloadTask
	stop      chan struct{}
	stopped   chan struct{}
	pause     chan struct{}
	unpause   chan struct{}
	progress  aptly.Progress
	aggWriter io.Writer
	threads   int
	client    *http.Client
}

// downloadTask represents a single item in the download queue
type downloadTask struct {
	url            string
	destination    string
	result         chan<- error
	expected       utils.ChecksumInfo
	ignoreMismatch bool
}

// NewDownloader creates a new instance of Downloader with the specified number
// of threads and a download limit in bytes/sec (0 disables rate limiting)
func NewDownloader(threads int, downLimit int64, progress aptly.Progress) aptly.Downloader {
	// use a copy of the default transport with compression disabled and FTP support registered
	transport := *http.DefaultTransport.(*http.Transport)
	transport.DisableCompression = true
	transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})

	downloader := &downloaderImpl{
		queue:    make(chan *downloadTask, 1000),
		stop:     make(chan struct{}, threads),
		stopped:  make(chan struct{}, threads),
		pause:    make(chan struct{}),
		unpause:  make(chan struct{}),
		threads:  threads,
		progress: progress,
		client: &http.Client{
			Transport: &transport,
		},
	}

	if downLimit > 0 {
		downloader.aggWriter = flowcontrol.NewWriter(progress, downLimit)
	} else {
		downloader.aggWriter = progress
	}

	for i := 0; i < downloader.threads; i++ {
		go downloader.process()
	}

	return downloader
}

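// A minimal usage sketch (hypothetical caller, not part of this file): the
// progress value must satisfy aptly.Progress, and Shutdown should be called
// once all results have been collected from the result channels.
//
//	d := NewDownloader(4, 0, progress)
//	ch := make(chan error, 1)
//	d.Download("http://example.com/pool/main/p/pkg/pkg_1.0.deb", "/tmp/pkg_1.0.deb", ch)
//	if err := <-ch; err != nil {
//		// handle download failure
//	}
//	d.Shutdown()
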
// Shutdown stops the downloader after currently running tasks are finished,
// but doesn't process the rest of the queue
func (downloader *downloaderImpl) Shutdown() {
	for i := 0; i < downloader.threads; i++ {
		downloader.stop <- struct{}{}
	}

	for i := 0; i < downloader.threads; i++ {
		<-downloader.stopped
	}
}

// Abort signals the downloader to stop but doesn't wait for it to finish
func (downloader *downloaderImpl) Abort() {
	for i := 0; i < downloader.threads; i++ {
		downloader.stop <- struct{}{}
	}
}

// Pause pauses task processing
func (downloader *downloaderImpl) Pause() {
	for i := 0; i < downloader.threads; i++ {
		downloader.pause <- struct{}{}
	}
}

// Resume resumes task processing
func (downloader *downloaderImpl) Resume() {
	for i := 0; i < downloader.threads; i++ {
		downloader.unpause <- struct{}{}
	}
}

// GetProgress returns the Progress object attached to the downloader
func (downloader *downloaderImpl) GetProgress() aptly.Progress {
	return downloader.progress
}

// Download starts a new download task
func (downloader *downloaderImpl) Download(url string, destination string, result chan<- error) {
	downloader.DownloadWithChecksum(url, destination, result, utils.ChecksumInfo{Size: -1}, false)
}

// DownloadWithChecksum starts a new download task with checksum verification
func (downloader *downloaderImpl) DownloadWithChecksum(url string, destination string, result chan<- error,
	expected utils.ChecksumInfo, ignoreMismatch bool) {
	downloader.queue <- &downloadTask{url: url, destination: destination, result: result, expected: expected, ignoreMismatch: ignoreMismatch}
}

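// A minimal sketch of a checksum-verified download (hypothetical URL and
// checksum values); any error, including a checksum mismatch, arrives on the
// result channel:
//
//	ch := make(chan error, 1)
//	d.DownloadWithChecksum("http://example.com/pkg_1.0.deb", "/tmp/pkg_1.0.deb", ch,
//		utils.ChecksumInfo{Size: 1234, MD5: "0123456789abcdef0123456789abcdef"}, false)
//	err := <-ch
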
// handleTask processes a single download task
func (downloader *downloaderImpl) handleTask(task *downloadTask) {
	downloader.progress.Printf("Downloading %s...\n", task.url)

	resp, err := downloader.client.Get(task.url)
	if err != nil {
		task.result <- fmt.Errorf("%s: %s", task.url, err)
		return
	}
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		task.result <- &HTTPError{Code: resp.StatusCode, URL: task.url}
		return
	}

	err = os.MkdirAll(filepath.Dir(task.destination), 0755)
	if err != nil {
		task.result <- fmt.Errorf("%s: %s", task.url, err)
		return
	}

	// download to a temporary file first, then rename into place
	temppath := task.destination + ".down"

	outfile, err := os.Create(temppath)
	if err != nil {
		task.result <- fmt.Errorf("%s: %s", task.url, err)
		return
	}
	defer outfile.Close()

	checksummer := utils.NewChecksumWriter()

	writers := []io.Writer{outfile, downloader.aggWriter}
	if task.expected.Size != -1 {
		writers = append(writers, checksummer)
	}

	w := io.MultiWriter(writers...)

	_, err = io.Copy(w, resp.Body)
	if err != nil {
		os.Remove(temppath)
		task.result <- fmt.Errorf("%s: %s", task.url, err)
		return
	}

	if task.expected.Size != -1 {
		actual := checksummer.Sum()

		if actual.Size != task.expected.Size {
			err = fmt.Errorf("%s: size check mismatch %d != %d", task.url, actual.Size, task.expected.Size)
		} else if task.expected.MD5 != "" && actual.MD5 != task.expected.MD5 {
			err = fmt.Errorf("%s: md5 hash mismatch %#v != %#v", task.url, actual.MD5, task.expected.MD5)
		} else if task.expected.SHA1 != "" && actual.SHA1 != task.expected.SHA1 {
			err = fmt.Errorf("%s: sha1 hash mismatch %#v != %#v", task.url, actual.SHA1, task.expected.SHA1)
		} else if task.expected.SHA256 != "" && actual.SHA256 != task.expected.SHA256 {
			err = fmt.Errorf("%s: sha256 hash mismatch %#v != %#v", task.url, actual.SHA256, task.expected.SHA256)
		}

		if err != nil {
			if task.ignoreMismatch {
				downloader.progress.Printf("WARNING: %s\n", err.Error())
			} else {
				os.Remove(temppath)
				task.result <- err
				return
			}
		}
	}

	err = os.Rename(temppath, task.destination)
	if err != nil {
		os.Remove(temppath)
		task.result <- fmt.Errorf("%s: %s", task.url, err)
		return
	}

	task.result <- nil
}

// process implements a download worker running in its own goroutine
func (downloader *downloaderImpl) process() {
	for {
		select {
		case <-downloader.stop:
			downloader.stopped <- struct{}{}
			return
		case <-downloader.pause:
			// block until unpaused
			<-downloader.unpause
		case task := <-downloader.queue:
			downloader.handleTask(task)
		}
	}
}

// DownloadTemp starts a new download to a temporary file and returns that File
//
// The temporary file is already unlinked, so no cleanup is needed beyond closing it
func DownloadTemp(downloader aptly.Downloader, url string) (*os.File, error) {
	return DownloadTempWithChecksum(downloader, url, utils.ChecksumInfo{Size: -1}, false)
}

// DownloadTempWithChecksum is DownloadTemp with checksum verification
//
// The temporary file is already unlinked, so no cleanup is needed beyond closing it
func DownloadTempWithChecksum(downloader aptly.Downloader, url string, expected utils.ChecksumInfo, ignoreMismatch bool) (*os.File, error) {
	tempdir, err := ioutil.TempDir(os.TempDir(), "aptly")
	if err != nil {
		return nil, err
	}
	// removing the directory unlinks the file; the open handle keeps it readable
	defer os.RemoveAll(tempdir)

	tempfile := filepath.Join(tempdir, "buffer")

	if expected.Size != -1 && downloader.GetProgress() != nil {
		downloader.GetProgress().InitBar(expected.Size, true)
		defer downloader.GetProgress().ShutdownBar()
	}

	ch := make(chan error, 1)
	downloader.DownloadWithChecksum(url, tempfile, ch, expected, ignoreMismatch)

	err = <-ch
	if err != nil {
		return nil, err
	}

	file, err := os.Open(tempfile)
	if err != nil {
		return nil, err
	}

	return file, nil
}

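// A minimal usage sketch (hypothetical caller): since the file's directory is
// already removed, closing the handle is the only cleanup required.
//
//	file, err := DownloadTemp(downloader, "http://example.com/dists/stable/Release")
//	if err == nil {
//		defer file.Close()
//		// read the downloaded contents via file
//	}
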
// List of extensions + corresponding decompression support
var compressionMethods = []struct {
	extension      string
	transformation func(io.Reader) (io.Reader, error)
}{
	{
		extension:      ".bz2",
		transformation: func(r io.Reader) (io.Reader, error) { return bzip2.NewReader(r), nil },
	},
	{
		extension:      ".gz",
		transformation: func(r io.Reader) (io.Reader, error) { return gzip.NewReader(r) },
	},
	{
		extension:      "",
		transformation: func(r io.Reader) (io.Reader, error) { return r, nil },
	},
}

// DownloadTryCompression tries to download the URL with .bz2, .gz and raw
// extensions until it finds an existing file
func DownloadTryCompression(downloader aptly.Downloader, url string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool) (io.Reader, *os.File, error) {
	var err error

	for _, method := range compressionMethods {
		var file *os.File

		tryURL := url + method.extension
		foundChecksum := false

		for suffix, expected := range expectedChecksums {
			if strings.HasSuffix(tryURL, suffix) {
				file, err = DownloadTempWithChecksum(downloader, tryURL, expected, ignoreMismatch)
				foundChecksum = true
				break
			}
		}

		if !foundChecksum {
			file, err = DownloadTemp(downloader, tryURL)
		}

		if err != nil {
			// 404 means the file doesn't exist; some servers (e.g. Amazon S3)
			// reply with 403 for missing files, so try the next extension
			if err1, ok := err.(*HTTPError); ok && (err1.Code == 404 || err1.Code == 403) {
				continue
			}
			return nil, nil, err
		}

		var uncompressed io.Reader
		uncompressed, err = method.transformation(file)
		if err != nil {
			return nil, nil, err
		}

		return uncompressed, file, err
	}
	return nil, nil, err
}
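
// A minimal usage sketch for DownloadTryCompression (hypothetical caller); the
// keys of expectedChecksums are URL suffixes such as "Packages.bz2":
//
//	reader, file, err := DownloadTryCompression(downloader, url+"/Packages", checksums, false)
//	if err == nil {
//		defer file.Close()
//		// reader yields uncompressed data whichever extension was found
//	}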