Grab downloader

This commit is contained in:
Lorenzo Bolla
2021-10-08 10:43:52 +02:00
parent f93bc6ef0f
commit 894192851e
38 changed files with 4240 additions and 1 deletions
+104
View File
@@ -0,0 +1,104 @@
package grabtest
import (
"bytes"
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
)
// AssertHTTPResponseStatusCode reports a test error on t when the status code
// of resp differs from expect. It returns true only when the codes match.
func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) {
	if got := resp.StatusCode; got != expect {
		t.Errorf("expected status code: %d, got: %d", expect, got)
		return false
	}
	return true
}
// AssertHTTPResponseHeader compares the value of response header key with the
// string produced by formatting format with a. A mismatch is reported as a
// test error on t. It returns true only when the values are equal.
func AssertHTTPResponseHeader(t *testing.T, resp *http.Response, key, format string, a ...interface{}) (ok bool) {
	want := fmt.Sprintf(format, a...)
	got := resp.Header.Get(key)
	ok = got == want
	if !ok {
		t.Errorf("expected header %s: %s, got: %s", key, want, got)
	}
	return ok
}
// AssertHTTPResponseContentLength checks that both the ContentLength field of
// resp and the actual length of its body equal n. Both checks always run (the
// body check consumes and closes the body); the result is true only when both
// pass.
func AssertHTTPResponseContentLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
	headerOK := resp.ContentLength == n
	if !headerOK {
		t.Errorf("expected header Content-Length: %d, got: %d", n, resp.ContentLength)
	}
	bodyOK := AssertHTTPResponseBodyLength(t, resp, n)
	return headerOK && bodyOK
}
// AssertHTTPResponseBodyLength reads the body of resp to the end and checks
// that exactly n bytes were received. A mismatch is reported as a test error
// on t. It returns true when the length matches and false otherwise.
//
// The body is always closed before returning. The function panics if the body
// cannot be read or closed.
func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
	defer func() {
		if err := resp.Body.Close(); err != nil {
			panic(err)
		}
	}()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}

	if got := int64(len(b)); got != n {
		t.Errorf("expected body length: %d, got: %d", n, got)
		return false
	}
	// Report success explicitly; the named result defaults to false.
	return true
}
// MustHTTPNewRequest wraps http.NewRequest and panics instead of returning an
// error.
func MustHTTPNewRequest(method, url string, body io.Reader) *http.Request {
	req, err := http.NewRequest(method, url, body)
	if err == nil {
		return req
	}
	panic(err)
}
// MustHTTPDo sends req with http.DefaultClient and returns the response. It
// panics if the client returns an error. The response body is left open for
// the caller.
func MustHTTPDo(req *http.Request) *http.Response {
	resp, err := http.DefaultClient.Do(req)
	if err == nil {
		return resp
	}
	panic(err)
}
// MustHTTPDoWithClose sends req via MustHTTPDo, discards the entire response
// body and closes it, then returns the response. It panics if draining or
// closing the body fails.
func MustHTTPDoWithClose(req *http.Request) *http.Response {
	resp := MustHTTPDo(req)
	body := resp.Body

	_, err := io.Copy(ioutil.Discard, body)
	if err != nil {
		panic(err)
	}
	if err = body.Close(); err != nil {
		panic(err)
	}
	return resp
}
// AssertSHA256Sum reads r to the end, computes the SHA-256 digest of the data
// and compares it with sum. A mismatch is reported as a test error on t with
// both digests hex-encoded. It returns true only when the digests are equal
// and panics if r cannot be read.
func AssertSHA256Sum(t *testing.T, sum []byte, r io.Reader) (ok bool) {
	hasher := sha256.New()
	if _, err := io.Copy(hasher, r); err != nil {
		panic(err)
	}

	actual := hasher.Sum(nil)
	if bytes.Equal(sum, actual) {
		return true
	}
	t.Errorf(
		"expected checksum: %s, got: %s",
		MustHexEncodeString(sum),
		MustHexEncodeString(actual),
	)
	return false
}
+160
View File
@@ -0,0 +1,160 @@
package grabtest
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// Default values associated with a handler created by NewHandler.
var (
// DefaultHandlerContentLength is the body length in bytes (1 << 20) that
// NewHandler assigns to a handler before any options are applied.
DefaultHandlerContentLength = 1 << 20
// DefaultHandlerMD5Checksum is a hex-encoded MD5 digest.
// NOTE(review): presumably the digest of the default response body - confirm.
DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372"
// DefaultHandlerMD5ChecksumBytes is DefaultHandlerMD5Checksum decoded from
// hex into raw bytes.
DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum)
// DefaultHandlerSHA256Checksum is a hex-encoded SHA-256 digest.
// NOTE(review): presumably the digest of the default response body - confirm.
DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
// DefaultHandlerSHA256ChecksumBytes is DefaultHandlerSHA256Checksum decoded
// from hex into raw bytes.
DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
)
// StatusCodeFunc returns the HTTP status code that a handler writes in
// response to req.
type StatusCodeFunc func(req *http.Request) int
// handler is the configurable http.Handler returned by NewHandler. Its fields
// are set through HandlerOption functions.
type handler struct {
// statusCodeFunc chooses the status code written for each request.
statusCodeFunc StatusCodeFunc
// methodWhitelist lists the request methods that are served; any other
// method is answered with 405 Method Not Allowed.
methodWhitelist []string
// headerBlacklist lists response header keys that are deleted just before
// the header is written.
headerBlacklist []string
// contentLength is the length of the full response body in bytes.
contentLength int
// acceptRanges enables the Accept-Ranges header and the handling of
// "Range: bytes=N-" request headers.
acceptRanges bool
// attachmentFilename, when non-empty, is sent in a Content-Disposition header.
attachmentFilename string
// lastModified is the Last-Modified timestamp; when it is the zero time the
// current time is sent instead.
lastModified time.Time
// ttfb is how long ServeHTTP sleeps before processing a request.
ttfb time.Duration
// rateLimiter, when non-nil, makes ServeHTTP flush and wait for one tick
// after every body byte.
rateLimiter *time.Ticker
}
// NewHandler builds a test http.Handler. The defaults are: status 200 OK for
// every request, GET and HEAD allowed, a body of DefaultHandlerContentLength
// bytes, and range requests accepted. The options are then applied in order;
// the first option that fails aborts construction and its error is returned.
func NewHandler(options ...HandlerOption) (http.Handler, error) {
	h := new(handler)
	h.statusCodeFunc = func(*http.Request) int { return http.StatusOK }
	h.methodWhitelist = []string{"GET", "HEAD"}
	h.contentLength = DefaultHandlerContentLength
	h.acceptRanges = true

	for i := range options {
		if err := options[i](h); err != nil {
			return nil, err
		}
	}
	return h, nil
}
// WithTestServer starts an httptest server backed by a handler configured
// with options, calls f with the server URL, and afterwards releases the
// handler's resources and shuts the server down. The test is failed
// immediately if the handler cannot be created.
func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
	h, err := NewHandler(options...)
	if err != nil {
		t.Fatalf("unable to create test server handler: %v", err)
		return
	}

	srv := httptest.NewServer(h)
	// Deferred calls run last-in-first-out: the handler is closed first, then
	// the server.
	defer srv.Close()
	defer h.(*handler).close()

	f(srv.URL)
}
// close stops the rate-limiter ticker, if one was configured.
func (h *handler) close() {
	if ticker := h.rateLimiter; ticker != nil {
		ticker.Stop()
	}
}
// ServeHTTP implements http.Handler.
//
// It optionally sleeps for the configured time to first byte, rejects methods
// that are not whitelisted with 405, sets the Accept-Ranges,
// Content-Disposition, Last-Modified and Content-Length headers, removes any
// blacklisted headers and writes the status code chosen by statusCodeFunc.
//
// For GET requests it then writes the body: byte number i of the full body has
// the value byte(i). When range requests are enabled, a "Range: bytes=N-"
// header makes the body start at offset N. An unparsable or negative offset is
// answered with 400, an offset at or beyond the content length with 416.
//
// When a rate limiter is configured, each byte is flushed to the client and
// the handler waits for the next tick (or for the request to be canceled)
// before sending the next byte.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// delay response
	if h.ttfb > 0 {
		time.Sleep(h.ttfb)
	}

	// validate request method
	allowed := false
	for _, m := range h.methodWhitelist {
		if r.Method == m {
			allowed = true
			break
		}
	}
	if !allowed {
		httpError(w, http.StatusMethodNotAllowed)
		return
	}

	// set server options
	if h.acceptRanges {
		w.Header().Set("Accept-Ranges", "bytes")
	}

	// set attachment filename
	if h.attachmentFilename != "" {
		w.Header().Set(
			"Content-Disposition",
			fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
		)
	}

	// set last modified timestamp
	lastMod := time.Now()
	if !h.lastModified.IsZero() {
		lastMod = h.lastModified
	}
	w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))

	// set content-length
	offset := 0
	if h.acceptRanges {
		if reqRange := r.Header.Get("Range"); reqRange != "" {
			if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
				httpError(w, http.StatusBadRequest)
				return
			}
			// A negative offset is not a valid range start and would
			// advertise more bytes than the configured content length.
			if offset < 0 {
				httpError(w, http.StatusBadRequest)
				return
			}
			if offset >= h.contentLength {
				httpError(w, http.StatusRequestedRangeNotSatisfiable)
				return
			}
		}
	}
	w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))

	// apply header blacklist
	for _, key := range h.headerBlacklist {
		w.Header().Del(key)
	}

	// send header and status code
	w.WriteHeader(h.statusCodeFunc(r))

	// send body
	if r.Method != "GET" {
		return
	}

	// Not every ResponseWriter can flush; skip the explicit flush rather than
	// panic when it cannot.
	flusher, canFlush := w.(http.Flusher)

	// use buffered io to reduce overhead on the reader
	bw := bufio.NewWriterSize(w, 4096)
	for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
		// WriteByte avoids allocating a one-byte slice per iteration. Once a
		// write fails every later write fails too, so stop immediately.
		if err := bw.WriteByte(byte(i)); err != nil {
			return
		}
		if h.rateLimiter == nil {
			continue
		}
		if err := bw.Flush(); err != nil {
			return
		}
		if canFlush {
			flusher.Flush() // force the server to send the data to the client
		}
		select {
		case <-h.rateLimiter.C:
		case <-r.Context().Done():
			return
		}
	}
	if !isRequestClosed(r) {
		// Best effort: the response has already been started, so a failed
		// final flush cannot be reported to the client.
		_ = bw.Flush()
	}
}
// isRequestClosed reports whether the context of r has ended, i.e. whether
// the context's Err method returns a non-nil error.
func isRequestClosed(r *http.Request) bool {
	if err := r.Context().Err(); err != nil {
		return true
	}
	return false
}
// httpError replies with the given status code, using the standard status
// text for that code as the response message.
func httpError(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
+92
View File
@@ -0,0 +1,92 @@
package grabtest
import (
"errors"
"net/http"
"time"
)
// HandlerOption configures a handler. NewHandler applies the options in order
// and stops at the first one that returns an error.
type HandlerOption func(*handler) error
// StatusCodeStatic makes the handler answer every request with the fixed
// status code.
func StatusCodeStatic(code int) HandlerOption {
	static := func(*http.Request) int { return code }
	return StatusCode(static)
}
// StatusCode makes the handler call f for each request to decide the status
// code. Applying the option fails if f is nil.
func StatusCode(f StatusCodeFunc) HandlerOption {
	return func(h *handler) error {
		if f != nil {
			h.statusCodeFunc = f
			return nil
		}
		return errors.New("status code function cannot be nil")
	}
}
// MethodWhitelist replaces the set of request methods the handler serves.
// Requests using any other method receive 405 Method Not Allowed.
func MethodWhitelist(methods ...string) HandlerOption {
	apply := func(h *handler) error {
		h.methodWhitelist = methods
		return nil
	}
	return apply
}
// HeaderBlacklist sets the response header keys that the handler deletes
// before sending the response header.
func HeaderBlacklist(headers ...string) HandlerOption {
	apply := func(h *handler) error {
		h.headerBlacklist = headers
		return nil
	}
	return apply
}
// ContentLength sets the length in bytes of the full response body. Applying
// the option fails if n is negative.
func ContentLength(n int) HandlerOption {
	return func(h *handler) error {
		if n >= 0 {
			h.contentLength = n
			return nil
		}
		return errors.New("content length must be zero or greater")
	}
}
// AcceptRanges switches range support on or off. When enabled, the handler
// sends "Accept-Ranges: bytes" and honors Range request headers.
func AcceptRanges(enabled bool) HandlerOption {
	apply := func(h *handler) error {
		h.acceptRanges = enabled
		return nil
	}
	return apply
}
// LastModified sets the timestamp sent in the Last-Modified response header.
// The time is stored converted to UTC.
func LastModified(t time.Time) HandlerOption {
	return func(h *handler) error {
		utc := t.UTC()
		h.lastModified = utc
		return nil
	}
}
// TimeToFirstByte makes the handler sleep for d before it processes each
// request. Applying the option fails unless d is greater than zero.
func TimeToFirstByte(d time.Duration) HandlerOption {
	return func(h *handler) error {
		if d >= 1 {
			h.ttfb = d
			return nil
		}
		return errors.New("time to first byte must be greater than zero")
	}
}
// RateLimiter throttles the response body to roughly bps bytes per second by
// giving the handler a ticker with an interval of one second divided by bps;
// the handler sends one byte per tick.
//
// Applying the option fails if bps is less than one, or if bps is so large
// that the interval would round down to zero (more than one byte per
// nanosecond), since time.NewTicker panics on a non-positive interval.
func RateLimiter(bps int) HandlerOption {
	return func(h *handler) error {
		if bps < 1 {
			return errors.New("bytes per second must be greater than zero")
		}
		interval := time.Second / time.Duration(bps)
		if interval <= 0 {
			return errors.New("bytes per second must not exceed 1000000000")
		}
		// Stop a ticker installed by an earlier application of this option so
		// that it is not leaked when it gets replaced.
		if h.rateLimiter != nil {
			h.rateLimiter.Stop()
		}
		h.rateLimiter = time.NewTicker(interval)
		return nil
	}
}
// AttachmentFilename sets the filename announced in the Content-Disposition
// response header. With an empty filename the header is not sent.
func AttachmentFilename(filename string) HandlerOption {
	apply := func(h *handler) error {
		h.attachmentFilename = filename
		return nil
	}
	return apply
}
+150
View File
@@ -0,0 +1,150 @@
package grabtest
import (
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"
)
func TestHandlerDefaults(t *testing.T) {
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseStatusCode(t, resp, http.StatusOK)
AssertHTTPResponseContentLength(t, resp, 1048576)
AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes")
})
}
func TestHandlerMethodWhitelist(t *testing.T) {
tests := []struct {
Whitelist []string
Method string
ExpectStatusCode int
}{
{[]string{"GET", "HEAD"}, "GET", http.StatusOK},
{[]string{"GET", "HEAD"}, "HEAD", http.StatusOK},
{[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed},
{[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed},
}
for _, test := range tests {
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil))
AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode)
}, MethodWhitelist(test.Whitelist...))
}
}
func TestHandlerHeaderBlacklist(t *testing.T) {
contentLength := 4096
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
defer resp.Body.Close()
if resp.ContentLength != -1 {
t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength)
}
AssertHTTPResponseHeader(t, resp, "Content-Length", "")
AssertHTTPResponseBodyLength(t, resp, int64(contentLength))
},
ContentLength(contentLength),
HeaderBlacklist("Content-Length"),
)
}
// TestHandlerStatusCodeFuncs checks that the status code returned by a custom
// StatusCodeFunc is the one sent to the client.
func TestHandlerStatusCodeFuncs(t *testing.T) {
	expect := http.StatusTeapot // 418, I'm a teapot
	WithTestServer(t, func(url string) {
		// Only the status code matters here, so drain and close the body
		// right away instead of leaving the response open.
		resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
		AssertHTTPResponseStatusCode(t, resp, expect)
	},
		StatusCode(func(req *http.Request) int { return expect }),
	)
}
func TestHandlerContentLength(t *testing.T) {
tests := []struct {
Method string
ContentLength int
ExpectHeaderLen int64
ExpectBodyLen int
}{
{"GET", 321, 321, 321},
{"HEAD", 321, 321, 0},
{"GET", 0, 0, 0},
{"HEAD", 0, 0, 0},
}
for _, test := range tests {
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil))
defer resp.Body.Close()
AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen)
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
panic(err)
}
if len(b) != test.ExpectBodyLen {
t.Errorf(
"expected body length: %v, got: %v, in: %v",
test.ExpectBodyLen,
len(b),
test,
)
}
},
ContentLength(test.ContentLength),
)
}
}
func TestHandlerAcceptRanges(t *testing.T) {
header := "Accept-Ranges"
n := 128
t.Run("Enabled", func(t *testing.T) {
WithTestServer(t, func(url string) {
req := MustHTTPNewRequest("GET", url, nil)
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
resp := MustHTTPDo(req)
AssertHTTPResponseHeader(t, resp, header, "bytes")
AssertHTTPResponseContentLength(t, resp, int64(n/2))
},
ContentLength(n),
)
})
t.Run("Disabled", func(t *testing.T) {
WithTestServer(t, func(url string) {
req := MustHTTPNewRequest("GET", url, nil)
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
resp := MustHTTPDo(req)
AssertHTTPResponseHeader(t, resp, header, "")
AssertHTTPResponseContentLength(t, resp, int64(n))
},
AcceptRanges(false),
ContentLength(n),
)
})
}
func TestHandlerAttachmentFilename(t *testing.T) {
filename := "foo.pdf"
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename)
},
AttachmentFilename(filename),
)
}
func TestHandlerLastModified(t *testing.T) {
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT")
},
LastModified(time.Unix(123456789, 0)),
)
}
+16
View File
@@ -0,0 +1,16 @@
package grabtest
import "encoding/hex"
// MustHexDecodeString decodes the hexadecimal string s into bytes and panics
// if s is not valid hexadecimal.
func MustHexDecodeString(s string) (b []byte) {
	decoded, err := hex.DecodeString(s)
	if err != nil {
		panic(err)
	}
	return decoded
}
// MustHexEncodeString returns the lowercase hexadecimal encoding of b.
func MustHexEncodeString(b []byte) (s string) {
	s = hex.EncodeToString(b)
	return s
}