mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-05-30 04:20:53 +00:00
Update vendored deps, including AWS SDK, openpgp, ftp, ...
This commit is contained in:
+1916
-465
File diff suppressed because it is too large
Load Diff
+58
@@ -0,0 +1,58 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
)
|
||||
|
||||
func BenchmarkPresign_GetObject(b *testing.B) {
|
||||
sess := unit.Session
|
||||
svc := New(sess)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req, _ := svc.GetObjectRequest(&GetObjectInput{
|
||||
Bucket: aws.String("mock-bucket"),
|
||||
Key: aws.String("mock-key"),
|
||||
})
|
||||
|
||||
u, h, err := req.PresignRequest(15 * time.Minute)
|
||||
if err != nil {
|
||||
b.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if len(u) == 0 {
|
||||
b.Fatalf("expect url, got none")
|
||||
}
|
||||
if len(h) != 0 {
|
||||
b.Fatalf("no signed headers, got %v", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPresign_PutObject(b *testing.B) {
|
||||
sess := unit.Session
|
||||
svc := New(sess)
|
||||
|
||||
body := make([]byte, 1024*1024*20)
|
||||
for i := 0; i < b.N; i++ {
|
||||
req, _ := svc.PutObjectRequest(&PutObjectInput{
|
||||
Bucket: aws.String("mock-bucket"),
|
||||
Key: aws.String("mock-key"),
|
||||
Body: bytes.NewReader(body),
|
||||
})
|
||||
|
||||
u, h, err := req.PresignRequest(15 * time.Minute)
|
||||
if err != nil {
|
||||
b.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if len(u) == 0 {
|
||||
b.Fatalf("expect url, got none")
|
||||
}
|
||||
if len(h) == 0 {
|
||||
b.Fatalf("expect signed header, got none")
|
||||
}
|
||||
}
|
||||
}
|
||||
+249
@@ -0,0 +1,249 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
const (
|
||||
contentMD5Header = "Content-Md5"
|
||||
contentSha256Header = "X-Amz-Content-Sha256"
|
||||
amzTeHeader = "X-Amz-Te"
|
||||
amzTxEncodingHeader = "X-Amz-Transfer-Encoding"
|
||||
|
||||
appendMD5TxEncoding = "append-md5"
|
||||
)
|
||||
|
||||
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
|
||||
// require it.
|
||||
func contentMD5(r *request.Request) {
|
||||
h := md5.New()
|
||||
|
||||
if !aws.IsReaderSeekable(r.Body) {
|
||||
if r.Config.Logger != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(
|
||||
"Unable to compute Content-MD5 for unseekable body, S3.%s",
|
||||
r.Operation.Name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := copySeekableBody(h, r.Body); err != nil {
|
||||
r.Error = awserr.New("ContentMD5", "failed to compute body MD5", err)
|
||||
return
|
||||
}
|
||||
|
||||
// encode the md5 checksum in base64 and set the request header.
|
||||
v := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
r.HTTPRequest.Header.Set(contentMD5Header, v)
|
||||
}
|
||||
|
||||
// computeBodyHashes will add Content MD5 and Content Sha256 hashes to the
|
||||
// request. If the body is not seekable or S3DisableContentMD5Validation set
|
||||
// this handler will be ignored.
|
||||
func computeBodyHashes(r *request.Request) {
|
||||
if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
|
||||
return
|
||||
}
|
||||
if r.IsPresigned() {
|
||||
return
|
||||
}
|
||||
if r.Error != nil || !aws.IsReaderSeekable(r.Body) {
|
||||
return
|
||||
}
|
||||
|
||||
var md5Hash, sha256Hash hash.Hash
|
||||
hashers := make([]io.Writer, 0, 2)
|
||||
|
||||
// Determine upfront which hashes can be set without overriding user
|
||||
// provide header data.
|
||||
if v := r.HTTPRequest.Header.Get(contentMD5Header); len(v) == 0 {
|
||||
md5Hash = md5.New()
|
||||
hashers = append(hashers, md5Hash)
|
||||
}
|
||||
|
||||
if v := r.HTTPRequest.Header.Get(contentSha256Header); len(v) == 0 {
|
||||
sha256Hash = sha256.New()
|
||||
hashers = append(hashers, sha256Hash)
|
||||
}
|
||||
|
||||
// Create the destination writer based on the hashes that are not already
|
||||
// provided by the user.
|
||||
var dst io.Writer
|
||||
switch len(hashers) {
|
||||
case 0:
|
||||
return
|
||||
case 1:
|
||||
dst = hashers[0]
|
||||
default:
|
||||
dst = io.MultiWriter(hashers...)
|
||||
}
|
||||
|
||||
if _, err := copySeekableBody(dst, r.Body); err != nil {
|
||||
r.Error = awserr.New("BodyHashError", "failed to compute body hashes", err)
|
||||
return
|
||||
}
|
||||
|
||||
// For the hashes created, set the associated headers that the user did not
|
||||
// already provide.
|
||||
if md5Hash != nil {
|
||||
sum := make([]byte, md5.Size)
|
||||
encoded := make([]byte, md5Base64EncLen)
|
||||
|
||||
base64.StdEncoding.Encode(encoded, md5Hash.Sum(sum[0:0]))
|
||||
r.HTTPRequest.Header[contentMD5Header] = []string{string(encoded)}
|
||||
}
|
||||
|
||||
if sha256Hash != nil {
|
||||
encoded := make([]byte, sha256HexEncLen)
|
||||
sum := make([]byte, sha256.Size)
|
||||
|
||||
hex.Encode(encoded, sha256Hash.Sum(sum[0:0]))
|
||||
r.HTTPRequest.Header[contentSha256Header] = []string{string(encoded)}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
md5Base64EncLen = (md5.Size + 2) / 3 * 4 // base64.StdEncoding.EncodedLen
|
||||
sha256HexEncLen = sha256.Size * 2 // hex.EncodedLen
|
||||
)
|
||||
|
||||
func copySeekableBody(dst io.Writer, src io.ReadSeeker) (int64, error) {
|
||||
curPos, err := src.Seek(0, sdkio.SeekCurrent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// hash the body. seek back to the first position after reading to reset
|
||||
// the body for transmission. copy errors may be assumed to be from the
|
||||
// body.
|
||||
n, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
_, err = src.Seek(curPos, sdkio.SeekStart)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Adds the x-amz-te: append_md5 header to the request. This requests the service
|
||||
// responds with a trailing MD5 checksum.
|
||||
//
|
||||
// Will not ask for append MD5 if disabled, the request is presigned or,
|
||||
// or the API operation does not support content MD5 validation.
|
||||
func askForTxEncodingAppendMD5(r *request.Request) {
|
||||
if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
|
||||
return
|
||||
}
|
||||
if r.IsPresigned() {
|
||||
return
|
||||
}
|
||||
r.HTTPRequest.Header.Set(amzTeHeader, appendMD5TxEncoding)
|
||||
}
|
||||
|
||||
func useMD5ValidationReader(r *request.Request) {
|
||||
if r.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if v := r.HTTPResponse.Header.Get(amzTxEncodingHeader); v != appendMD5TxEncoding {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyReader *io.ReadCloser
|
||||
var contentLen int64
|
||||
switch tv := r.Data.(type) {
|
||||
case *GetObjectOutput:
|
||||
bodyReader = &tv.Body
|
||||
contentLen = aws.Int64Value(tv.ContentLength)
|
||||
// Update ContentLength hiden the trailing MD5 checksum.
|
||||
tv.ContentLength = aws.Int64(contentLen - md5.Size)
|
||||
tv.ContentRange = aws.String(r.HTTPResponse.Header.Get("X-Amz-Content-Range"))
|
||||
default:
|
||||
r.Error = awserr.New("ChecksumValidationError",
|
||||
fmt.Sprintf("%s: %s header received on unsupported API, %s",
|
||||
amzTxEncodingHeader, appendMD5TxEncoding, r.Operation.Name,
|
||||
), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if contentLen < md5.Size {
|
||||
r.Error = awserr.New("ChecksumValidationError",
|
||||
fmt.Sprintf("invalid Content-Length %d for %s %s",
|
||||
contentLen, appendMD5TxEncoding, amzTxEncodingHeader,
|
||||
), nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap and swap the response body reader with the validation reader.
|
||||
*bodyReader = newMD5ValidationReader(*bodyReader, contentLen-md5.Size)
|
||||
}
|
||||
|
||||
type md5ValidationReader struct {
|
||||
rawReader io.ReadCloser
|
||||
payload io.Reader
|
||||
hash hash.Hash
|
||||
|
||||
payloadLen int64
|
||||
read int64
|
||||
}
|
||||
|
||||
func newMD5ValidationReader(reader io.ReadCloser, payloadLen int64) *md5ValidationReader {
|
||||
h := md5.New()
|
||||
return &md5ValidationReader{
|
||||
rawReader: reader,
|
||||
payload: io.TeeReader(&io.LimitedReader{R: reader, N: payloadLen}, h),
|
||||
hash: h,
|
||||
payloadLen: payloadLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *md5ValidationReader) Read(p []byte) (n int, err error) {
|
||||
n, err = v.payload.Read(p)
|
||||
if err != nil && err != io.EOF {
|
||||
return n, err
|
||||
}
|
||||
|
||||
v.read += int64(n)
|
||||
|
||||
if err == io.EOF {
|
||||
if v.read != v.payloadLen {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
expectSum := make([]byte, md5.Size)
|
||||
actualSum := make([]byte, md5.Size)
|
||||
if _, sumReadErr := io.ReadFull(v.rawReader, expectSum); sumReadErr != nil {
|
||||
return n, sumReadErr
|
||||
}
|
||||
actualSum = v.hash.Sum(actualSum[0:0])
|
||||
if !bytes.Equal(expectSum, actualSum) {
|
||||
return n, awserr.New("InvalidChecksum",
|
||||
fmt.Sprintf("expected MD5 checksum %s, got %s",
|
||||
hex.EncodeToString(expectSum),
|
||||
hex.EncodeToString(actualSum),
|
||||
),
|
||||
nil)
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (v *md5ValidationReader) Close() error {
|
||||
return v.rawReader.Close()
|
||||
}
|
||||
+523
@@ -0,0 +1,523 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
type errorReader struct{}
|
||||
|
||||
func (errorReader) Read([]byte) (int, error) {
|
||||
return 0, fmt.Errorf("errorReader error")
|
||||
}
|
||||
func (errorReader) Seek(int64, int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestComputeBodyHases(t *testing.T) {
|
||||
bodyContent := []byte("bodyContent goes here")
|
||||
|
||||
cases := []struct {
|
||||
Req *request.Request
|
||||
ExpectMD5 string
|
||||
ExpectSHA256 string
|
||||
Error string
|
||||
DisableContentMD5 bool
|
||||
Presigned bool
|
||||
}{
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "CqD6NNPvoNOBT/5pkjtzOw==",
|
||||
ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(contentMD5Header, "MD5AlreadySet")
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "MD5AlreadySet",
|
||||
ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(contentSha256Header, "SHA256AlreadySet")
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "CqD6NNPvoNOBT/5pkjtzOw==",
|
||||
ExpectSHA256: "SHA256AlreadySet",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(contentMD5Header, "MD5AlreadySet")
|
||||
h.Set(contentSha256Header, "SHA256AlreadySet")
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "MD5AlreadySet",
|
||||
ExpectSHA256: "SHA256AlreadySet",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
// Non-seekable reader
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer(bodyContent)),
|
||||
},
|
||||
ExpectMD5: "",
|
||||
ExpectSHA256: "",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
// Empty seekable body
|
||||
Body: aws.ReadSeekCloser(bytes.NewReader(nil)),
|
||||
},
|
||||
ExpectMD5: "1B2M2Y8AsgTpgAmY7PhCfg==",
|
||||
ExpectSHA256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
},
|
||||
{
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
// failure while reading reader
|
||||
Body: errorReader{},
|
||||
},
|
||||
ExpectMD5: "",
|
||||
ExpectSHA256: "",
|
||||
Error: "errorReader error",
|
||||
},
|
||||
{
|
||||
// Disabled ContentMD5 validation
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "",
|
||||
ExpectSHA256: "",
|
||||
DisableContentMD5: true,
|
||||
},
|
||||
{
|
||||
// Disabled ContentMD5 validation
|
||||
Req: &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Body: bytes.NewReader(bodyContent),
|
||||
},
|
||||
ExpectMD5: "",
|
||||
ExpectSHA256: "",
|
||||
Presigned: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
c.Req.Config.S3DisableContentMD5Validation = aws.Bool(c.DisableContentMD5)
|
||||
|
||||
if c.Presigned {
|
||||
c.Req.ExpireTime = 10 * time.Minute
|
||||
}
|
||||
computeBodyHashes(c.Req)
|
||||
|
||||
if e, a := c.ExpectMD5, c.Req.HTTPRequest.Header.Get(contentMD5Header); e != a {
|
||||
t.Errorf("%d, expect %v md5, got %v", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := c.ExpectSHA256, c.Req.HTTPRequest.Header.Get(contentSha256Header); e != a {
|
||||
t.Errorf("%d, expect %v sha256, got %v", i, e, a)
|
||||
}
|
||||
|
||||
if len(c.Error) != 0 {
|
||||
if c.Req.Error == nil {
|
||||
t.Fatalf("%d, expect error, got none", i)
|
||||
}
|
||||
if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d, expect %v error to be in %v", i, e, a)
|
||||
}
|
||||
|
||||
} else if c.Req.Error != nil {
|
||||
t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComputeBodyHashes(b *testing.B) {
|
||||
body := bytes.NewReader(make([]byte, 2*1024))
|
||||
req := &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Body: body,
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
computeBodyHashes(req)
|
||||
if req.Error != nil {
|
||||
b.Fatalf("expect no error, got %v", req.Error)
|
||||
}
|
||||
|
||||
req.HTTPRequest.Header = http.Header{}
|
||||
body.Seek(0, sdkio.SeekStart)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAskForTxEncodingAppendMD5(t *testing.T) {
|
||||
cases := []struct {
|
||||
DisableContentMD5 bool
|
||||
Presigned bool
|
||||
}{
|
||||
{DisableContentMD5: true},
|
||||
{DisableContentMD5: false},
|
||||
{Presigned: true},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
req := &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Config: aws.Config{
|
||||
S3DisableContentMD5Validation: aws.Bool(c.DisableContentMD5),
|
||||
},
|
||||
}
|
||||
if c.Presigned {
|
||||
req.ExpireTime = 10 * time.Minute
|
||||
}
|
||||
|
||||
askForTxEncodingAppendMD5(req)
|
||||
|
||||
v := req.HTTPRequest.Header.Get(amzTeHeader)
|
||||
|
||||
expectHeader := !(c.DisableContentMD5 || c.Presigned)
|
||||
|
||||
if e, a := expectHeader, len(v) != 0; e != a {
|
||||
t.Errorf("%d, expect %t disable content MD5, got %t, %s", i, e, a, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseMD5ValidationReader(t *testing.T) {
|
||||
body := []byte("create a really cool md5 checksum of me")
|
||||
bodySum := md5.Sum(body)
|
||||
bodyWithSum := append(body, bodySum[:]...)
|
||||
|
||||
emptyBodySum := md5.Sum([]byte{})
|
||||
|
||||
cases := []struct {
|
||||
Req *request.Request
|
||||
Error string
|
||||
Validate func(outupt interface{}) error
|
||||
}{
|
||||
{
|
||||
// Positive: Use Validation reader
|
||||
Req: &request.Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Data: &GetObjectOutput{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
|
||||
ContentLength: aws.Int64(int64(len(bodyWithSum))),
|
||||
},
|
||||
},
|
||||
Validate: func(output interface{}) error {
|
||||
getObjOut := output.(*GetObjectOutput)
|
||||
reader, ok := getObjOut.Body.(*md5ValidationReader)
|
||||
if !ok {
|
||||
return fmt.Errorf("expect %T updated body reader, got %T",
|
||||
(*md5ValidationReader)(nil), getObjOut.Body)
|
||||
}
|
||||
|
||||
if reader.rawReader == nil {
|
||||
return fmt.Errorf("expect rawReader not to be nil")
|
||||
}
|
||||
if reader.payload == nil {
|
||||
return fmt.Errorf("expect payload not to be nil")
|
||||
}
|
||||
if e, a := int64(len(bodyWithSum)-md5.Size), reader.payloadLen; e != a {
|
||||
return fmt.Errorf("expect %v payload len, got %v", e, a)
|
||||
}
|
||||
if reader.hash == nil {
|
||||
return fmt.Errorf("expect hash not to be nil")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Positive: Use Validation reader, empty object
|
||||
Req: &request.Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Data: &GetObjectOutput{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(emptyBodySum[:])),
|
||||
ContentLength: aws.Int64(int64(len(emptyBodySum[:]))),
|
||||
},
|
||||
},
|
||||
Validate: func(output interface{}) error {
|
||||
getObjOut := output.(*GetObjectOutput)
|
||||
reader, ok := getObjOut.Body.(*md5ValidationReader)
|
||||
if !ok {
|
||||
return fmt.Errorf("expect %T updated body reader, got %T",
|
||||
(*md5ValidationReader)(nil), getObjOut.Body)
|
||||
}
|
||||
|
||||
if reader.rawReader == nil {
|
||||
return fmt.Errorf("expect rawReader not to be nil")
|
||||
}
|
||||
if reader.payload == nil {
|
||||
return fmt.Errorf("expect payload not to be nil")
|
||||
}
|
||||
if e, a := int64(len(emptyBodySum)-md5.Size), reader.payloadLen; e != a {
|
||||
return fmt.Errorf("expect %v payload len, got %v", e, a)
|
||||
}
|
||||
if reader.hash == nil {
|
||||
return fmt.Errorf("expect hash not to be nil")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Negative: amzTxEncoding header not set
|
||||
Req: &request.Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Header: http.Header{},
|
||||
},
|
||||
Data: &GetObjectOutput{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: aws.Int64(int64(len(body))),
|
||||
},
|
||||
},
|
||||
Validate: func(output interface{}) error {
|
||||
getObjOut := output.(*GetObjectOutput)
|
||||
reader, ok := getObjOut.Body.(*md5ValidationReader)
|
||||
if ok {
|
||||
return fmt.Errorf("expect body reader not to be %T",
|
||||
reader)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Negative: Not GetObjectOutput type.
|
||||
Req: &request.Request{
|
||||
Operation: &request.Operation{
|
||||
Name: "PutObject",
|
||||
},
|
||||
HTTPResponse: &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Data: &PutObjectOutput{},
|
||||
},
|
||||
Error: "header received on unsupported API",
|
||||
Validate: func(output interface{}) error {
|
||||
_, ok := output.(*PutObjectOutput)
|
||||
if !ok {
|
||||
return fmt.Errorf("expect %T output not to change, got %T",
|
||||
(*PutObjectOutput)(nil), output)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Negative: invalid content length.
|
||||
Req: &request.Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Data: &GetObjectOutput{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
|
||||
ContentLength: aws.Int64(-1),
|
||||
},
|
||||
},
|
||||
Error: "invalid Content-Length -1",
|
||||
Validate: func(output interface{}) error {
|
||||
getObjOut := output.(*GetObjectOutput)
|
||||
reader, ok := getObjOut.Body.(*md5ValidationReader)
|
||||
if ok {
|
||||
return fmt.Errorf("expect body reader not to be %T",
|
||||
reader)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
// Negative: invalid content length, < md5.Size.
|
||||
Req: &request.Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
|
||||
return h
|
||||
}(),
|
||||
},
|
||||
Data: &GetObjectOutput{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(make([]byte, 5))),
|
||||
ContentLength: aws.Int64(5),
|
||||
},
|
||||
},
|
||||
Error: "invalid Content-Length 5",
|
||||
Validate: func(output interface{}) error {
|
||||
getObjOut := output.(*GetObjectOutput)
|
||||
reader, ok := getObjOut.Body.(*md5ValidationReader)
|
||||
if ok {
|
||||
return fmt.Errorf("expect body reader not to be %T",
|
||||
reader)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
useMD5ValidationReader(c.Req)
|
||||
if len(c.Error) != 0 {
|
||||
if c.Req.Error == nil {
|
||||
t.Fatalf("%d, expect error, got none", i)
|
||||
}
|
||||
if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d, expect %v error to be in %v", i, e, a)
|
||||
}
|
||||
} else if c.Req.Error != nil {
|
||||
t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
|
||||
}
|
||||
|
||||
if c.Validate != nil {
|
||||
if err := c.Validate(c.Req.Data); err != nil {
|
||||
t.Errorf("%d, expect Data to validate, got %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReaderMD5Validation(t *testing.T) {
|
||||
body := []byte("create a really cool md5 checksum of me")
|
||||
bodySum := md5.Sum(body)
|
||||
bodyWithSum := append(body, bodySum[:]...)
|
||||
emptyBodySum := md5.Sum([]byte{})
|
||||
badBodySum := append(body, emptyBodySum[:]...)
|
||||
|
||||
cases := []struct {
|
||||
Content []byte
|
||||
ContentReader io.ReadCloser
|
||||
PayloadLen int64
|
||||
Error string
|
||||
}{
|
||||
{
|
||||
Content: bodyWithSum,
|
||||
PayloadLen: int64(len(body)),
|
||||
},
|
||||
{
|
||||
Content: emptyBodySum[:],
|
||||
PayloadLen: 0,
|
||||
},
|
||||
{
|
||||
Content: badBodySum,
|
||||
PayloadLen: int64(len(body)),
|
||||
Error: "expected MD5 checksum",
|
||||
},
|
||||
{
|
||||
Content: emptyBodySum[:len(emptyBodySum)-2],
|
||||
PayloadLen: 0,
|
||||
Error: "unexpected EOF",
|
||||
},
|
||||
{
|
||||
Content: body,
|
||||
PayloadLen: int64(len(body) * 2),
|
||||
Error: "unexpected EOF",
|
||||
},
|
||||
{
|
||||
ContentReader: ioutil.NopCloser(errorReader{}),
|
||||
PayloadLen: int64(len(body)),
|
||||
Error: "errorReader error",
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
reader := c.ContentReader
|
||||
if reader == nil {
|
||||
reader = ioutil.NopCloser(bytes.NewReader(c.Content))
|
||||
}
|
||||
v := newMD5ValidationReader(reader, c.PayloadLen)
|
||||
|
||||
var actual bytes.Buffer
|
||||
n, err := io.Copy(&actual, v)
|
||||
if len(c.Error) != 0 {
|
||||
if err == nil {
|
||||
t.Errorf("%d, expect error, got none", i)
|
||||
}
|
||||
if e, a := c.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d, expect %v error to be in %v", i, e, a)
|
||||
}
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Errorf("%d, expect no error, got %v", i, err)
|
||||
continue
|
||||
}
|
||||
if e, a := c.PayloadLen, n; e != a {
|
||||
t.Errorf("%d, expect %v len, got %v", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := c.Content[:c.PayloadLen], actual.Bytes(); !bytes.Equal(e, a) {
|
||||
t.Errorf("%d, expect:\n%v\nactual:\n%v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
+4
-2
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var s3LocationTests = []struct {
|
||||
@@ -32,7 +31,10 @@ func TestGetBucketLocation(t *testing.T) {
|
||||
})
|
||||
|
||||
resp, err := s.GetBucketLocation(&s3.GetBucketLocationInput{Bucket: aws.String("bucket")})
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if test.loc == "" {
|
||||
if v := resp.LocationConstraint; v != nil {
|
||||
t.Errorf("expect location constraint to be nil, got %s", *v)
|
||||
|
||||
-36
@@ -1,36 +0,0 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
|
||||
// require it.
|
||||
func contentMD5(r *request.Request) {
|
||||
h := md5.New()
|
||||
|
||||
// hash the body. seek back to the first position after reading to reset
|
||||
// the body for transmission. copy errors may be assumed to be from the
|
||||
// body.
|
||||
_, err := io.Copy(h, r.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("ContentMD5", "failed to read body", err)
|
||||
return
|
||||
}
|
||||
_, err = r.Body.Seek(0, 0)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("ContentMD5", "failed to seek body", err)
|
||||
return
|
||||
}
|
||||
|
||||
// encode the md5 checksum in base64 and set the request header.
|
||||
sum := h.Sum(nil)
|
||||
sum64 := make([]byte, base64.StdEncoding.EncodedLen(len(sum)))
|
||||
base64.StdEncoding.Encode(sum64, sum)
|
||||
r.HTTPRequest.Header.Set("Content-MD5", string(sum64))
|
||||
}
|
||||
+6
@@ -42,6 +42,12 @@ func defaultInitRequestFn(r *request.Request) {
|
||||
r.Handlers.Validate.PushFront(populateLocationConstraint)
|
||||
case opCopyObject, opUploadPartCopy, opCompleteMultipartUpload:
|
||||
r.Handlers.Unmarshal.PushFront(copyMultipartStatusOKUnmarhsalError)
|
||||
case opPutObject, opUploadPart:
|
||||
r.Handlers.Build.PushBack(computeBodyHashes)
|
||||
// Disabled until #1837 root issue is resolved.
|
||||
// case opGetObject:
|
||||
// r.Handlers.Build.PushBack(askForTxEncodingAppendMD5)
|
||||
// r.Handlers.Unmarshal.PushBack(useMD5ValidationReader)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+21
-8
@@ -13,17 +13,22 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func assertMD5(t *testing.T, req *request.Request) {
|
||||
err := req.Build()
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
|
||||
out := md5.Sum(b)
|
||||
assert.NotEmpty(t, b)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"))
|
||||
if len(b) == 0 {
|
||||
t.Error("expected non-empty value")
|
||||
}
|
||||
if e, a := base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMD5InPutBucketCors(t *testing.T) {
|
||||
@@ -115,7 +120,9 @@ const (
|
||||
|
||||
func TestPutObjectMetadataWithUnicode(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, utf8Value, r.Header.Get(metaKeyPrefix+utf8KeySuffix))
|
||||
if e, a := utf8Value, r.Header.Get(metaKeyPrefix+utf8KeySuffix); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}))
|
||||
svc := s3.New(unit.Session, &aws.Config{
|
||||
Endpoint: aws.String(server.URL),
|
||||
@@ -133,7 +140,9 @@ func TestPutObjectMetadataWithUnicode(t *testing.T) {
|
||||
}(),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetObjectMetadataWithUnicode(t *testing.T) {
|
||||
@@ -150,9 +159,13 @@ func TestGetObjectMetadataWithUnicode(t *testing.T) {
|
||||
Key: aws.String("my_key"),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
assert.Equal(t, utf8Value, *resp.Metadata[utf8KeySuffix])
|
||||
if e, a := utf8Value, *resp.Metadata[utf8KeySuffix]; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@
|
||||
//
|
||||
// Using the Client
|
||||
//
|
||||
// To Amazon Simple Storage Service with the SDK use the New function to create
|
||||
// To contact Amazon Simple Storage Service with the SDK use the New function to create
|
||||
// a new service client. With that client you can make API requests to the service.
|
||||
// These clients are safe to use concurrently.
|
||||
//
|
||||
|
||||
+2
-2
@@ -35,7 +35,7 @@
|
||||
//
|
||||
// The s3manager package's Downloader provides concurrently downloading of Objects
|
||||
// from S3. The Downloader will write S3 Object content with an io.WriterAt.
|
||||
// Once the Downloader instance is created you can call Upload concurrently from
|
||||
// Once the Downloader instance is created you can call Download concurrently from
|
||||
// multiple goroutines safely.
|
||||
//
|
||||
// // The session the S3 Downloader will use
|
||||
@@ -56,7 +56,7 @@
|
||||
// Key: aws.String(myString),
|
||||
// })
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to upload file, %v", err)
|
||||
// return fmt.Errorf("failed to download file, %v", err)
|
||||
// }
|
||||
// fmt.Printf("file downloaded, %d bytes\n", n)
|
||||
//
|
||||
|
||||
+15
-1
@@ -1480,6 +1480,12 @@ func ExampleS3_PutBucketLifecycleConfiguration_shared00() {
|
||||
LifecycleConfiguration: &s3.BucketLifecycleConfiguration{
|
||||
Rules: []*s3.LifecycleRule{
|
||||
{
|
||||
Expiration: &s3.LifecycleExpiration{
|
||||
Days: aws.Int64(3650),
|
||||
},
|
||||
Filter: &s3.LifecycleRuleFilter{
|
||||
Prefix: aws.String("documents/"),
|
||||
},
|
||||
ID: aws.String("TestOnly"),
|
||||
Status: aws.String("Enabled"),
|
||||
Transitions: []*s3.Transition{
|
||||
@@ -1525,6 +1531,10 @@ func ExampleS3_PutBucketLogging_shared00() {
|
||||
TargetBucket: aws.String("targetbucket"),
|
||||
TargetGrants: []*s3.TargetGrant{
|
||||
{
|
||||
Grantee: &s3.Grantee{
|
||||
Type: aws.String("Group"),
|
||||
URI: aws.String("http://acs.amazonaws.com/groups/global/AllUsers"),
|
||||
},
|
||||
Permission: aws.String("READ"),
|
||||
},
|
||||
},
|
||||
@@ -1628,6 +1638,10 @@ func ExampleS3_PutBucketReplication_shared00() {
|
||||
Role: aws.String("arn:aws:iam::123456789012:role/examplerole"),
|
||||
Rules: []*s3.ReplicationRule{
|
||||
{
|
||||
Destination: &s3.Destination{
|
||||
Bucket: aws.String("arn:aws:s3:::destinationbucket"),
|
||||
StorageClass: aws.String("STANDARD"),
|
||||
},
|
||||
Prefix: aws.String(""),
|
||||
Status: aws.String("Enabled"),
|
||||
},
|
||||
@@ -2198,7 +2212,7 @@ func ExampleS3_UploadPartCopy_shared01() {
|
||||
svc := s3.New(session.New())
|
||||
input := &s3.UploadPartCopyInput{
|
||||
Bucket: aws.String("examplebucket"),
|
||||
CopySource: aws.String("bucketname/sourceobjectkey"),
|
||||
CopySource: aws.String("/bucketname/sourceobjectkey"),
|
||||
Key: aws.String("examplelargeobject"),
|
||||
PartNumber: aws.Int64(1),
|
||||
UploadId: aws.String("exampleuoh_10OhKhT7YukE9bjzTPRiuaCotmZM_pFngJFir9OZNrSr5cWa3cq3LZSUsfjI4FI7PkP91We7Nrw--"),
|
||||
|
||||
+24
-9
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAdd100Continue_Added(t *testing.T) {
|
||||
@@ -22,8 +21,12 @@ func TestAdd100Continue_Added(t *testing.T) {
|
||||
|
||||
err := r.Sign()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "100-Continue", r.HTTPRequest.Header.Get("Expect"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := "100-Continue", r.HTTPRequest.Header.Get("Expect"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd100Continue_SkipDisabled(t *testing.T) {
|
||||
@@ -36,8 +39,12 @@ func TestAdd100Continue_SkipDisabled(t *testing.T) {
|
||||
|
||||
err := r.Sign()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if r.HTTPRequest.Header.Get("Expect") != "" {
|
||||
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd100Continue_SkipNonPUT(t *testing.T) {
|
||||
@@ -49,8 +56,12 @@ func TestAdd100Continue_SkipNonPUT(t *testing.T) {
|
||||
|
||||
err := r.Sign()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if r.HTTPRequest.Header.Get("Expect") != "" {
|
||||
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd100Continue_SkipTooSmall(t *testing.T) {
|
||||
@@ -63,6 +74,10 @@ func TestAdd100Continue_SkipTooSmall(t *testing.T) {
|
||||
|
||||
err := r.Sign()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if r.HTTPRequest.Header.Get("Expect") != "" {
|
||||
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
|
||||
}
|
||||
}
|
||||
|
||||
Generated
Vendored
+11
-6
@@ -3,21 +3,26 @@ package s3crypto_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
|
||||
)
|
||||
|
||||
func TestAESGCMContentCipherBuilder(t *testing.T) {
|
||||
generator := mockGenerator{}
|
||||
builder := s3crypto.AESGCMContentCipherBuilder(generator)
|
||||
assert.NotNil(t, builder)
|
||||
if builder := s3crypto.AESGCMContentCipherBuilder(generator); builder == nil {
|
||||
t.Error("expected non-nil value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESGCMContentCipherNewEncryptor(t *testing.T) {
|
||||
generator := mockGenerator{}
|
||||
builder := s3crypto.AESGCMContentCipherBuilder(generator)
|
||||
cipher, err := builder.ContentCipher()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cipher)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if cipher == nil {
|
||||
t.Errorf("expected non-nil vaue")
|
||||
}
|
||||
}
|
||||
|
||||
+19
-11
@@ -5,8 +5,6 @@ import (
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// AES GCM
|
||||
@@ -52,22 +50,32 @@ func aesgcmTest(t *testing.T, iv, key, plaintext, expected, tag []byte) {
|
||||
IV: iv,
|
||||
}
|
||||
gcm, err := newAESGCM(cd)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
cipherdata := gcm.Encrypt(bytes.NewReader(plaintext))
|
||||
|
||||
ciphertext, err := ioutil.ReadAll(cipherdata)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
// splitting tag and ciphertext
|
||||
etag := ciphertext[len(ciphertext)-16:]
|
||||
assert.Equal(t, etag, tag)
|
||||
assert.Equal(t, len(ciphertext), len(expected))
|
||||
assert.Equal(t, ciphertext, expected)
|
||||
if !bytes.Equal(etag, tag) {
|
||||
t.Errorf("expected tags to be equivalent")
|
||||
}
|
||||
if !bytes.Equal(ciphertext, expected) {
|
||||
t.Errorf("expected ciphertext to be equivalent")
|
||||
}
|
||||
|
||||
data := gcm.Decrypt(bytes.NewReader(ciphertext))
|
||||
assert.NoError(t, err)
|
||||
text, err := ioutil.ReadAll(data)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(text), len(plaintext))
|
||||
assert.Equal(t, text, plaintext)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, text) {
|
||||
t.Errorf("expected ciphertext to be equivalent")
|
||||
}
|
||||
}
|
||||
|
||||
+12
-6
@@ -5,8 +5,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
|
||||
)
|
||||
|
||||
@@ -16,8 +14,12 @@ func TestCryptoReadCloserRead(t *testing.T) {
|
||||
rc := &s3crypto.CryptoReadCloser{Body: ioutil.NopCloser(str), Decrypter: str}
|
||||
|
||||
b, err := ioutil.ReadAll(rc)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedStr, string(b))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if expectedStr != string(b) {
|
||||
t.Errorf("expected %s, but received %s", expectedStr, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCryptoReadCloserClose(t *testing.T) {
|
||||
@@ -29,6 +31,10 @@ func TestCryptoReadCloserClose(t *testing.T) {
|
||||
rc.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(rc)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedStr, string(b))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if expectedStr != string(b) {
|
||||
t.Errorf("expected %s, but received %s", expectedStr, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
+64
-22
@@ -8,8 +8,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
@@ -31,10 +29,17 @@ func TestWrapFactory(t *testing.T) {
|
||||
MatDesc: `{"kms_cmk_id":""}`,
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
_, ok := wrap.(*kmsKeyHandler)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, wrap)
|
||||
assert.True(t, ok)
|
||||
w, ok := wrap.(*kmsKeyHandler)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if wrap == nil {
|
||||
t.Error("expected non-nil value")
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("expected kmsKeyHandler, but received %v", *w)
|
||||
}
|
||||
}
|
||||
func TestWrapFactoryErrorNoWrap(t *testing.T) {
|
||||
c := DecryptionClient{
|
||||
@@ -52,8 +57,13 @@ func TestWrapFactoryErrorNoWrap(t *testing.T) {
|
||||
MatDesc: `{"kms_cmk_id":""}`,
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, wrap)
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error, but received none")
|
||||
}
|
||||
if wrap != nil {
|
||||
t.Errorf("expected nil wrap value, received %v", wrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapFactoryCustomEntry(t *testing.T) {
|
||||
@@ -72,8 +82,13 @@ func TestWrapFactoryCustomEntry(t *testing.T) {
|
||||
MatDesc: `{"kms_cmk_id":""}`,
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, wrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if wrap == nil {
|
||||
t.Errorf("expected nil wrap value, received %v", wrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCEKFactory(t *testing.T) {
|
||||
@@ -106,11 +121,15 @@ func TestCEKFactory(t *testing.T) {
|
||||
},
|
||||
}
|
||||
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
ivB64 := base64.URLEncoding.EncodeToString(iv)
|
||||
|
||||
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
|
||||
|
||||
env := Envelope{
|
||||
@@ -122,8 +141,13 @@ func TestCEKFactory(t *testing.T) {
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
cek, err := c.cekFromEnvelope(env, wrap)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cek)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if cek == nil {
|
||||
t.Errorf("expected non-nil cek")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCEKFactoryNoCEK(t *testing.T) {
|
||||
@@ -156,11 +180,15 @@ func TestCEKFactoryNoCEK(t *testing.T) {
|
||||
},
|
||||
}
|
||||
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
ivB64 := base64.URLEncoding.EncodeToString(iv)
|
||||
|
||||
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
|
||||
|
||||
env := Envelope{
|
||||
@@ -172,8 +200,13 @@ func TestCEKFactoryNoCEK(t *testing.T) {
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
cek, err := c.cekFromEnvelope(env, wrap)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cek)
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error, but received none")
|
||||
}
|
||||
if cek != nil {
|
||||
t.Errorf("expected nil cek value, received %v", wrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCEKFactoryCustomEntry(t *testing.T) {
|
||||
@@ -204,11 +237,15 @@ func TestCEKFactoryCustomEntry(t *testing.T) {
|
||||
PadderRegistry: map[string]Padder{},
|
||||
}
|
||||
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
ivB64 := base64.URLEncoding.EncodeToString(iv)
|
||||
|
||||
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
|
||||
|
||||
env := Envelope{
|
||||
@@ -220,6 +257,11 @@ func TestCEKFactoryCustomEntry(t *testing.T) {
|
||||
}
|
||||
wrap, err := c.wrapFromEnvelope(env)
|
||||
cek, err := c.cekFromEnvelope(env, wrap)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cek)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if cek == nil {
|
||||
t.Errorf("expected non-nil cek")
|
||||
}
|
||||
}
|
||||
|
||||
+61
-25
@@ -11,8 +11,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
@@ -40,7 +38,9 @@ func TestGetObjectGCM(t *testing.T) {
|
||||
})
|
||||
|
||||
c := s3crypto.NewDecryptionClient(sess)
|
||||
assert.NotNil(t, c)
|
||||
if c == nil {
|
||||
t.Error("expected non-nil value")
|
||||
}
|
||||
input := &s3.GetObjectInput{
|
||||
Key: aws.String("test"),
|
||||
Bucket: aws.String("test"),
|
||||
@@ -49,9 +49,14 @@ func TestGetObjectGCM(t *testing.T) {
|
||||
req.Handlers.Send.Clear()
|
||||
req.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
b, err := hex.DecodeString("fa4362189661d163fcd6a56d8bf0405ad636ac1bbedd5cc3ee727dc2ab4a9489")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: 200,
|
||||
@@ -69,14 +74,21 @@ func TestGetObjectGCM(t *testing.T) {
|
||||
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
|
||||
})
|
||||
err := req.Send()
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(out.Body)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
expected, err := hex.DecodeString("2db5168e932556f8089a0622981d017d")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expected), len(b))
|
||||
assert.Equal(t, expected, b)
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Error("expected bytes to be equivalent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetObjectCBC(t *testing.T) {
|
||||
@@ -97,7 +109,9 @@ func TestGetObjectCBC(t *testing.T) {
|
||||
})
|
||||
|
||||
c := s3crypto.NewDecryptionClient(sess)
|
||||
assert.NotNil(t, c)
|
||||
if c == nil {
|
||||
t.Error("expected non-nil value")
|
||||
}
|
||||
input := &s3.GetObjectInput{
|
||||
Key: aws.String("test"),
|
||||
Bucket: aws.String("test"),
|
||||
@@ -106,9 +120,13 @@ func TestGetObjectCBC(t *testing.T) {
|
||||
req.Handlers.Send.Clear()
|
||||
req.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
iv, err := hex.DecodeString("9dea7621945988f96491083849b068df")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
b, err := hex.DecodeString("e232cd6ef50047801ee681ec30f61d53cfd6b0bca02fd03c1b234baa10ea82ac9dab8b960926433a19ce6dea08677e34")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: 200,
|
||||
@@ -125,14 +143,21 @@ func TestGetObjectCBC(t *testing.T) {
|
||||
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
|
||||
})
|
||||
err := req.Send()
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(out.Body)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
expected, err := hex.DecodeString("0397f4f6820b1f9386f14403be5ac16e50213bd473b4874b9bcbf5f318ee686b1d")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expected), len(b))
|
||||
assert.Equal(t, expected, b)
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Error("expected bytes to be equivalent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetObjectCBC2(t *testing.T) {
|
||||
@@ -153,7 +178,9 @@ func TestGetObjectCBC2(t *testing.T) {
|
||||
})
|
||||
|
||||
c := s3crypto.NewDecryptionClient(sess)
|
||||
assert.NotNil(t, c)
|
||||
if c == nil {
|
||||
t.Error("expected non-nil value")
|
||||
}
|
||||
input := &s3.GetObjectInput{
|
||||
Key: aws.String("test"),
|
||||
Bucket: aws.String("test"),
|
||||
@@ -162,7 +189,9 @@ func TestGetObjectCBC2(t *testing.T) {
|
||||
req.Handlers.Send.Clear()
|
||||
req.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
b, err := hex.DecodeString("fd0c71ecb7ed16a9bf42ea5f75501d416df608f190890c3b4d8897f24744cd7f9ea4a0b212e60634302450e1c5378f047ff753ccefe365d411c36339bf22e301fae4c3a6226719a4b93dc74c1af79d0296659b5d56c0892315f2c7cc30190220db1eaafae3920d6d9c65d0aa366499afc17af493454e141c6e0fbdeb6a990cb4")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: 200,
|
||||
@@ -180,14 +209,21 @@ func TestGetObjectCBC2(t *testing.T) {
|
||||
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
|
||||
})
|
||||
err := req.Send()
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(out.Body)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
expected, err := hex.DecodeString("a6ccd3482f5ce25c9ddeb69437cd0acbc0bdda2ef8696d90781de2b35704543529871b2032e68ef1c5baed1769aba8d420d1aca181341b49b8b3587a6580cdf1d809c68f06735f7735c16691f4b70c967d68fc08195b81ad71bcc4df452fd0a5799c1e1234f92f1cd929fc072167ccf9f2ac85b93170932b32")
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expected), len(b))
|
||||
assert.Equal(t, expected, b)
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Error("expected bytes to be equivalent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetObjectWithContext(t *testing.T) {
|
||||
|
||||
+4
-4
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3iface"
|
||||
)
|
||||
@@ -64,19 +65,18 @@ func NewEncryptionClient(prov client.ConfigProvider, builder ContentCipherBuilde
|
||||
// req, out := svc.PutObjectRequest(&s3.PutObjectInput {
|
||||
// Key: aws.String("testKey"),
|
||||
// Bucket: aws.String("testBucket"),
|
||||
// Body: bytes.NewBuffer("test data"),
|
||||
// Body: strings.NewReader("test data"),
|
||||
// })
|
||||
// err := req.Send()
|
||||
func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
|
||||
req, out := c.S3Client.PutObjectRequest(input)
|
||||
|
||||
// Get Size of file
|
||||
n, err := input.Body.Seek(0, 2)
|
||||
n, err := aws.SeekerLen(input.Body)
|
||||
if err != nil {
|
||||
req.Error = err
|
||||
return req, out
|
||||
}
|
||||
input.Body.Seek(0, 0)
|
||||
|
||||
dst, err := getWriterStore(req, c.TempFolderPath, n >= c.MinFileSize)
|
||||
if err != nil {
|
||||
@@ -115,7 +115,7 @@ func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) (*request.
|
||||
shaHex := hex.EncodeToString(sha.GetValue())
|
||||
req.HTTPRequest.Header.Set("X-Amz-Content-Sha256", shaHex)
|
||||
|
||||
dst.Seek(0, 0)
|
||||
dst.Seek(0, sdkio.SeekStart)
|
||||
input.Body = dst
|
||||
|
||||
err = c.SaveStrategy.Save(env, r)
|
||||
|
||||
+21
-9
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
@@ -32,9 +30,15 @@ func TestDefaultConfigValues(t *testing.T) {
|
||||
|
||||
c := s3crypto.NewEncryptionClient(sess, s3crypto.AESGCMContentCipherBuilder(handler))
|
||||
|
||||
assert.NotNil(t, c)
|
||||
assert.NotNil(t, c.ContentCipherBuilder)
|
||||
assert.NotNil(t, c.SaveStrategy)
|
||||
if c == nil {
|
||||
t.Error("expected non-vil client value")
|
||||
}
|
||||
if c.ContentCipherBuilder == nil {
|
||||
t.Error("expected non-vil content cipher builder value")
|
||||
}
|
||||
if c.SaveStrategy == nil {
|
||||
t.Error("expected non-vil save strategy value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutObject(t *testing.T) {
|
||||
@@ -49,7 +53,9 @@ func TestPutObject(t *testing.T) {
|
||||
Region: aws.String("us-west-2"),
|
||||
})
|
||||
c := s3crypto.NewEncryptionClient(sess, cb)
|
||||
assert.NotNil(t, c)
|
||||
if c == nil {
|
||||
t.Error("expected non-vil client value")
|
||||
}
|
||||
input := &s3.PutObjectInput{
|
||||
Key: aws.String("test"),
|
||||
Bucket: aws.String("test"),
|
||||
@@ -64,10 +70,16 @@ func TestPutObject(t *testing.T) {
|
||||
}
|
||||
})
|
||||
err := req.Send()
|
||||
assert.Equal(t, "stop", err.Error())
|
||||
if e, a := "stop", err.Error(); e != a {
|
||||
t.Errorf("expected %s error, but received %s", e, a)
|
||||
}
|
||||
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, b)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Error("expected bytes to be equivalent, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutObjectWithContext(t *testing.T) {
|
||||
|
||||
+8
-4
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// From Go stdlib encoding/sha256 test cases
|
||||
@@ -13,7 +11,10 @@ func TestSHA256(t *testing.T) {
|
||||
sha := newSHA256Writer(nil)
|
||||
expected, _ := hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
|
||||
b := sha.GetValue()
|
||||
assert.Equal(t, expected, b)
|
||||
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Errorf("expected equivalent sha values, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256_Case2(t *testing.T) {
|
||||
@@ -21,5 +22,8 @@ func TestSHA256_Case2(t *testing.T) {
|
||||
sha.Write([]byte("hello"))
|
||||
expected, _ := hex.DecodeString("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")
|
||||
b := sha.GetValue()
|
||||
assert.Equal(t, expected, b)
|
||||
|
||||
if !bytes.Equal(expected, b) {
|
||||
t.Errorf("expected equivalent sha values, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
+52
-15
@@ -1,9 +1,10 @@
|
||||
package s3crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
func TestBytesReadWriteSeeker_Read(t *testing.T) {
|
||||
@@ -12,9 +13,17 @@ func TestBytesReadWriteSeeker_Read(t *testing.T) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := b.Read(buf)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, expected, buf)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if e, a := 3, n; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
if !bytes.Equal(expected, buf) {
|
||||
t.Error("expected equivalent byte slices, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytesReadWriteSeeker_Write(t *testing.T) {
|
||||
@@ -23,25 +32,53 @@ func TestBytesReadWriteSeeker_Write(t *testing.T) {
|
||||
buf := make([]byte, 3)
|
||||
n, err := b.Write([]byte{1, 2, 3})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, n)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if e, a := 3, n; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
n, err = b.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, expected, buf)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if e, a := 3, n; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
if !bytes.Equal(expected, buf) {
|
||||
t.Error("expected equivalent byte slices, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytesReadWriteSeeker_Seek(t *testing.T) {
|
||||
b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0}
|
||||
expected := []byte{2, 3}
|
||||
m, err := b.Seek(1, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, int(m))
|
||||
m, err := b.Seek(1, sdkio.SeekStart)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if e, a := 1, int(m); e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
buf := make([]byte, 3)
|
||||
n, err := b.Read(buf)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
assert.Equal(t, expected, buf[:n])
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if e, a := 2, n; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
if !bytes.Equal(expected, buf[:n]) {
|
||||
t.Error("expected equivalent byte slices, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
+9
-5
@@ -2,15 +2,19 @@ package s3crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenerateBytes(t *testing.T) {
|
||||
b := generateBytes(5)
|
||||
assert.Equal(t, 5, len(b))
|
||||
if e, a := 5, len(b); e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
b = generateBytes(0)
|
||||
assert.Equal(t, 0, len(b))
|
||||
if e, a := 0, len(b); e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
b = generateBytes(1024)
|
||||
assert.Equal(t, 1024, len(b))
|
||||
if e, a := 1024, len(b); e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+34
-14
@@ -1,15 +1,15 @@
|
||||
package s3crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
func TestBuildKMSEncryptHandler(t *testing.T) {
|
||||
svc := kms.New(unit.Session)
|
||||
handler := NewKMSKeyGenerator(svc, "testid")
|
||||
assert.NotNil(t, handler)
|
||||
if handler == nil {
|
||||
t.Error("expected non-nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
|
||||
@@ -26,14 +28,19 @@ func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
|
||||
handler := NewKMSKeyGeneratorWithMatDesc(svc, "testid", MaterialDescription{
|
||||
"Testing": aws.String("123"),
|
||||
})
|
||||
assert.NotNil(t, handler)
|
||||
if handler == nil {
|
||||
t.Error("expected non-nil handler")
|
||||
}
|
||||
|
||||
kmsHandler := handler.(*kmsKeyHandler)
|
||||
expected := MaterialDescription{
|
||||
"kms_cmk_id": aws.String("testid"),
|
||||
"Testing": aws.String("123"),
|
||||
}
|
||||
assert.Equal(t, expected, kmsHandler.CipherData.MaterialDescription)
|
||||
|
||||
if !reflect.DeepEqual(expected, kmsHandler.CipherData.MaterialDescription) {
|
||||
t.Errorf("expected %v, but received %v", expected, kmsHandler.CipherData.MaterialDescription)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMSGenerateCipherData(t *testing.T) {
|
||||
@@ -56,11 +63,15 @@ func TestKMSGenerateCipherData(t *testing.T) {
|
||||
ivSize := 16
|
||||
|
||||
cd, err := handler.GenerateCipherData(keySize, ivSize)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, keySize, len(cd.Key))
|
||||
assert.Equal(t, ivSize, len(cd.IV))
|
||||
assert.NotEmpty(t, cd.Key)
|
||||
assert.NotEmpty(t, cd.IV)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if keySize != len(cd.Key) {
|
||||
t.Errorf("expected %d, but received %d", keySize, len(cd.Key))
|
||||
}
|
||||
if ivSize != len(cd.IV) {
|
||||
t.Errorf("expected %d, but received %d", ivSize, len(cd.IV))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMSDecrypt(t *testing.T) {
|
||||
@@ -78,11 +89,18 @@ func TestKMSDecrypt(t *testing.T) {
|
||||
Region: aws.String("us-west-2"),
|
||||
})
|
||||
handler, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"}`})
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
plaintextKey, err := handler.DecryptKey([]byte{1, 2, 3, 4})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, plaintextKey)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(key, plaintextKey) {
|
||||
t.Errorf("expected %v, but received %v", key, plaintextKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMSDecryptBadJSON(t *testing.T) {
|
||||
@@ -101,5 +119,7 @@ func TestKMSDecryptBadJSON(t *testing.T) {
|
||||
})
|
||||
|
||||
_, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"`})
|
||||
assert.Error(t, err)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, but received none")
|
||||
}
|
||||
}
|
||||
|
||||
+13
-6
@@ -1,10 +1,9 @@
|
||||
package s3crypto
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
)
|
||||
|
||||
@@ -13,8 +12,12 @@ func TestEncodeMaterialDescription(t *testing.T) {
|
||||
md["foo"] = aws.String("bar")
|
||||
b, err := md.encodeDescription()
|
||||
expected := `{"foo":"bar"}`
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, string(b))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if expected != string(b) {
|
||||
t.Errorf("expected %s, but received %s", expected, string(b))
|
||||
}
|
||||
}
|
||||
func TestDecodeMaterialDescription(t *testing.T) {
|
||||
md := MaterialDescription{}
|
||||
@@ -23,6 +26,10 @@ func TestDecodeMaterialDescription(t *testing.T) {
|
||||
expected := MaterialDescription{
|
||||
"foo": aws.String("bar"),
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, md)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(expected, md) {
|
||||
t.Error("expected material description to be equivalent, but received otherwise")
|
||||
}
|
||||
}
|
||||
|
||||
+4
-1
@@ -63,9 +63,12 @@ func (strat HeaderV2SaveStrategy) Save(env Envelope, req *request.Request) error
|
||||
input.Metadata[http.CanonicalHeaderKey(matDescHeader)] = &env.MatDesc
|
||||
input.Metadata[http.CanonicalHeaderKey(wrapAlgorithmHeader)] = &env.WrapAlg
|
||||
input.Metadata[http.CanonicalHeaderKey(cekAlgorithmHeader)] = &env.CEKAlg
|
||||
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = &env.TagLen
|
||||
input.Metadata[http.CanonicalHeaderKey(unencryptedMD5Header)] = &env.UnencryptedMD5
|
||||
input.Metadata[http.CanonicalHeaderKey(unencryptedContentLengthHeader)] = &env.UnencryptedContentLen
|
||||
|
||||
if len(env.TagLen) > 0 {
|
||||
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = &env.TagLen
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+62
-31
@@ -1,10 +1,9 @@
|
||||
package s3crypto_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
@@ -12,35 +11,67 @@ import (
|
||||
)
|
||||
|
||||
func TestHeaderV2SaveStrategy(t *testing.T) {
|
||||
env := s3crypto.Envelope{
|
||||
CipherKey: "Foo",
|
||||
IV: "Bar",
|
||||
MatDesc: "{}",
|
||||
WrapAlg: s3crypto.KMSWrap,
|
||||
CEKAlg: s3crypto.AESGCMNoPadding,
|
||||
TagLen: "128",
|
||||
UnencryptedMD5: "hello",
|
||||
UnencryptedContentLen: "0",
|
||||
}
|
||||
params := &s3.PutObjectInput{}
|
||||
req := &request.Request{
|
||||
Params: params,
|
||||
}
|
||||
strat := s3crypto.HeaderV2SaveStrategy{}
|
||||
err := strat.Save(env, req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := map[string]*string{
|
||||
"X-Amz-Key-V2": aws.String("Foo"),
|
||||
"X-Amz-Iv": aws.String("Bar"),
|
||||
"X-Amz-Matdesc": aws.String("{}"),
|
||||
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
|
||||
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
|
||||
"X-Amz-Tag-Len": aws.String("128"),
|
||||
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
|
||||
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
|
||||
cases := []struct {
|
||||
env s3crypto.Envelope
|
||||
expected map[string]*string
|
||||
}{
|
||||
{
|
||||
s3crypto.Envelope{
|
||||
CipherKey: "Foo",
|
||||
IV: "Bar",
|
||||
MatDesc: "{}",
|
||||
WrapAlg: s3crypto.KMSWrap,
|
||||
CEKAlg: s3crypto.AESGCMNoPadding,
|
||||
TagLen: "128",
|
||||
UnencryptedMD5: "hello",
|
||||
UnencryptedContentLen: "0",
|
||||
},
|
||||
map[string]*string{
|
||||
"X-Amz-Key-V2": aws.String("Foo"),
|
||||
"X-Amz-Iv": aws.String("Bar"),
|
||||
"X-Amz-Matdesc": aws.String("{}"),
|
||||
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
|
||||
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
|
||||
"X-Amz-Tag-Len": aws.String("128"),
|
||||
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
|
||||
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
|
||||
},
|
||||
},
|
||||
{
|
||||
s3crypto.Envelope{
|
||||
CipherKey: "Foo",
|
||||
IV: "Bar",
|
||||
MatDesc: "{}",
|
||||
WrapAlg: s3crypto.KMSWrap,
|
||||
CEKAlg: s3crypto.AESGCMNoPadding,
|
||||
UnencryptedMD5: "hello",
|
||||
UnencryptedContentLen: "0",
|
||||
},
|
||||
map[string]*string{
|
||||
"X-Amz-Key-V2": aws.String("Foo"),
|
||||
"X-Amz-Iv": aws.String("Bar"),
|
||||
"X-Amz-Matdesc": aws.String("{}"),
|
||||
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
|
||||
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
|
||||
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
|
||||
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expected), len(params.Metadata))
|
||||
assert.Equal(t, expected, params.Metadata)
|
||||
for _, c := range cases {
|
||||
params := &s3.PutObjectInput{}
|
||||
req := &request.Request{
|
||||
Params: params,
|
||||
}
|
||||
strat := s3crypto.HeaderV2SaveStrategy{}
|
||||
err := strat.Save(c.env, req)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(c.expected, params.Metadata) {
|
||||
t.Errorf("expected %v, but received %v", c.expected, params.Metadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+12
@@ -92,6 +92,10 @@ type S3API interface {
|
||||
DeleteBucketCorsWithContext(aws.Context, *s3.DeleteBucketCorsInput, ...request.Option) (*s3.DeleteBucketCorsOutput, error)
|
||||
DeleteBucketCorsRequest(*s3.DeleteBucketCorsInput) (*request.Request, *s3.DeleteBucketCorsOutput)
|
||||
|
||||
DeleteBucketEncryption(*s3.DeleteBucketEncryptionInput) (*s3.DeleteBucketEncryptionOutput, error)
|
||||
DeleteBucketEncryptionWithContext(aws.Context, *s3.DeleteBucketEncryptionInput, ...request.Option) (*s3.DeleteBucketEncryptionOutput, error)
|
||||
DeleteBucketEncryptionRequest(*s3.DeleteBucketEncryptionInput) (*request.Request, *s3.DeleteBucketEncryptionOutput)
|
||||
|
||||
DeleteBucketInventoryConfiguration(*s3.DeleteBucketInventoryConfigurationInput) (*s3.DeleteBucketInventoryConfigurationOutput, error)
|
||||
DeleteBucketInventoryConfigurationWithContext(aws.Context, *s3.DeleteBucketInventoryConfigurationInput, ...request.Option) (*s3.DeleteBucketInventoryConfigurationOutput, error)
|
||||
DeleteBucketInventoryConfigurationRequest(*s3.DeleteBucketInventoryConfigurationInput) (*request.Request, *s3.DeleteBucketInventoryConfigurationOutput)
|
||||
@@ -148,6 +152,10 @@ type S3API interface {
|
||||
GetBucketCorsWithContext(aws.Context, *s3.GetBucketCorsInput, ...request.Option) (*s3.GetBucketCorsOutput, error)
|
||||
GetBucketCorsRequest(*s3.GetBucketCorsInput) (*request.Request, *s3.GetBucketCorsOutput)
|
||||
|
||||
GetBucketEncryption(*s3.GetBucketEncryptionInput) (*s3.GetBucketEncryptionOutput, error)
|
||||
GetBucketEncryptionWithContext(aws.Context, *s3.GetBucketEncryptionInput, ...request.Option) (*s3.GetBucketEncryptionOutput, error)
|
||||
GetBucketEncryptionRequest(*s3.GetBucketEncryptionInput) (*request.Request, *s3.GetBucketEncryptionOutput)
|
||||
|
||||
GetBucketInventoryConfiguration(*s3.GetBucketInventoryConfigurationInput) (*s3.GetBucketInventoryConfigurationOutput, error)
|
||||
GetBucketInventoryConfigurationWithContext(aws.Context, *s3.GetBucketInventoryConfigurationInput, ...request.Option) (*s3.GetBucketInventoryConfigurationOutput, error)
|
||||
GetBucketInventoryConfigurationRequest(*s3.GetBucketInventoryConfigurationInput) (*request.Request, *s3.GetBucketInventoryConfigurationOutput)
|
||||
@@ -295,6 +303,10 @@ type S3API interface {
|
||||
PutBucketCorsWithContext(aws.Context, *s3.PutBucketCorsInput, ...request.Option) (*s3.PutBucketCorsOutput, error)
|
||||
PutBucketCorsRequest(*s3.PutBucketCorsInput) (*request.Request, *s3.PutBucketCorsOutput)
|
||||
|
||||
PutBucketEncryption(*s3.PutBucketEncryptionInput) (*s3.PutBucketEncryptionOutput, error)
|
||||
PutBucketEncryptionWithContext(aws.Context, *s3.PutBucketEncryptionInput, ...request.Option) (*s3.PutBucketEncryptionOutput, error)
|
||||
PutBucketEncryptionRequest(*s3.PutBucketEncryptionInput) (*request.Request, *s3.PutBucketEncryptionOutput)
|
||||
|
||||
PutBucketInventoryConfiguration(*s3.PutBucketInventoryConfigurationInput) (*s3.PutBucketInventoryConfigurationOutput, error)
|
||||
PutBucketInventoryConfigurationWithContext(aws.Context, *s3.PutBucketInventoryConfigurationInput, ...request.Option) (*s3.PutBucketInventoryConfigurationOutput, error)
|
||||
PutBucketInventoryConfigurationRequest(*s3.PutBucketInventoryConfigurationInput) (*request.Request, *s3.PutBucketInventoryConfigurationOutput)
|
||||
|
||||
+32
-8
@@ -60,7 +60,15 @@ func newError(err error, bucket, key *string) Error {
|
||||
}
|
||||
|
||||
func (err *Error) Error() string {
|
||||
return fmt.Sprintf("failed to upload %q to %q:\n%s", err.Key, err.Bucket, err.OrigErr.Error())
|
||||
origErr := ""
|
||||
if err.OrigErr != nil {
|
||||
origErr = ":\n" + err.OrigErr.Error()
|
||||
}
|
||||
return fmt.Sprintf("failed to perform batch operation on %q to %q%s",
|
||||
aws.StringValue(err.Key),
|
||||
aws.StringValue(err.Bucket),
|
||||
origErr,
|
||||
)
|
||||
}
|
||||
|
||||
// NewBatchError will return a BatchError that satisfies the awserr.Error interface.
|
||||
@@ -206,7 +214,7 @@ type BatchDelete struct {
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// if err := batcher.Delete(&s3manager.DeleteObjectsIterator{
|
||||
// if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
|
||||
// Objects: objects,
|
||||
// }); err != nil {
|
||||
// return err
|
||||
@@ -239,7 +247,7 @@ func NewBatchDeleteWithClient(client s3iface.S3API, options ...func(*BatchDelete
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// if err := batcher.Delete(&s3manager.DeleteObjectsIterator{
|
||||
// if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
|
||||
// Objects: objects,
|
||||
// }); err != nil {
|
||||
// return err
|
||||
@@ -312,7 +320,7 @@ func (d *BatchDelete) Delete(ctx aws.Context, iter BatchDeleteIterator) error {
|
||||
}
|
||||
|
||||
if len(input.Delete.Objects) == d.BatchSize || !parity {
|
||||
if err := deleteBatch(d, input, objects); err != nil {
|
||||
if err := deleteBatch(ctx, d, input, objects); err != nil {
|
||||
errs = append(errs, err...)
|
||||
}
|
||||
|
||||
@@ -331,7 +339,7 @@ func (d *BatchDelete) Delete(ctx aws.Context, iter BatchDeleteIterator) error {
|
||||
}
|
||||
|
||||
if input != nil && len(input.Delete.Objects) > 0 {
|
||||
if err := deleteBatch(d, input, objects); err != nil {
|
||||
if err := deleteBatch(ctx, d, input, objects); err != nil {
|
||||
errs = append(errs, err...)
|
||||
}
|
||||
}
|
||||
@@ -351,17 +359,33 @@ func initDeleteObjectsInput(o *s3.DeleteObjectInput) *s3.DeleteObjectsInput {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// ErrDeleteBatchFailCode represents an error code which will be returned
|
||||
// only when DeleteObjects.Errors has an error that does not contain a code.
|
||||
ErrDeleteBatchFailCode = "DeleteBatchError"
|
||||
errDefaultDeleteBatchMessage = "failed to delete"
|
||||
)
|
||||
|
||||
// deleteBatch will delete a batch of items in the objects parameters.
|
||||
func deleteBatch(d *BatchDelete, input *s3.DeleteObjectsInput, objects []BatchDeleteObject) []Error {
|
||||
func deleteBatch(ctx aws.Context, d *BatchDelete, input *s3.DeleteObjectsInput, objects []BatchDeleteObject) []Error {
|
||||
errs := []Error{}
|
||||
|
||||
if result, err := d.Client.DeleteObjects(input); err != nil {
|
||||
if result, err := d.Client.DeleteObjectsWithContext(ctx, input); err != nil {
|
||||
for i := 0; i < len(input.Delete.Objects); i++ {
|
||||
errs = append(errs, newError(err, input.Bucket, input.Delete.Objects[i].Key))
|
||||
}
|
||||
} else if len(result.Errors) > 0 {
|
||||
for i := 0; i < len(result.Errors); i++ {
|
||||
errs = append(errs, newError(err, input.Bucket, result.Errors[i].Key))
|
||||
code := ErrDeleteBatchFailCode
|
||||
msg := errDefaultDeleteBatchMessage
|
||||
if result.Errors[i].Message != nil {
|
||||
msg = *result.Errors[i].Message
|
||||
}
|
||||
if result.Errors[i].Code != nil {
|
||||
code = *result.Errors[i].Code
|
||||
}
|
||||
|
||||
errs = append(errs, newError(awserr.New(code, msg, err), input.Bucket, result.Errors[i].Key))
|
||||
}
|
||||
}
|
||||
for _, object := range objects {
|
||||
|
||||
+116
@@ -0,0 +1,116 @@
|
||||
// +build go1.7
|
||||
|
||||
package s3manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
// #1790 bug
|
||||
func TestBatchDeleteContext(t *testing.T) {
|
||||
cases := []struct {
|
||||
objects []BatchDeleteObject
|
||||
size int
|
||||
expected int
|
||||
ctx aws.Context
|
||||
closeAt int
|
||||
errCheck func(error) (string, bool)
|
||||
}{
|
||||
{
|
||||
[]BatchDeleteObject{
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("1"),
|
||||
Bucket: aws.String("bucket1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("2"),
|
||||
Bucket: aws.String("bucket2"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("3"),
|
||||
Bucket: aws.String("bucket3"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("4"),
|
||||
Bucket: aws.String("bucket4"),
|
||||
},
|
||||
},
|
||||
},
|
||||
1,
|
||||
0,
|
||||
aws.BackgroundContext(),
|
||||
0,
|
||||
func(err error) (string, bool) {
|
||||
batchErr, ok := err.(*BatchError)
|
||||
if !ok {
|
||||
return "not BatchError type", false
|
||||
}
|
||||
|
||||
errs := batchErr.Errors
|
||||
if len(errs) != 4 {
|
||||
return fmt.Sprintf("expected 1, but received %d", len(errs)), false
|
||||
}
|
||||
|
||||
for _, tempErr := range errs {
|
||||
aerr, ok := tempErr.OrigErr.(awserr.Error)
|
||||
if !ok {
|
||||
return "not awserr.Error type", false
|
||||
}
|
||||
|
||||
if code := aerr.Code(); code != request.CanceledErrorCode {
|
||||
return fmt.Sprintf("expected %q, but received %q", request.CanceledErrorCode, code), false
|
||||
}
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
count := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
count++
|
||||
}))
|
||||
|
||||
svc := &mockS3Client{S3: buildS3SvcClient(server.URL)}
|
||||
for i, c := range cases {
|
||||
ctx, done := context.WithCancel(c.ctx)
|
||||
defer done()
|
||||
if i == c.closeAt {
|
||||
done()
|
||||
}
|
||||
|
||||
batcher := BatchDelete{
|
||||
Client: svc,
|
||||
BatchSize: c.size,
|
||||
}
|
||||
|
||||
err := batcher.Delete(ctx, &DeleteObjectsIterator{Objects: c.objects})
|
||||
|
||||
if msg, ok := c.errCheck(err); !ok {
|
||||
t.Error(msg)
|
||||
}
|
||||
|
||||
if count != c.expected {
|
||||
t.Errorf("Case %d: expected %d, but received %d", i, c.expected, count)
|
||||
}
|
||||
|
||||
count = 0
|
||||
}
|
||||
}
|
||||
+134
-2
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
@@ -309,10 +310,101 @@ func TestBatchDelete(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchDeleteError(t *testing.T) {
|
||||
cases := []struct {
|
||||
objects []BatchDeleteObject
|
||||
output s3.DeleteObjectsOutput
|
||||
size int
|
||||
expectedErrCode string
|
||||
expectedErrMessage string
|
||||
}{
|
||||
{
|
||||
[]BatchDeleteObject{
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("1"),
|
||||
Bucket: aws.String("bucket1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
s3.DeleteObjectsOutput{
|
||||
Errors: []*s3.Error{
|
||||
{
|
||||
Code: aws.String("foo code"),
|
||||
Message: aws.String("foo error"),
|
||||
},
|
||||
},
|
||||
},
|
||||
1,
|
||||
"foo code",
|
||||
"foo error",
|
||||
},
|
||||
{
|
||||
[]BatchDeleteObject{
|
||||
{
|
||||
Object: &s3.DeleteObjectInput{
|
||||
Key: aws.String("1"),
|
||||
Bucket: aws.String("bucket1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
s3.DeleteObjectsOutput{
|
||||
Errors: []*s3.Error{
|
||||
{},
|
||||
},
|
||||
},
|
||||
1,
|
||||
ErrDeleteBatchFailCode,
|
||||
errDefaultDeleteBatchMessage,
|
||||
},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
index := 0
|
||||
svc := &mockS3Client{
|
||||
S3: buildS3SvcClient(server.URL),
|
||||
deleteObjects: func(input *s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error) {
|
||||
output := &cases[index].output
|
||||
index++
|
||||
return output, nil
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
batcher := BatchDelete{
|
||||
Client: svc,
|
||||
BatchSize: c.size,
|
||||
}
|
||||
|
||||
err := batcher.Delete(aws.BackgroundContext(), &DeleteObjectsIterator{Objects: c.objects})
|
||||
if err == nil {
|
||||
t.Errorf("expected error, but received none")
|
||||
}
|
||||
|
||||
berr := err.(*BatchError)
|
||||
|
||||
if len(berr.Errors) != 1 {
|
||||
t.Errorf("expected 1 error, but received %d", len(berr.Errors))
|
||||
}
|
||||
|
||||
aerr := berr.Errors[0].OrigErr.(awserr.Error)
|
||||
if e, a := c.expectedErrCode, aerr.Code(); e != a {
|
||||
t.Errorf("expected %q, but received %q", e, a)
|
||||
}
|
||||
|
||||
if e, a := c.expectedErrMessage, aerr.Message(); e != a {
|
||||
t.Errorf("expected %q, but received %q", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockS3Client struct {
|
||||
*s3.S3
|
||||
index int
|
||||
objects []*s3.ListObjectsOutput
|
||||
index int
|
||||
objects []*s3.ListObjectsOutput
|
||||
deleteObjects func(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
|
||||
}
|
||||
|
||||
func (client *mockS3Client) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
|
||||
@@ -321,6 +413,46 @@ func (client *mockS3Client) ListObjects(input *s3.ListObjectsInput) (*s3.ListObj
|
||||
return object, nil
|
||||
}
|
||||
|
||||
func (client *mockS3Client) DeleteObjects(input *s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error) {
|
||||
if client.deleteObjects == nil {
|
||||
return client.S3.DeleteObjectsWithContext(aws.BackgroundContext(), input)
|
||||
}
|
||||
|
||||
return client.deleteObjects(input)
|
||||
}
|
||||
|
||||
func (client *mockS3Client) DeleteObjectsWithContext(ctx aws.Context, input *s3.DeleteObjectsInput, opt ...request.Option) (*s3.DeleteObjectsOutput, error) {
|
||||
if client.deleteObjects == nil {
|
||||
return client.S3.DeleteObjectsWithContext(ctx, input)
|
||||
}
|
||||
|
||||
return client.deleteObjects(input)
|
||||
}
|
||||
|
||||
func TestNilOrigError(t *testing.T) {
|
||||
err := Error{
|
||||
Bucket: aws.String("bucket"),
|
||||
Key: aws.String("key"),
|
||||
}
|
||||
errStr := err.Error()
|
||||
const expected1 = `failed to perform batch operation on "key" to "bucket"`
|
||||
if errStr != expected1 {
|
||||
t.Errorf("Expected %s, but received %s", expected1, errStr)
|
||||
}
|
||||
|
||||
err = Error{
|
||||
OrigErr: errors.New("foo"),
|
||||
Bucket: aws.String("bucket"),
|
||||
Key: aws.String("key"),
|
||||
}
|
||||
errStr = err.Error()
|
||||
const expected2 = "failed to perform batch operation on \"key\" to \"bucket\":\nfoo"
|
||||
if errStr != expected2 {
|
||||
t.Errorf("Expected %s, but received %s", expected2, errStr)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBatchDeleteList(t *testing.T) {
|
||||
count := 0
|
||||
|
||||
|
||||
+10
-5
@@ -14,8 +14,11 @@ import (
|
||||
//
|
||||
// The request will not be signed, and will not use your AWS credentials.
|
||||
//
|
||||
// A "NotFound" error code will be returned if the bucket does not exist in
|
||||
// the AWS partition the regionHint belongs to.
|
||||
// A "NotFound" error code will be returned if the bucket does not exist in the
|
||||
// AWS partition the regionHint belongs to. If the regionHint parameter is an
|
||||
// empty string GetBucketRegion will fallback to the ConfigProvider's region
|
||||
// config. If the regionHint is empty, and the ConfigProvider does not have a
|
||||
// region value, an error will be returned..
|
||||
//
|
||||
// For example to get the region of a bucket which exists in "eu-central-1"
|
||||
// you could provide a region hint of "us-west-2".
|
||||
@@ -33,9 +36,11 @@ import (
|
||||
// fmt.Printf("Bucket %s is in %s region\n", bucket, region)
|
||||
//
|
||||
func GetBucketRegion(ctx aws.Context, c client.ConfigProvider, bucket, regionHint string, opts ...request.Option) (string, error) {
|
||||
svc := s3.New(c, &aws.Config{
|
||||
Region: aws.String(regionHint),
|
||||
})
|
||||
var cfg aws.Config
|
||||
if len(regionHint) != 0 {
|
||||
cfg.Region = aws.String(regionHint)
|
||||
}
|
||||
svc := s3.New(c, &cfg)
|
||||
return GetBucketRegionWithClient(ctx, svc, bucket, opts...)
|
||||
}
|
||||
|
||||
|
||||
+12
-8
@@ -21,12 +21,15 @@ func testSetupGetBucketRegionServer(region string, statusCode int, incHeader boo
|
||||
}
|
||||
|
||||
var testGetBucketRegionCases = []struct {
|
||||
RespRegion string
|
||||
StatusCode int
|
||||
RespRegion string
|
||||
StatusCode int
|
||||
HintRegion string
|
||||
ExpectReqRegion string
|
||||
}{
|
||||
{"bucket-region", 301},
|
||||
{"bucket-region", 403},
|
||||
{"bucket-region", 200},
|
||||
{"bucket-region", 301, "hint-region", ""},
|
||||
{"bucket-region", 403, "hint-region", ""},
|
||||
{"bucket-region", 200, "hint-region", ""},
|
||||
{"bucket-region", 200, "", "default-region"},
|
||||
}
|
||||
|
||||
func TestGetBucketRegion_Exists(t *testing.T) {
|
||||
@@ -34,11 +37,12 @@ func TestGetBucketRegion_Exists(t *testing.T) {
|
||||
server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
|
||||
|
||||
sess := unit.Session.Copy()
|
||||
sess.Config.Region = aws.String("default-region")
|
||||
sess.Config.Endpoint = aws.String(server.URL)
|
||||
sess.Config.DisableSSL = aws.Bool(true)
|
||||
|
||||
ctx := aws.BackgroundContext()
|
||||
region, err := GetBucketRegion(ctx, sess, "bucket", "region")
|
||||
region, err := GetBucketRegion(ctx, sess, "bucket", c.HintRegion)
|
||||
if err != nil {
|
||||
t.Fatalf("%d, expect no error, got %v", i, err)
|
||||
}
|
||||
@@ -56,7 +60,7 @@ func TestGetBucketRegion_NotExists(t *testing.T) {
|
||||
sess.Config.DisableSSL = aws.Bool(true)
|
||||
|
||||
ctx := aws.BackgroundContext()
|
||||
region, err := GetBucketRegion(ctx, sess, "bucket", "region")
|
||||
region, err := GetBucketRegion(ctx, sess, "bucket", "hint-region")
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, but did not get one")
|
||||
}
|
||||
@@ -74,7 +78,7 @@ func TestGetBucketRegionWithClient(t *testing.T) {
|
||||
server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
|
||||
|
||||
svc := s3.New(unit.Session, &aws.Config{
|
||||
Region: aws.String("region"),
|
||||
Region: aws.String("hint-region"),
|
||||
Endpoint: aws.String(server.URL),
|
||||
DisableSSL: aws.Bool(true),
|
||||
})
|
||||
|
||||
+26
-16
@@ -117,6 +117,9 @@ type UploadInput struct {
|
||||
// The language the content is in.
|
||||
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
|
||||
|
||||
// The base64-encoded 128-bit MD5 digest of the part data.
|
||||
ContentMD5 *string `location:"header" locationName:"Content-MD5" type:"string"`
|
||||
|
||||
// A standard MIME type describing the format of the object data.
|
||||
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
|
||||
|
||||
@@ -440,6 +443,8 @@ type uploader struct {
|
||||
|
||||
readerPos int64 // current reader position
|
||||
totalSize int64 // set to -1 if the size is not known
|
||||
|
||||
bufferPool sync.Pool
|
||||
}
|
||||
|
||||
// internal logic for deciding whether to upload a single part or use a
|
||||
@@ -453,7 +458,7 @@ func (u *uploader) upload() (*UploadOutput, error) {
|
||||
}
|
||||
|
||||
// Do one read to determine if we have more than one part
|
||||
reader, _, err := u.nextReader()
|
||||
reader, _, part, err := u.nextReader()
|
||||
if err == io.EOF { // single part
|
||||
return u.singlePart(reader)
|
||||
} else if err != nil {
|
||||
@@ -461,7 +466,7 @@ func (u *uploader) upload() (*UploadOutput, error) {
|
||||
}
|
||||
|
||||
mu := multiuploader{uploader: u}
|
||||
return mu.upload(reader)
|
||||
return mu.upload(reader, part)
|
||||
}
|
||||
|
||||
// init will initialize all default options.
|
||||
@@ -473,6 +478,10 @@ func (u *uploader) init() {
|
||||
u.cfg.PartSize = DefaultUploadPartSize
|
||||
}
|
||||
|
||||
u.bufferPool = sync.Pool{
|
||||
New: func() interface{} { return make([]byte, u.cfg.PartSize) },
|
||||
}
|
||||
|
||||
// Try to get the total size for some optimizations
|
||||
u.initSize()
|
||||
}
|
||||
@@ -484,10 +493,7 @@ func (u *uploader) initSize() {
|
||||
|
||||
switch r := u.in.Body.(type) {
|
||||
case io.Seeker:
|
||||
pos, _ := r.Seek(0, 1)
|
||||
defer r.Seek(pos, 0)
|
||||
|
||||
n, err := r.Seek(0, 2)
|
||||
n, err := aws.SeekerLen(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -507,7 +513,7 @@ func (u *uploader) initSize() {
|
||||
// This operation increases the shared u.readerPos counter, but note that it
|
||||
// does not need to be wrapped in a mutex because nextReader is only called
|
||||
// from the main thread.
|
||||
func (u *uploader) nextReader() (io.ReadSeeker, int, error) {
|
||||
func (u *uploader) nextReader() (io.ReadSeeker, int, []byte, error) {
|
||||
type readerAtSeeker interface {
|
||||
io.ReaderAt
|
||||
io.ReadSeeker
|
||||
@@ -529,14 +535,14 @@ func (u *uploader) nextReader() (io.ReadSeeker, int, error) {
|
||||
reader := io.NewSectionReader(r, u.readerPos, n)
|
||||
u.readerPos += n
|
||||
|
||||
return reader, int(n), err
|
||||
return reader, int(n), nil, err
|
||||
|
||||
default:
|
||||
part := make([]byte, u.cfg.PartSize)
|
||||
part := u.bufferPool.Get().([]byte)
|
||||
n, err := readFillBuf(r, part)
|
||||
u.readerPos += int64(n)
|
||||
|
||||
return bytes.NewReader(part[0:n]), n, err
|
||||
return bytes.NewReader(part[0:n]), n, part, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,8 +592,9 @@ type multiuploader struct {
|
||||
|
||||
// keeps track of a single chunk of data being sent to S3.
|
||||
type chunk struct {
|
||||
buf io.ReadSeeker
|
||||
num int64
|
||||
buf io.ReadSeeker
|
||||
part []byte
|
||||
num int64
|
||||
}
|
||||
|
||||
// completedParts is a wrapper to make parts sortable by their part number,
|
||||
@@ -600,7 +607,7 @@ func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].Pa
|
||||
|
||||
// upload will perform a multipart upload using the firstBuf buffer containing
|
||||
// the first chunk of data.
|
||||
func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
|
||||
func (u *multiuploader) upload(firstBuf io.ReadSeeker, firstPart []byte) (*UploadOutput, error) {
|
||||
params := &s3.CreateMultipartUploadInput{}
|
||||
awsutil.Copy(params, u.in)
|
||||
|
||||
@@ -620,7 +627,7 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
|
||||
|
||||
// Send part 1 to the workers
|
||||
var num int64 = 1
|
||||
ch <- chunk{buf: firstBuf, num: num}
|
||||
ch <- chunk{buf: firstBuf, part: firstPart, num: num}
|
||||
|
||||
// Read and queue the rest of the parts
|
||||
for u.geterr() == nil && err == nil {
|
||||
@@ -641,7 +648,8 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
|
||||
|
||||
var reader io.ReadSeeker
|
||||
var nextChunkLen int
|
||||
reader, nextChunkLen, err = u.nextReader()
|
||||
var part []byte
|
||||
reader, nextChunkLen, part, err = u.nextReader()
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
u.seterr(awserr.New(
|
||||
@@ -658,7 +666,7 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
|
||||
break
|
||||
}
|
||||
|
||||
ch <- chunk{buf: reader, num: num}
|
||||
ch <- chunk{buf: reader, part: part, num: num}
|
||||
}
|
||||
|
||||
// Close the channel, wait for workers, and complete upload
|
||||
@@ -714,6 +722,8 @@ func (u *multiuploader) send(c chunk) error {
|
||||
PartNumber: &c.num,
|
||||
}
|
||||
resp, err := u.cfg.S3.UploadPartWithContext(u.ctx, params, u.cfg.RequestOptions...)
|
||||
// put the byte array back into the pool to conserve memory
|
||||
u.bufferPool.Put(c.part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+49
-17
@@ -1,13 +1,13 @@
|
||||
package s3_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSSECustomerKeyOverHTTPError(t *testing.T) {
|
||||
@@ -20,9 +20,15 @@ func TestSSECustomerKeyOverHTTPError(t *testing.T) {
|
||||
})
|
||||
err := req.Build()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
|
||||
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
|
||||
if err == nil {
|
||||
t.Error("expected an error")
|
||||
}
|
||||
if e, a := "ConfigError", err.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if !strings.Contains(err.(awserr.Error).Message(), "cannot send SSE keys over HTTP") {
|
||||
t.Errorf("expected error to contain 'cannot send SSE keys over HTTP', but received %s", err.(awserr.Error).Message())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
|
||||
@@ -35,9 +41,15 @@ func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
|
||||
})
|
||||
err := req.Build()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
|
||||
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
|
||||
if err == nil {
|
||||
t.Error("expected an error")
|
||||
}
|
||||
if e, a := "ConfigError", err.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if !strings.Contains(err.(awserr.Error).Message(), "cannot send SSE keys over HTTP") {
|
||||
t.Errorf("expected error to contain 'cannot send SSE keys over HTTP', but received %s", err.(awserr.Error).Message())
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeSSEKeys(t *testing.T) {
|
||||
@@ -51,11 +63,21 @@ func TestComputeSSEKeys(t *testing.T) {
|
||||
})
|
||||
err := req.Build()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
|
||||
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
|
||||
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
|
||||
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeSSEKeysShortcircuit(t *testing.T) {
|
||||
@@ -71,9 +93,19 @@ func TestComputeSSEKeysShortcircuit(t *testing.T) {
|
||||
})
|
||||
err := req.Build()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
|
||||
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
|
||||
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
|
||||
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
|
||||
@@ -17,7 +18,7 @@ func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
|
||||
}
|
||||
body := bytes.NewReader(b)
|
||||
r.HTTPResponse.Body = ioutil.NopCloser(body)
|
||||
defer body.Seek(0, 0)
|
||||
defer body.Seek(0, sdkio.SeekStart)
|
||||
|
||||
if body.Len() == 0 {
|
||||
// If there is no body don't attempt to parse the body.
|
||||
|
||||
+60
-24
@@ -7,9 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
@@ -31,10 +28,15 @@ func TestCopyObjectNoError(t *testing.T) {
|
||||
Key: aws.String("destination.txt"),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyObjectResult.ETag)
|
||||
assert.Equal(t, lastModifiedTime, *res.CopyObjectResult.LastModified)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyObjectResult.ETag; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := lastModifiedTime, *res.CopyObjectResult.LastModified; !e.Equal(a) {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyObjectError(t *testing.T) {
|
||||
@@ -44,11 +46,17 @@ func TestCopyObjectError(t *testing.T) {
|
||||
Key: aws.String("destination.txt"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
if err == nil {
|
||||
t.Error("expected error, but received none")
|
||||
}
|
||||
e := err.(awserr.Error)
|
||||
|
||||
assert.Equal(t, "ErrorCode", e.Code())
|
||||
assert.Equal(t, "message body", e.Message())
|
||||
if e, a := "ErrorCode", e.Code(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "message body", e.Message(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadPartCopySuccess(t *testing.T) {
|
||||
@@ -64,10 +72,16 @@ func TestUploadPartCopySuccess(t *testing.T) {
|
||||
UploadId: aws.String("uploadID"),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyPartResult.ETag)
|
||||
assert.Equal(t, lastModifiedTime, *res.CopyPartResult.LastModified)
|
||||
if e, a := fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyPartResult.ETag; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := lastModifiedTime, *res.CopyPartResult.LastModified; !e.Equal(a) {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadPartCopyError(t *testing.T) {
|
||||
@@ -79,11 +93,17 @@ func TestUploadPartCopyError(t *testing.T) {
|
||||
UploadId: aws.String("uploadID"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
if err == nil {
|
||||
t.Error("expected an error")
|
||||
}
|
||||
e := err.(awserr.Error)
|
||||
|
||||
assert.Equal(t, "ErrorCode", e.Code())
|
||||
assert.Equal(t, "message body", e.Message())
|
||||
if e, a := "ErrorCode", e.Code(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "message body", e.Message(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteMultipartUploadSuccess(t *testing.T) {
|
||||
@@ -96,12 +116,22 @@ func TestCompleteMultipartUploadSuccess(t *testing.T) {
|
||||
UploadId: aws.String("uploadID"),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, `"etagVal"`, *res.ETag)
|
||||
assert.Equal(t, "bucketName", *res.Bucket)
|
||||
assert.Equal(t, "keyName", *res.Key)
|
||||
assert.Equal(t, "locationName", *res.Location)
|
||||
if e, a := `"etagVal"`, *res.ETag; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "bucketName", *res.Bucket; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "keyName", *res.Key; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "locationName", *res.Location; e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteMultipartUploadError(t *testing.T) {
|
||||
@@ -111,11 +141,17 @@ func TestCompleteMultipartUploadError(t *testing.T) {
|
||||
UploadId: aws.String("uploadID"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
if err == nil {
|
||||
t.Error("expected an error")
|
||||
}
|
||||
e := err.(awserr.Error)
|
||||
|
||||
assert.Equal(t, "ErrorCode", e.Code())
|
||||
assert.Equal(t, "message body", e.Message())
|
||||
if e, a := "ErrorCode", e.Code(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
if e, a := "message body", e.Message(); e != a {
|
||||
t.Errorf("expected %s, but received %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func newCopyTestSvc(errMsg string) *s3.S3 {
|
||||
|
||||
+11
-4
@@ -1,7 +1,6 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@@ -27,7 +26,15 @@ func TestUnmarhsalErrorLeak(t *testing.T) {
|
||||
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
|
||||
unmarshalError(req)
|
||||
|
||||
assert.NotNil(t, req.Error)
|
||||
assert.Equal(t, reader.Closed, true)
|
||||
assert.Equal(t, reader.Size, 0)
|
||||
if req.Error == nil {
|
||||
t.Error("expected an error, but received none")
|
||||
}
|
||||
|
||||
if !reader.Closed {
|
||||
t.Error("expected reader to be closed")
|
||||
}
|
||||
|
||||
if e, a := 0, reader.Size; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user