Conver to regular Go vendor + dep tool

This commit is contained in:
Andrey Smirnov
2017-03-22 17:38:32 +03:00
parent 070347295e
commit c6c1012330
3260 changed files with 1742550 additions and 72 deletions
+19228
View File
File diff suppressed because it is too large Load Diff
+43
View File
@@ -0,0 +1,43 @@
package s3
import (
"io/ioutil"
"regexp"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)
func buildGetBucketLocation(r *request.Request) {
if r.DataFilled() {
out := r.Data.(*GetBucketLocationOutput)
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed reading response body", err)
return
}
match := reBucketLocation.FindSubmatch(b)
if len(match) > 1 {
loc := string(match[1])
out.LocationConstraint = &loc
}
}
}
func populateLocationConstraint(r *request.Request) {
if r.ParamsFilled() && aws.StringValue(r.Config.Region) != "us-east-1" {
in := r.Params.(*CreateBucketInput)
if in.CreateBucketConfiguration == nil {
r.Params = awsutil.CopyOf(r.Params)
in = r.Params.(*CreateBucketInput)
in.CreateBucketConfiguration = &CreateBucketConfiguration{
LocationConstraint: r.Config.Region,
}
}
}
}
+78
View File
@@ -0,0 +1,78 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var s3LocationTests = []struct {
body string
loc string
}{
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/"/>`, ``},
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">EU</LocationConstraint>`, `EU`},
}
func TestGetBucketLocation(t *testing.T) {
for _, test := range s3LocationTests {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
reader := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{StatusCode: 200, Body: reader}
})
resp, err := s.GetBucketLocation(&s3.GetBucketLocationInput{Bucket: aws.String("bucket")})
assert.NoError(t, err)
if test.loc == "" {
assert.Nil(t, resp.LocationConstraint)
} else {
assert.Equal(t, test.loc, *resp.LocationConstraint)
}
}
}
func TestPopulateLocationConstraint(t *testing.T) {
s := s3.New(unit.Session)
in := &s3.CreateBucketInput{
Bucket: aws.String("bucket"),
}
req, _ := s.CreateBucketRequest(in)
err := req.Build()
assert.NoError(t, err)
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
assert.Equal(t, "mock-region", *(v[0].(*string)))
assert.Nil(t, in.CreateBucketConfiguration) // don't modify original params
}
func TestNoPopulateLocationConstraintIfProvided(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
CreateBucketConfiguration: &s3.CreateBucketConfiguration{},
})
err := req.Build()
assert.NoError(t, err)
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
assert.Equal(t, 0, len(v))
}
func TestNoPopulateLocationConstraintIfClassic(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{Region: aws.String("us-east-1")})
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
})
err := req.Build()
assert.NoError(t, err)
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
assert.Equal(t, 0, len(v))
}
+36
View File
@@ -0,0 +1,36 @@
package s3
import (
"crypto/md5"
"encoding/base64"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
// require it.
func contentMD5(r *request.Request) {
h := md5.New()
// hash the body. seek back to the first position after reading to reset
// the body for transmission. copy errors may be assumed to be from the
// body.
_, err := io.Copy(h, r.Body)
if err != nil {
r.Error = awserr.New("ContentMD5", "failed to read body", err)
return
}
_, err = r.Body.Seek(0, 0)
if err != nil {
r.Error = awserr.New("ContentMD5", "failed to seek body", err)
return
}
// encode the md5 checksum in base64 and set the request header.
sum := h.Sum(nil)
sum64 := make([]byte, base64.StdEncoding.EncodedLen(len(sum)))
base64.StdEncoding.Encode(sum64, sum)
r.HTTPRequest.Header.Set("Content-MD5", string(sum64))
}
+46
View File
@@ -0,0 +1,46 @@
package s3
import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
)
func init() {
initClient = defaultInitClientFn
initRequest = defaultInitRequestFn
}
func defaultInitClientFn(c *client.Client) {
// Support building custom endpoints based on config
c.Handlers.Build.PushFront(updateEndpointForS3Config)
// Require SSL when using SSE keys
c.Handlers.Validate.PushBack(validateSSERequiresSSL)
c.Handlers.Build.PushBack(computeSSEKeys)
// S3 uses custom error unmarshaling logic
c.Handlers.UnmarshalError.Clear()
c.Handlers.UnmarshalError.PushBack(unmarshalError)
}
func defaultInitRequestFn(r *request.Request) {
// Add reuest handlers for specific platforms.
// e.g. 100-continue support for PUT requests using Go 1.6
platformRequestHandlers(r)
switch r.Operation.Name {
case opPutBucketCors, opPutBucketLifecycle, opPutBucketPolicy,
opPutBucketTagging, opDeleteObjects, opPutBucketLifecycleConfiguration,
opPutBucketReplication:
// These S3 operations require Content-MD5 to be set
r.Handlers.Build.PushBack(contentMD5)
case opGetBucketLocation:
// GetBucketLocation has custom parsing logic
r.Handlers.Unmarshal.PushFront(buildGetBucketLocation)
case opCreateBucket:
// Auto-populate LocationConstraint with current region
r.Handlers.Validate.PushFront(populateLocationConstraint)
case opCopyObject, opUploadPartCopy, opCompleteMultipartUpload:
r.Handlers.Unmarshal.PushFront(copyMultipartStatusOKUnmarhsalError)
}
}
+158
View File
@@ -0,0 +1,158 @@
package s3_test
import (
"crypto/md5"
"encoding/base64"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
func assertMD5(t *testing.T, req *request.Request) {
err := req.Build()
assert.NoError(t, err)
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
out := md5.Sum(b)
assert.NotEmpty(t, b)
assert.Equal(t, base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"))
}
func TestMD5InPutBucketCors(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketCorsRequest(&s3.PutBucketCorsInput{
Bucket: aws.String("bucketname"),
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
{
AllowedMethods: []*string{aws.String("GET")},
AllowedOrigins: []*string{aws.String("*")},
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycle(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketLifecycleRequest(&s3.PutBucketLifecycleInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.Rule{
{
ID: aws.String("ID"),
Prefix: aws.String("Prefix"),
Status: aws.String("Enabled"),
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketPolicy(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketPolicyRequest(&s3.PutBucketPolicyInput{
Bucket: aws.String("bucketname"),
Policy: aws.String("{}"),
})
assertMD5(t, req)
}
func TestMD5InPutBucketTagging(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketTaggingRequest(&s3.PutBucketTaggingInput{
Bucket: aws.String("bucketname"),
Tagging: &s3.Tagging{
TagSet: []*s3.Tag{
{Key: aws.String("KEY"), Value: aws.String("VALUE")},
},
},
})
assertMD5(t, req)
}
func TestMD5InDeleteObjects(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.DeleteObjectsRequest(&s3.DeleteObjectsInput{
Bucket: aws.String("bucketname"),
Delete: &s3.Delete{
Objects: []*s3.ObjectIdentifier{
{Key: aws.String("key")},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycleConfiguration(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketLifecycleConfigurationRequest(&s3.PutBucketLifecycleConfigurationInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.BucketLifecycleConfiguration{
Rules: []*s3.LifecycleRule{
{Prefix: aws.String("prefix"), Status: aws.String(s3.ExpirationStatusEnabled)},
},
},
})
assertMD5(t, req)
}
const (
metaKeyPrefix = `X-Amz-Meta-`
utf8KeySuffix = `My-Info`
utf8Value = "hello-世界\u0444"
)
func TestPutObjectMetadataWithUnicode(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, utf8Value, r.Header.Get(metaKeyPrefix+utf8KeySuffix))
}))
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
_, err := svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String("my_bucket"),
Key: aws.String("my_key"),
Body: strings.NewReader(""),
Metadata: func() map[string]*string {
v := map[string]*string{}
v[utf8KeySuffix] = aws.String(utf8Value)
return v
}(),
})
assert.NoError(t, err)
}
func TestGetObjectMetadataWithUnicode(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(metaKeyPrefix+utf8KeySuffix, utf8Value)
}))
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
resp, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("my_bucket"),
Key: aws.String("my_key"),
})
assert.NoError(t, err)
resp.Body.Close()
assert.Equal(t, utf8Value, *resp.Metadata[utf8KeySuffix])
}
+48
View File
@@ -0,0 +1,48 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3
const (
// ErrCodeBucketAlreadyExists for service response error code
// "BucketAlreadyExists".
//
// The requested bucket name is not available. The bucket namespace is shared
// by all users of the system. Please select a different name and try again.
ErrCodeBucketAlreadyExists = "BucketAlreadyExists"
// ErrCodeBucketAlreadyOwnedByYou for service response error code
// "BucketAlreadyOwnedByYou".
ErrCodeBucketAlreadyOwnedByYou = "BucketAlreadyOwnedByYou"
// ErrCodeNoSuchBucket for service response error code
// "NoSuchBucket".
//
// The specified bucket does not exist.
ErrCodeNoSuchBucket = "NoSuchBucket"
// ErrCodeNoSuchKey for service response error code
// "NoSuchKey".
//
// The specified key does not exist.
ErrCodeNoSuchKey = "NoSuchKey"
// ErrCodeNoSuchUpload for service response error code
// "NoSuchUpload".
//
// The specified multipart upload does not exist.
ErrCodeNoSuchUpload = "NoSuchUpload"
// ErrCodeObjectAlreadyInActiveTierError for service response error code
// "ObjectAlreadyInActiveTierError".
//
// This operation is not allowed against this storage tier
ErrCodeObjectAlreadyInActiveTierError = "ObjectAlreadyInActiveTierError"
// ErrCodeObjectNotInActiveTierError for service response error code
// "ObjectNotInActiveTierError".
//
// The source object of the COPY operation is not in the active tier and is
// only stored in Amazon Glacier.
ErrCodeObjectNotInActiveTierError = "ObjectNotInActiveTierError"
)
File diff suppressed because it is too large Load Diff
+162
View File
@@ -0,0 +1,162 @@
package s3
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
// an operationBlacklist is a list of operation names that should a
// request handler should not be executed with.
type operationBlacklist []string
// Continue will return true of the Request's operation name is not
// in the blacklist. False otherwise.
func (b operationBlacklist) Continue(r *request.Request) bool {
for i := 0; i < len(b); i++ {
if b[i] == r.Operation.Name {
return false
}
}
return true
}
var accelerateOpBlacklist = operationBlacklist{
opListBuckets, opCreateBucket, opDeleteBucket,
}
// Request handler to automatically add the bucket name to the endpoint domain
// if possible. This style of bucket is valid for all bucket names which are
// DNS compatible and do not contain "."
func updateEndpointForS3Config(r *request.Request) {
forceHostStyle := aws.BoolValue(r.Config.S3ForcePathStyle)
accelerate := aws.BoolValue(r.Config.S3UseAccelerate)
if accelerate && accelerateOpBlacklist.Continue(r) {
if forceHostStyle {
if r.Config.Logger != nil {
r.Config.Logger.Log("ERROR: aws.Config.S3UseAccelerate is not compatible with aws.Config.S3ForcePathStyle, ignoring S3ForcePathStyle.")
}
}
updateEndpointForAccelerate(r)
} else if !forceHostStyle && r.Operation.Name != opGetBucketLocation {
updateEndpointForHostStyle(r)
}
}
func updateEndpointForHostStyle(r *request.Request) {
bucket, ok := bucketNameFromReqParams(r.Params)
if !ok {
// Ignore operation requests if the bucketname was not provided
// if this is an input validation error the validation handler
// will report it.
return
}
if !hostCompatibleBucketName(r.HTTPRequest.URL, bucket) {
// bucket name must be valid to put into the host
return
}
moveBucketToHost(r.HTTPRequest.URL, bucket)
}
var (
accelElem = []byte("s3-accelerate.dualstack.")
)
func updateEndpointForAccelerate(r *request.Request) {
bucket, ok := bucketNameFromReqParams(r.Params)
if !ok {
// Ignore operation requests if the bucketname was not provided
// if this is an input validation error the validation handler
// will report it.
return
}
if !hostCompatibleBucketName(r.HTTPRequest.URL, bucket) {
r.Error = awserr.New("InvalidParameterException",
fmt.Sprintf("bucket name %s is not compatible with S3 Accelerate", bucket),
nil)
return
}
parts := strings.Split(r.HTTPRequest.URL.Host, ".")
if len(parts) < 3 {
r.Error = awserr.New("InvalidParameterExecption",
fmt.Sprintf("unable to update endpoint host for S3 accelerate, hostname invalid, %s",
r.HTTPRequest.URL.Host), nil)
return
}
if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
parts[0] = "s3-accelerate"
}
for i := 1; i+1 < len(parts); i++ {
if parts[i] == aws.StringValue(r.Config.Region) {
parts = append(parts[:i], parts[i+1:]...)
break
}
}
r.HTTPRequest.URL.Host = strings.Join(parts, ".")
moveBucketToHost(r.HTTPRequest.URL, bucket)
}
// Attempts to retrieve the bucket name from the request input parameters.
// If no bucket is found, or the field is empty "", false will be returned.
func bucketNameFromReqParams(params interface{}) (string, bool) {
b, _ := awsutil.ValuesAtPath(params, "Bucket")
if len(b) == 0 {
return "", false
}
if bucket, ok := b[0].(*string); ok {
if bucketStr := aws.StringValue(bucket); bucketStr != "" {
return bucketStr, true
}
}
return "", false
}
// hostCompatibleBucketName returns true if the request should
// put the bucket in the host. This is false if S3ForcePathStyle is
// explicitly set or if the bucket is not DNS compatible.
func hostCompatibleBucketName(u *url.URL, bucket string) bool {
// Bucket might be DNS compatible but dots in the hostname will fail
// certificate validation, so do not use host-style.
if u.Scheme == "https" && strings.Contains(bucket, ".") {
return false
}
// if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
var reDomain = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
var reIPAddress = regexp.MustCompile(`^(\d+\.){3}\d+$`)
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
func dnsCompatibleBucketName(bucket string) bool {
return reDomain.MatchString(bucket) &&
!reIPAddress.MatchString(bucket) &&
!strings.Contains(bucket, "..")
}
// moveBucketToHost moves the bucket name from the URI path to URL host.
func moveBucketToHost(u *url.URL, bucket string) {
u.Host = bucket + "." + u.Host
u.Path = strings.Replace(u.Path, "/{Bucket}", "", -1)
if u.Path == "" {
u.Path = "/"
}
}
+127
View File
@@ -0,0 +1,127 @@
package s3_test
import (
"net/url"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type s3BucketTest struct {
bucket string
url string
errCode string
}
var (
sslTests = []s3BucketTest{
{"abc", "https://abc.s3.mock-region.amazonaws.com/", ""},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c", ""},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc", ""},
}
nosslTests = []s3BucketTest{
{"a.b.c", "http://a.b.c.s3.mock-region.amazonaws.com/", ""},
{"a..bc", "http://s3.mock-region.amazonaws.com/a..bc", ""},
}
forcepathTests = []s3BucketTest{
{"abc", "https://s3.mock-region.amazonaws.com/abc", ""},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c", ""},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc", ""},
}
accelerateTests = []s3BucketTest{
{"abc", "https://abc.s3-accelerate.amazonaws.com/", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
accelerateNoSSLTests = []s3BucketTest{
{"abc", "http://abc.s3-accelerate.amazonaws.com/", ""},
{"a.b.c", "http://a.b.c.s3-accelerate.amazonaws.com/", ""},
{"a$b$c", "http://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
accelerateDualstack = []s3BucketTest{
{"abc", "https://abc.s3-accelerate.dualstack.amazonaws.com/", ""},
{"a.b.c", "https://s3.dualstack.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
{"a$b$c", "https://s3.dualstack.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
)
func runTests(t *testing.T, svc *s3.S3, tests []s3BucketTest) {
for i, test := range tests {
req, _ := svc.ListObjectsRequest(&s3.ListObjectsInput{Bucket: &test.bucket})
req.Build()
if e, a := test.url, req.HTTPRequest.URL.String(); e != a {
t.Errorf("%d, expect url %s, got %s", i, e, a)
}
if test.errCode != "" {
if err := req.Error; err == nil {
t.Fatalf("%d, expect no error", i)
}
if a, e := req.Error.(awserr.Error).Code(), test.errCode; !strings.Contains(a, e) {
t.Errorf("%d, expect error code to contain %q, got %q", i, e, a)
}
}
}
}
func TestAccelerateBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3UseAccelerate: aws.Bool(true)})
runTests(t, s, accelerateTests)
}
func TestAccelerateNoSSLBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3UseAccelerate: aws.Bool(true), DisableSSL: aws.Bool(true)})
runTests(t, s, accelerateNoSSLTests)
}
func TestAccelerateDualstackBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{
S3UseAccelerate: aws.Bool(true),
UseDualStack: aws.Bool(true),
})
runTests(t, s, accelerateDualstack)
}
func TestHostStyleBucketBuild(t *testing.T) {
s := s3.New(unit.Session)
runTests(t, s, sslTests)
}
func TestHostStyleBucketBuildNoSSL(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
runTests(t, s, nosslTests)
}
func TestPathStyleBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3ForcePathStyle: aws.Bool(true)})
runTests(t, s, forcepathTests)
}
func TestHostStyleBucketGetBucketLocation(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.GetBucketLocationRequest(&s3.GetBucketLocationInput{
Bucket: aws.String("bucket"),
})
req.Build()
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
u, _ := url.Parse(req.HTTPRequest.URL.String())
if e, a := "bucket", u.Host; strings.Contains(a, e) {
t.Errorf("expect %s to not be in %s", e, a)
}
if e, a := "bucket", u.Path; !strings.Contains(a, e) {
t.Errorf("expect %s to be in %s", e, a)
}
}
+8
View File
@@ -0,0 +1,8 @@
// +build !go1.6
package s3
import "github.com/aws/aws-sdk-go/aws/request"
func platformRequestHandlers(r *request.Request) {
}
+28
View File
@@ -0,0 +1,28 @@
// +build go1.6
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
func platformRequestHandlers(r *request.Request) {
if r.Operation.HTTPMethod == "PUT" {
// 100-Continue should only be used on put requests.
r.Handlers.Sign.PushBack(add100Continue)
}
}
func add100Continue(r *request.Request) {
if aws.BoolValue(r.Config.S3Disable100Continue) {
return
}
if r.HTTPRequest.ContentLength < 1024*1024*2 {
// Ignore requests smaller than 2MB. This helps prevent delaying
// requests unnecessarily.
return
}
r.HTTPRequest.Header.Set("Expect", "100-Continue")
}
@@ -0,0 +1,68 @@
// +build go1.6
package s3_test
import (
"bytes"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
func TestAdd100Continue_Added(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*5)),
})
err := r.Sign()
assert.NoError(t, err)
assert.Equal(t, "100-Continue", r.HTTPRequest.Header.Get("Expect"))
}
func TestAdd100Continue_SkipDisabled(t *testing.T) {
svc := s3.New(unit.Session, aws.NewConfig().WithS3Disable100Continue(true))
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*5)),
})
err := r.Sign()
assert.NoError(t, err)
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
}
func TestAdd100Continue_SkipNonPUT(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
})
err := r.Sign()
assert.NoError(t, err)
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
}
func TestAdd100Continue_SkipTooSmall(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*1)),
})
err := r.Sign()
assert.NoError(t, err)
assert.Empty(t, r.HTTPRequest.Header.Get("Expect"))
}
+190
View File
@@ -0,0 +1,190 @@
package s3crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"io"
)
// AESCBC is a symmetric crypto algorithm. This algorithm
// requires a padder due to CBC needing to be of the same block
// size. AES CBC is vulnerable to Padding Oracle attacks and
// so should be avoided when possible.
type aesCBC struct {
encrypter cipher.BlockMode
decrypter cipher.BlockMode
padder Padder
}
// newAESCBC creates a new AES CBC cipher. Expects keys to be of
// the correct size.
func newAESCBC(cd CipherData, padder Padder) (Cipher, error) {
block, err := aes.NewCipher(cd.Key)
if err != nil {
return nil, err
}
encrypter := cipher.NewCBCEncrypter(block, cd.IV)
decrypter := cipher.NewCBCDecrypter(block, cd.IV)
return &aesCBC{encrypter, decrypter, padder}, nil
}
// Encrypt will encrypt the data using AES CBC by returning
// an io.Reader. The io.Reader will encrypt the data as Read
// is called.
func (c *aesCBC) Encrypt(src io.Reader) io.Reader {
reader := &cbcEncryptReader{
encrypter: c.encrypter,
src: src,
padder: c.padder,
}
return reader
}
type cbcEncryptReader struct {
encrypter cipher.BlockMode
src io.Reader
padder Padder
size int
buf bytes.Buffer
}
// Read will read from our io.Reader and encrypt the data as necessary.
// Due to padding, we have to do some logic that when we encounter an
// end of file to pad properly.
func (reader *cbcEncryptReader) Read(data []byte) (int, error) {
n, err := reader.src.Read(data)
reader.size += n
blockSize := reader.encrypter.BlockSize()
reader.buf.Write(data[:n])
if err == io.EOF {
b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
n, err = reader.buf.Read(b)
if err != nil && err != io.EOF {
return n, err
}
// The buffer is now empty, we can now pad the data
if reader.buf.Len() == 0 {
b, err = reader.padder.Pad(b[:n], reader.size)
if err != nil {
return n, err
}
n = len(b)
err = io.EOF
}
// We only want to encrypt if we have read anything
if n > 0 {
reader.encrypter.CryptBlocks(data, b)
}
return n, err
}
if err != nil {
return n, err
}
if size := reader.buf.Len(); size >= blockSize {
nBlocks := size / blockSize
if size > len(data) {
nBlocks = len(data) / blockSize
}
if nBlocks > 0 {
b := make([]byte, nBlocks*blockSize)
n, _ = reader.buf.Read(b)
reader.encrypter.CryptBlocks(data, b[:n])
}
} else {
n = 0
}
return n, nil
}
// Decrypt will decrypt the data using AES CBC
func (c *aesCBC) Decrypt(src io.Reader) io.Reader {
return &cbcDecryptReader{
decrypter: c.decrypter,
src: src,
padder: c.padder,
}
}
type cbcDecryptReader struct {
decrypter cipher.BlockMode
src io.Reader
padder Padder
buf bytes.Buffer
}
// Read will read from our io.Reader and decrypt the data as necessary.
// Due to padding, we have to do some logic that when we encounter an
// end of file to pad properly.
func (reader *cbcDecryptReader) Read(data []byte) (int, error) {
n, err := reader.src.Read(data)
blockSize := reader.decrypter.BlockSize()
reader.buf.Write(data[:n])
if err == io.EOF {
b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
n, err = reader.buf.Read(b)
if err != nil && err != io.EOF {
return n, err
}
// We only want to decrypt if we have read anything
if n > 0 {
reader.decrypter.CryptBlocks(data, b)
}
if reader.buf.Len() == 0 {
b, err = reader.padder.Unpad(data[:n])
n = len(b)
if err != nil {
return n, err
}
err = io.EOF
}
return n, err
}
if err != nil {
return n, err
}
if size := reader.buf.Len(); size >= blockSize {
nBlocks := size / blockSize
if size > len(data) {
nBlocks = len(data) / blockSize
}
// The last block is always padded. This will allow us to unpad
// when we receive an io.EOF error
nBlocks -= blockSize
if nBlocks > 0 {
b := make([]byte, nBlocks*blockSize)
n, _ = reader.buf.Read(b)
reader.decrypter.CryptBlocks(data, b[:n])
} else {
n = 0
}
}
return n, nil
}
// getSliceSize will return the correct amount of bytes we need to
// read with regards to padding.
func getSliceSize(blockSize, bufSize, dataSize int) int {
size := bufSize
if bufSize > dataSize {
size = dataSize
}
size = size - (size % blockSize) - blockSize
if size <= 0 {
size = blockSize
}
return size
}
@@ -0,0 +1,73 @@
package s3crypto
import (
"io"
"strings"
)
const (
cbcKeySize = 32
cbcNonceSize = 16
)
type cbcContentCipherBuilder struct {
generator CipherDataGenerator
padder Padder
}
// AESCBCContentCipherBuilder returns a new encryption only mode structure with a specific cipher
// for the master key
func AESCBCContentCipherBuilder(generator CipherDataGenerator, padder Padder) ContentCipherBuilder {
return cbcContentCipherBuilder{generator: generator, padder: padder}
}
func (builder cbcContentCipherBuilder) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(cbcKeySize, cbcNonceSize)
if err != nil {
return nil, err
}
cd.Padder = builder.padder
return newAESCBCContentCipher(cd)
}
// newAESCBCContentCipher will create a new aes cbc content cipher. If the cipher data's
// will set the CEK algorithm if it hasn't been set.
func newAESCBCContentCipher(cd CipherData) (ContentCipher, error) {
if len(cd.CEKAlgorithm) == 0 {
cd.CEKAlgorithm = strings.Join([]string{AESCBC, cd.Padder.Name()}, "/")
}
cipher, err := newAESCBC(cd, cd.Padder)
if err != nil {
return nil, err
}
return &aesCBCContentCipher{
CipherData: cd,
Cipher: cipher,
}, nil
}
// aesCBCContentCipher will use AES CBC for the main cipher.
type aesCBCContentCipher struct {
CipherData CipherData
Cipher Cipher
}
// EncryptContents will generate a random key and iv and encrypt the data using cbc
func (cc *aesCBCContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
return cc.Cipher.Encrypt(src), nil
}
// DecryptContents will use the symmetric key provider to instantiate a new CBC cipher.
// We grab a decrypt reader from CBC and wrap it in a CryptoReadCloser. The only error
// expected here is when the key or iv is of invalid length.
func (cc *aesCBCContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
reader := cc.Cipher.Decrypt(src)
return &CryptoReadCloser{Body: src, Decrypter: reader}, nil
}
// GetCipherData returns cipher data
func (cc aesCBCContentCipher) GetCipherData() CipherData {
return cc.CipherData
}
@@ -0,0 +1,20 @@
package s3crypto_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestAESCBCBuilder(t *testing.T) {
generator := mockGenerator{}
builder := s3crypto.AESCBCContentCipherBuilder(generator, s3crypto.NoPadder)
if builder == nil {
t.Fatal(builder)
}
_, err := builder.ContentCipher()
if err != nil {
t.Fatal(err)
}
}
+29
View File
@@ -0,0 +1,29 @@
package s3crypto
const (
pkcs5BlockSize = 16
)
var aescbcPadding = aescbcPadder{pkcs7Padder{16}}
// AESCBCPadder is used to pad AES encrypted and decrypted data.
// Altough it uses the pkcs5Padder, it isn't following the RFC
// for PKCS5. The only reason why it is called pkcs5Padder is
// due to the Name returning PKCS5Padding.
var AESCBCPadder = Padder(aescbcPadding)
type aescbcPadder struct {
padder pkcs7Padder
}
func (padder aescbcPadder) Pad(b []byte, n int) ([]byte, error) {
return padder.padder.Pad(b, n)
}
func (padder aescbcPadder) Unpad(b []byte) ([]byte, error) {
return padder.padder.Unpad(b)
}
func (padder aescbcPadder) Name() string {
return "PKCS5Padding"
}
@@ -0,0 +1,41 @@
package s3crypto
import (
"bytes"
"fmt"
"testing"
)
func TestAESCBCPadding(t *testing.T) {
for i := 0; i < 16; i++ {
input := make([]byte, i)
expected := append(input, bytes.Repeat([]byte{byte(16 - i)}, 16-i)...)
b, err := AESCBCPadder.Pad(input, len(input))
if err != nil {
t.Fatal("Expected error to be nil but received " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func TestAESCBCUnpadding(t *testing.T) {
for i := 0; i < 16; i++ {
expected := make([]byte, i)
input := append(expected, bytes.Repeat([]byte{byte(16 - i)}, 16-i)...)
b, err := AESCBCPadder.Unpad(input)
if err != nil {
t.Fatal("Error received, was expecting nil: " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
+501
View File
@@ -0,0 +1,501 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"testing"
)
func TestAESCBCEncryptDecrypt(t *testing.T) {
var testCases = []struct {
key string
iv string
plaintext string
ciphertext string
decodeHex bool
padder Padder
}{
// Test vectors from RFC 3602: https://tools.ietf.org/html/rfc3602
{
"06a9214036b8a15b512e03d534120006",
"3dafba429d9eb430b422da802c9fac41",
"Single block msg",
"e353779c1079aeb82708942dbe77181a",
false,
NoPadder,
},
{
"c286696d887c9aa0611bbb3e2025a45a",
"562e17996d093d28ddb3ba695a2e6f58",
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
"d296cd94c2cccf8a3a863028b5e1dc0a7586602d253cfff91b8266bea6d61ab1",
true,
NoPadder,
},
{
"6c3ea0477630ce21a2ce334aa746c2cd",
"c782dc4c098c66cbd9cd27d825682c81",
"This is a 48-byte message (exactly 3 AES blocks)",
"d0a02b3836451753d493665d33f0e8862dea54cdb293abc7506939276772f8d5021c19216bad525c8579695d83ba2684",
false,
NoPadder,
},
{
"56e47a38c5598974bc46903dba290349",
"8ce82eefbea0da3c44699ed7db51b7d9",
"a0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedf",
"c30e32ffedc0774e6aff6af0869f71aa0f3af07a9a31a9c684db207eb0ef8e4e35907aa632c3ffdf868bb7b29d3d46ad83ce9f9a102ee99d49a53e87f4c3da55",
true,
NoPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"",
"B012949BA07D1A6DCE9DEE67274D41AB",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41",
"8A11ABA68A566132FFE04DB336621D41",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141",
"97D0896E41DFDB5CEA4A9EB70A938CFD",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141",
"8464EAD45FA2D8790E8741E32C28083F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141",
"1E656D6E2745BA9F154FAF136B2BC73D",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141",
"0B6031C4B230DAC6BD6D3F195645B287",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141",
"5D09FEB6462BB489489A7E18FD341D9D",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141",
"85745E398F2FD1050C2CE8F8614DA369",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141",
"7BE52933970BA7B0FC6FB3FC37648205",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141",
"ED3A1E134EF36CCFE60C8123B4272F89",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141",
"C3B7C9E177E1052FC736F65FC1E74209",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141",
"C3A8B53F7F57F0B9D346FA99810A3C28",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141",
"D16B1ECE5BF00AF919E139E99775FF06",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141",
"B258F4DF57FFCA1EFCF8D76140F05139",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141",
"3CD2282DE24A2CF9E23326CC3DC9077A",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141",
"3010232E7C752A3B4C9EE428B4C4FE88",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A22BC4E6D03BFD2418DD412D1ED1B31AF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A5427BBD4A4D50776989441370E3B5B16",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A7FF985F55567D1B25EA40E23BB4CB1FE",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A0835E548C7370D8F8D9925C0E6B54727",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ADC0CF1436399E67BC1122B31CB596649",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A3D096F0DEAFF91938B82E5D404B0B065",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AD56ABA897A355CF307CCB74226243192",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A151284F950B1B1DBCAD6D9E7900DF4E6",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AEF85A612514121C299A1D87116C4A182",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A67F157569BFB4013EA3AD16DB8C69AD6",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AF8520D191F6ACBD88B2140588B91C697",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ADD8BBAA71745669B96F2683E2F5AEC35",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AFB2D4282817D7EC6B33EFAD7AA14A3C5",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A459B89E7E0DAF3DA654576B60B2DA7CE",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A65759F23F9789D05B23D5DBAA9E32036",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A03C78FBD5E2CB08B3B6D181E23FBDE79",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA013D941FBBDE56C106C482CD022F290F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA0645D313AC3C29B79DB1AA2E00A5B393",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA2ED0FD8048053BF22EBE501D82C4B3F1",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAC57D706C7866A01D6E913F98AE57EE54",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB7FC1241FAFDFE45C4FF982D5DC1DAEF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA7063EA296922DE8BDFD3B29D786C5F91",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA3A4603475F4AFDBFADC6E7FA908188B1",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA3365C63C2AF2A6C8FB4D0E9ED3C6FDA3",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA78BCC1874C0B7EB52645FC8F03B9C9CF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA9B7A31397718EECB89B9E9CCCD729326",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB15EA8A67E9E9FADB4249710277F3D4F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA94641D6A076193C660632CEA3F9CB02C",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB2170A08417BE77F0EAA9110F4790E12",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA4E30F1CD7B2256ABD57DC3DAB05376C9",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA9909B7B93D01BDAAC22D15AF34DF1EEF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAD97F5D1206F00E5C7225CAD81CCD4027",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA570CBB001A0C87558906B60C884AB5F41DA97CEF2A9401BC6DD0D22A54DBAD6D",
true,
AESCBCPadder,
},
}
for i, testCase := range testCases {
key, _ := hex.DecodeString(testCase.key)
iv, _ := hex.DecodeString(testCase.iv)
cd := CipherData{
Key: key,
IV: iv,
}
cbc, err := newAESCBC(cd, testCase.padder)
if err != nil {
t.Fatal(fmt.Sprintf("Case %d: Expected no error for cipher creation, but received: %v", i, err.Error()))
}
plaintext := []byte(testCase.plaintext)
if testCase.decodeHex {
plaintext, _ = hex.DecodeString(testCase.plaintext)
}
cipherdata := cbc.Encrypt(bytes.NewReader(plaintext))
ciphertext := []byte{}
b := make([]byte, 19)
err = nil
n := 0
for err != io.EOF {
n, err = cipherdata.Read(b)
ciphertext = append(ciphertext, b[:n]...)
}
if err != io.EOF {
t.Fatal(fmt.Sprintf("Case %d: Expected no error during io reading, but received: %v", i, err.Error()))
}
expectedData, _ := hex.DecodeString(testCase.ciphertext)
if bytes.Compare(expectedData, ciphertext) != 0 {
t.Log("\n", ciphertext, "\n", expectedData)
t.Fatal(fmt.Sprintf("Case %d: AES CBC encryption fails. Data is not the same", i))
}
plaindata := cbc.Decrypt(bytes.NewReader(ciphertext))
plaintextDecrypted := []byte{}
err = nil
for err != io.EOF {
n, err = plaindata.Read(b)
plaintextDecrypted = append(plaintextDecrypted, b[:n]...)
}
if err != io.EOF {
t.Fatal(fmt.Sprintf("Case %d: Expected no error during io reading, but received: %v", i, err.Error()))
}
if bytes.Compare(plaintext, plaintextDecrypted) != 0 {
t.Log("\n", plaintext, "\n", plaintextDecrypted)
t.Fatal(fmt.Sprintf("Case %d: AES CBC decryption fails. Data is not the same", i))
}
}
}
+105
View File
@@ -0,0 +1,105 @@
package s3crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"io"
"io/ioutil"
)
// AESGCM Symmetric encryption algorithm. Since Golang designed this
// with only TLS in mind. We have to load it all into memory meaning
// this isn't streamed.
type aesGCM struct {
aead cipher.AEAD
nonce []byte
}
// newAESGCM creates a new AES GCM cipher. Expects keys to be of
// the correct size.
//
// Example:
//
// cd := &s3crypto.CipherData{
// Key: key,
// "IV": iv,
// }
// cipher, err := s3crypto.newAESGCM(cd)
func newAESGCM(cd CipherData) (Cipher, error) {
block, err := aes.NewCipher(cd.Key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return &aesGCM{aesgcm, cd.IV}, nil
}
// Encrypt will encrypt the data using AES GCM
// Tag will be included as the last 16 bytes of the slice
func (c *aesGCM) Encrypt(src io.Reader) io.Reader {
reader := &gcmEncryptReader{
encrypter: c.aead,
nonce: c.nonce,
src: src,
}
return reader
}
type gcmEncryptReader struct {
encrypter cipher.AEAD
nonce []byte
src io.Reader
buf *bytes.Buffer
}
func (reader *gcmEncryptReader) Read(data []byte) (int, error) {
if reader.buf == nil {
b, err := ioutil.ReadAll(reader.src)
if err != nil {
return len(b), err
}
b = reader.encrypter.Seal(b[:0], reader.nonce, b, nil)
reader.buf = bytes.NewBuffer(b)
}
return reader.buf.Read(data)
}
// Decrypt will decrypt the data using AES GCM
func (c *aesGCM) Decrypt(src io.Reader) io.Reader {
return &gcmDecryptReader{
decrypter: c.aead,
nonce: c.nonce,
src: src,
}
}
type gcmDecryptReader struct {
decrypter cipher.AEAD
nonce []byte
src io.Reader
buf *bytes.Buffer
}
func (reader *gcmDecryptReader) Read(data []byte) (int, error) {
if reader.buf == nil {
b, err := ioutil.ReadAll(reader.src)
if err != nil {
return len(b), err
}
b, err = reader.decrypter.Open(b[:0], reader.nonce, b, nil)
if err != nil {
return len(b), err
}
reader.buf = bytes.NewBuffer(b)
}
return reader.buf.Read(data)
}
@@ -0,0 +1,68 @@
package s3crypto
import (
"io"
)
const (
gcmKeySize = 32
gcmNonceSize = 12
)
type gcmContentCipherBuilder struct {
generator CipherDataGenerator
}
// AESGCMContentCipherBuilder returns a new encryption only mode structure with a specific cipher
// for the master key
func AESGCMContentCipherBuilder(generator CipherDataGenerator) ContentCipherBuilder {
return gcmContentCipherBuilder{generator}
}
func (builder gcmContentCipherBuilder) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(gcmKeySize, gcmNonceSize)
if err != nil {
return nil, err
}
return newAESGCMContentCipher(cd)
}
func newAESGCMContentCipher(cd CipherData) (ContentCipher, error) {
cd.CEKAlgorithm = AESGCMNoPadding
cd.TagLength = "128"
cipher, err := newAESGCM(cd)
if err != nil {
return nil, err
}
return &aesGCMContentCipher{
CipherData: cd,
Cipher: cipher,
}, nil
}
// AESGCMContentCipher will use AES GCM for the main cipher.
type aesGCMContentCipher struct {
CipherData CipherData
Cipher Cipher
}
// EncryptContents will generate a random key and iv and encrypt the data using cbc
func (cc *aesGCMContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
return cc.Cipher.Encrypt(src), nil
}
// DecryptContents will use the symmetric key provider to instantiate a new GCM cipher.
// We grab a decrypt reader from gcm and wrap it in a CryptoReadCloser. The only error
// expected here is when the key or iv is of invalid length.
func (cc *aesGCMContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
reader := cc.Cipher.Decrypt(src)
return &CryptoReadCloser{Body: src, Decrypter: reader}, nil
}
// GetCipherData returns cipher data
func (cc aesGCMContentCipher) GetCipherData() CipherData {
return cc.CipherData
}
@@ -0,0 +1,23 @@
package s3crypto_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestAESGCMContentCipherBuilder(t *testing.T) {
generator := mockGenerator{}
builder := s3crypto.AESGCMContentCipherBuilder(generator)
assert.NotNil(t, builder)
}
func TestAESGCMContentCipherNewEncryptor(t *testing.T) {
generator := mockGenerator{}
builder := s3crypto.AESGCMContentCipherBuilder(generator)
cipher, err := builder.ContentCipher()
assert.NoError(t, err)
assert.NotNil(t, cipher)
}
+73
View File
@@ -0,0 +1,73 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
// AES GCM
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_128_Test_0(t *testing.T) {
iv, _ := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
plaintext, _ := hex.DecodeString("2db5168e932556f8089a0622981d017d")
expected, _ := hex.DecodeString("fa4362189661d163fcd6a56d8bf0405ad636ac1bbedd5cc3ee727dc2ab4a9489")
tag, _ := hex.DecodeString("d636ac1bbedd5cc3ee727dc2ab4a9489")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_104_Test_3(t *testing.T) {
iv, _ := hex.DecodeString("4742357c335913153ff0eb0f")
key, _ := hex.DecodeString("e5a0eb92cc2b064e1bc80891faf1fab5e9a17a9c3a984e25416720e30e6c2b21")
plaintext, _ := hex.DecodeString("8499893e16b0ba8b007d54665a")
expected, _ := hex.DecodeString("eb8e6175f1fe38eb1acf95fd5188a8b74bb74fda553e91020a23deed45")
tag, _ := hex.DecodeString("88a8b74bb74fda553e91020a23deed45")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_256_Test_6(t *testing.T) {
iv, _ := hex.DecodeString("a291484c3de8bec6b47f525f")
key, _ := hex.DecodeString("37f39137416bafde6f75022a7a527cc593b6000a83ff51ec04871a0ff5360e4e")
plaintext, _ := hex.DecodeString("fafd94cede8b5a0730394bec68a8e77dba288d6ccaa8e1563a81d6e7ccc7fc97")
expected, _ := hex.DecodeString("44dc868006b21d49284016565ffb3979cc4271d967628bf7cdaf86db888e92e501a2b578aa2f41ec6379a44a31cc019c")
tag, _ := hex.DecodeString("01a2b578aa2f41ec6379a44a31cc019c")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_408_Test_8(t *testing.T) {
iv, _ := hex.DecodeString("92f258071d79af3e63672285")
key, _ := hex.DecodeString("595f259c55abe00ae07535ca5d9b09d6efb9f7e9abb64605c337acbd6b14fc7e")
plaintext, _ := hex.DecodeString("a6fee33eb110a2d769bbc52b0f36969c287874f665681477a25fc4c48015c541fbe2394133ba490a34ee2dd67b898177849a91")
expected, _ := hex.DecodeString("bbca4a9e09ae9690c0f6f8d405e53dccd666aa9c5fa13c8758bc30abe1ddd1bcce0d36a1eaaaaffef20cd3c5970b9673f8a65c26ccecb9976fd6ac9c2c0f372c52c821")
tag, _ := hex.DecodeString("26ccecb9976fd6ac9c2c0f372c52c821")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func aesgcmTest(t *testing.T, iv, key, plaintext, expected, tag []byte) {
cd := CipherData{
Key: key,
IV: iv,
}
gcm, err := newAESGCM(cd)
assert.NoError(t, err)
cipherdata := gcm.Encrypt(bytes.NewReader(plaintext))
ciphertext, err := ioutil.ReadAll(cipherdata)
assert.NoError(t, err)
// splitting tag and ciphertext
etag := ciphertext[len(ciphertext)-16:]
assert.Equal(t, etag, tag)
assert.Equal(t, len(ciphertext), len(expected))
assert.Equal(t, ciphertext, expected)
data := gcm.Decrypt(bytes.NewReader(ciphertext))
assert.NoError(t, err)
text, err := ioutil.ReadAll(data)
assert.NoError(t, err)
assert.Equal(t, len(text), len(plaintext))
assert.Equal(t, text, plaintext)
}
+43
View File
@@ -0,0 +1,43 @@
package s3crypto
import (
"io"
)
// Cipher interface allows for either encryption and decryption of an object
type Cipher interface {
Encrypter
Decrypter
}
// Encrypter interface with only the encrypt method
type Encrypter interface {
Encrypt(io.Reader) io.Reader
}
// Decrypter interface with only the decrypt method
type Decrypter interface {
Decrypt(io.Reader) io.Reader
}
// CryptoReadCloser handles closing of the body and allowing reads from the decrypted
// content.
type CryptoReadCloser struct {
Body io.ReadCloser
Decrypter io.Reader
isClosed bool
}
// Close lets the CryptoReadCloser satisfy io.ReadCloser interface
func (rc *CryptoReadCloser) Close() error {
rc.isClosed = true
return rc.Body.Close()
}
// Read lets the CryptoReadCloser satisfy io.ReadCloser interface
func (rc *CryptoReadCloser) Read(b []byte) (int, error) {
if rc.isClosed {
return 0, io.EOF
}
return rc.Decrypter.Read(b)
}
+31
View File
@@ -0,0 +1,31 @@
package s3crypto
import "io"
// ContentCipherBuilder is a builder interface that builds
// ciphers for each request.
type ContentCipherBuilder interface {
ContentCipher() (ContentCipher, error)
}
// ContentCipher deals with encrypting and decrypting content
type ContentCipher interface {
EncryptContents(io.Reader) (io.Reader, error)
DecryptContents(io.ReadCloser) (io.ReadCloser, error)
GetCipherData() CipherData
}
// CipherData is used for content encryption. It used for storing the
// metadata of the encrypted content.
type CipherData struct {
Key []byte
IV []byte
WrapAlgorithm string
CEKAlgorithm string
TagLength string
MaterialDescription MaterialDescription
// EncryptedKey should be populated when calling GenerateCipherData
EncryptedKey []byte
Padder Padder
}
+34
View File
@@ -0,0 +1,34 @@
package s3crypto_test
import (
"io/ioutil"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestCryptoReadCloserRead(t *testing.T) {
expectedStr := "HELLO WORLD "
str := strings.NewReader(expectedStr)
rc := &s3crypto.CryptoReadCloser{Body: ioutil.NopCloser(str), Decrypter: str}
b, err := ioutil.ReadAll(rc)
assert.NoError(t, err)
assert.Equal(t, expectedStr, string(b))
}
func TestCryptoReadCloserClose(t *testing.T) {
data := "HELLO WORLD "
expectedStr := ""
str := strings.NewReader(data)
rc := &s3crypto.CryptoReadCloser{Body: ioutil.NopCloser(str), Decrypter: str}
rc.Close()
b, err := ioutil.ReadAll(rc)
assert.NoError(t, err)
assert.Equal(t, expectedStr, string(b))
}
+111
View File
@@ -0,0 +1,111 @@
package s3crypto
import (
"encoding/base64"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
func (client *DecryptionClient) contentCipherFromEnvelope(env Envelope) (ContentCipher, error) {
wrap, err := client.wrapFromEnvelope(env)
if err != nil {
return nil, err
}
return client.cekFromEnvelope(env, wrap)
}
func (client *DecryptionClient) wrapFromEnvelope(env Envelope) (CipherDataDecrypter, error) {
f, ok := client.WrapRegistry[env.WrapAlg]
if !ok || f == nil {
return nil, awserr.New(
"InvalidWrapAlgorithmError",
"wrap algorithm isn't supported, "+env.WrapAlg,
nil,
)
}
return f(env)
}
// AESGCMNoPadding is the constant value that is used to specify
// the CEK algorithm consiting of AES GCM with no padding.
const AESGCMNoPadding = "AES/GCM/NoPadding"
// AESCBC is the string constant that signifies the AES CBC algorithm cipher.
const AESCBC = "AES/CBC"
func (client *DecryptionClient) cekFromEnvelope(env Envelope, decrypter CipherDataDecrypter) (ContentCipher, error) {
f, ok := client.CEKRegistry[env.CEKAlg]
if !ok || f == nil {
return nil, awserr.New(
"InvalidCEKAlgorithmError",
"cek algorithm isn't supported, "+env.CEKAlg,
nil,
)
}
key, err := base64.StdEncoding.DecodeString(env.CipherKey)
if err != nil {
return nil, err
}
iv, err := base64.StdEncoding.DecodeString(env.IV)
if err != nil {
return nil, err
}
key, err = decrypter.DecryptKey(key)
if err != nil {
return nil, err
}
cd := CipherData{
Key: key,
IV: iv,
CEKAlgorithm: env.CEKAlg,
Padder: client.getPadder(env.CEKAlg),
}
return f(cd)
}
// getPadder will return an unpadder with checking the cek algorithm specific padder.
// If there wasn't a cek algorithm specific padder, we check the padder itself.
// We return a no unpadder, if no unpadder was found. This means any customization
// either contained padding within the cipher implementation, and to maintain
// backwards compatility we will simply not unpad anything.
func (client *DecryptionClient) getPadder(cekAlg string) Padder {
padder, ok := client.PadderRegistry[cekAlg]
if !ok {
padder, ok = client.PadderRegistry[cekAlg[strings.LastIndex(cekAlg, "/")+1:]]
if !ok {
return NoPadder
}
}
return padder
}
func encodeMeta(reader hashReader, cd CipherData) (Envelope, error) {
iv := base64.StdEncoding.EncodeToString(cd.IV)
key := base64.StdEncoding.EncodeToString(cd.EncryptedKey)
md5 := reader.GetValue()
contentLength := reader.GetContentLength()
md5Str := base64.StdEncoding.EncodeToString(md5)
matdesc, err := cd.MaterialDescription.encodeDescription()
if err != nil {
return Envelope{}, err
}
return Envelope{
CipherKey: key,
IV: iv,
MatDesc: string(matdesc),
WrapAlg: cd.WrapAlgorithm,
CEKAlg: cd.CEKAlgorithm,
TagLen: cd.TagLength,
UnencryptedMD5: md5Str,
UnencryptedContentLen: strconv.FormatInt(contentLength, 10),
}, nil
}
@@ -0,0 +1,225 @@
package s3crypto
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
)
func TestWrapFactory(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: KMSWrap,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
_, ok := wrap.(*kmsKeyHandler)
assert.NoError(t, err)
assert.NotNil(t, wrap)
assert.True(t, ok)
}
func TestWrapFactoryErrorNoWrap(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: "none",
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
assert.Error(t, err)
assert.Nil(t, wrap)
}
func TestWrapFactoryCustomEntry(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
"custom": (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: "custom",
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
assert.NoError(t, err)
assert.NotNil(t, wrap)
}
func TestCEKFactory(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{
NoPadder.Name(): NoPadder,
},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
assert.NoError(t, err)
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
assert.NoError(t, err)
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: AESGCMNoPadding,
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
assert.NoError(t, err)
assert.NotNil(t, cek)
}
func TestCEKFactoryNoCEK(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{
NoPadder.Name(): NoPadder,
},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
assert.NoError(t, err)
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
assert.NoError(t, err)
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: "none",
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
assert.Error(t, err)
assert.Nil(t, cek)
}
func TestCEKFactoryCustomEntry(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
"custom": newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
assert.NoError(t, err)
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
assert.NoError(t, err)
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: "custom",
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
assert.NoError(t, err)
assert.NotNil(t, cek)
}
@@ -0,0 +1,135 @@
package s3crypto
import (
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// WrapEntry is builder that return a proper key decrypter and error
type WrapEntry func(Envelope) (CipherDataDecrypter, error)
// CEKEntry is a builder thatn returns a proper content decrypter and error
type CEKEntry func(CipherData) (ContentCipher, error)
// DecryptionClient is an S3 crypto client. The decryption client
// will handle all get object requests from Amazon S3.
// Supported key wrapping algorithms:
// *AWS KMS
//
// Supported content ciphers:
// * AES/GCM
// * AES/CBC
type DecryptionClient struct {
S3Client s3iface.S3API
// LoadStrategy is used to load the metadata either from the metadata of the object
// or from a separate file in s3.
//
// Defaults to our default load strategy.
LoadStrategy LoadStrategy
WrapRegistry map[string]WrapEntry
CEKRegistry map[string]CEKEntry
PadderRegistry map[string]Padder
}
// NewDecryptionClient instantiates a new S3 crypto client
//
// Example:
// sess := session.New()
// svc := s3crypto.NewDecryptionClient(sess, func(svc *s3crypto.DecryptionClient{
// // Custom client options here
// }))
func NewDecryptionClient(prov client.ConfigProvider, options ...func(*DecryptionClient)) *DecryptionClient {
s3client := s3.New(prov)
client := &DecryptionClient{
S3Client: s3client,
LoadStrategy: defaultV2LoadStrategy{
client: s3client,
},
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(prov),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
strings.Join([]string{AESCBC, AESCBCPadder.Name()}, "/"): newAESCBCContentCipher,
},
PadderRegistry: map[string]Padder{
strings.Join([]string{AESCBC, AESCBCPadder.Name()}, "/"): AESCBCPadder,
"NoPadding": NoPadder,
},
}
for _, option := range options {
option(client)
}
return client
}
// GetObjectRequest will make a request to s3 and retrieve the object. In this process
// decryption will be done. The SDK only supports V2 reads of KMS and GCM.
//
// Example:
// sess := session.New()
// svc := s3crypto.NewDecryptionClient(sess)
// req, out := svc.GetObjectRequest(&s3.GetObjectInput {
// Key: aws.String("testKey"),
// Bucket: aws.String("testBucket"),
// })
// err := req.Send()
func (c *DecryptionClient) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
req, out := c.S3Client.GetObjectRequest(input)
req.Handlers.Unmarshal.PushBack(func(r *request.Request) {
env, err := c.LoadStrategy.Load(r)
if err != nil {
r.Error = err
out.Body.Close()
return
}
// If KMS should return the correct CEK algorithm with the proper
// KMS key provider
cipher, err := c.contentCipherFromEnvelope(env)
if err != nil {
r.Error = err
out.Body.Close()
return
}
reader, err := cipher.DecryptContents(out.Body)
if err != nil {
r.Error = err
out.Body.Close()
return
}
out.Body = reader
})
return req, out
}
// GetObject is a wrapper for GetObjectRequest
func (c *DecryptionClient) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
req, out := c.GetObjectRequest(input)
return out, req.Send()
}
// GetObjectWithContext is a wrapper for GetObjectRequest with the additional
// context, and request options support.
//
// GetObjectWithContext is the same as GetObject with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, ect. In the future
// this may create sub-contexts for individual underlying requests.
func (c *DecryptionClient) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
req, out := c.GetObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return out, req.Send()
}
@@ -0,0 +1,215 @@
package s3crypto_test
import (
"bytes"
"encoding/base64"
"encoding/hex"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestGetObjectGCM(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
assert.NotNil(t, c)
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
assert.NoError(t, err)
b, err := hex.DecodeString("fa4362189661d163fcd6a56d8bf0405ad636ac1bbedd5cc3ee727dc2ab4a9489")
assert.NoError(t, err)
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"SpFRES0JyU8BLZSKo51SrwILK4lhtZsWiMNjgO4WmoK+joMwZPG7Hw=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{base64.URLEncoding.EncodeToString(iv)},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{s3crypto.AESGCMNoPadding},
http.CanonicalHeaderKey("x-amz-meta-x-amz-tag-len"): []string{"128"},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
assert.NoError(t, err)
b, err := ioutil.ReadAll(out.Body)
assert.NoError(t, err)
expected, err := hex.DecodeString("2db5168e932556f8089a0622981d017d")
assert.NoError(t, err)
assert.Equal(t, len(expected), len(b))
assert.Equal(t, expected, b)
}
func TestGetObjectCBC(t *testing.T) {
key, _ := hex.DecodeString("898be9cc5004ed0fa6e117c9a3099d31")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
assert.NotNil(t, c)
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
iv, err := hex.DecodeString("9dea7621945988f96491083849b068df")
assert.NoError(t, err)
b, err := hex.DecodeString("e232cd6ef50047801ee681ec30f61d53cfd6b0bca02fd03c1b234baa10ea82ac9dab8b960926433a19ce6dea08677e34")
assert.NoError(t, err)
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"SpFRES0JyU8BLZSKo51SrwILK4lhtZsWiMNjgO4WmoK+joMwZPG7Hw=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{base64.URLEncoding.EncodeToString(iv)},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{strings.Join([]string{s3crypto.AESCBC, s3crypto.AESCBCPadder.Name()}, "/")},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
assert.NoError(t, err)
b, err := ioutil.ReadAll(out.Body)
assert.NoError(t, err)
expected, err := hex.DecodeString("0397f4f6820b1f9386f14403be5ac16e50213bd473b4874b9bcbf5f318ee686b1d")
assert.NoError(t, err)
assert.Equal(t, len(expected), len(b))
assert.Equal(t, expected, b)
}
func TestGetObjectCBC2(t *testing.T) {
key, _ := hex.DecodeString("8d70e92489c4e6cfb12261b4d17f4b85826da687fc8742fcf9f87fadb5b4cb89")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
assert.NotNil(t, c)
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
b, err := hex.DecodeString("fd0c71ecb7ed16a9bf42ea5f75501d416df608f190890c3b4d8897f24744cd7f9ea4a0b212e60634302450e1c5378f047ff753ccefe365d411c36339bf22e301fae4c3a6226719a4b93dc74c1af79d0296659b5d56c0892315f2c7cc30190220db1eaafae3920d6d9c65d0aa366499afc17af493454e141c6e0fbdeb6a990cb4")
assert.NoError(t, err)
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"AQEDAHikdGvcj7Gil5VqAR/JWvvPp3ue26+t2vhWy4lL2hg4mAAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCcy43wCR0bSsnzTrAIBEIA7WdD2jxC3tCrK6TOdiEfbIN64m+UN7Velz4y0LRra5jn2U1CDClacwIpiBYuDp5ymPKO+ZqUGE0WEf20="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{"EMMWJY8ZLcK/9FOj3iCpng=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{strings.Join([]string{s3crypto.AESCBC, s3crypto.AESCBCPadder.Name()}, "/")},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
fmt.Println("HEADER", r.HTTPResponse.Header)
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
assert.NoError(t, err)
b, err := ioutil.ReadAll(out.Body)
assert.NoError(t, err)
expected, err := hex.DecodeString("a6ccd3482f5ce25c9ddeb69437cd0acbc0bdda2ef8696d90781de2b35704543529871b2032e68ef1c5baed1769aba8d420d1aca181341b49b8b3587a6580cdf1d809c68f06735f7735c16691f4b70c967d68fc08195b81ad71bcc4df452fd0a5799c1e1234f92f1cd929fc072167ccf9f2ac85b93170932b32")
assert.NoError(t, err)
assert.Equal(t, len(expected), len(b))
assert.Equal(t, expected, b)
}
func TestGetObjectWithContext(t *testing.T) {
c := s3crypto.NewDecryptionClient(unit.Session)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
input := s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
_, err := c.GetObjectWithContext(ctx, &input)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
+66
View File
@@ -0,0 +1,66 @@
/*
Package s3crypto provides encryption to S3 using KMS and AES GCM.
Keyproviders are interfaces that handle masterkeys. Masterkeys are used to encrypt and decrypt the randomly
generated cipher keys. The SDK currently uses KMS to do this. A user does not need to provide a master key
since all that information is hidden in KMS.
Modes are interfaces that handle content encryption and decryption. It is an abstraction layer that instantiates
the ciphers. If content is being encrypted we generate the key and iv of the cipher. For decryption, we use the
metadata stored either on the object or an instruction file object to decrypt the contents.
Ciphers are interfaces that handle encryption and decryption of data. This may be key wrap ciphers or content
ciphers.
Creating an S3 cryptography client
cmkID := "<some key ID>"
sess := session.New()
// Create the KeyProvider
handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
// Create an encryption and decryption client
// We need to pass the session here so S3 can use it. In addition, any decryption that
// occurs will use the KMS client.
svc := s3crypto.NewEncryptionClient(sess, s3crypto.AESGCMContentCipherBuilder(handler))
svc := s3crypto.NewDecryptionClient(sess)
Configuration of the S3 cryptography client
cfg := s3crypto.EncryptionConfig{
// Save instruction files to separate objects
SaveStrategy: NewS3SaveStrategy(session.New(), ""),
// Change instruction file suffix to .example
InstructionFileSuffix: ".example",
// Set temp folder path
TempFolderPath: "/path/to/tmp/folder/",
// Any content less than the minimum file size will use memory
// instead of writing the contents to a temp file.
MinFileSize: int64(1024 * 1024 * 1024),
}
The default SaveStrategy is to the object's header.
The InstructionFileSuffix defaults to .instruction. Careful here though, if you do this, be sure you know
what that suffix is in grabbing data. All requests will look for fooKey.example instead of fooKey.instruction.
This suffix only affects gets and not puts. Put uses the keyprovider's suffix.
Registration of new wrap or cek algorithms are also supported by the SDK. Let's say we want to support `AES Wrap`
and `AES CTR`. Let's assume we have already defined the functionality.
svc := s3crypto.NewDecryptionClient(sess)
svc.WrapRegistry["AESWrap"] = NewAESWrap
svc.CEKRegistry["AES/CTR/NoPadding"] = NewAESCTR
We have now registered these new algorithms to the decryption client. When the client calls `GetObject` and sees
the wrap as `AESWrap` then it'll use that wrap algorithm. This is also true for `AES/CTR/NoPadding`.
For encryption adding a custom content cipher builder and key handler will allow for encryption of custom
defined ciphers.
// Our wrap algorithm, AESWrap
handler := NewAESWrap(key, iv)
// Our content cipher builder, AESCTRContentCipherBuilder
svc := s3crypto.NewEncryptionClient(sess, NewAESCTRContentCipherBuilder(handler))
*/
package s3crypto
@@ -0,0 +1,146 @@
package s3crypto
import (
"encoding/hex"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// DefaultMinFileSize is used to check whether we want to write to a temp file
// or store the data in memory.
const DefaultMinFileSize = 1024 * 512 * 5
// EncryptionClient is an S3 crypto client. By default the SDK will use Authentication mode which
// will use KMS for key wrapping and AES GCM for content encryption.
// AES GCM will load all data into memory. However, the rest of the content algorithms
// do not load the entire contents into memory.
type EncryptionClient struct {
S3Client s3iface.S3API
ContentCipherBuilder ContentCipherBuilder
// SaveStrategy will dictate where the envelope is saved.
//
// Defaults to the object's metadata
SaveStrategy SaveStrategy
// TempFolderPath is used to store temp files when calling PutObject.
// Temporary files are needed to compute the X-Amz-Content-Sha256 header.
TempFolderPath string
// MinFileSize is the minimum size for the content to write to a
// temporary file instead of using memory.
MinFileSize int64
}
// NewEncryptionClient instantiates a new S3 crypto client
//
// Example:
// cmkID := "arn:aws:kms:region:000000000000:key/00000000-0000-0000-0000-000000000000"
// sess := session.New()
// handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
// svc := s3crypto.New(sess, s3crypto.AESGCMContentCipherBuilder(handler))
func NewEncryptionClient(prov client.ConfigProvider, builder ContentCipherBuilder, options ...func(*EncryptionClient)) *EncryptionClient {
client := &EncryptionClient{
S3Client: s3.New(prov),
ContentCipherBuilder: builder,
SaveStrategy: HeaderV2SaveStrategy{},
MinFileSize: DefaultMinFileSize,
}
for _, option := range options {
option(client)
}
return client
}
// PutObjectRequest creates a temp file to encrypt the contents into. It then streams
// that data to S3.
//
// Example:
// svc := s3crypto.New(session.New(), s3crypto.AESGCMContentCipherBuilder(handler))
// req, out := svc.PutObjectRequest(&s3.PutObjectInput {
// Key: aws.String("testKey"),
// Bucket: aws.String("testBucket"),
// Body: bytes.NewBuffer("test data"),
// })
// err := req.Send()
func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
req, out := c.S3Client.PutObjectRequest(input)
// Get Size of file
n, err := input.Body.Seek(0, 2)
if err != nil {
req.Error = err
return req, out
}
input.Body.Seek(0, 0)
dst, err := getWriterStore(req, c.TempFolderPath, n >= c.MinFileSize)
if err != nil {
req.Error = err
return req, out
}
encryptor, err := c.ContentCipherBuilder.ContentCipher()
req.Handlers.Build.PushFront(func(r *request.Request) {
if err != nil {
r.Error = err
return
}
md5 := newMD5Reader(input.Body)
sha := newSHA256Writer(dst)
reader, err := encryptor.EncryptContents(md5)
if err != nil {
r.Error = err
return
}
_, err = io.Copy(sha, reader)
if err != nil {
r.Error = err
return
}
data := encryptor.GetCipherData()
env, err := encodeMeta(md5, data)
if err != nil {
r.Error = err
return
}
shaHex := hex.EncodeToString(sha.GetValue())
req.HTTPRequest.Header.Set("X-Amz-Content-Sha256", shaHex)
dst.Seek(0, 0)
input.Body = dst
err = c.SaveStrategy.Save(env, r)
r.Error = err
})
return req, out
}
// PutObject is a wrapper for PutObjectRequest
func (c *EncryptionClient) PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
req, out := c.PutObjectRequest(input)
return out, req.Send()
}
// PutObjectWithContext is a wrapper for PutObjectRequest with the additional
// context, and request options support.
//
// PutObjectWithContext is the same as PutObject with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, ect. In the future
// this may create sub-contexts for individual underlying requests.
func (c *EncryptionClient) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
req, out := c.PutObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return out, req.Send()
}
@@ -0,0 +1,99 @@
package s3crypto_test
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestDefaultConfigValues(t *testing.T) {
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
svc := kms.New(sess)
handler := s3crypto.NewKMSKeyGenerator(svc, "testid")
c := s3crypto.NewEncryptionClient(sess, s3crypto.AESGCMContentCipherBuilder(handler))
assert.NotNil(t, c)
assert.NotNil(t, c.ContentCipherBuilder)
assert.NotNil(t, c.SaveStrategy)
}
func TestPutObject(t *testing.T) {
size := 1024 * 1024
data := make([]byte, size)
expected := bytes.Repeat([]byte{1}, size)
generator := mockGenerator{}
cb := mockCipherBuilder{generator}
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewEncryptionClient(sess, cb)
assert.NotNil(t, c)
input := &s3.PutObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
Body: bytes.NewReader(data),
}
req, _ := c.PutObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
r.Error = errors.New("stop")
r.HTTPResponse = &http.Response{
StatusCode: 200,
}
})
err := req.Send()
assert.Equal(t, "stop", err.Error())
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
assert.NoError(t, err)
assert.Equal(t, expected, b)
}
func TestPutObjectWithContext(t *testing.T) {
generator := mockGenerator{}
cb := mockCipherBuilder{generator}
c := s3crypto.NewEncryptionClient(unit.Session, cb)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
input := s3.PutObjectInput{
Bucket: aws.String("test"),
Key: aws.String("test"),
Body: bytes.NewReader([]byte{}),
}
_, err := c.PutObjectWithContext(ctx, &input)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
+37
View File
@@ -0,0 +1,37 @@
package s3crypto
// DefaultInstructionKeySuffix is appended to the end of the instruction file key when
// grabbing or saving to S3
const DefaultInstructionKeySuffix = ".instruction"
const (
metaHeader = "x-amz-meta"
keyV1Header = "x-amz-key"
keyV2Header = keyV1Header + "-v2"
ivHeader = "x-amz-iv"
matDescHeader = "x-amz-matdesc"
cekAlgorithmHeader = "x-amz-cek-alg"
wrapAlgorithmHeader = "x-amz-wrap-alg"
tagLengthHeader = "x-amz-tag-len"
unencryptedMD5Header = "x-amz-unencrypted-content-md5"
unencryptedContentLengthHeader = "x-amz-unencrypted-content-length"
)
// Envelope encryption starts off by generating a random symmetric key using
// AES GCM. The SDK generates a random IV based off the encryption cipher
// chosen. The master key that was provided, whether by the user or KMS, will be used
// to encrypt the randomly generated symmetric key and base64 encode the iv. This will
// allow for decryption of that same data later.
type Envelope struct {
// IV is the randomly generated IV base64 encoded.
IV string `json:"x-amz-iv"`
// CipherKey is the randomly generated cipher key.
CipherKey string `json:"x-amz-key-v2, x-amz-key"`
// MaterialDesc is a description to distinguish from other envelopes.
MatDesc string `json:"x-amz-matdesc"`
WrapAlg string `json:"x-amz-wrap-alg"`
CEKAlg string `json:"x-amz-cek-alg"`
TagLen string `json:"x-amz-tag-len"`
UnencryptedMD5 string `json:"x-amz-unencrypted-content-md5"`
UnencryptedContentLen string `json:"x-amz-unencrypted-content-length"`
}
+61
View File
@@ -0,0 +1,61 @@
package s3crypto
import (
"crypto/md5"
"crypto/sha256"
"hash"
"io"
)
// hashReader is used for calculating SHA256 when following the sigv4 specification.
// Additionally this used for calculating the unencrypted MD5.
type hashReader interface {
GetValue() []byte
GetContentLength() int64
}
type sha256Writer struct {
sha256 []byte
hash hash.Hash
out io.Writer
}
func newSHA256Writer(f io.Writer) *sha256Writer {
return &sha256Writer{hash: sha256.New(), out: f}
}
func (r *sha256Writer) Write(b []byte) (int, error) {
r.hash.Write(b)
return r.out.Write(b)
}
func (r *sha256Writer) GetValue() []byte {
return r.hash.Sum(nil)
}
type md5Reader struct {
contentLength int64
hash hash.Hash
body io.Reader
}
func newMD5Reader(body io.Reader) *md5Reader {
return &md5Reader{hash: md5.New(), body: body}
}
func (w *md5Reader) Read(b []byte) (int, error) {
n, err := w.body.Read(b)
if err != nil && err != io.EOF {
return n, err
}
w.contentLength += int64(n)
w.hash.Write(b[:n])
return n, err
}
func (w *md5Reader) GetValue() []byte {
return w.hash.Sum(nil)
}
func (w *md5Reader) GetContentLength() int64 {
return w.contentLength
}
+25
View File
@@ -0,0 +1,25 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
)
// From Go stdlib encoding/sha256 test cases
func TestSHA256(t *testing.T) {
sha := newSHA256Writer(nil)
expected, _ := hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
b := sha.GetValue()
assert.Equal(t, expected, b)
}
func TestSHA256_Case2(t *testing.T) {
sha := newSHA256Writer(bytes.NewBuffer([]byte{}))
sha.Write([]byte("hello"))
expected, _ := hex.DecodeString("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")
b := sha.GetValue()
assert.Equal(t, expected, b)
}
+70
View File
@@ -0,0 +1,70 @@
package s3crypto
import (
"errors"
"io"
"io/ioutil"
"os"
"path/filepath"
"github.com/aws/aws-sdk-go/aws/request"
)
func getWriterStore(req *request.Request, path string, useTempFile bool) (io.ReadWriteSeeker, error) {
if !useTempFile {
return &bytesReadWriteSeeker{}, nil
}
// Create temp file to be used later for calculating the SHA256 header
f, err := ioutil.TempFile(path, "")
if err != nil {
return nil, err
}
req.Handlers.Send.PushBack(func(r *request.Request) {
// Close the temp file and cleanup
f.Close()
fpath := filepath.Join(path, f.Name())
os.Remove(fpath)
})
return f, nil
}
type bytesReadWriteSeeker struct {
buf []byte
i int64
}
// Copied from Go stdlib bytes.Reader
func (ws *bytesReadWriteSeeker) Read(b []byte) (int, error) {
if ws.i >= int64(len(ws.buf)) {
return 0, io.EOF
}
n := copy(b, ws.buf[ws.i:])
ws.i += int64(n)
return n, nil
}
func (ws *bytesReadWriteSeeker) Write(b []byte) (int, error) {
ws.buf = append(ws.buf, b...)
return len(b), nil
}
// Copied from Go stdlib bytes.Reader
func (ws *bytesReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case 0:
abs = offset
case 1:
abs = int64(ws.i) + offset
case 2:
abs = int64(len(ws.buf)) + offset
default:
return 0, errors.New("bytes.Reader.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("bytes.Reader.Seek: negative position")
}
ws.i = abs
return abs, nil
}
+47
View File
@@ -0,0 +1,47 @@
package s3crypto
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBytesReadWriteSeeker_Read(t *testing.T) {
b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0}
expected := []byte{1, 2, 3}
buf := make([]byte, 3)
n, err := b.Read(buf)
assert.NoError(t, err)
assert.Equal(t, 3, n)
assert.Equal(t, expected, buf)
}
func TestBytesReadWriteSeeker_Write(t *testing.T) {
b := &bytesReadWriteSeeker{}
expected := []byte{1, 2, 3}
buf := make([]byte, 3)
n, err := b.Write([]byte{1, 2, 3})
assert.NoError(t, err)
assert.Equal(t, 3, n)
n, err = b.Read(buf)
assert.NoError(t, err)
assert.Equal(t, 3, n)
assert.Equal(t, expected, buf)
}
func TestBytesReadWriteSeeker_Seek(t *testing.T) {
b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0}
expected := []byte{2, 3}
m, err := b.Seek(1, 0)
assert.NoError(t, err)
assert.Equal(t, 1, int(m))
buf := make([]byte, 3)
n, err := b.Read(buf)
assert.NoError(t, err)
assert.Equal(t, 2, n)
assert.Equal(t, expected, buf[:n])
}
+21
View File
@@ -0,0 +1,21 @@
package s3crypto
import "crypto/rand"
// CipherDataGenerator handles generating proper key and IVs of proper size for the
// content cipher. CipherDataGenerator will also encrypt the key and store it in
// the CipherData.
type CipherDataGenerator interface {
GenerateCipherData(int, int) (CipherData, error)
}
// CipherDataDecrypter is a handler to decrypt keys from the envelope.
type CipherDataDecrypter interface {
DecryptKey([]byte) ([]byte, error)
}
func generateBytes(n int) []byte {
b := make([]byte, n)
rand.Read(b)
return b
}
@@ -0,0 +1,16 @@
package s3crypto
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGenerateBytes(t *testing.T) {
b := generateBytes(5)
assert.Equal(t, 5, len(b))
b = generateBytes(0)
assert.Equal(t, 0, len(b))
b = generateBytes(1024)
assert.Equal(t, 1024, len(b))
}
+114
View File
@@ -0,0 +1,114 @@
package s3crypto
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
)
const (
// KMSWrap is a constant used during decryption to build a KMS key handler.
KMSWrap = "kms"
)
// kmsKeyHandler will make calls to KMS to get the masterkey
type kmsKeyHandler struct {
kms kmsiface.KMSAPI
cmkID *string
CipherData
}
// NewKMSKeyGenerator builds a new KMS key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
func NewKMSKeyGenerator(kmsClient kmsiface.KMSAPI, cmkID string) CipherDataGenerator {
return NewKMSKeyGeneratorWithMatDesc(kmsClient, cmkID, MaterialDescription{})
}
// NewKMSKeyGeneratorWithMatDesc builds a new KMS key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler, err := s3crypto.NewKMSKeyGeneratorWithMatDesc(kms.New(sess), cmkID, matdesc)
func NewKMSKeyGeneratorWithMatDesc(kmsClient kmsiface.KMSAPI, cmkID string, matdesc MaterialDescription) CipherDataGenerator {
if matdesc == nil {
matdesc = MaterialDescription{}
}
matdesc["kms_cmk_id"] = &cmkID
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: kmsClient,
cmkID: &cmkID,
}
// These values are read only making them thread safe
kp.CipherData.WrapAlgorithm = KMSWrap
kp.CipherData.MaterialDescription = matdesc
return kp
}
// decryptHandler initializes a KMS keyprovider with a material description. This
// is used with Decrypting kms content, due to the cmkID being in the material description.
func (kp kmsKeyHandler) decryptHandler(env Envelope) (CipherDataDecrypter, error) {
m := MaterialDescription{}
err := m.decodeDescription([]byte(env.MatDesc))
if err != nil {
return nil, err
}
cmkID, ok := m["kms_cmk_id"]
if !ok {
return nil, awserr.New("MissingCMKIDError", "Material description is missing CMK ID", nil)
}
kp.CipherData.MaterialDescription = m
kp.cmkID = cmkID
kp.WrapAlgorithm = KMSWrap
return &kp, nil
}
// DecryptKey makes a call to KMS to decrypt the key.
func (kp *kmsKeyHandler) DecryptKey(key []byte) ([]byte, error) {
out, err := kp.kms.Decrypt(&kms.DecryptInput{
EncryptionContext: map[string]*string(kp.CipherData.MaterialDescription),
CiphertextBlob: key,
GrantTokens: []*string{},
})
if err != nil {
return nil, err
}
return out.Plaintext, nil
}
// GenerateCipherData makes a call to KMS to generate a data key, Upon making
// the call, it also sets the encrypted key.
func (kp *kmsKeyHandler) GenerateCipherData(keySize, ivSize int) (CipherData, error) {
out, err := kp.kms.GenerateDataKey(&kms.GenerateDataKeyInput{
EncryptionContext: kp.CipherData.MaterialDescription,
KeyId: kp.cmkID,
KeySpec: aws.String("AES_256"),
})
if err != nil {
return CipherData{}, err
}
iv := generateBytes(ivSize)
cd := CipherData{
Key: out.Plaintext,
IV: iv,
WrapAlgorithm: KMSWrap,
MaterialDescription: kp.CipherData.MaterialDescription,
EncryptedKey: out.CiphertextBlob,
}
return cd, nil
}
@@ -0,0 +1,105 @@
package s3crypto
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
)
func TestBuildKMSEncryptHandler(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGenerator(svc, "testid")
assert.NotNil(t, handler)
}
func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGeneratorWithMatDesc(svc, "testid", MaterialDescription{
"Testing": aws.String("123"),
})
assert.NotNil(t, handler)
kmsHandler := handler.(*kmsKeyHandler)
expected := MaterialDescription{
"kms_cmk_id": aws.String("testid"),
"Testing": aws.String("123"),
}
assert.Equal(t, expected, kmsHandler.CipherData.MaterialDescription)
}
func TestKMSGenerateCipherData(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"CiphertextBlob":"AQEDAHhqBCCY1MSimw8gOGcUma79cn4ANvTtQyv9iuBdbcEF1QAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDJ6IcN5E4wVbk38MNAIBEIA7oF1E3lS7FY9DkoxPc/UmJsEwHzL82zMqoLwXIvi8LQHr8If4Lv6zKqY8u0+JRgSVoqCvZDx3p8Cn6nM=","KeyId":"arn:aws:kms:us-west-2:042062605278:key/c80a5cdb-8d09-4f9f-89ee-df01b2e3870a","Plaintext":"6tmyz9JLBE2yIuU7iXpArqpDVle172WSmxjcO6GNT7E="}`)
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
svc := kms.New(sess)
handler := NewKMSKeyGenerator(svc, "testid")
keySize := 32
ivSize := 16
cd, err := handler.GenerateCipherData(keySize, ivSize)
assert.NoError(t, err)
assert.Equal(t, keySize, len(cd.Key))
assert.Equal(t, ivSize, len(cd.IV))
assert.NotEmpty(t, cd.Key)
assert.NotEmpty(t, cd.IV)
}
func TestKMSDecrypt(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
handler, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"}`})
assert.NoError(t, err)
plaintextKey, err := handler.DecryptKey([]byte{1, 2, 3, 4})
assert.NoError(t, err)
assert.Equal(t, key, plaintextKey)
}
func TestKMSDecryptBadJSON(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
_, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"`})
assert.Error(t, err)
}
+18
View File
@@ -0,0 +1,18 @@
package s3crypto
import (
"encoding/json"
)
// MaterialDescription is used to identify how and what master
// key has been used.
type MaterialDescription map[string]*string
func (md *MaterialDescription) encodeDescription() ([]byte, error) {
v, err := json.Marshal(&md)
return v, err
}
func (md *MaterialDescription) decodeDescription(b []byte) error {
return json.Unmarshal(b, &md)
}
+28
View File
@@ -0,0 +1,28 @@
package s3crypto
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
)
func TestEncodeMaterialDescription(t *testing.T) {
md := MaterialDescription{}
md["foo"] = aws.String("bar")
b, err := md.encodeDescription()
expected := `{"foo":"bar"}`
assert.NoError(t, err)
assert.Equal(t, expected, string(b))
}
func TestDecodeMaterialDescription(t *testing.T) {
md := MaterialDescription{}
json := `{"foo":"bar"}`
err := md.decodeDescription([]byte(json))
expected := MaterialDescription{
"foo": aws.String("bar"),
}
assert.NoError(t, err)
assert.Equal(t, expected, md)
}
+70
View File
@@ -0,0 +1,70 @@
package s3crypto_test
import (
"bytes"
"io"
"io/ioutil"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
type mockGenerator struct {
}
func (m mockGenerator) GenerateCipherData(keySize, ivSize int) (s3crypto.CipherData, error) {
cd := s3crypto.CipherData{
Key: make([]byte, keySize),
IV: make([]byte, ivSize),
}
return cd, nil
}
func (m mockGenerator) EncryptKey(key []byte) ([]byte, error) {
size := len(key)
b := bytes.Repeat([]byte{1}, size)
return b, nil
}
func (m mockGenerator) DecryptKey(key []byte) ([]byte, error) {
return make([]byte, 16), nil
}
type mockCipherBuilder struct {
generator s3crypto.CipherDataGenerator
}
func (builder mockCipherBuilder) ContentCipher() (s3crypto.ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(32, 16)
if err != nil {
return nil, err
}
return &mockContentCipher{cd}, nil
}
type mockContentCipher struct {
cd s3crypto.CipherData
}
func (cipher *mockContentCipher) GetCipherData() s3crypto.CipherData {
return cipher.cd
}
func (cipher *mockContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
b, err := ioutil.ReadAll(src)
if err != nil {
return nil, err
}
size := len(b)
b = bytes.Repeat([]byte{1}, size)
return bytes.NewReader(b), nil
}
func (cipher *mockContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
b, err := ioutil.ReadAll(src)
if err != nil {
return nil, err
}
size := len(b)
return ioutil.NopCloser(bytes.NewReader(make([]byte, size))), nil
}
+35
View File
@@ -0,0 +1,35 @@
package s3crypto
// Padder handles padding of crypto data
type Padder interface {
// Pad will pad the byte array.
// The second parameter is NOT how many
// bytes to pad by, but how many bytes
// have been read prior to the padding.
// This allows for streamable padding.
Pad([]byte, int) ([]byte, error)
// Unpad will unpad the byte bytes. Unpad
// methods must be constant time.
Unpad([]byte) ([]byte, error)
// Name returns the name of the padder.
// This is used when decrypting on
// instantiating new padders.
Name() string
}
// NoPadder does not pad anything
var NoPadder = Padder(noPadder{})
type noPadder struct{}
func (padder noPadder) Pad(b []byte, n int) ([]byte, error) {
return b, nil
}
func (padder noPadder) Unpad(b []byte) ([]byte, error) {
return b, nil
}
func (padder noPadder) Name() string {
return "NoPadding"
}
+80
View File
@@ -0,0 +1,80 @@
package s3crypto
// Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Licensed under the MIT License. Copyright (c) 2016 Carl Jackson
import (
"bytes"
"crypto/subtle"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const (
pkcs7MaxPaddingSize = 255
)
type pkcs7Padder struct {
blockSize int
}
// NewPKCS7Padder follows the RFC 2315: https://www.ietf.org/rfc/rfc2315.txt
// PKCS7 padding is subject to side-channel attacks and timing attacks. For
// the most secure data, use an authenticated crypto algorithm.
func NewPKCS7Padder(blockSize int) Padder {
return pkcs7Padder{blockSize}
}
var errPKCS7Padding = awserr.New("InvalidPadding", "invalid padding", nil)
// Pad will pad the data relative to how many bytes have been read.
// Pad follows the PKCS7 standard.
func (padder pkcs7Padder) Pad(buf []byte, n int) ([]byte, error) {
if padder.blockSize < 1 || padder.blockSize > pkcs7MaxPaddingSize {
return nil, awserr.New("InvalidBlockSize", "block size must be between 1 and 255", nil)
}
size := padder.blockSize - (n % padder.blockSize)
pad := bytes.Repeat([]byte{byte(size)}, size)
buf = append(buf, pad...)
return buf, nil
}
// Unpad will unpad the correct amount of bytes based off
// of the PKCS7 standard
func (padder pkcs7Padder) Unpad(buf []byte) ([]byte, error) {
if len(buf) == 0 {
return nil, errPKCS7Padding
}
// Here be dragons. We're attempting to check the padding in constant
// time. The only piece of information here which is public is len(buf).
// This code is modeled loosely after tls1_cbc_remove_padding from
// OpenSSL.
padLen := buf[len(buf)-1]
toCheck := pkcs7MaxPaddingSize
good := 1
if toCheck > len(buf) {
toCheck = len(buf)
}
for i := 0; i < toCheck; i++ {
b := buf[len(buf)-1-i]
outOfRange := subtle.ConstantTimeLessOrEq(int(padLen), i)
equal := subtle.ConstantTimeByteEq(padLen, b)
good &= subtle.ConstantTimeSelect(outOfRange, 1, equal)
}
good &= subtle.ConstantTimeLessOrEq(1, int(padLen))
good &= subtle.ConstantTimeLessOrEq(int(padLen), len(buf))
if good != 1 {
return nil, errPKCS7Padding
}
return buf[:len(buf)-int(padLen)], nil
}
func (padder pkcs7Padder) Name() string {
return "PKCS7Padding"
}
@@ -0,0 +1,57 @@
package s3crypto_test
import (
"bytes"
"fmt"
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func padTest(size int, t *testing.T) {
padder := s3crypto.NewPKCS7Padder(size)
for i := 0; i < size; i++ {
input := make([]byte, i)
expected := append(input, bytes.Repeat([]byte{byte(size - i)}, size-i)...)
b, err := padder.Pad(input, len(input))
if err != nil {
t.Fatal("Expected error to be nil but received " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func unpadTest(size int, t *testing.T) {
padder := s3crypto.NewPKCS7Padder(size)
for i := 0; i < size; i++ {
expected := make([]byte, i)
input := append(expected, bytes.Repeat([]byte{byte(size - i)}, size-i)...)
b, err := padder.Unpad(input)
if err != nil {
t.Fatal("Error received, was expecting nil: " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func TestPKCS7Padding(t *testing.T) {
padTest(10, t)
padTest(16, t)
padTest(255, t)
}
func TestPKCS7Unpadding(t *testing.T) {
unpadTest(10, t)
unpadTest(16, t)
unpadTest(255, t)
}
+142
View File
@@ -0,0 +1,142 @@
package s3crypto
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// SaveStrategy is how the data's metadata wants to be saved
type SaveStrategy interface {
Save(Envelope, *request.Request) error
}
// S3SaveStrategy will save the metadata to a separate instruction file in S3
type S3SaveStrategy struct {
Client *s3.S3
InstructionFileSuffix string
}
// Save will save the envelope contents to s3.
func (strat S3SaveStrategy) Save(env Envelope, req *request.Request) error {
input := req.Params.(*s3.PutObjectInput)
b, err := json.Marshal(env)
if err != nil {
return err
}
instInput := s3.PutObjectInput{
Bucket: input.Bucket,
Body: bytes.NewReader(b),
}
if strat.InstructionFileSuffix == "" {
instInput.Key = aws.String(*input.Key + DefaultInstructionKeySuffix)
} else {
instInput.Key = aws.String(*input.Key + strat.InstructionFileSuffix)
}
_, err = strat.Client.PutObject(&instInput)
return err
}
// HeaderV2SaveStrategy will save the metadata of the crypto contents to the header of
// the object.
type HeaderV2SaveStrategy struct{}
// Save will save the envelope to the request's header.
func (strat HeaderV2SaveStrategy) Save(env Envelope, req *request.Request) error {
input := req.Params.(*s3.PutObjectInput)
if input.Metadata == nil {
input.Metadata = map[string]*string{}
}
input.Metadata[http.CanonicalHeaderKey(keyV2Header)] = &env.CipherKey
input.Metadata[http.CanonicalHeaderKey(ivHeader)] = &env.IV
input.Metadata[http.CanonicalHeaderKey(matDescHeader)] = &env.MatDesc
input.Metadata[http.CanonicalHeaderKey(wrapAlgorithmHeader)] = &env.WrapAlg
input.Metadata[http.CanonicalHeaderKey(cekAlgorithmHeader)] = &env.CEKAlg
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = &env.TagLen
input.Metadata[http.CanonicalHeaderKey(unencryptedMD5Header)] = &env.UnencryptedMD5
input.Metadata[http.CanonicalHeaderKey(unencryptedContentLengthHeader)] = &env.UnencryptedContentLen
return nil
}
// LoadStrategy ...
type LoadStrategy interface {
Load(*request.Request) (Envelope, error)
}
// S3LoadStrategy will load the instruction file from s3
type S3LoadStrategy struct {
Client *s3.S3
InstructionFileSuffix string
}
// Load from a given instruction file suffix
func (load S3LoadStrategy) Load(req *request.Request) (Envelope, error) {
env := Envelope{}
if load.InstructionFileSuffix == "" {
load.InstructionFileSuffix = DefaultInstructionKeySuffix
}
input := req.Params.(*s3.GetObjectInput)
out, err := load.Client.GetObject(&s3.GetObjectInput{
Key: aws.String(strings.Join([]string{*input.Key, load.InstructionFileSuffix}, "")),
Bucket: input.Bucket,
})
if err != nil {
return env, err
}
b, err := ioutil.ReadAll(out.Body)
if err != nil {
return env, err
}
err = json.Unmarshal(b, &env)
return env, err
}
// HeaderV2LoadStrategy will load the envelope from the metadata
type HeaderV2LoadStrategy struct{}
// Load from a given object's header
func (load HeaderV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
env := Envelope{}
env.CipherKey = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-"))
env.IV = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, ivHeader}, "-"))
env.MatDesc = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, matDescHeader}, "-"))
env.WrapAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, wrapAlgorithmHeader}, "-"))
env.CEKAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, cekAlgorithmHeader}, "-"))
env.TagLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, tagLengthHeader}, "-"))
env.UnencryptedMD5 = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedMD5Header}, "-"))
env.UnencryptedContentLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedContentLengthHeader}, "-"))
return env, nil
}
type defaultV2LoadStrategy struct {
client *s3.S3
suffix string
}
func (load defaultV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
if value := req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-")); value != "" {
strat := HeaderV2LoadStrategy{}
return strat.Load(req)
} else if value = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV1Header}, "-")); value != "" {
return Envelope{}, awserr.New("V1NotSupportedError", "The AWS SDK for Go does not support version 1", nil)
}
strat := S3LoadStrategy{
Client: load.client,
InstructionFileSuffix: load.suffix,
}
return strat.Load(req)
}
+46
View File
@@ -0,0 +1,46 @@
package s3crypto_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestHeaderV2SaveStrategy(t *testing.T) {
env := s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
TagLen: "128",
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
}
params := &s3.PutObjectInput{}
req := &request.Request{
Params: params,
}
strat := s3crypto.HeaderV2SaveStrategy{}
err := strat.Save(env, req)
assert.NoError(t, err)
expected := map[string]*string{
"X-Amz-Key-V2": aws.String("Foo"),
"X-Amz-Iv": aws.String("Bar"),
"X-Amz-Matdesc": aws.String("{}"),
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
"X-Amz-Tag-Len": aws.String("128"),
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
}
assert.Equal(t, len(expected), len(params.Metadata))
assert.Equal(t, expected, params.Metadata)
}
+387
View File
@@ -0,0 +1,387 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Package s3iface provides an interface to enable mocking the Amazon Simple Storage Service service client
// for testing your code.
//
// It is important to note that this interface will have breaking changes
// when the service model is updated and adds new API operations, paginators,
// and waiters.
package s3iface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// S3API provides an interface to enable mocking the
// s3.S3 service client's API operation,
// paginators, and waiters. This make unit testing your code that calls out
// to the SDK's service client's calls easier.
//
// The best way to use this interface is so the SDK's service client's calls
// can be stubbed out for unit testing your code with the SDK without needing
// to inject custom request handlers into the the SDK's request pipeline.
//
// // myFunc uses an SDK service client to make a request to
// // Amazon Simple Storage Service.
// func myFunc(svc s3iface.S3API) bool {
// // Make svc.AbortMultipartUpload request
// }
//
// func main() {
// sess := session.New()
// svc := s3.New(sess)
//
// myFunc(svc)
// }
//
// In your _test.go file:
//
// // Define a mock struct to be used in your unit tests of myFunc.
// type mockS3Client struct {
// s3iface.S3API
// }
// func (m *mockS3Client) AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) {
// // mock response/functionality
// }
//
// func TestMyFunc(t *testing.T) {
// // Setup Test
// mockSvc := &mockS3Client{}
//
// myfunc(mockSvc)
//
// // Verify myFunc's functionality
// }
//
// It is important to note that this interface will have breaking changes
// when the service model is updated and adds new API operations, paginators,
// and waiters. Its suggested to use the pattern above for testing, or using
// tooling to generate mocks to satisfy the interfaces.
type S3API interface {
AbortMultipartUpload(*s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
AbortMultipartUploadWithContext(aws.Context, *s3.AbortMultipartUploadInput, ...request.Option) (*s3.AbortMultipartUploadOutput, error)
AbortMultipartUploadRequest(*s3.AbortMultipartUploadInput) (*request.Request, *s3.AbortMultipartUploadOutput)
CompleteMultipartUpload(*s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
CompleteMultipartUploadWithContext(aws.Context, *s3.CompleteMultipartUploadInput, ...request.Option) (*s3.CompleteMultipartUploadOutput, error)
CompleteMultipartUploadRequest(*s3.CompleteMultipartUploadInput) (*request.Request, *s3.CompleteMultipartUploadOutput)
CopyObject(*s3.CopyObjectInput) (*s3.CopyObjectOutput, error)
CopyObjectWithContext(aws.Context, *s3.CopyObjectInput, ...request.Option) (*s3.CopyObjectOutput, error)
CopyObjectRequest(*s3.CopyObjectInput) (*request.Request, *s3.CopyObjectOutput)
CreateBucket(*s3.CreateBucketInput) (*s3.CreateBucketOutput, error)
CreateBucketWithContext(aws.Context, *s3.CreateBucketInput, ...request.Option) (*s3.CreateBucketOutput, error)
CreateBucketRequest(*s3.CreateBucketInput) (*request.Request, *s3.CreateBucketOutput)
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
CreateMultipartUploadWithContext(aws.Context, *s3.CreateMultipartUploadInput, ...request.Option) (*s3.CreateMultipartUploadOutput, error)
CreateMultipartUploadRequest(*s3.CreateMultipartUploadInput) (*request.Request, *s3.CreateMultipartUploadOutput)
DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error)
DeleteBucketWithContext(aws.Context, *s3.DeleteBucketInput, ...request.Option) (*s3.DeleteBucketOutput, error)
DeleteBucketRequest(*s3.DeleteBucketInput) (*request.Request, *s3.DeleteBucketOutput)
DeleteBucketAnalyticsConfiguration(*s3.DeleteBucketAnalyticsConfigurationInput) (*s3.DeleteBucketAnalyticsConfigurationOutput, error)
DeleteBucketAnalyticsConfigurationWithContext(aws.Context, *s3.DeleteBucketAnalyticsConfigurationInput, ...request.Option) (*s3.DeleteBucketAnalyticsConfigurationOutput, error)
DeleteBucketAnalyticsConfigurationRequest(*s3.DeleteBucketAnalyticsConfigurationInput) (*request.Request, *s3.DeleteBucketAnalyticsConfigurationOutput)
DeleteBucketCors(*s3.DeleteBucketCorsInput) (*s3.DeleteBucketCorsOutput, error)
DeleteBucketCorsWithContext(aws.Context, *s3.DeleteBucketCorsInput, ...request.Option) (*s3.DeleteBucketCorsOutput, error)
DeleteBucketCorsRequest(*s3.DeleteBucketCorsInput) (*request.Request, *s3.DeleteBucketCorsOutput)
DeleteBucketInventoryConfiguration(*s3.DeleteBucketInventoryConfigurationInput) (*s3.DeleteBucketInventoryConfigurationOutput, error)
DeleteBucketInventoryConfigurationWithContext(aws.Context, *s3.DeleteBucketInventoryConfigurationInput, ...request.Option) (*s3.DeleteBucketInventoryConfigurationOutput, error)
DeleteBucketInventoryConfigurationRequest(*s3.DeleteBucketInventoryConfigurationInput) (*request.Request, *s3.DeleteBucketInventoryConfigurationOutput)
DeleteBucketLifecycle(*s3.DeleteBucketLifecycleInput) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketLifecycleWithContext(aws.Context, *s3.DeleteBucketLifecycleInput, ...request.Option) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketLifecycleRequest(*s3.DeleteBucketLifecycleInput) (*request.Request, *s3.DeleteBucketLifecycleOutput)
DeleteBucketMetricsConfiguration(*s3.DeleteBucketMetricsConfigurationInput) (*s3.DeleteBucketMetricsConfigurationOutput, error)
DeleteBucketMetricsConfigurationWithContext(aws.Context, *s3.DeleteBucketMetricsConfigurationInput, ...request.Option) (*s3.DeleteBucketMetricsConfigurationOutput, error)
DeleteBucketMetricsConfigurationRequest(*s3.DeleteBucketMetricsConfigurationInput) (*request.Request, *s3.DeleteBucketMetricsConfigurationOutput)
DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketPolicyWithContext(aws.Context, *s3.DeleteBucketPolicyInput, ...request.Option) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketPolicyRequest(*s3.DeleteBucketPolicyInput) (*request.Request, *s3.DeleteBucketPolicyOutput)
DeleteBucketReplication(*s3.DeleteBucketReplicationInput) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketReplicationWithContext(aws.Context, *s3.DeleteBucketReplicationInput, ...request.Option) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketReplicationRequest(*s3.DeleteBucketReplicationInput) (*request.Request, *s3.DeleteBucketReplicationOutput)
DeleteBucketTagging(*s3.DeleteBucketTaggingInput) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketTaggingWithContext(aws.Context, *s3.DeleteBucketTaggingInput, ...request.Option) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketTaggingRequest(*s3.DeleteBucketTaggingInput) (*request.Request, *s3.DeleteBucketTaggingOutput)
DeleteBucketWebsite(*s3.DeleteBucketWebsiteInput) (*s3.DeleteBucketWebsiteOutput, error)
DeleteBucketWebsiteWithContext(aws.Context, *s3.DeleteBucketWebsiteInput, ...request.Option) (*s3.DeleteBucketWebsiteOutput, error)
DeleteBucketWebsiteRequest(*s3.DeleteBucketWebsiteInput) (*request.Request, *s3.DeleteBucketWebsiteOutput)
DeleteObject(*s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
DeleteObjectWithContext(aws.Context, *s3.DeleteObjectInput, ...request.Option) (*s3.DeleteObjectOutput, error)
DeleteObjectRequest(*s3.DeleteObjectInput) (*request.Request, *s3.DeleteObjectOutput)
DeleteObjectTagging(*s3.DeleteObjectTaggingInput) (*s3.DeleteObjectTaggingOutput, error)
DeleteObjectTaggingWithContext(aws.Context, *s3.DeleteObjectTaggingInput, ...request.Option) (*s3.DeleteObjectTaggingOutput, error)
DeleteObjectTaggingRequest(*s3.DeleteObjectTaggingInput) (*request.Request, *s3.DeleteObjectTaggingOutput)
DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
DeleteObjectsWithContext(aws.Context, *s3.DeleteObjectsInput, ...request.Option) (*s3.DeleteObjectsOutput, error)
DeleteObjectsRequest(*s3.DeleteObjectsInput) (*request.Request, *s3.DeleteObjectsOutput)
GetBucketAccelerateConfiguration(*s3.GetBucketAccelerateConfigurationInput) (*s3.GetBucketAccelerateConfigurationOutput, error)
GetBucketAccelerateConfigurationWithContext(aws.Context, *s3.GetBucketAccelerateConfigurationInput, ...request.Option) (*s3.GetBucketAccelerateConfigurationOutput, error)
GetBucketAccelerateConfigurationRequest(*s3.GetBucketAccelerateConfigurationInput) (*request.Request, *s3.GetBucketAccelerateConfigurationOutput)
GetBucketAcl(*s3.GetBucketAclInput) (*s3.GetBucketAclOutput, error)
GetBucketAclWithContext(aws.Context, *s3.GetBucketAclInput, ...request.Option) (*s3.GetBucketAclOutput, error)
GetBucketAclRequest(*s3.GetBucketAclInput) (*request.Request, *s3.GetBucketAclOutput)
GetBucketAnalyticsConfiguration(*s3.GetBucketAnalyticsConfigurationInput) (*s3.GetBucketAnalyticsConfigurationOutput, error)
GetBucketAnalyticsConfigurationWithContext(aws.Context, *s3.GetBucketAnalyticsConfigurationInput, ...request.Option) (*s3.GetBucketAnalyticsConfigurationOutput, error)
GetBucketAnalyticsConfigurationRequest(*s3.GetBucketAnalyticsConfigurationInput) (*request.Request, *s3.GetBucketAnalyticsConfigurationOutput)
GetBucketCors(*s3.GetBucketCorsInput) (*s3.GetBucketCorsOutput, error)
GetBucketCorsWithContext(aws.Context, *s3.GetBucketCorsInput, ...request.Option) (*s3.GetBucketCorsOutput, error)
GetBucketCorsRequest(*s3.GetBucketCorsInput) (*request.Request, *s3.GetBucketCorsOutput)
GetBucketInventoryConfiguration(*s3.GetBucketInventoryConfigurationInput) (*s3.GetBucketInventoryConfigurationOutput, error)
GetBucketInventoryConfigurationWithContext(aws.Context, *s3.GetBucketInventoryConfigurationInput, ...request.Option) (*s3.GetBucketInventoryConfigurationOutput, error)
GetBucketInventoryConfigurationRequest(*s3.GetBucketInventoryConfigurationInput) (*request.Request, *s3.GetBucketInventoryConfigurationOutput)
GetBucketLifecycle(*s3.GetBucketLifecycleInput) (*s3.GetBucketLifecycleOutput, error)
GetBucketLifecycleWithContext(aws.Context, *s3.GetBucketLifecycleInput, ...request.Option) (*s3.GetBucketLifecycleOutput, error)
GetBucketLifecycleRequest(*s3.GetBucketLifecycleInput) (*request.Request, *s3.GetBucketLifecycleOutput)
GetBucketLifecycleConfiguration(*s3.GetBucketLifecycleConfigurationInput) (*s3.GetBucketLifecycleConfigurationOutput, error)
GetBucketLifecycleConfigurationWithContext(aws.Context, *s3.GetBucketLifecycleConfigurationInput, ...request.Option) (*s3.GetBucketLifecycleConfigurationOutput, error)
GetBucketLifecycleConfigurationRequest(*s3.GetBucketLifecycleConfigurationInput) (*request.Request, *s3.GetBucketLifecycleConfigurationOutput)
GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetBucketLocationWithContext(aws.Context, *s3.GetBucketLocationInput, ...request.Option) (*s3.GetBucketLocationOutput, error)
GetBucketLocationRequest(*s3.GetBucketLocationInput) (*request.Request, *s3.GetBucketLocationOutput)
GetBucketLogging(*s3.GetBucketLoggingInput) (*s3.GetBucketLoggingOutput, error)
GetBucketLoggingWithContext(aws.Context, *s3.GetBucketLoggingInput, ...request.Option) (*s3.GetBucketLoggingOutput, error)
GetBucketLoggingRequest(*s3.GetBucketLoggingInput) (*request.Request, *s3.GetBucketLoggingOutput)
GetBucketMetricsConfiguration(*s3.GetBucketMetricsConfigurationInput) (*s3.GetBucketMetricsConfigurationOutput, error)
GetBucketMetricsConfigurationWithContext(aws.Context, *s3.GetBucketMetricsConfigurationInput, ...request.Option) (*s3.GetBucketMetricsConfigurationOutput, error)
GetBucketMetricsConfigurationRequest(*s3.GetBucketMetricsConfigurationInput) (*request.Request, *s3.GetBucketMetricsConfigurationOutput)
GetBucketNotification(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationWithContext(aws.Context, *s3.GetBucketNotificationConfigurationRequest, ...request.Option) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfigurationDeprecated)
GetBucketNotificationConfiguration(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfiguration, error)
GetBucketNotificationConfigurationWithContext(aws.Context, *s3.GetBucketNotificationConfigurationRequest, ...request.Option) (*s3.NotificationConfiguration, error)
GetBucketNotificationConfigurationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfiguration)
GetBucketPolicy(*s3.GetBucketPolicyInput) (*s3.GetBucketPolicyOutput, error)
GetBucketPolicyWithContext(aws.Context, *s3.GetBucketPolicyInput, ...request.Option) (*s3.GetBucketPolicyOutput, error)
GetBucketPolicyRequest(*s3.GetBucketPolicyInput) (*request.Request, *s3.GetBucketPolicyOutput)
GetBucketReplication(*s3.GetBucketReplicationInput) (*s3.GetBucketReplicationOutput, error)
GetBucketReplicationWithContext(aws.Context, *s3.GetBucketReplicationInput, ...request.Option) (*s3.GetBucketReplicationOutput, error)
GetBucketReplicationRequest(*s3.GetBucketReplicationInput) (*request.Request, *s3.GetBucketReplicationOutput)
GetBucketRequestPayment(*s3.GetBucketRequestPaymentInput) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketRequestPaymentWithContext(aws.Context, *s3.GetBucketRequestPaymentInput, ...request.Option) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketRequestPaymentRequest(*s3.GetBucketRequestPaymentInput) (*request.Request, *s3.GetBucketRequestPaymentOutput)
GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error)
GetBucketTaggingWithContext(aws.Context, *s3.GetBucketTaggingInput, ...request.Option) (*s3.GetBucketTaggingOutput, error)
GetBucketTaggingRequest(*s3.GetBucketTaggingInput) (*request.Request, *s3.GetBucketTaggingOutput)
GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error)
GetBucketVersioningWithContext(aws.Context, *s3.GetBucketVersioningInput, ...request.Option) (*s3.GetBucketVersioningOutput, error)
GetBucketVersioningRequest(*s3.GetBucketVersioningInput) (*request.Request, *s3.GetBucketVersioningOutput)
GetBucketWebsite(*s3.GetBucketWebsiteInput) (*s3.GetBucketWebsiteOutput, error)
GetBucketWebsiteWithContext(aws.Context, *s3.GetBucketWebsiteInput, ...request.Option) (*s3.GetBucketWebsiteOutput, error)
GetBucketWebsiteRequest(*s3.GetBucketWebsiteInput) (*request.Request, *s3.GetBucketWebsiteOutput)
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
GetObjectWithContext(aws.Context, *s3.GetObjectInput, ...request.Option) (*s3.GetObjectOutput, error)
GetObjectRequest(*s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput)
GetObjectAcl(*s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error)
GetObjectAclWithContext(aws.Context, *s3.GetObjectAclInput, ...request.Option) (*s3.GetObjectAclOutput, error)
GetObjectAclRequest(*s3.GetObjectAclInput) (*request.Request, *s3.GetObjectAclOutput)
GetObjectTagging(*s3.GetObjectTaggingInput) (*s3.GetObjectTaggingOutput, error)
GetObjectTaggingWithContext(aws.Context, *s3.GetObjectTaggingInput, ...request.Option) (*s3.GetObjectTaggingOutput, error)
GetObjectTaggingRequest(*s3.GetObjectTaggingInput) (*request.Request, *s3.GetObjectTaggingOutput)
GetObjectTorrent(*s3.GetObjectTorrentInput) (*s3.GetObjectTorrentOutput, error)
GetObjectTorrentWithContext(aws.Context, *s3.GetObjectTorrentInput, ...request.Option) (*s3.GetObjectTorrentOutput, error)
GetObjectTorrentRequest(*s3.GetObjectTorrentInput) (*request.Request, *s3.GetObjectTorrentOutput)
HeadBucket(*s3.HeadBucketInput) (*s3.HeadBucketOutput, error)
HeadBucketWithContext(aws.Context, *s3.HeadBucketInput, ...request.Option) (*s3.HeadBucketOutput, error)
HeadBucketRequest(*s3.HeadBucketInput) (*request.Request, *s3.HeadBucketOutput)
HeadObject(*s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
HeadObjectWithContext(aws.Context, *s3.HeadObjectInput, ...request.Option) (*s3.HeadObjectOutput, error)
HeadObjectRequest(*s3.HeadObjectInput) (*request.Request, *s3.HeadObjectOutput)
ListBucketAnalyticsConfigurations(*s3.ListBucketAnalyticsConfigurationsInput) (*s3.ListBucketAnalyticsConfigurationsOutput, error)
ListBucketAnalyticsConfigurationsWithContext(aws.Context, *s3.ListBucketAnalyticsConfigurationsInput, ...request.Option) (*s3.ListBucketAnalyticsConfigurationsOutput, error)
ListBucketAnalyticsConfigurationsRequest(*s3.ListBucketAnalyticsConfigurationsInput) (*request.Request, *s3.ListBucketAnalyticsConfigurationsOutput)
ListBucketInventoryConfigurations(*s3.ListBucketInventoryConfigurationsInput) (*s3.ListBucketInventoryConfigurationsOutput, error)
ListBucketInventoryConfigurationsWithContext(aws.Context, *s3.ListBucketInventoryConfigurationsInput, ...request.Option) (*s3.ListBucketInventoryConfigurationsOutput, error)
ListBucketInventoryConfigurationsRequest(*s3.ListBucketInventoryConfigurationsInput) (*request.Request, *s3.ListBucketInventoryConfigurationsOutput)
ListBucketMetricsConfigurations(*s3.ListBucketMetricsConfigurationsInput) (*s3.ListBucketMetricsConfigurationsOutput, error)
ListBucketMetricsConfigurationsWithContext(aws.Context, *s3.ListBucketMetricsConfigurationsInput, ...request.Option) (*s3.ListBucketMetricsConfigurationsOutput, error)
ListBucketMetricsConfigurationsRequest(*s3.ListBucketMetricsConfigurationsInput) (*request.Request, *s3.ListBucketMetricsConfigurationsOutput)
ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error)
ListBucketsWithContext(aws.Context, *s3.ListBucketsInput, ...request.Option) (*s3.ListBucketsOutput, error)
ListBucketsRequest(*s3.ListBucketsInput) (*request.Request, *s3.ListBucketsOutput)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsWithContext(aws.Context, *s3.ListMultipartUploadsInput, ...request.Option) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsRequest(*s3.ListMultipartUploadsInput) (*request.Request, *s3.ListMultipartUploadsOutput)
ListMultipartUploadsPages(*s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool) error
ListMultipartUploadsPagesWithContext(aws.Context, *s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool, ...request.Option) error
ListObjectVersions(*s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsWithContext(aws.Context, *s3.ListObjectVersionsInput, ...request.Option) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsRequest(*s3.ListObjectVersionsInput) (*request.Request, *s3.ListObjectVersionsOutput)
ListObjectVersionsPages(*s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool) error
ListObjectVersionsPagesWithContext(aws.Context, *s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool, ...request.Option) error
ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
ListObjectsWithContext(aws.Context, *s3.ListObjectsInput, ...request.Option) (*s3.ListObjectsOutput, error)
ListObjectsRequest(*s3.ListObjectsInput) (*request.Request, *s3.ListObjectsOutput)
ListObjectsPages(*s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool) error
ListObjectsPagesWithContext(aws.Context, *s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool, ...request.Option) error
ListObjectsV2(*s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error)
ListObjectsV2WithContext(aws.Context, *s3.ListObjectsV2Input, ...request.Option) (*s3.ListObjectsV2Output, error)
ListObjectsV2Request(*s3.ListObjectsV2Input) (*request.Request, *s3.ListObjectsV2Output)
ListObjectsV2Pages(*s3.ListObjectsV2Input, func(*s3.ListObjectsV2Output, bool) bool) error
ListObjectsV2PagesWithContext(aws.Context, *s3.ListObjectsV2Input, func(*s3.ListObjectsV2Output, bool) bool, ...request.Option) error
ListParts(*s3.ListPartsInput) (*s3.ListPartsOutput, error)
ListPartsWithContext(aws.Context, *s3.ListPartsInput, ...request.Option) (*s3.ListPartsOutput, error)
ListPartsRequest(*s3.ListPartsInput) (*request.Request, *s3.ListPartsOutput)
ListPartsPages(*s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool) error
ListPartsPagesWithContext(aws.Context, *s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool, ...request.Option) error
PutBucketAccelerateConfiguration(*s3.PutBucketAccelerateConfigurationInput) (*s3.PutBucketAccelerateConfigurationOutput, error)
PutBucketAccelerateConfigurationWithContext(aws.Context, *s3.PutBucketAccelerateConfigurationInput, ...request.Option) (*s3.PutBucketAccelerateConfigurationOutput, error)
PutBucketAccelerateConfigurationRequest(*s3.PutBucketAccelerateConfigurationInput) (*request.Request, *s3.PutBucketAccelerateConfigurationOutput)
PutBucketAcl(*s3.PutBucketAclInput) (*s3.PutBucketAclOutput, error)
PutBucketAclWithContext(aws.Context, *s3.PutBucketAclInput, ...request.Option) (*s3.PutBucketAclOutput, error)
PutBucketAclRequest(*s3.PutBucketAclInput) (*request.Request, *s3.PutBucketAclOutput)
PutBucketAnalyticsConfiguration(*s3.PutBucketAnalyticsConfigurationInput) (*s3.PutBucketAnalyticsConfigurationOutput, error)
PutBucketAnalyticsConfigurationWithContext(aws.Context, *s3.PutBucketAnalyticsConfigurationInput, ...request.Option) (*s3.PutBucketAnalyticsConfigurationOutput, error)
PutBucketAnalyticsConfigurationRequest(*s3.PutBucketAnalyticsConfigurationInput) (*request.Request, *s3.PutBucketAnalyticsConfigurationOutput)
PutBucketCors(*s3.PutBucketCorsInput) (*s3.PutBucketCorsOutput, error)
PutBucketCorsWithContext(aws.Context, *s3.PutBucketCorsInput, ...request.Option) (*s3.PutBucketCorsOutput, error)
PutBucketCorsRequest(*s3.PutBucketCorsInput) (*request.Request, *s3.PutBucketCorsOutput)
PutBucketInventoryConfiguration(*s3.PutBucketInventoryConfigurationInput) (*s3.PutBucketInventoryConfigurationOutput, error)
PutBucketInventoryConfigurationWithContext(aws.Context, *s3.PutBucketInventoryConfigurationInput, ...request.Option) (*s3.PutBucketInventoryConfigurationOutput, error)
PutBucketInventoryConfigurationRequest(*s3.PutBucketInventoryConfigurationInput) (*request.Request, *s3.PutBucketInventoryConfigurationOutput)
PutBucketLifecycle(*s3.PutBucketLifecycleInput) (*s3.PutBucketLifecycleOutput, error)
PutBucketLifecycleWithContext(aws.Context, *s3.PutBucketLifecycleInput, ...request.Option) (*s3.PutBucketLifecycleOutput, error)
PutBucketLifecycleRequest(*s3.PutBucketLifecycleInput) (*request.Request, *s3.PutBucketLifecycleOutput)
PutBucketLifecycleConfiguration(*s3.PutBucketLifecycleConfigurationInput) (*s3.PutBucketLifecycleConfigurationOutput, error)
PutBucketLifecycleConfigurationWithContext(aws.Context, *s3.PutBucketLifecycleConfigurationInput, ...request.Option) (*s3.PutBucketLifecycleConfigurationOutput, error)
PutBucketLifecycleConfigurationRequest(*s3.PutBucketLifecycleConfigurationInput) (*request.Request, *s3.PutBucketLifecycleConfigurationOutput)
PutBucketLogging(*s3.PutBucketLoggingInput) (*s3.PutBucketLoggingOutput, error)
PutBucketLoggingWithContext(aws.Context, *s3.PutBucketLoggingInput, ...request.Option) (*s3.PutBucketLoggingOutput, error)
PutBucketLoggingRequest(*s3.PutBucketLoggingInput) (*request.Request, *s3.PutBucketLoggingOutput)
PutBucketMetricsConfiguration(*s3.PutBucketMetricsConfigurationInput) (*s3.PutBucketMetricsConfigurationOutput, error)
PutBucketMetricsConfigurationWithContext(aws.Context, *s3.PutBucketMetricsConfigurationInput, ...request.Option) (*s3.PutBucketMetricsConfigurationOutput, error)
PutBucketMetricsConfigurationRequest(*s3.PutBucketMetricsConfigurationInput) (*request.Request, *s3.PutBucketMetricsConfigurationOutput)
PutBucketNotification(*s3.PutBucketNotificationInput) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationWithContext(aws.Context, *s3.PutBucketNotificationInput, ...request.Option) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationRequest(*s3.PutBucketNotificationInput) (*request.Request, *s3.PutBucketNotificationOutput)
PutBucketNotificationConfiguration(*s3.PutBucketNotificationConfigurationInput) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketNotificationConfigurationWithContext(aws.Context, *s3.PutBucketNotificationConfigurationInput, ...request.Option) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketNotificationConfigurationRequest(*s3.PutBucketNotificationConfigurationInput) (*request.Request, *s3.PutBucketNotificationConfigurationOutput)
PutBucketPolicy(*s3.PutBucketPolicyInput) (*s3.PutBucketPolicyOutput, error)
PutBucketPolicyWithContext(aws.Context, *s3.PutBucketPolicyInput, ...request.Option) (*s3.PutBucketPolicyOutput, error)
PutBucketPolicyRequest(*s3.PutBucketPolicyInput) (*request.Request, *s3.PutBucketPolicyOutput)
PutBucketReplication(*s3.PutBucketReplicationInput) (*s3.PutBucketReplicationOutput, error)
PutBucketReplicationWithContext(aws.Context, *s3.PutBucketReplicationInput, ...request.Option) (*s3.PutBucketReplicationOutput, error)
PutBucketReplicationRequest(*s3.PutBucketReplicationInput) (*request.Request, *s3.PutBucketReplicationOutput)
PutBucketRequestPayment(*s3.PutBucketRequestPaymentInput) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketRequestPaymentWithContext(aws.Context, *s3.PutBucketRequestPaymentInput, ...request.Option) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketRequestPaymentRequest(*s3.PutBucketRequestPaymentInput) (*request.Request, *s3.PutBucketRequestPaymentOutput)
PutBucketTagging(*s3.PutBucketTaggingInput) (*s3.PutBucketTaggingOutput, error)
PutBucketTaggingWithContext(aws.Context, *s3.PutBucketTaggingInput, ...request.Option) (*s3.PutBucketTaggingOutput, error)
PutBucketTaggingRequest(*s3.PutBucketTaggingInput) (*request.Request, *s3.PutBucketTaggingOutput)
PutBucketVersioning(*s3.PutBucketVersioningInput) (*s3.PutBucketVersioningOutput, error)
PutBucketVersioningWithContext(aws.Context, *s3.PutBucketVersioningInput, ...request.Option) (*s3.PutBucketVersioningOutput, error)
PutBucketVersioningRequest(*s3.PutBucketVersioningInput) (*request.Request, *s3.PutBucketVersioningOutput)
PutBucketWebsite(*s3.PutBucketWebsiteInput) (*s3.PutBucketWebsiteOutput, error)
PutBucketWebsiteWithContext(aws.Context, *s3.PutBucketWebsiteInput, ...request.Option) (*s3.PutBucketWebsiteOutput, error)
PutBucketWebsiteRequest(*s3.PutBucketWebsiteInput) (*request.Request, *s3.PutBucketWebsiteOutput)
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
PutObjectWithContext(aws.Context, *s3.PutObjectInput, ...request.Option) (*s3.PutObjectOutput, error)
PutObjectRequest(*s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput)
PutObjectAcl(*s3.PutObjectAclInput) (*s3.PutObjectAclOutput, error)
PutObjectAclWithContext(aws.Context, *s3.PutObjectAclInput, ...request.Option) (*s3.PutObjectAclOutput, error)
PutObjectAclRequest(*s3.PutObjectAclInput) (*request.Request, *s3.PutObjectAclOutput)
PutObjectTagging(*s3.PutObjectTaggingInput) (*s3.PutObjectTaggingOutput, error)
PutObjectTaggingWithContext(aws.Context, *s3.PutObjectTaggingInput, ...request.Option) (*s3.PutObjectTaggingOutput, error)
PutObjectTaggingRequest(*s3.PutObjectTaggingInput) (*request.Request, *s3.PutObjectTaggingOutput)
RestoreObject(*s3.RestoreObjectInput) (*s3.RestoreObjectOutput, error)
RestoreObjectWithContext(aws.Context, *s3.RestoreObjectInput, ...request.Option) (*s3.RestoreObjectOutput, error)
RestoreObjectRequest(*s3.RestoreObjectInput) (*request.Request, *s3.RestoreObjectOutput)
UploadPart(*s3.UploadPartInput) (*s3.UploadPartOutput, error)
UploadPartWithContext(aws.Context, *s3.UploadPartInput, ...request.Option) (*s3.UploadPartOutput, error)
UploadPartRequest(*s3.UploadPartInput) (*request.Request, *s3.UploadPartOutput)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
UploadPartCopyWithContext(aws.Context, *s3.UploadPartCopyInput, ...request.Option) (*s3.UploadPartCopyOutput, error)
UploadPartCopyRequest(*s3.UploadPartCopyInput) (*request.Request, *s3.UploadPartCopyOutput)
WaitUntilBucketExists(*s3.HeadBucketInput) error
WaitUntilBucketExistsWithContext(aws.Context, *s3.HeadBucketInput, ...request.WaiterOption) error
WaitUntilBucketNotExists(*s3.HeadBucketInput) error
WaitUntilBucketNotExistsWithContext(aws.Context, *s3.HeadBucketInput, ...request.WaiterOption) error
WaitUntilObjectExists(*s3.HeadObjectInput) error
WaitUntilObjectExistsWithContext(aws.Context, *s3.HeadObjectInput, ...request.WaiterOption) error
WaitUntilObjectNotExists(*s3.HeadObjectInput) error
WaitUntilObjectNotExistsWithContext(aws.Context, *s3.HeadObjectInput, ...request.WaiterOption) error
}
var _ S3API = (*s3.S3)(nil)
+3
View File
@@ -0,0 +1,3 @@
// Package s3manager provides utilities to upload and download objects from
// S3 concurrently. Helpful for when working with large objects.
package s3manager
+429
View File
@@ -0,0 +1,429 @@
package s3manager
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 5
// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 5
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to be done concurrently.
type Downloader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
Concurrency int
// An S3 client to use when performing downloads.
S3 s3iface.S3API
// List of request options that will be passed down to individual API
// operation requests made by the downloader.
RequestOptions []request.Option
}
// WithDownloaderRequestOptions appends to the Downloader's API request options.
func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) {
return func(d *Downloader) {
d.RequestOptions = append(d.RequestOptions, opts...)
}
}
// NewDownloader creates a new Downloader instance to downloads objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// a S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // Create a downloader with the session and default options
// downloader := s3manager.NewDownloader(sess)
//
// // Create a downloader with the session and custom options
// downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: s3.New(c),
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
// NewDownloaderWithClient creates a new Downloader instance to downloads
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires a S3 service client
// to make S3 API calls.
//
// Example:
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // The S3 client the S3 Downloader will use
// s3Svc := s3.new(sess)
//
// // Create a downloader with the s3 client and default options
// downloader := s3manager.NewDownloaderWithClient(s3Svc)
//
// // Create a downloader with the s3 client and custom options
// downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: svc,
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
type maxRetrier interface {
MaxRetries() int
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is called from.
// Modifying the options will not impact the original Downloader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
return d.DownloadWithContext(aws.BackgroundContext(), w, input, options...)
}
// DownloadWithContext downloads an object in S3 and writes the payload into w
// using concurrent GET requests.
//
// DownloadWithContext is the same as Download with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, ect. The
// DownloadWithContext may create sub-contexts for individual underlying
// requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is
// called from. Modifying the options will not impact the original Downloader
// instance. Use the WithDownloaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this downloader.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// It is safe to call this method concurrently across goroutines.
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
for _, option := range options {
option(&impl.cfg)
}
impl.cfg.RequestOptions = append(impl.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))
if s, ok := d.S3.(maxRetrier); ok {
impl.partBodyMaxRetries = s.MaxRetries()
}
impl.totalBytes = -1
if impl.cfg.Concurrency == 0 {
impl.cfg.Concurrency = DefaultDownloadConcurrency
}
if impl.cfg.PartSize == 0 {
impl.cfg.PartSize = DefaultDownloadPartSize
}
return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx aws.Context
cfg Downloader
in *s3.GetObjectInput
w io.WriterAt
wg sync.WaitGroup
m sync.Mutex
pos int64
totalBytes int64
written int64
err error
partBodyMaxRetries int
}
// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
// Spin off first worker to check additional header information
d.getChunk()
if total := d.getTotalBytes(); total >= 0 {
// Spin up workers
ch := make(chan dlchunk, d.cfg.Concurrency)
for i := 0; i < d.cfg.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Assign work
for d.getErr() == nil {
if d.pos >= total {
break // We're finished queuing chunks
}
// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
d.pos += d.cfg.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
} else {
// Checking if we read anything new
for d.err == nil {
d.getChunk()
}
// We expect a 416 error letting us know we are done downloading the
// total bytes. Since we do not know the content's length, this will
// keep grabbing chunks of data until the range of bytes specified in
// the request is out of range of the content. Once, this happens, a
// 416 should occur.
e, ok := d.err.(awserr.RequestFailure)
if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
d.err = nil
}
}
// Return error
return d.written, d.err
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok || d.getErr() != nil {
break
}
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
break
}
}
}
// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only used when grabbing data on a single thread.
func (d *downloader) getChunk() {
if d.getErr() != nil {
return
}
chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
d.pos += d.cfg.PartSize
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
}
// downloadChunk downloads the chunk froom s3
func (d *downloader) downloadChunk(chunk dlchunk) error {
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
// Get the next byte range of data
rng := fmt.Sprintf("bytes=%d-%d", chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
var n int64
var err error
for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
var resp *s3.GetObjectOutput
resp, err = d.cfg.S3.GetObjectWithContext(d.ctx, in, d.cfg.RequestOptions...)
if err != nil {
return err
}
d.setTotalBytes(resp) // Set total if not yet set.
n, err = io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err == nil {
break
}
chunk.cur = 0
logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
aws.StringValue(in.Key), err, retry))
}
d.incrWritten(n)
return err
}
func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
s, ok := svc.(*s3.S3)
if !ok {
return
}
if s.Config.Logger == nil {
return
}
if s.Config.LogLevel.Matches(level) {
s.Config.Logger.Log(msg)
}
}
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
defer d.m.Unlock()
return d.totalBytes
}
// setTotalBytes is a thread-safe setter for setting the total byte status.
// Will extract the object's total bytes from the Content-Range if the file
// will be chunked, or Content-Length. Content-Length is used when the response
// does not include a Content-Range. Meaning the object was not chunked. This
// occurs when the full file fits within the PartSize directive.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
d.m.Lock()
defer d.m.Unlock()
if d.totalBytes >= 0 {
return
}
if resp.ContentRange == nil {
// ContentRange is nil when the full file contents is provied, and
// is not chunked. Use ContentLength instead.
if resp.ContentLength != nil {
d.totalBytes = *resp.ContentLength
return
}
} else {
parts := strings.Split(*resp.ContentRange, "/")
total := int64(-1)
var err error
// Checking for whether or not a numbered total exists
// If one does not exist, we will assume the total to be -1, undefined,
// and sequentially download each chunk until hitting a 416 error
totalStr := parts[len(parts)-1]
if totalStr != "*" {
total, err = strconv.ParseInt(totalStr, 10, 64)
if err != nil {
d.err = err
return
}
}
d.totalBytes = total
}
}
func (d *downloader) incrWritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
w io.WriterAt
start int64
size int64
cur int64
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
return 0, io.EOF
}
n, err = c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)
return
}
+457
View File
@@ -0,0 +1,457 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin > int64(len(data)) {
fin = int64(len(data))
}
bodyBytes := data[start:fin]
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
start, fin-1, len(data)))
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
})
return svc, &names, &ranges
}
func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(data)))
})
return svc, &names
}
func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
var index int = 0
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
index++
})
return svc, &names
}
func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
var index int = 0
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin >= int64(len(data)) {
fin = int64(len(data))
}
// Setting start and finish to 0 because this state of 1 is suppose to
// be an error state of 416
if index == len(states)-1 {
start = 0
fin = 0
}
bodyBytes := data[start:fin]
r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
start, fin-1))
index++
})
return svc, &names
}
func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
var index int = 0
svc := s3.New(unit.Session, &aws.Config{
MaxRetries: aws.Int(len(cases) - 1),
})
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
c := cases[index]
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(&c),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range",
fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len))
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", c.Len))
index++
})
return svc, &names
}
func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf12MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}, *ranges)
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadZero(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{})
d := s3manager.NewDownloaderWithClient(s)
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879"}, *ranges)
}
func TestDownloadSetPartSize(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
d.PartSize = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}, *ranges)
assert.Equal(t, []byte{1, 2, 3}, w.Bytes())
}
func TestDownloadError(t *testing.T) {
s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
num := 0
s.Handlers.Send.PushBack(func(r *request.Request) {
num++
if num > 1 {
r.HTTPResponse.StatusCode = 400
r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
}
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
d.PartSize = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.NotNil(t, err)
assert.Equal(t, int64(1), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte{1}, w.Bytes())
}
func TestDownloadNonChunk(t *testing.T) {
s, names := dlLoggingSvcNoChunk(buf2MB)
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject"}, *names)
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadNoContentRangeLength(t *testing.T) {
s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadContentRangeTotalAny(t *testing.T) {
s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
{Buf: []byte("123"), Len: 3, Err: io.EOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte("123"), w.Bytes())
}
func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("abc"), Len: 3, Err: io.EOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []byte("abc"), w.Bytes())
}
func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Error(t, err)
assert.Equal(t, int64(2), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []byte("ab"), w.Bytes())
}
func TestDownloadWithContextCanceled(t *testing.T) {
d := s3manager.NewDownloader(unit.Session)
params := s3.GetObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
}
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
w := &aws.WriteAtBuffer{}
_, err := d.DownloadWithContext(ctx, w, &params)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
type testErrReader struct {
Buf []byte
Err error
Len int64
off int
}
func (r *testErrReader) Read(p []byte) (int, error) {
to := len(r.Buf) - r.off
n := copy(p, r.Buf[r.off:to])
r.off += n
if n < len(p) {
return n, r.Err
}
return n, nil
}
@@ -0,0 +1,23 @@
// Package s3manageriface provides an interface for the s3manager package
package s3manageriface
import (
"io"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// DownloaderAPI is the interface type for s3manager.Downloader.
type DownloaderAPI interface {
Download(io.WriterAt, *s3.GetObjectInput, ...func(*s3manager.Downloader)) (int64, error)
}
var _ DownloaderAPI = (*s3manager.Downloader)(nil)
// UploaderAPI is the interface type for s3manager.Uploader.
type UploaderAPI interface {
Upload(*s3manager.UploadInput, ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
var _ UploaderAPI = (*s3manager.Uploader)(nil)
+4
View File
@@ -0,0 +1,4 @@
package s3manager_test
var buf12MB = make([]byte, 1024*1024*12)
var buf2MB = make([]byte, 1024*1024*2)
+729
View File
@@ -0,0 +1,729 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// MaxUploadParts is the maximum allowed number of parts in a multi-part upload
// on Amazon S3.
const MaxUploadParts = 10000
// MinUploadPartSize is the minimum allowed part size when uploading a part to
// Amazon S3.
const MinUploadPartSize int64 = 1024 * 1024 * 5
// DefaultUploadPartSize is the default part size to buffer chunks of a
// payload into.
const DefaultUploadPartSize = MinUploadPartSize
// DefaultUploadConcurrency is the default number of goroutines to spin up when
// using Upload().
const DefaultUploadConcurrency = 5
// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
// will satisfy this interface when a multi part upload failed to upload all
// chucks to S3. In the case of a failure the UploadID is needed to operate on
// the chunks, if any, which were uploaded.
//
// Example:
//
// u := s3manager.NewUploader(opts)
// output, err := u.upload(input)
// if err != nil {
// if multierr, ok := err.(s3manager.MultiUploadFailure); ok {
// // Process error and its associated uploadID
// fmt.Println("Error:", multierr.Code(), multierr.Message(), multierr.UploadID())
// } else {
// // Process error generically
// fmt.Println("Error:", err.Error())
// }
// }
//
type MultiUploadFailure interface {
awserr.Error
// Returns the upload id for the S3 multipart upload that failed.
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred failing a S3 multipart upload,
// and a upload ID is available. If an uploadID is not available a more relevant
type multiUploadError struct {
awsError
// ID for multipart upload which failed.
uploadID string
}
// Error returns the string representation of the error.
//
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m multiUploadError) UploadID() string {
return m.uploadID
}
// UploadInput contains all input for upload requests to Amazon S3.
type UploadInput struct {
// The canned ACL to apply to the object.
ACL *string `location:"header" locationName:"x-amz-acl" type:"string"`
Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
// Specifies caching behavior along the request/reply chain.
CacheControl *string `location:"header" locationName:"Cache-Control" type:"string"`
// Specifies presentational information for the object.
ContentDisposition *string `location:"header" locationName:"Content-Disposition" type:"string"`
// Specifies what content encodings have been applied to the object and thus
// what decoding mechanisms must be applied to obtain the media-type referenced
// by the Content-Type header field.
ContentEncoding *string `location:"header" locationName:"Content-Encoding" type:"string"`
// The language the content is in.
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
// A standard MIME type describing the format of the object data.
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
// The date and time at which the object is no longer cacheable.
Expires *time.Time `location:"header" locationName:"Expires" type:"timestamp" timestampFormat:"rfc822"`
// Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object.
GrantFullControl *string `location:"header" locationName:"x-amz-grant-full-control" type:"string"`
// Allows grantee to read the object data and its metadata.
GrantRead *string `location:"header" locationName:"x-amz-grant-read" type:"string"`
// Allows grantee to read the object ACL.
GrantReadACP *string `location:"header" locationName:"x-amz-grant-read-acp" type:"string"`
// Allows grantee to write the ACL for the applicable object.
GrantWriteACP *string `location:"header" locationName:"x-amz-grant-write-acp" type:"string"`
Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
// A map of metadata to store with the object in S3.
Metadata map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
// Confirms that the requester knows that she or he will be charged for the
// request. Bucket owners need not specify this parameter in their requests.
// Documentation on downloading objects from requester pays buckets can be found
// at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html
RequestPayer *string `location:"header" locationName:"x-amz-request-payer" type:"string"`
// Specifies the algorithm to use to when encrypting the object (e.g., AES256,
// aws:kms).
SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
// Specifies the customer-provided encryption key for Amazon S3 to use in encrypting
// data. This value is used to store the object and then it is discarded; Amazon
// does not store the encryption key. The key must be appropriate for use with
// the algorithm specified in the x-amz-server-side-encryption-customer-algorithm
// header.
SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
// Amazon S3 uses this header for a message integrity check to ensure the encryption
// key was transmitted without error.
SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
// Specifies the AWS KMS key ID to use for object encryption. All GET and PUT
// requests for an object protected by AWS KMS will fail if not made via SSL
// or using SigV4. Documentation on configuring any of the officially supported
// AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version
SSEKMSKeyId *string `location:"header" locationName:"x-amz-server-side-encryption-aws-kms-key-id" type:"string"`
// The Server-side encryption algorithm used when storing this object in S3
// (e.g., AES256, aws:kms).
ServerSideEncryption *string `location:"header" locationName:"x-amz-server-side-encryption" type:"string"`
// The type of storage to use for the object. Defaults to 'STANDARD'.
StorageClass *string `location:"header" locationName:"x-amz-storage-class" type:"string"`
// The tag-set for the object. The tag-set must be encoded as URL Query parameters
Tagging *string `location:"header" locationName:"x-amz-tagging" type:"string"`
// If the bucket is configured as a website, redirects requests for this object
// to another object in the same bucket or to an external URL. Amazon S3 stores
// the value of this header in the object metadata.
WebsiteRedirectLocation *string `location:"header" locationName:"x-amz-website-redirect-location" type:"string"`
// The readable body payload to send to S3.
Body io.Reader
}
// UploadOutput represents a response from the Upload() call.
type UploadOutput struct {
// The URL where the object was uploaded to.
Location string
// The version of the object that was uploaded. Will only be populated if
// the S3 Bucket is versioned. If the bucket is not versioned this field
// will not be set.
VersionID *string
// The ID for a multipart upload to S3. In the case of an error the error
// can be cast to the MultiUploadFailure interface to extract the upload ID.
UploadID string
}
// WithUploaderRequestOptions appends to the Uploader's API request options.
func WithUploaderRequestOptions(opts ...request.Option) func(*Uploader) {
return func(u *Uploader) {
u.RequestOptions = append(u.RequestOptions, opts...)
}
}
// The Uploader structure that calls Upload(). It is safe to call Upload()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Uploader's properties is not safe to be done concurrently.
type Uploader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultUploadConcurrency value will be used.
Concurrency int
// Setting this value to true will cause the SDK to avoid calling
// AbortMultipartUpload on a failure, leaving all successfully uploaded
// parts on S3 for manual recovery.
//
// Note that storing parts of an incomplete multipart upload counts towards
// space usage on S3 and will add additional costs if not cleaned up.
LeavePartsOnError bool
// MaxUploadParts is the max number of parts which will be uploaded to S3.
// Will be used to calculate the partsize of the object to be uploaded.
// E.g: 5GB file, with MaxUploadParts set to 100, will upload the file
// as 100, 50MB parts.
// With a limited of s3.MaxUploadParts (10,000 parts).
MaxUploadParts int
// The client to use when uploading to S3.
S3 s3iface.S3API
// List of request options that will be passed down to individual API
// operation requests made by the uploader.
RequestOptions []request.Option
}
// NewUploader creates a new Uploader instance to upload objects to S3. Pass In
// additional functional options to customize the uploader's behavior. Requires a
// client.ConfigProvider in order to create a S3 service client. The session.Session
// satisfies the client.ConfigProvider interface.
//
// Example:
// // The session the S3 Uploader will use
// sess := session.Must(session.NewSession())
//
// // Create an uploader with the session and default options
// uploader := s3manager.NewUploader(sess)
//
// // Create an uploader with the session and custom options
// uploader := s3manager.NewUploader(session, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploader(c client.ConfigProvider, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: s3.New(c),
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// NewUploaderWithClient creates a new Uploader instance to upload objects to S3. Pass in
// additional functional options to customize the uploader's behavior. Requires
// a S3 service client to make S3 API calls.
//
// Example:
// // The session the S3 Uploader will use
// sess := session.Must(session.NewSession())
//
// // S3 service client the Upload manager will use.
// s3Svc := s3.New(sess)
//
// // Create an uploader with S3 client and default options
// uploader := s3manager.NewUploaderWithClient(s3Svc)
//
// // Create an uploader with S3 client and custom options
// uploader := s3manager.NewUploaderWithClient(s3Svc, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploaderWithClient(svc s3iface.S3API, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: svc,
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// Upload uploads an object to S3, intelligently buffering large files into
// smaller chunks and sending them in parallel across multiple goroutines. You
// can configure the buffer size and concurrency through the Uploader's parameters.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// Use the WithUploaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this uploader.
//
// It is safe to call this method concurrently across goroutines.
//
// Example:
// // Upload input parameters
// upParams := &s3manager.UploadInput{
// Bucket: &bucketName,
// Key: &keyName,
// Body: file,
// }
//
// // Perform an upload.
// result, err := uploader.Upload(upParams)
//
// // Perform upload with options different than the those in the Uploader.
// result, err := uploader.Upload(upParams, func(u *s3manager.Uploader) {
// u.PartSize = 10 * 1024 * 1024 // 10MB part size
// u.LeavePartsOnError = true // Don't delete the parts if the upload fails.
// })
func (u Uploader) Upload(input *UploadInput, options ...func(*Uploader)) (*UploadOutput, error) {
return u.UploadWithContext(aws.BackgroundContext(), input, options...)
}
// UploadWithContext uploads an object to S3, intelligently buffering large
// files into smaller chunks and sending them in parallel across multiple
// goroutines. You can configure the buffer size and concurrency through the
// Uploader's parameters.
//
// UploadWithContext is the same as Upload with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the context to add deadlining, timeouts, ect. The
// UploadWithContext may create sub-contexts for individual underlying requests.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// Use the WithUploaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this uploader.
//
// It is safe to call this method concurrently across goroutines.
func (u Uploader) UploadWithContext(ctx aws.Context, input *UploadInput, opts ...func(*Uploader)) (*UploadOutput, error) {
i := uploader{in: input, cfg: u, ctx: ctx}
for _, opt := range opts {
opt(&i.cfg)
}
i.cfg.RequestOptions = append(i.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))
return i.upload()
}
// internal structure to manage an upload to S3.
type uploader struct {
ctx aws.Context
cfg Uploader
in *UploadInput
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
}
// internal logic for deciding whether to upload a single part or use a
// multipart upload.
func (u *uploader) upload() (*UploadOutput, error) {
u.init()
if u.cfg.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
reader, _, err := u.nextReader()
if err == io.EOF { // single part
return u.singlePart(reader)
} else if err != nil {
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
return mu.upload(reader)
}
// init will initialize all default options.
func (u *uploader) init() {
if u.cfg.Concurrency == 0 {
u.cfg.Concurrency = DefaultUploadConcurrency
}
if u.cfg.PartSize == 0 {
u.cfg.PartSize = DefaultUploadPartSize
}
// Try to get the total size for some optimizations
u.initSize()
}
// initSize tries to detect the total stream size, setting u.totalSize. If
// the size is not known, totalSize is set to -1.
func (u *uploader) initSize() {
u.totalSize = -1
switch r := u.in.Body.(type) {
case io.Seeker:
pos, _ := r.Seek(0, 1)
defer r.Seek(pos, 0)
n, err := r.Seek(0, 2)
if err != nil {
return
}
u.totalSize = n
// Try to adjust partSize if it is too small and account for
// integer division truncation.
if u.totalSize/u.cfg.PartSize >= int64(u.cfg.MaxUploadParts) {
// Add one to the part size to account for remainders
// during the size calculation. e.g odd number of bytes.
u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1
}
}
}
// nextReader returns a seekable reader representing the next packet of data.
// This operation increases the shared u.readerPos counter, but note that it
// does not need to be wrapped in a mutex because nextReader is only called
// from the main thread.
func (u *uploader) nextReader() (io.ReadSeeker, int, error) {
type readerAtSeeker interface {
io.ReaderAt
io.ReadSeeker
}
switch r := u.in.Body.(type) {
case readerAtSeeker:
var err error
n := u.cfg.PartSize
if u.totalSize >= 0 {
bytesLeft := u.totalSize - u.readerPos
if bytesLeft <= u.cfg.PartSize {
err = io.EOF
n = bytesLeft
}
}
reader := io.NewSectionReader(r, u.readerPos, n)
u.readerPos += n
return reader, int(n), err
default:
part := make([]byte, u.cfg.PartSize)
n, err := readFillBuf(r, part)
u.readerPos += int64(n)
return bytes.NewReader(part[0:n]), n, err
}
}
func readFillBuf(r io.Reader, b []byte) (offset int, err error) {
for offset < len(b) && err == nil {
var n int
n, err = r.Read(b[offset:])
offset += n
}
return offset, err
}
// singlePart contains upload logic for uploading a single chunk via
// a regular PutObject request. Multipart requests require at least two
// parts, or at least 5MB of data.
func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.PutObjectInput{}
awsutil.Copy(params, u.in)
params.Body = buf
// Need to use request form because URL generated in request is
// used in return.
req, out := u.cfg.S3.PutObjectRequest(params)
req.SetContext(u.ctx)
req.ApplyOptions(u.cfg.RequestOptions...)
if err := req.Send(); err != nil {
return nil, err
}
url := req.HTTPRequest.URL.String()
return &UploadOutput{
Location: url,
VersionID: out.VersionId,
}, nil
}
// internal structure to manage a specific multipart upload to S3.
type multiuploader struct {
*uploader
wg sync.WaitGroup
m sync.Mutex
err error
uploadID string
parts completedParts
}
// keeps track of a single chunk of data being sent to S3.
type chunk struct {
buf io.ReadSeeker
num int64
}
// completedParts is a wrapper to make parts sortable by their part number,
// since S3 required this list to be sent in sorted order.
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
// upload will perform a multipart upload using the firstBuf buffer containing
// the first chunk of data.
func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.CreateMultipartUploadInput{}
awsutil.Copy(params, u.in)
// Create the multipart
resp, err := u.cfg.S3.CreateMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
return nil, err
}
u.uploadID = *resp.UploadId
// Create the workers
ch := make(chan chunk, u.cfg.Concurrency)
for i := 0; i < u.cfg.Concurrency; i++ {
u.wg.Add(1)
go u.readChunk(ch)
}
// Send part 1 to the workers
var num int64 = 1
ch <- chunk{buf: firstBuf, num: num}
// Read and queue the rest of the parts
for u.geterr() == nil && err == nil {
num++
// This upload exceeded maximum number of supported parts, error now.
if num > int64(u.cfg.MaxUploadParts) || num > int64(MaxUploadParts) {
var msg string
if num > int64(u.cfg.MaxUploadParts) {
msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit",
u.cfg.MaxUploadParts)
} else {
msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit",
MaxUploadParts)
}
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
var reader io.ReadSeeker
var nextChunkLen int
reader, nextChunkLen, err = u.nextReader()
if err != nil && err != io.EOF {
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
if nextChunkLen == 0 {
// No need to upload empty part, if file was empty to start
// with empty single part would of been created and never
// started multipart upload.
break
}
ch <- chunk{buf: reader, num: num}
}
// Close the channel, wait for workers, and complete upload
close(ch)
u.wg.Wait()
complete := u.complete()
if err := u.geterr(); err != nil {
return nil, &multiUploadError{
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{
Location: aws.StringValue(complete.Location),
VersionID: complete.VersionId,
UploadID: u.uploadID,
}, nil
}
// readChunk runs in worker goroutines to pull chunks off of the ch channel
// and send() them as UploadPart requests.
func (u *multiuploader) readChunk(ch chan chunk) {
defer u.wg.Done()
for {
data, ok := <-ch
if !ok {
break
}
if u.geterr() == nil {
if err := u.send(data); err != nil {
u.seterr(err)
}
}
}
}
// send performs an UploadPart request and keeps track of the completed
// part information.
func (u *multiuploader) send(c chunk) error {
params := &s3.UploadPartInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
Body: c.buf,
UploadId: &u.uploadID,
SSECustomerAlgorithm: u.in.SSECustomerAlgorithm,
SSECustomerKey: u.in.SSECustomerKey,
PartNumber: &c.num,
}
resp, err := u.cfg.S3.UploadPartWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
return err
}
n := c.num
completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &n}
u.m.Lock()
u.parts = append(u.parts, completed)
u.m.Unlock()
return nil
}
// geterr is a thread-safe getter for the error object
func (u *multiuploader) geterr() error {
u.m.Lock()
defer u.m.Unlock()
return u.err
}
// seterr is a thread-safe setter for the error object
func (u *multiuploader) seterr(e error) {
u.m.Lock()
defer u.m.Unlock()
u.err = e
}
// fail will abort the multipart unless LeavePartsOnError is set to true.
func (u *multiuploader) fail() {
if u.cfg.LeavePartsOnError {
return
}
params := &s3.AbortMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
}
_, err := u.cfg.S3.AbortMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
logMessage(u.cfg.S3, aws.LogDebug, fmt.Sprintf("failed to abort multipart upload, %v", err))
}
}
// complete successfully completes a multipart upload and returns the response.
func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
if u.geterr() != nil {
u.fail()
return nil
}
// Parts must be sorted in PartNumber order.
sort.Sort(u.parts)
params := &s3.CompleteMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
}
resp, err := u.cfg.S3.CompleteMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
u.seterr(err)
u.fail()
}
return resp
}
+755
View File
@@ -0,0 +1,755 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strings"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var emptyList = []string{}
func val(i interface{}, s string) interface{} {
v, err := awsutil.ValuesAtPath(i, s)
if err != nil || len(v) == 0 {
return nil
}
if _, ok := v[0].(io.Reader); ok {
return v[0]
}
if rv := reflect.ValueOf(v[0]); rv.Kind() == reflect.Ptr {
return rv.Elem().Interface()
}
return v[0]
}
func contains(src []string, s string) bool {
for _, v := range src {
if s == v {
return true
}
}
return false
}
func loggingSvc(ignoreOps []string) (*s3.S3, *[]string, *[]interface{}) {
var m sync.Mutex
partNum := 0
names := []string{}
params := []interface{}{}
svc := s3.New(unit.Session)
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
if !contains(ignoreOps, r.Operation.Name) {
names = append(names, r.Operation.Name)
params = append(params, r.Params)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
switch data := r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
data.UploadId = aws.String("UPLOAD-ID")
case *s3.UploadPartOutput:
partNum++
data.ETag = aws.String(fmt.Sprintf("ETAG%d", partNum))
case *s3.CompleteMultipartUploadOutput:
data.Location = aws.String("https://location")
data.VersionId = aws.String("VERSION-ID")
case *s3.PutObjectOutput:
data.VersionId = aws.String("VERSION-ID")
}
})
return svc, &names, &params
}
func buflen(i interface{}) int {
r := i.(io.Reader)
b, _ := ioutil.ReadAll(r)
return len(b)
}
func TestUploadOrderMulti(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
u := s3manager.NewUploaderWithClient(s)
resp, err := u.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
ServerSideEncryption: aws.String("aws:kms"),
SSEKMSKeyId: aws.String("KmsId"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
assert.Equal(t, "https://location", resp.Location)
assert.Equal(t, "UPLOAD-ID", resp.UploadID)
assert.Equal(t, aws.String("VERSION-ID"), resp.VersionID)
// Validate input values
// UploadPart
assert.Equal(t, "UPLOAD-ID", val((*args)[1], "UploadId"))
assert.Equal(t, "UPLOAD-ID", val((*args)[2], "UploadId"))
assert.Equal(t, "UPLOAD-ID", val((*args)[3], "UploadId"))
// CompleteMultipartUpload
assert.Equal(t, "UPLOAD-ID", val((*args)[4], "UploadId"))
assert.Equal(t, int64(1), val((*args)[4], "MultipartUpload.Parts[0].PartNumber"))
assert.Equal(t, int64(2), val((*args)[4], "MultipartUpload.Parts[1].PartNumber"))
assert.Equal(t, int64(3), val((*args)[4], "MultipartUpload.Parts[2].PartNumber"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[0].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[1].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[2].ETag"))
// Custom headers
assert.Equal(t, "aws:kms", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "KmsId", val((*args)[0], "SSEKMSKeyId"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.PartSize = 1024 * 1024 * 7
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, 1024*1024*7, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*5, buflen(val((*args)[2], "Body")))
}
func TestUploadIncreasePartSize(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
u.MaxUploadParts = 2
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, int64(s3manager.DefaultDownloadPartSize), mgr.PartSize)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, (1024*1024*6)+1, buflen(val((*args)[1], "Body")))
assert.Equal(t, (1024*1024*6)-1, buflen(val((*args)[2], "Body")))
}
func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
mgr := s3manager.NewUploader(unit.Session, func(u *s3manager.Uploader) {
u.PartSize = 5
})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Nil(t, resp)
assert.NotNil(t, err)
aerr := err.(awserr.Error)
assert.Equal(t, "ConfigError", aerr.Code())
assert.Contains(t, aerr.Message(), "part size must be at least")
}
func TestUploadOrderSingle(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
ServerSideEncryption: aws.String("aws:kms"),
SSEKMSKeyId: aws.String("KmsId"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, aws.String("VERSION-ID"), resp.VersionID)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, "aws:kms", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "KmsId", val((*args)[0], "SSEKMSKeyId"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderSingleFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse.StatusCode = 400
})
mgr := s3manager.NewUploaderWithClient(s)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.Nil(t, resp)
}
func TestUploadOrderZero(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, 0, buflen(val((*args)[0], "Body")))
}
func TestUploadOrderMultiFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *request.Request) {
switch t := r.Data.(type) {
case *s3.UploadPartOutput:
if *t.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *request.Request) {
switch r.Data.(type) {
case *s3.CompleteMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *request.Request) {
switch r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploaderWithClient(s)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *request.Request) {
switch data := r.Data.(type) {
case *s3.UploadPartOutput:
if *data.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
u.LeavePartsOnError = true
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops)
}
type failreader struct {
times int
failCount int
}
func (f *failreader) Read(b []byte) (int, error) {
f.failCount++
if f.failCount >= f.times {
return 0, fmt.Errorf("random failure")
}
return len(b), nil
}
func TestUploadOrderReadFail1(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 1},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{}, *ops)
}
func TestUploadOrderReadFail2(t *testing.T) {
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 2},
})
assert.Equal(t, "MultipartUpload", err.(awserr.Error).Code())
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).OrigErr().(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).OrigErr().Error(), "random failure")
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
}
type sizedReader struct {
size int
cur int
err error
}
func (s *sizedReader) Read(p []byte) (n int, err error) {
if s.cur >= s.size {
if s.err == nil {
s.err = io.EOF
}
return 0, s.err
}
n = len(p)
s.cur += len(p)
if s.cur > s.size {
n -= s.cur - s.size
}
return
}
func TestUploadOrderMultiBufferedReader(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
parts := []int{
buflen(val((*args)[1], "Body")),
buflen(val((*args)[2], "Body")),
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
}
func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
parts := []int{
buflen(val((*args)[1], "Body")),
buflen(val((*args)[2], "Body")),
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
}
// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the
// file size is the same as part size.
func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
parts := []int{
buflen(val((*args)[1], "Body")),
buflen(val((*args)[2], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
}
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
u.MaxUploadParts = 2
})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
aerr := err.(awserr.Error)
assert.Equal(t, "MultipartUpload", aerr.Code())
assert.Equal(t, "TotalPartsExceeded", aerr.OrigErr().(awserr.Error).Code())
assert.Contains(t, aerr.Error(), "configured MaxUploadParts (2)")
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 2},
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
}
func TestUploadZeroLenObject(t *testing.T) {
requestMade := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestMade = true
w.WriteHeader(http.StatusOK)
}))
mgr := s3manager.NewUploaderWithClient(s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
}))
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: strings.NewReader(""),
})
assert.NoError(t, err)
assert.True(t, requestMade)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
}
func TestUploadInputS3PutObjectInputPairity(t *testing.T) {
matchings := compareStructType(reflect.TypeOf(s3.PutObjectInput{}),
reflect.TypeOf(s3manager.UploadInput{}))
aOnly := []string{}
bOnly := []string{}
for k, c := range matchings {
if c == 1 && k != "ContentLength" {
aOnly = append(aOnly, k)
} else if c == 2 {
bOnly = append(bOnly, k)
}
}
assert.Empty(t, aOnly, "s3.PutObjectInput")
assert.Empty(t, bOnly, "s3Manager.UploadInput")
}
type testIncompleteReader struct {
Buf []byte
Count int
}
func (r *testIncompleteReader) Read(p []byte) (n int, err error) {
if r.Count < 0 {
return 0, io.ErrUnexpectedEOF
}
r.Count--
return copy(p, r.Buf), nil
}
func TestUploadUnexpectedEOF(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploaderWithClient(s, func(u *s3manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &testIncompleteReader{
Buf: make([]byte, 1024*1024*5),
Count: 1,
},
})
assert.Error(t, err)
assert.Equal(t, "CreateMultipartUpload", (*ops)[0])
assert.Equal(t, "UploadPart", (*ops)[1])
assert.Equal(t, "AbortMultipartUpload", (*ops)[len(*ops)-1])
// Part lengths
assert.Equal(t, 1024*1024*5, buflen(val((*args)[1], "Body")))
}
func compareStructType(a, b reflect.Type) map[string]int {
if a.Kind() != reflect.Struct || b.Kind() != reflect.Struct {
panic(fmt.Sprintf("types must both be structs, got %v and %v", a.Kind(), b.Kind()))
}
aFields := enumFields(a)
bFields := enumFields(b)
matchings := map[string]int{}
for i := 0; i < len(aFields) || i < len(bFields); i++ {
if i < len(aFields) {
c := matchings[aFields[i].Name]
matchings[aFields[i].Name] = c + 1
}
if i < len(bFields) {
c := matchings[bFields[i].Name]
matchings[bFields[i].Name] = c + 2
}
}
return matchings
}
func enumFields(v reflect.Type) []reflect.StructField {
fields := []reflect.StructField{}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
// Ignoreing anon fields
if field.PkgPath != "" {
// Ignore unexported fields
continue
}
fields = append(fields, field)
}
return fields
}
type fooReaderAt struct{}
func (r *fooReaderAt) Read(p []byte) (n int, err error) {
return 12, io.EOF
}
func (r *fooReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
return 12, io.EOF
}
func TestReaderAt(t *testing.T) {
svc := s3.New(unit.Session)
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.Send.Clear()
contentLen := ""
svc.Handlers.Send.PushBack(func(r *request.Request) {
contentLen = r.HTTPRequest.Header.Get("Content-Length")
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
mgr := s3manager.NewUploaderWithClient(svc, func(u *s3manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &fooReaderAt{},
})
assert.NoError(t, err)
assert.Equal(t, contentLen, "12")
}
func TestSSE(t *testing.T) {
svc := s3.New(unit.Session)
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.ValidateResponse.Clear()
svc.Handlers.Send.Clear()
partNum := 0
mutex := &sync.Mutex{}
svc.Handlers.Send.PushBack(func(r *request.Request) {
mutex.Lock()
defer mutex.Unlock()
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
switch data := r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
data.UploadId = aws.String("UPLOAD-ID")
case *s3.UploadPartOutput:
input := r.Params.(*s3.UploadPartInput)
if input.SSECustomerAlgorithm == nil {
t.Fatal("SSECustomerAlgoritm should not be nil")
}
if input.SSECustomerKey == nil {
t.Fatal("SSECustomerKey should not be nil")
}
partNum++
data.ETag = aws.String(fmt.Sprintf("ETAG%d", partNum))
case *s3.CompleteMultipartUploadOutput:
data.Location = aws.String("https://location")
data.VersionId = aws.String("VERSION-ID")
case *s3.PutObjectOutput:
data.VersionId = aws.String("VERSION-ID")
}
})
mgr := s3manager.NewUploaderWithClient(svc, func(u *s3manager.Uploader) {
u.Concurrency = 5
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
SSECustomerAlgorithm: aws.String("AES256"),
SSECustomerKey: aws.String("foo"),
Body: bytes.NewBuffer(make([]byte, 1024*1024*10)),
})
if err != nil {
t.Fatal("Expected no error, but received" + err.Error())
}
}
func TestUploadWithContextCanceled(t *testing.T) {
u := s3manager.NewUploader(unit.Session)
params := s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
}
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
_, err := u.UploadWithContext(ctx, &params)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
+91
View File
@@ -0,0 +1,91 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol/restxml"
)
// S3 is a client for Amazon S3.
// The service client's operations are safe to be used concurrently.
// It is not safe to mutate any of the client's properties though.
// Please also see https://docs.aws.amazon.com/goto/WebAPI/s3-2006-03-01
type S3 struct {
*client.Client
}
// Used for custom client initialization logic
var initClient func(*client.Client)
// Used for custom request initialization logic
var initRequest func(*request.Request)
// Service information constants
const (
ServiceName = "s3" // Service endpoint prefix API calls made to.
EndpointsID = ServiceName // Service ID for Regions and Endpoints metadata.
)
// New creates a new instance of the S3 client with a session.
// If additional configuration is needed for the client instance use the optional
// aws.Config parameter to add your extra config.
//
// Example:
// // Create a S3 client from just a session.
// svc := s3.New(mySession)
//
// // Create a S3 client with additional configuration
// svc := s3.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *S3 {
c := p.ClientConfig(EndpointsID, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *S3 {
svc := &S3{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2006-03-01",
},
handlers,
),
}
// Handlers
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(restxml.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(restxml.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(restxml.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(restxml.UnmarshalErrorHandler)
// Run custom client initialization if present
if initClient != nil {
initClient(svc.Client)
}
return svc
}
// newRequest creates a new request for a S3 operation and runs any
// custom request initialization.
func (c *S3) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {
initRequest(req)
}
return req
}
+44
View File
@@ -0,0 +1,44 @@
package s3
import (
"crypto/md5"
"encoding/base64"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var errSSERequiresSSL = awserr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
func validateSSERequiresSSL(r *request.Request) {
if r.HTTPRequest.URL.Scheme != "https" {
p, _ := awsutil.ValuesAtPath(r.Params, "SSECustomerKey||CopySourceSSECustomerKey")
if len(p) > 0 {
r.Error = errSSERequiresSSL
}
}
}
func computeSSEKeys(r *request.Request) {
headers := []string{
"x-amz-server-side-encryption-customer-key",
"x-amz-copy-source-server-side-encryption-customer-key",
}
for _, h := range headers {
md5h := h + "-md5"
if key := r.HTTPRequest.Header.Get(h); key != "" {
// Base64-encode the value
b64v := base64.StdEncoding.EncodeToString([]byte(key))
r.HTTPRequest.Header.Set(h, b64v)
// Add MD5 if it wasn't computed
if r.HTTPRequest.Header.Get(md5h) == "" {
sum := md5.Sum([]byte(key))
b64sum := base64.StdEncoding.EncodeToString(sum[:])
r.HTTPRequest.Header.Set(md5h, b64sum)
}
}
}
}
+79
View File
@@ -0,0 +1,79 @@
package s3_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
func TestSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestComputeSSEKeys(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}
func TestComputeSSEKeysShortcircuit(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
SSECustomerKeyMD5: aws.String("MD5"),
CopySourceSSECustomerKeyMD5: aws.String("MD5"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}
+35
View File
@@ -0,0 +1,35 @@
package s3
import (
"bytes"
"io/ioutil"
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to read response body", err)
return
}
body := bytes.NewReader(b)
r.HTTPResponse.Body = ioutil.NopCloser(body)
defer body.Seek(0, 0)
if body.Len() == 0 {
// If there is no body don't attempt to parse the body.
return
}
unmarshalError(r)
if err, ok := r.Error.(awserr.Error); ok && err != nil {
if err.Code() == "SerializationError" {
r.Error = nil
return
}
r.HTTPResponse.StatusCode = http.StatusServiceUnavailable
}
}
+130
View File
@@ -0,0 +1,130 @@
package s3_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
const errMsg = `<Error><Code>ErrorCode</Code><Message>message body</Message><RequestId>requestID</RequestId><HostId>hostID=</HostId></Error>`
var lastModifiedTime = time.Date(2009, 11, 23, 0, 0, 0, 0, time.UTC)
func TestCopyObjectNoError(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<CopyObjectResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><LastModified>2009-11-23T0:00:00Z</LastModified><ETag>&quot;1da64c7f13d1e8dbeaea40b905fd586c&quot;</ETag></CopyObjectResult>`
res, err := newCopyTestSvc(successMsg).CopyObject(&s3.CopyObjectInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/exists.txt"),
Key: aws.String("destination.txt"),
})
require.NoError(t, err)
assert.Equal(t, fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyObjectResult.ETag)
assert.Equal(t, lastModifiedTime, *res.CopyObjectResult.LastModified)
}
func TestCopyObjectError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).CopyObject(&s3.CopyObjectInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
})
require.Error(t, err)
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "message body", e.Message())
}
func TestUploadPartCopySuccess(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<UploadPartCopyResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><LastModified>2009-11-23T0:00:00Z</LastModified><ETag>&quot;1da64c7f13d1e8dbeaea40b905fd586c&quot;</ETag></CopyObjectResult>`
res, err := newCopyTestSvc(successMsg).UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
PartNumber: aws.Int64(0),
UploadId: aws.String("uploadID"),
})
require.NoError(t, err)
assert.Equal(t, fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyPartResult.ETag)
assert.Equal(t, lastModifiedTime, *res.CopyPartResult.LastModified)
}
func TestUploadPartCopyError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
PartNumber: aws.Int64(0),
UploadId: aws.String("uploadID"),
})
require.Error(t, err)
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "message body", e.Message())
}
func TestCompleteMultipartUploadSuccess(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Location>locationName</Location><Bucket>bucketName</Bucket><Key>keyName</Key><ETag>"etagVal"</ETag></CompleteMultipartUploadResult>`
res, err := newCopyTestSvc(successMsg).CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucketname"),
Key: aws.String("key"),
UploadId: aws.String("uploadID"),
})
require.NoError(t, err)
assert.Equal(t, `"etagVal"`, *res.ETag)
assert.Equal(t, "bucketName", *res.Bucket)
assert.Equal(t, "keyName", *res.Key)
assert.Equal(t, "locationName", *res.Location)
}
func TestCompleteMultipartUploadError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucketname"),
Key: aws.String("key"),
UploadId: aws.String("uploadID"),
})
require.Error(t, err)
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "message body", e.Message())
}
func newCopyTestSvc(errMsg string) *s3.S3 {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, errMsg, http.StatusOK)
}))
return s3.New(unit.Session, aws.NewConfig().
WithEndpoint(server.URL).
WithDisableSSL(true).
WithMaxRetries(0).
WithS3ForcePathStyle(true))
}
+103
View File
@@ -0,0 +1,103 @@
package s3
import (
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
defer io.Copy(ioutil.Discard, r.HTTPResponse.Body)
hostID := r.HTTPResponse.Header.Get("X-Amz-Id-2")
// Bucket exists in a different region, and request needs
// to be made to the correct region.
if r.HTTPResponse.StatusCode == http.StatusMovedPermanently {
r.Error = requestFailure{
RequestFailure: awserr.NewRequestFailure(
awserr.New("BucketRegionError",
fmt.Sprintf("incorrect region, the bucket is not in '%s' region",
aws.StringValue(r.Config.Region)),
nil),
r.HTTPResponse.StatusCode,
r.RequestID,
),
hostID: hostID,
}
return
}
var errCode, errMsg string
// Attempt to parse error from body if it is known
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
errCode = "SerializationError"
errMsg = "failed to decode S3 XML error response"
} else {
errCode = resp.Code
errMsg = resp.Message
err = nil
}
// Fallback to status code converted to message if still no error code
if len(errCode) == 0 {
statusText := http.StatusText(r.HTTPResponse.StatusCode)
errCode = strings.Replace(statusText, " ", "", -1)
errMsg = statusText
}
r.Error = requestFailure{
RequestFailure: awserr.NewRequestFailure(
awserr.New(errCode, errMsg, err),
r.HTTPResponse.StatusCode,
r.RequestID,
),
hostID: hostID,
}
}
// A RequestFailure provides access to the S3 Request ID and Host ID values
// returned from API operation errors. Getting the error as a string will
// return the formated error with the same information as awserr.RequestFailure,
// while also adding the HostID value from the response.
type RequestFailure interface {
awserr.RequestFailure
// Host ID is the S3 Host ID needed for debug, and contacting support
HostID() string
}
type requestFailure struct {
awserr.RequestFailure
hostID string
}
func (r requestFailure) Error() string {
extra := fmt.Sprintf("status code: %d, request id: %s, host id: %s",
r.StatusCode(), r.RequestID(), r.hostID)
return awserr.SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}
func (r requestFailure) String() string {
return r.Error()
}
func (r requestFailure) HostID() string {
return r.hostID
}
@@ -0,0 +1,33 @@
package s3
import (
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestUnmarhsalErrorLeak(t *testing.T) {
req := &request.Request{
HTTPRequest: &http.Request{
Header: make(http.Header),
Body: &awstesting.ReadCloser{Size: 2048},
},
}
req.HTTPResponse = &http.Response{
Body: &awstesting.ReadCloser{Size: 2048},
Header: http.Header{
"X-Amzn-Requestid": []string{"1"},
},
StatusCode: http.StatusOK,
}
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
unmarshalError(req)
assert.NotNil(t, req.Error)
assert.Equal(t, reader.Closed, true)
assert.Equal(t, reader.Size, 0)
}
+253
View File
@@ -0,0 +1,253 @@
package s3_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type testErrorCase struct {
RespFn func() *http.Response
ReqID, HostID string
Code, Msg string
WithoutStatusMsg bool
}
var testUnmarshalCases = []testErrorCase{
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 301,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "BucketRegionError", Msg: "incorrect region, the bucket is not in 'mock-region' region",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 403,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "Forbidden", Msg: "Forbidden",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 400,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "BadRequest", Msg: "Bad Request",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found",
},
{
// SDK only reads request ID and host ID from the header. The values
// in message body are ignored.
RespFn: func() *http.Response {
body := `<Error><Code>SomeException</Code><Message>Exception message</Message><RequestId>ignored-request-id</RequestId><HostId>ignored-host-id</HostId></Error>`
return &http.Response{
StatusCode: 500,
Header: http.Header{
"X-Amz-Request-Id": []string{"taken-request-id"},
"X-Amz-Id-2": []string{"taken-host-id"},
},
Body: ioutil.NopCloser(strings.NewReader(body)),
ContentLength: int64(len(body)),
}
},
ReqID: "taken-request-id",
HostID: "taken-host-id",
Code: "SomeException", Msg: "Exception message",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found", WithoutStatusMsg: true,
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found",
},
}
func TestUnmarshalError(t *testing.T) {
for i, c := range testUnmarshalCases {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = c.RespFn()
if !c.WithoutStatusMsg {
r.HTTPResponse.Status = fmt.Sprintf("%d%s",
r.HTTPResponse.StatusCode,
http.StatusText(r.HTTPResponse.StatusCode))
}
})
_, err := s.PutBucketAcl(&s3.PutBucketAclInput{
Bucket: aws.String("bucket"), ACL: aws.String("public-read"),
})
if err == nil {
t.Fatalf("%d, expected error, got nil", i)
}
if e, a := c.Code, err.(awserr.Error).Code(); e != a {
t.Errorf("%d, Code: expect %s, got %s", i, e, a)
}
if e, a := c.Msg, err.(awserr.Error).Message(); e != a {
t.Errorf("%d, Message: expect %s, got %s", i, e, a)
}
if e, a := c.ReqID, err.(awserr.RequestFailure).RequestID(); e != a {
t.Errorf("%d, RequestID: expect %s, got %s", i, e, a)
}
if e, a := c.HostID, err.(s3.RequestFailure).HostID(); e != a {
t.Errorf("%d, HostID: expect %s, got %s", i, e, a)
}
}
}
const completeMultiResp = `
163
<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Location>https://bucket.s3-us-west-2.amazonaws.com/key</Location><Bucket>bucket</Bucket><Key>key</Key><ETag>&quot;a7d414b9133d6483d9a1c4e04e856e3b-2&quot;</ETag></CompleteMultipartUploadResult>
0
`
func Test200NoErrorUnmarshalError(t *testing.T) {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(strings.NewReader(completeMultiResp)),
ContentLength: -1,
}
r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode)
})
_, err := s.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucket"), Key: aws.String("key"),
UploadId: aws.String("id"),
MultipartUpload: &s3.CompletedMultipartUpload{Parts: []*s3.CompletedPart{
{ETag: aws.String("etag"), PartNumber: aws.Int64(1)},
}},
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}
const completeMultiErrResp = `<Error><Code>SomeException</Code><Message>Exception message</Message></Error>`
func Test200WithErrorUnmarshalError(t *testing.T) {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(strings.NewReader(completeMultiErrResp)),
ContentLength: -1,
}
r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode)
})
_, err := s.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucket"), Key: aws.String("key"),
UploadId: aws.String("id"),
MultipartUpload: &s3.CompletedMultipartUpload{Parts: []*s3.CompletedPart{
{ETag: aws.String("etag"), PartNumber: aws.Int64(1)},
}},
})
if err == nil {
t.Fatalf("expected error, got nil")
}
if e, a := "SomeException", err.(awserr.Error).Code(); e != a {
t.Errorf("Code: expect %s, got %s", e, a)
}
if e, a := "Exception message", err.(awserr.Error).Message(); e != a {
t.Errorf("Message: expect %s, got %s", e, a)
}
if e, a := "abc123", err.(s3.RequestFailure).RequestID(); e != a {
t.Errorf("RequestID: expect %s, got %s", e, a)
}
if e, a := "321cba", err.(s3.RequestFailure).HostID(); e != a {
t.Errorf("HostID: expect %s, got %s", e, a)
}
}
+194
View File
@@ -0,0 +1,194 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// WaitUntilBucketExists uses the Amazon S3 API operation
// HeadBucket to wait for a condition to be met before returning.
// If the condition is not meet within the max attempt window an error will
// be returned.
func (c *S3) WaitUntilBucketExists(input *HeadBucketInput) error {
return c.WaitUntilBucketExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilBucketExistsWithContext is an extended version of WaitUntilBucketExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilBucketExistsWithContext(ctx aws.Context, input *HeadBucketInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilBucketExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 301,
},
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 403,
},
{
State: request.RetryWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
req, _ := c.HeadBucketRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilBucketNotExists uses the Amazon S3 API operation
// HeadBucket to wait for a condition to be met before returning.
// If the condition is not meet within the max attempt window an error will
// be returned.
func (c *S3) WaitUntilBucketNotExists(input *HeadBucketInput) error {
return c.WaitUntilBucketNotExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilBucketNotExistsWithContext is an extended version of WaitUntilBucketNotExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilBucketNotExistsWithContext(ctx aws.Context, input *HeadBucketInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilBucketNotExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
req, _ := c.HeadBucketRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilObjectExists uses the Amazon S3 API operation
// HeadObject to wait for a condition to be met before returning.
// If the condition is not meet within the max attempt window an error will
// be returned.
func (c *S3) WaitUntilObjectExists(input *HeadObjectInput) error {
return c.WaitUntilObjectExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilObjectExistsWithContext is an extended version of WaitUntilObjectExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilObjectExistsWithContext(ctx aws.Context, input *HeadObjectInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilObjectExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
{
State: request.RetryWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
req, _ := c.HeadObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilObjectNotExists uses the Amazon S3 API operation
// HeadObject to wait for a condition to be met before returning.
// If the condition is not meet within the max attempt window an error will
// be returned.
func (c *S3) WaitUntilObjectNotExists(input *HeadObjectInput) error {
return c.WaitUntilObjectNotExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilObjectNotExistsWithContext is an extended version of WaitUntilObjectNotExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilObjectNotExistsWithContext(ctx aws.Context, input *HeadObjectInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilObjectNotExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
req, _ := c.HeadObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}