Update vendored deps, including AWS SDK, openpgp, ftp, ...

This commit is contained in:
Andrey Smirnov
2018-04-05 17:46:45 +03:00
parent cef4fefc40
commit 0e6ee35942
1497 changed files with 450721 additions and 68034 deletions
+133 -45
View File
@@ -5,11 +5,11 @@ import (
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
@@ -81,12 +81,24 @@ func TestCopy1(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, f2.D, f1.D)
assert.Equal(t, f2.E.B, f1.E.B)
assert.Equal(t, f2.E.D, f1.E.D)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.D, f1.D; !v1.Equal(*v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.B, f1.E.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.D, f1.E.D; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// But pointers are not!
str3 := "nothello"
@@ -99,14 +111,30 @@ func TestCopy1(t *testing.T) {
*f2.E.B = int3
f2.E.c = 5
f2.E.D = 5
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
assert.NotEqual(t, f2.D, f1.D)
assert.NotEqual(t, f2.E.a, f1.E.a)
assert.NotEqual(t, f2.E.B, f1.E.B)
assert.NotEqual(t, f2.E.c, f1.E.c)
assert.NotEqual(t, f2.E.D, f1.E.D)
if v1, v2 := f2.A, f1.A; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B, f1.B; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.C, f1.C; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.D, f1.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.a, f1.E.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.B, f1.E.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.c, f1.E.c; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.D, f1.E.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
}
func TestCopyNestedWithUnexported(t *testing.T) {
@@ -125,10 +153,18 @@ func TestCopyNestedWithUnexported(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values match
assert.Equal(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.B.a, f1.B.a)
assert.Equal(t, f2.B.B, f2.B.B)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.a, f1.B.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.B, f2.B.B; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyIgnoreNilMembers(t *testing.T) {
@@ -139,34 +175,56 @@ func TestCopyIgnoreNilMembers(t *testing.T) {
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
if v1 := f.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
if v1 := f2.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
if v1 := f3.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
if v1, v2 := "hello", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
if v1, v2 := "", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyReader(t *testing.T) {
@@ -174,13 +232,21 @@ func TestCopyReader(t *testing.T) {
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte("hello world"), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte(""), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyDifferentStructs(t *testing.T) {
@@ -226,17 +292,39 @@ func TestCopyDifferentStructs(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
assert.Equal(t, int1, *f1.unexportedPtr)
assert.Nil(t, f2.unexportedPtr)
assert.Equal(t, int2, *f1.ExportedPtr)
assert.Equal(t, int2, *f2.ExportedPtr)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "unique", f1.SrcUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 1, f1.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 0, f2.DstUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "", f2.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int1, *f1.unexportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1 := f2.unexportedPtr; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1, v2 := int2, *f1.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int2, *f2.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func ExampleCopyOf() {
+3 -2
View File
@@ -5,7 +5,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
@@ -24,6 +23,8 @@ func TestDeepEqual(t *testing.T) {
}
for i, c := range cases {
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
if awsutil.DeepEqual(c.a, c.b) != c.equal {
t.Errorf("%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}
}
+61 -21
View File
@@ -1,10 +1,10 @@
package awsutil_test
import (
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
@@ -50,8 +50,12 @@ func TestValueAtPathSuccess(t *testing.T) {
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
}
}
@@ -78,12 +82,18 @@ func TestValueAtPathFailure(t *testing.T) {
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
if !strings.Contains(err.Error(), c.errContains) {
t.Errorf("case %v, expected error, %v", i, c.path)
}
continue
} else {
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
@@ -92,51 +102,81 @@ func TestSetValueAtPathSuccess(t *testing.T) {
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
if e, a := "test1", s.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test2", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test3", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
if e, a := "test0", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test0", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s2 Struct
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
if e, a := "test0", s2.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
if e, a := []Struct{{}}, s2.A; !awsutil.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
}
+6
View File
@@ -15,6 +15,12 @@ type Config struct {
Endpoint string
SigningRegion string
SigningName string
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the
// service has.
SigningNameDerived bool
}
// ConfigProvider provides a generic way for a service client to receive
+48 -28
View File
@@ -1,11 +1,11 @@
package client
import (
"math/rand"
"sync"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
@@ -30,25 +30,27 @@ func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())})
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
if delay, ok := getRetryDelay(r); ok {
return delay
}
minTime = 500
}
retryCount := r.RetryCount
if retryCount > 13 {
retryCount = 13
} else if throttle && retryCount > 8 {
if throttle && retryCount > 8 {
retryCount = 8
} else if retryCount > 13 {
retryCount = 13
}
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime)
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
}
@@ -60,7 +62,7 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 {
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
@@ -68,29 +70,47 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
if r.HTTPResponse.StatusCode == 502 ||
r.HTTPResponse.StatusCode == 503 ||
r.HTTPResponse.StatusCode == 504 {
return true
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return r.IsErrorThrottle()
return true
}
// lockedSource is a thread-safe implementation of rand.Source
type lockedSource struct {
lk sync.Mutex
src rand.Source
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}
delayStr := r.HTTPResponse.Header.Get("Retry-After")
if len(delayStr) == 0 {
return 0, false
}
delay, err := strconv.Atoi(delayStr)
if err != nil {
return 0, false
}
return time.Duration(delay) * time.Second, true
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.src.Int63()
r.lk.Unlock()
return
}
// Will look at the status code to see if the retry header pertains to
// the status code.
func canUseRetryAfterHeader(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 503:
default:
return false
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.src.Seed(seed)
r.lk.Unlock()
return true
}
+189
View File
@@ -0,0 +1,189 @@
package client
import (
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestRetryThrottleStatusCodes(t *testing.T) {
cases := []struct {
expectThrottle bool
expectRetry bool
r request.Request
}{
{
false,
false,
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 502},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 504},
},
},
{
false,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
},
}
d := DefaultRetryer{NumMaxRetries: 10}
for i, c := range cases {
throttle := d.shouldThrottle(&c.r)
retry := d.ShouldRetry(&c.r)
if e, a := c.expectThrottle, throttle; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
if e, a := c.expectRetry, retry; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
}
}
func TestCanUseRetryAfter(t *testing.T) {
cases := []struct {
r request.Request
e bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
true,
},
}
for i, c := range cases {
a := canUseRetryAfterHeader(&c.r)
if c.e != a {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestGetRetryDelay(t *testing.T) {
cases := []struct {
r request.Request
e time.Duration
equal bool
ok bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}},
},
3600 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
120 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
1 * time.Second,
false,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}},
},
0 * time.Second,
true,
false,
},
}
for i, c := range cases {
a, ok := getRetryDelay(&c.r)
if c.ok != ok {
t.Errorf("%d: expected %v, but received %v", i, c.ok, ok)
}
if (c.e != a) == c.equal {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestRetryDelay(t *testing.T) {
r := request.Request{}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.HTTPResponse = &http.Response{StatusCode: 500, Header: http.Header{"Retry-After": []string{""}}}
rTemp.RetryCount = i
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.RetryCount = i
rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
}
+4
View File
@@ -46,6 +46,7 @@ func (reader *teeReaderCloser) Close() error {
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
@@ -53,6 +54,9 @@ func logRequest(r *request.Request) {
}
if logBody {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
+87
View File
@@ -2,8 +2,17 @@ package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
type mockCloser struct {
@@ -55,3 +64,81 @@ func TestLogWriter(t *testing.T) {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}
func TestLogRequest(t *testing.T) {
cases := []struct {
Body io.ReadSeeker
ExpectBody []byte
LogLevel aws.LogLevelType
}{
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
ExpectBody: []byte("body content"),
},
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
logW := bytes.NewBuffer(nil)
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.SetReaderBody(c.Body)
req.Build()
logRequest(req)
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
type bufLogger struct {
w *bytes.Buffer
}
func (l *bufLogger) Log(args ...interface{}) {
fmt.Fprintln(l.w, args...)
}
func testHandlers() request.Handlers {
var handlers request.Handlers
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
return handlers
}
+23 -1
View File
@@ -151,6 +151,15 @@ type Config struct {
// with accelerate.
S3UseAccelerate *bool
// S3DisableContentMD5Validation config option is temporarily disabled,
// For S3 GetObject API calls, #1837.
//
// Set this to `true` to disable the S3 service client from automatically
// adding the ContentMD5 to S3 Object Put and Upload API calls. This option
// will also disable the SDK from performing object ContentMD5 validation
// on GetObject API calls.
S3DisableContentMD5Validation *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
@@ -168,7 +177,7 @@ type Config struct {
//
EC2MetadataDisableTimeoutOverride *bool
// Instructs the endpiont to be generated for a service client to
// Instructs the endpoint to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
//
@@ -336,6 +345,15 @@ func (c *Config) WithS3Disable100Continue(disable bool) *Config {
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
c.S3UseAccelerate = &enable
return c
}
// WithS3DisableContentMD5Validation sets a config
// S3DisableContentMD5Validation value returning a Config pointer for chaining.
func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
c.S3DisableContentMD5Validation = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
@@ -435,6 +453,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3UseAccelerate = other.S3UseAccelerate
}
if other.S3DisableContentMD5Validation != nil {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
+259 -88
View File
@@ -1,10 +1,9 @@
package aws
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var testCasesStringSlice = [][]string{
@@ -18,14 +17,22 @@ func TestStringSlice(t *testing.T) {
continue
}
out := StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -39,22 +46,34 @@ func TestStringValueSlice(t *testing.T) {
continue
}
out := StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := *in[i], *out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -70,14 +89,22 @@ func TestStringMap(t *testing.T) {
continue
}
out := StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -91,14 +118,22 @@ func TestBoolSlice(t *testing.T) {
continue
}
out := BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -110,22 +145,34 @@ func TestBoolValueSlice(t *testing.T) {
continue
}
out := BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -141,14 +188,22 @@ func TestBoolMap(t *testing.T) {
continue
}
out := BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -162,14 +217,22 @@ func TestIntSlice(t *testing.T) {
continue
}
out := IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -181,22 +244,34 @@ func TestIntValueSlice(t *testing.T) {
continue
}
out := IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -212,14 +287,22 @@ func TestIntMap(t *testing.T) {
continue
}
out := IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -233,14 +316,22 @@ func TestInt64Slice(t *testing.T) {
continue
}
out := Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -252,22 +343,34 @@ func TestInt64ValueSlice(t *testing.T) {
continue
}
out := Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -283,14 +386,22 @@ func TestInt64Map(t *testing.T) {
continue
}
out := Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -304,14 +415,22 @@ func TestFloat64Slice(t *testing.T) {
continue
}
out := Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -323,22 +442,34 @@ func TestFloat64ValueSlice(t *testing.T) {
continue
}
out := Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -354,14 +485,22 @@ func TestFloat64Map(t *testing.T) {
continue
}
out := Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -375,14 +514,22 @@ func TestTimeSlice(t *testing.T) {
continue
}
out := TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -394,22 +541,34 @@ func TestTimeValueSlice(t *testing.T) {
continue
}
out := TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if !out[i].IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if !(*(out2[i])).IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -425,14 +584,22 @@ func TestTimeMap(t *testing.T) {
continue
}
out := TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -458,13 +625,17 @@ var testCasesTimeValue = []TimeValueTestCase{
func TestSecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := SecondsTimeValue(&testCase.in)
assert.Equal(t, testCase.outSecs, out, "Unexpected value for time value at %d", idx)
if e, a := testCase.outSecs, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}
func TestMillisecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := MillisecondsTimeValue(&testCase.in)
assert.Equal(t, testCase.outMillis, out, "Unexpected value for time value at %d", idx)
if e, a := testCase.outMillis, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}
+7 -21
View File
@@ -3,12 +3,10 @@ package corehandlers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"runtime"
"strconv"
"time"
@@ -36,18 +34,13 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ = strconv.ParseInt(slength, 10, 64)
} else {
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
if r.Body != nil {
var err error
length, err = aws.SeekerLen(r.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
return
}
}
}
@@ -60,13 +53,6 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
}
}}
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
+67 -24
View File
@@ -7,11 +7,10 @@ import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
@@ -32,7 +31,9 @@ func TestValidateEndpointHandler(t *testing.T) {
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
@@ -45,8 +46,12 @@ func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, aws.ErrMissingRegion, err)
if err == nil {
t.Errorf("expect error, got none")
}
if e, a := aws.ErrMissingRegion, err; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
type mockCredsProvider struct {
@@ -82,18 +87,30 @@ func TestAfterRetryRefreshCreds(t *testing.T) {
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if !credProvider.retrieveCalled {
t.Errorf("expect not called")
}
}
func TestAfterRetryWithContextCanceled(t *testing.T) {
@@ -202,8 +219,12 @@ func TestSendHandlerError(t *testing.T) {
r.Send()
assert.Error(t, r.Error)
assert.NotNil(t, r.HTTPResponse)
if r.Error == nil {
t.Errorf("expect error, got none")
}
if r.HTTPResponse == nil {
t.Errorf("expect response, got none")
}
}
func TestSendWithoutFollowRedirects(t *testing.T) {
@@ -273,31 +294,47 @@ func TestValidateReqSigHandler(t *testing.T) {
corehandlers.ValidateReqSigHandler.Fn(c.Req)
assert.NoError(t, c.Req.Error, "%d, expect no error", i)
assert.Equal(t, c.Resign, resigned, "%d, expected resigning to match", i)
if c.Req.Error != nil {
t.Errorf("expect no error, got %v", c.Req.Error)
}
if e, a := c.Resign, resigned; e != a {
t.Errorf("%d, expect %v to be %v", i, e, a)
}
}
}
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Header["Content-Length"]
assert.Equal(t, hasContentLength, ok, "expect content length to be set, %t", hasContentLength)
if e, a := hasContentLength, ok; e != a {
t.Errorf("expect %v to be %v", e, a)
}
if hasContentLength {
assert.Equal(t, contentLength, r.ContentLength)
if e, a := contentLength, r.ContentLength; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
b, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
r.Body.Close()
authHeader := r.Header.Get("Authorization")
if hasContentLength {
assert.Contains(t, authHeader, "content-length")
if e, a := "content-length", authHeader; !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
} else {
assert.NotContains(t, authHeader, "content-length")
if e, a := "content-length", authHeader; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v", e, a)
}
}
assert.Equal(t, contentLength, int64(len(b)))
if e, a := contentLength, int64(len(b)); e != a {
t.Errorf("expect %v to be %v", e, a)
}
}))
return server
@@ -316,7 +353,9 @@ func TestBuildContentLength_ZeroBody(t *testing.T) {
Key: aws.String("keyname"),
})
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestBuildContentLength_NegativeBody(t *testing.T) {
@@ -334,7 +373,9 @@ func TestBuildContentLength_NegativeBody(t *testing.T) {
req.HTTPRequest.Header.Set("Content-Length", "-1")
assert.NoError(t, req.Send())
if req.Error != nil {
t.Errorf("expect no error, got %v", req.Error)
}
}
func TestBuildContentLength_WithBody(t *testing.T) {
@@ -351,5 +392,7 @@ func TestBuildContentLength_WithBody(t *testing.T) {
Body: bytes.NewReader(make([]byte, 1024)),
})
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
+52 -20
View File
@@ -3,8 +3,7 @@ package corehandlers_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"reflect"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -14,7 +13,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/stretchr/testify/require"
)
var testSvc = func() *client.Client {
@@ -113,7 +111,9 @@ func TestNoErrors(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.NoError(t, req.Error)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
}
func TestMissingRequiredParameters(t *testing.T) {
@@ -121,17 +121,33 @@ func TestMissingRequiredParameters(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
assert.Len(t, errs, 3)
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error())
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
assert.Equal(t, "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error())
if e, a := "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestNestedMissingRequiredParameters(t *testing.T) {
@@ -148,15 +164,29 @@ func TestNestedMissingRequiredParameters(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
assert.Len(t, errs, 3)
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error())
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
type testInput struct {
@@ -226,7 +256,9 @@ func TestValidateFieldMinParameter(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
assert.Equal(t, c.err, req.Error, "%d case failed", i)
if e, a := c.err, req.Error; !reflect.DeepEqual(e,a) {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
+37
View File
@@ -0,0 +1,37 @@
package corehandlers
import (
"os"
"runtime"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version
// to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent.
//
// If the environment variable AWS_EXECUTION_ENV is set, its value will be
// appended to the user agent string.
var AddHostExecEnvUserAgentHander = request.NamedHandler{
Name: "core.AddHostExecEnvUserAgentHander",
Fn: func(r *request.Request) {
v := os.Getenv(execEnvVar)
if len(v) == 0 {
return
}
request.AddToUserAgent(r, execEnvUAKey+"/"+v)
},
}
+40
View File
@@ -0,0 +1,40 @@
package corehandlers
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
cases := []struct {
ExecEnv string
Expect string
}{
{ExecEnv: "Lambda", Expect: "exec_env/Lambda"},
{ExecEnv: "", Expect: ""},
{ExecEnv: "someThingCool", Expect: "exec_env/someThingCool"},
}
for i, c := range cases {
os.Clearenv()
os.Setenv(execEnvVar, c.ExecEnv)
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
}
AddHostExecEnvUserAgentHander.Fn(req)
if err := req.Error; err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.Expect, req.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("%d, expect %v user agent, got %v", i, e, a)
}
}
}
+33 -2
View File
@@ -9,6 +9,7 @@ package defaults
import (
"fmt"
"net"
"net/http"
"net/url"
"os"
@@ -72,6 +73,7 @@ func Handlers() request.Handlers {
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
handlers.Validate.AfterEachFn = request.HandlerListStopOnError
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Build.PushBackNamed(corehandlers.AddHostExecEnvUserAgentHander)
handlers.Build.AfterEachFn = request.HandlerListStopOnError
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
@@ -118,14 +120,43 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return ec2RoleProvider(cfg, handlers)
}
var lookupHostFn = net.LookupHost
func isLoopbackHost(host string) (bool, error) {
ip := net.ParseIP(host)
if ip != nil {
return ip.IsLoopback(), nil
}
// Host is not an ip, perform lookup
addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}
for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() {
return false, nil
}
}
return true, nil
}
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string
parsed, err := url.Parse(u)
if err != nil {
errMsg = fmt.Sprintf("invalid URL, %v", err)
} else if host := aws.URLHostname(parsed); !(host == "localhost" || host == "127.0.0.1") {
errMsg = fmt.Sprintf("invalid host address, %q, only localhost and 127.0.0.1 are valid.", host)
} else {
host := aws.URLHostname(parsed)
if len(host) == 0 {
errMsg = "unable to parse host from local HTTP cred provider URL"
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host)
}
}
if len(errMsg) > 0 {
+30 -2
View File
@@ -13,12 +13,40 @@ import (
)
func TestHTTPCredProvider(t *testing.T) {
origFn := lookupHostFn
defer func() { lookupHostFn = origFn }()
lookupHostFn = func(host string) ([]string, error) {
m := map[string]struct {
Addrs []string
Err error
}{
"localhost": {Addrs: []string{"::1", "127.0.0.1"}},
"actuallylocal": {Addrs: []string{"127.0.0.2"}},
"notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}},
"www.example.com": {Addrs: []string{"10.10.10.10"}},
}
h, ok := m[host]
if !ok {
t.Fatalf("unknown host in test, %v", host)
return nil, fmt.Errorf("unknown host")
}
return h.Addrs, h.Err
}
cases := []struct {
Host string
Fail bool
}{
{"localhost", false}, {"127.0.0.1", false},
{"www.example.com", true}, {"169.254.170.2", true},
{"localhost", false},
{"actuallylocal", false},
{"127.0.0.1", false},
{"127.1.1.1", false},
{"[::1]", false},
{"www.example.com", true},
{"169.254.170.2", true},
}
defer os.Clearenv()
+82 -35
View File
@@ -11,8 +11,6 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
@@ -71,8 +69,12 @@ func TestEndpoint(t *testing.T) {
}
req := c.NewRequest(op, nil, nil)
assert.Equal(t, "http://169.254.169.254/latest", req.ClientInfo.Endpoint)
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetMetadata(t *testing.T) {
@@ -85,8 +87,12 @@ func TestGetMetadata(t *testing.T) {
resp, err := c.GetMetadata("some/path")
assert.NoError(t, err)
assert.Equal(t, "success", resp)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData(t *testing.T) {
@@ -99,8 +105,12 @@ func TestGetUserData(t *testing.T) {
resp, err := c.GetUserData()
assert.NoError(t, err)
assert.Equal(t, "success", resp)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData_Error(t *testing.T) {
@@ -126,12 +136,17 @@ func TestGetUserData_Error(t *testing.T) {
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
assert.Error(t, err)
assert.Empty(t, resp)
if err == nil {
t.Errorf("expect error")
}
if len(resp) != 0 {
t.Errorf("expect empty, got %v", resp)
}
aerr, ok := err.(awserr.Error)
assert.True(t, ok)
assert.Equal(t, "NotFoundError", aerr.Code())
aerr := err.(awserr.Error)
if e, a := "NotFoundError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetRegion(t *testing.T) {
@@ -144,8 +159,12 @@ func TestGetRegion(t *testing.T) {
region, err := c.Region()
assert.NoError(t, err)
assert.Equal(t, "us-west-2", region)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "us-west-2", region; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataAvailable(t *testing.T) {
@@ -156,9 +175,9 @@ func TestMetadataAvailable(t *testing.T) {
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
available := c.Available()
assert.True(t, available)
if !c.Available() {
t.Errorf("expect available")
}
}
func TestMetadataIAMInfo_success(t *testing.T) {
@@ -170,10 +189,18 @@ func TestMetadataIAMInfo_success(t *testing.T) {
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
assert.NoError(t, err)
assert.Equal(t, "Success", iamInfo.Code)
assert.Equal(t, "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn)
assert.Equal(t, "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "Success", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataIAMInfo_failure(t *testing.T) {
@@ -185,10 +212,18 @@ func TestMetadataIAMInfo_failure(t *testing.T) {
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
assert.NotNil(t, err)
assert.Equal(t, "", iamInfo.Code)
assert.Equal(t, "", iamInfo.InstanceProfileArn)
assert.Equal(t, "", iamInfo.InstanceProfileID)
if err == nil {
t.Errorf("expect error")
}
if e, a := "", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataNotAvailable(t *testing.T) {
@@ -204,9 +239,9 @@ func TestMetadataNotAvailable(t *testing.T) {
r.Retryable = aws.Bool(true) // network errors are retryable
})
available := c.Available()
assert.False(t, available)
if c.Available() {
t.Errorf("expect not available")
}
}
func TestMetadataErrorResponse(t *testing.T) {
@@ -222,8 +257,12 @@ func TestMetadataErrorResponse(t *testing.T) {
})
data, err := c.GetMetadata("uri/path")
assert.Empty(t, data)
assert.Contains(t, err.Error(), "error message text")
if len(data) != 0 {
t.Errorf("expect empty, got %v", data)
}
if e, a := "error message text", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
}
func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
@@ -235,8 +274,16 @@ func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
doc, err := c.GetInstanceIdentityDocument()
assert.Nil(t, err, "Expect no error, %v", err)
assert.Equal(t, doc.AccountID, "123456789012")
assert.Equal(t, doc.AvailabilityZone, "us-east-1d")
assert.Equal(t, doc.Region, "us-east-1")
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := doc.AccountID, "123456789012"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.AvailabilityZone, "us-east-1d"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.Region, "us-east-1"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
+24
View File
@@ -1,5 +1,10 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
//
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
package ec2metadata
import (
@@ -7,17 +12,21 @@ import (
"errors"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
)
// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
@@ -75,6 +84,21 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
nil)
},
})
}
// Add additional options to the service config
for _, option := range opts {
option(svc.Client)
+52 -10
View File
@@ -3,21 +3,30 @@ package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)
assert.Equal(t, 5*time.Second, svc.Config.HTTPClient.Timeout)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
t.Errorf("expect %v, not to equal %v", e, a)
}
if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
@@ -28,18 +37,25 @@ func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)
assert.Nil(t, tr.Dial)
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
if tr == nil {
t.Fatalf("expect transport not to be nil")
}
if tr.Dial != nil {
t.Errorf("expect dial to be nil, was not")
}
}
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
@@ -63,6 +79,30 @@ func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
runEC2MetadataClients(t, cfg, 100)
}
func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
svc := ec2metadata.New(unit.Session)
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
@@ -70,7 +110,9 @@ func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
wg.Done()
}()
}
File diff suppressed because it is too large Load Diff
+4
View File
@@ -347,6 +347,10 @@ type ResolvedEndpoint struct {
// The service name that should be used for signing requests.
SigningName string
// States that the signing name for this endpoint was derived from metadata
// passed in, but was not explicitly modeled.
SigningNameDerived bool
// The signing method that should be used for signing requests.
SigningMethod string
}
+8 -4
View File
@@ -226,16 +226,20 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
if len(signingRegion) == 0 {
signingRegion = region
}
signingName := e.CredentialScope.Service
var signingNameDerived bool
if len(signingName) == 0 {
signingName = service
signingNameDerived = true
}
return ResolvedEndpoint{
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}
+265 -78
View File
@@ -2,10 +2,9 @@ package endpoints
import (
"encoding/json"
"reflect"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUnmarshalRegionRegex(t *testing.T) {
@@ -16,12 +15,18 @@ func TestUnmarshalRegionRegex(t *testing.T) {
p := partition{}
err := json.Unmarshal(input, &p)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectRegexp, err := regexp.Compile(`^(us|eu|ap|sa|ca)\-\w+\-\d+$`)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
assert.Equal(t, expectRegexp.String(), p.RegionRegex.Regexp.String())
if e, a := expectRegexp.String(), p.RegionRegex.Regexp.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalRegion(t *testing.T) {
@@ -37,16 +42,28 @@ func TestUnmarshalRegion(t *testing.T) {
rs := regions{}
err := json.Unmarshal(input, &rs)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
assert.Len(t, rs, 2)
if e, a := 2, len(rs); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
r, ok := rs["aws-global"]
assert.True(t, ok)
assert.Equal(t, "AWS partition-global endpoint", r.Description)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "AWS partition-global endpoint", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
r, ok = rs["us-east-1"]
assert.True(t, ok)
assert.Equal(t, "US East (N. Virginia)", r.Description)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "US East (N. Virginia)", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalServices(t *testing.T) {
@@ -75,23 +92,45 @@ func TestUnmarshalServices(t *testing.T) {
ss := services{}
err := json.Unmarshal(input, &ss)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
assert.Len(t, ss, 3)
if e, a := 3, len(ss); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := ss["acm"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 1)
assert.Equal(t, boxedBoolUnset, s.IsRegionalized)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 1, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedBoolUnset, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["apigateway"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 2)
assert.Equal(t, boxedTrue, s.IsRegionalized)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedTrue, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["notRegionalized"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 2)
assert.Equal(t, boxedFalse, s.IsRegionalized)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedFalse, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalEndpoints(t *testing.T) {
@@ -115,16 +154,32 @@ func TestUnmarshalEndpoints(t *testing.T) {
es := endpoints{}
err := json.Unmarshal(inputs, &es)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
assert.Len(t, es, 2)
if e, a := 2, len(es); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := es["aws-global"]
assert.True(t, ok)
assert.Equal(t, "cloudfront.amazonaws.com", s.Hostname)
assert.Equal(t, []string{"http", "https"}, s.Protocols)
assert.Equal(t, []string{"v4"}, s.SignatureVersions)
assert.Equal(t, credentialScope{"us-east-1", "serviceName"}, s.CredentialScope)
assert.Equal(t, "commonName", s.SSLCommonName)
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "cloudfront.amazonaws.com", s.Hostname; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"http", "https"}, s.Protocols; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"v4"}, s.SignatureVersions; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := (credentialScope{"us-east-1", "serviceName"}), s.CredentialScope; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "commonName", s.SSLCommonName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointResolve(t *testing.T) {
@@ -155,10 +210,18 @@ func TestEndpointResolve(t *testing.T) {
defs, Options{},
)
assert.Equal(t, "https://service.region.dnsSuffix", resolved.URL)
assert.Equal(t, "signing_service", resolved.SigningName)
assert.Equal(t, "signing_region", resolved.SigningRegion)
assert.Equal(t, "v4", resolved.SigningMethod)
if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_region", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "v4", resolved.SigningMethod; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointMergeIn(t *testing.T) {
@@ -185,7 +248,9 @@ func TestEndpointMergeIn(t *testing.T) {
},
})
assert.Equal(t, expected, actual)
if e, a := expected, actual; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
}
var testPartitions = partitions{
@@ -213,6 +278,11 @@ var testPartitions = partitions{
Services: services{
"s3": service{},
"service1": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service1",
},
},
Endpoints: endpoints{
"us-east-1": {},
"us-west-2": {
@@ -221,7 +291,13 @@ var testPartitions = partitions{
},
},
},
"service2": service{},
"service2": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service2",
},
},
},
"httpService": service{
Defaults: endpoint{
Protocols: []string{"http"},
@@ -246,109 +322,220 @@ var testPartitions = partitions{
func TestResolveEndpoint(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "https://service2.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_DisableSSL(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2", DisableSSLOption)
assert.NoError(t, err)
assert.Equal(t, "http://service2.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UseDualStack(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service1", "us-west-2", UseDualStackOption)
assert.NoError(t, err)
assert.Equal(t, "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service1", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service1", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_HTTPProtocol(t *testing.T) {
resolved, err := testPartitions.EndpointFor("httpService", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "http://httpService.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "httpService", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://httpService.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "httpService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownService(t *testing.T) {
_, err := testPartitions.EndpointFor("unknownservice", "us-west-2")
assert.Error(t, err)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownServiceError)
assert.True(t, ok, "expect error to be UnknownServiceError")
if !ok {
t.Errorf("expect error to be UnknownServiceError")
}
}
func TestResolveEndpoint_ResolveUnknownService(t *testing.T) {
resolved, err := testPartitions.EndpointFor("unknown-service", "us-region-1",
ResolveUnknownServiceOption)
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
assert.Equal(t, "https://unknown-service.us-region-1.amazonaws.com", resolved.URL)
assert.Equal(t, "us-region-1", resolved.SigningRegion)
assert.Equal(t, "unknown-service", resolved.SigningName)
if e, a := "https://unknown-service.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknown-service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownMatchedRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-region-1")
assert.NoError(t, err)
assert.Equal(t, "https://service2.us-region-1.amazonaws.com", resolved.URL)
assert.Equal(t, "us-region-1", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UnknownRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "unknownregion")
assert.NoError(t, err)
assert.Equal(t, "https://service2.unknownregion.amazonaws.com", resolved.URL)
assert.Equal(t, "unknownregion", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.unknownregion.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknownregion", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_StrictPartitionUnknownEndpoint(t *testing.T) {
_, err := testPartitions[0].EndpointFor("service2", "unknownregion", StrictMatchingOption)
assert.Error(t, err)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
assert.True(t, ok, "expect error to be UnknownEndpointError")
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_StrictPartitionsUnknownEndpoint(t *testing.T) {
_, err := testPartitions.EndpointFor("service2", "us-region-1", StrictMatchingOption)
assert.Error(t, err)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
assert.True(t, ok, "expect error to be UnknownEndpointError")
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_NotRegionalized(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
assert.Equal(t, "us-east-1", resolved.SigningRegion)
assert.Equal(t, "globalService", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_AwsGlobal(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "aws-global")
assert.NoError(t, err)
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
assert.Equal(t, "us-east-1", resolved.SigningRegion)
assert.Equal(t, "globalService", resolved.SigningName)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
+3 -1
View File
@@ -3,6 +3,8 @@ package request
import (
"io"
"sync"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
// offsetReader is a thread-safe io.ReadCloser to prevent racing
@@ -15,7 +17,7 @@ type offsetReader struct {
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
reader := &offsetReader{}
buf.Seek(offset, 0)
buf.Seek(offset, sdkio.SeekStart)
reader.buf = buf
return reader
+4 -3
View File
@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/internal/sdkio"
"github.com/stretchr/testify/assert"
)
@@ -28,15 +29,15 @@ func TestOffsetReaderSeek(t *testing.T) {
buf := []byte("testData")
reader := newOffsetReader(bytes.NewReader(buf), 0)
orig, err := reader.Seek(0, 1)
orig, err := reader.Seek(0, sdkio.SeekCurrent)
assert.NoError(t, err)
assert.Equal(t, int64(0), orig)
n, err := reader.Seek(0, 2)
n, err := reader.Seek(0, sdkio.SeekEnd)
assert.NoError(t, err)
assert.Equal(t, int64(len(buf)), n)
n, err = reader.Seek(orig, 0)
n, err = reader.Seek(orig, sdkio.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
}
+133 -60
View File
@@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
const (
@@ -28,6 +29,10 @@ const (
// during body reads.
ErrCodeResponseTimeout = "ResponseTimeout"
// ErrCodeInvalidPresignExpire is returned when the expire time provided to
// presign is invalid
ErrCodeInvalidPresignExpire = "InvalidPresignExpireError"
// CanceledErrorCode is the error code that will be returned by an
// API request that was canceled. Requests given a aws.Context may
// return this error when canceled.
@@ -42,7 +47,6 @@ type Request struct {
Retryer
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
@@ -60,6 +64,11 @@ type Request struct {
LastSignedAt time.Time
DisableFollowRedirects bool
// A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods.
ExpireTime time.Duration
context aws.Context
built bool
@@ -104,6 +113,8 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
}
SanitizeHostForHeader(httpReq)
r := &Request{
Config: cfg,
ClientInfo: clientInfo,
@@ -214,6 +225,9 @@ func (r *Request) SetContext(ctx aws.Context) {
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
if !aws.IsReaderSeekable(r.Body) && r.HTTPRequest.Body != NoBody {
return false
}
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
@@ -245,45 +259,70 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
r.ResetBody()
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
func (r *Request) Presign(expireTime time.Duration) (string, error) {
r.ExpireTime = expireTime
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
func (r *Request) Presign(expire time.Duration) (string, error) {
r = r.copy()
// Presign requires all headers be hoisted. There is no way to retrieve
// the signed headers not hoisted without this. Making the presigned URL
// useless.
r.NotHoist = false
if r.Operation.BeforePresignFn != nil {
r = r.copy()
err := r.Operation.BeforePresignFn(r)
if err != nil {
return "", err
}
}
r.Sign()
if r.Error != nil {
return "", r.Error
}
return r.HTTPRequest.URL.String(), nil
u, _, err := getPresignedURL(r, expire)
return u, err
}
// PresignRequest behaves just like presign, with the addition of returning a
// set of headers that were signed.
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
//
// Returns the URL string for the API operation with signature in the query string,
// and the HTTP headers that were included in the signature. These headers must
// be included in any HTTP request made with the presigned URL.
//
// To prevent hoisting any headers to the query string set NotHoist to true on
// this Request value prior to calling PresignRequest.
func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, error) {
r.ExpireTime = expireTime
r.Sign()
if r.Error != nil {
return "", nil, r.Error
func (r *Request) PresignRequest(expire time.Duration) (string, http.Header, error) {
r = r.copy()
return getPresignedURL(r, expire)
}
// IsPresigned returns true if the request represents a presigned API url.
func (r *Request) IsPresigned() bool {
return r.ExpireTime != 0
}
func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, error) {
if expire <= 0 {
return "", nil, awserr.New(
ErrCodeInvalidPresignExpire,
"presigned URL requires an expire duration greater than 0",
nil,
)
}
r.ExpireTime = expire
if r.Operation.BeforePresignFn != nil {
if err := r.Operation.BeforePresignFn(r); err != nil {
return "", nil, err
}
}
if err := r.Sign(); err != nil {
return "", nil, err
}
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
}
@@ -303,7 +342,7 @@ func debugLogReqError(r *Request, stage string, retrying bool, err error) {
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
// Any additional build Handlers set on this request will be run
// in the order they were set.
//
// The request will only be built once. Multiple calls to build will have
@@ -364,7 +403,7 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// of the SDK if they used that field.
//
// Related golang/go#18257
l, err := computeBodyLength(r.Body)
l, err := aws.SeekerLen(r.Body)
if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
}
@@ -382,7 +421,8 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Transfer-Encoding: chunked bodies for these methods.
//
// This would only happen if a aws.ReaderSeekerCloser was used with
// a io.Reader that was not also an io.Seeker.
// a io.Reader that was not also an io.Seeker, or did not implement
// Len() method.
switch r.Operation.HTTPMethod {
case "GET", "HEAD", "DELETE":
body = NoBody
@@ -394,42 +434,6 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
return body, nil
}
// Attempts to compute the length of the body of the reader using the
// io.Seeker interface. If the value is not seekable because of being
// a ReaderSeekerCloser without an unerlying Seeker -1 will be returned.
// If no error occurs the length of the body will be returned.
func computeBodyLength(r io.ReadSeeker) (int64, error) {
seekable := true
// Determine if the seeker is actually seekable. ReaderSeekerCloser
// hides the fact that a io.Readers might not actually be seekable.
switch v := r.(type) {
case aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
case *aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
}
if !seekable {
return -1, nil
}
curOffset, err := r.Seek(0, 1)
if err != nil {
return 0, err
}
endOffset, err := r.Seek(0, 2)
if err != nil {
return 0, err
}
_, err = r.Seek(curOffset, 0)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}
// GetBody will return an io.ReadSeeker of the Request's underlying
// input body with a concurrency safe wrapper.
func (r *Request) GetBody() io.ReadSeeker {
@@ -579,3 +583,72 @@ func shouldRetryCancel(r *Request) bool {
errStr != "net/http: request canceled while waiting for connection")
}
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}
+20 -5
View File
@@ -142,13 +142,28 @@ func (r *Request) nextPageTokens() []interface{} {
tokens := []interface{}{}
tokenAdded := false
for _, outToken := range r.Operation.OutputTokens {
v, _ := awsutil.ValuesAtPath(r.Data, outToken)
if len(v) > 0 {
tokens = append(tokens, v[0])
tokenAdded = true
} else {
vs, _ := awsutil.ValuesAtPath(r.Data, outToken)
if len(vs) == 0 {
tokens = append(tokens, nil)
continue
}
v := vs[0]
switch tv := v.(type) {
case *string:
if len(aws.StringValue(tv)) == 0 {
tokens = append(tokens, nil)
continue
}
case string:
if len(tv) == 0 {
tokens = append(tokens, nil)
continue
}
}
tokenAdded = true
tokens = append(tokens, v)
}
if !tokenAdded {
return nil
+78 -63
View File
@@ -454,78 +454,93 @@ func TestPaginationWithContextNilInput(t *testing.T) {
}
}
type testPageInput struct {
NextToken string
}
type testPageOutput struct {
Value string
NextToken *string
}
func TestPagination_Standalone(t *testing.T) {
expect := []struct {
Value, PrevToken, NextToken string
}{
{"FirstValue", "InitalToken", "FirstToken"},
{"SecondValue", "FirstToken", "SecondToken"},
{"ThirdValue", "SecondToken", ""},
type testPageInput struct {
NextToken *string
}
input := testPageInput{
NextToken: expect[0].PrevToken,
type testPageOutput struct {
Value *string
NextToken *string
}
type testCase struct {
Value, PrevToken, NextToken *string
}
c := awstesting.NewClient()
i := 0
p := request.Pagination{
NewRequest: func() (*request.Request, error) {
r := c.NewRequest(
&request.Operation{
Name: "Operation",
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
},
},
&input, &testPageOutput{},
)
// Setup handlers for testing
r.Handlers.Clear()
r.Handlers.Build.PushBack(func(req *request.Request) {
in := req.Params.(*testPageInput)
if e, a := expect[i].PrevToken, in.NextToken; e != a {
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
}
})
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
out := &testPageOutput{
Value: expect[i].Value,
}
if len(expect[i].NextToken) > 0 {
out.NextToken = aws.String(expect[i].NextToken)
}
req.Data = out
})
return r, nil
cases := [][]testCase{
{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), nil},
},
{
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), aws.String("")},
},
}
for p.Next() {
data := p.Page().(*testPageOutput)
if e, a := expect[i].Value, data.Value; e != a {
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
}
if e, a := expect[i].NextToken, aws.StringValue(data.NextToken); e != a {
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
for _, c := range cases {
input := testPageInput{
NextToken: c[0].PrevToken,
}
i++
}
if e, a := len(expect), i; e != a {
t.Errorf("expected to process %d pages, did %d", e, a)
}
if err := p.Err(); err != nil {
t.Fatalf("%d, expected no error, got %v", i, err)
svc := awstesting.NewClient()
i := 0
p := request.Pagination{
NewRequest: func() (*request.Request, error) {
r := svc.NewRequest(
&request.Operation{
Name: "Operation",
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
},
},
&input, &testPageOutput{},
)
// Setup handlers for testing
r.Handlers.Clear()
r.Handlers.Build.PushBack(func(req *request.Request) {
if e, a := len(c), i+1; a > e {
t.Fatalf("expect no more than %d requests, got %d", e, a)
}
in := req.Params.(*testPageInput)
if e, a := aws.StringValue(c[i].PrevToken), aws.StringValue(in.NextToken); e != a {
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
}
})
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
out := &testPageOutput{
Value: c[i].Value,
}
if c[i].NextToken != nil {
next := *c[i].NextToken
out.NextToken = aws.String(next)
}
req.Data = out
})
return r, nil
},
}
for p.Next() {
data := p.Page().(*testPageOutput)
if e, a := aws.StringValue(c[i].Value), aws.StringValue(data.Value); e != a {
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
}
if e, a := aws.StringValue(c[i].NextToken), aws.StringValue(data.NextToken); e != a {
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
}
i++
}
if e, a := len(c), i; e != a {
t.Errorf("expected to process %d pages, did %d", e, a)
}
if err := p.Err(); err != nil {
t.Fatalf("%d, expected no error, got %v", i, err)
}
}
}
+60 -11
View File
@@ -2,6 +2,7 @@ package request
import (
"bytes"
"io"
"net/http"
"strings"
"testing"
@@ -25,21 +26,70 @@ func TestResetBody_WithBodyContents(t *testing.T) {
}
}
func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
type mockReader struct{}
func (mockReader) Read([]byte) (int, error) {
return 0, io.EOF
}
func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) {
cases := []struct {
Method string
Body io.ReadSeeker
IsNoBody bool
}{
{"GET", true},
{"HEAD", true},
{"DELETE", true},
{"PUT", false},
{"PATCH", false},
{"POST", false},
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "HEAD",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "DELETE",
IsNoBody: true,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PUT",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "PATCH",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(mockReader{}),
},
{
Method: "GET",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "GET",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
{
Method: "POST",
IsNoBody: false,
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
},
{
Method: "POST",
IsNoBody: true,
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
},
}
reader := aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc")))
for i, c := range cases {
r := Request{
HTTPRequest: &http.Request{},
@@ -47,8 +97,7 @@ func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
HTTPMethod: c.Method,
},
}
r.SetReaderBody(reader)
r.SetReaderBody(c.Body)
if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e {
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
+280 -3
View File
@@ -8,9 +8,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"runtime"
"strconv"
"strings"
"testing"
"time"
@@ -20,6 +22,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting"
@@ -80,7 +83,7 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
@@ -112,7 +115,8 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 400, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"FooException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
@@ -131,7 +135,7 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 2, int(r.RetryCount); e != a {
if e, a := 3, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
@@ -842,3 +846,276 @@ func TestRequest_TemporaryRetry(t *testing.T) {
t.Errorf("expect temporary error, was not")
}
}
func TestRequest_Presign(t *testing.T) {
presign := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
u, err := r.Presign(expire)
return u, nil, err
}
presignRequest := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
return r.PresignRequest(expire)
}
mustParseURL := func(v string) *url.URL {
u, err := url.Parse(v)
if err != nil {
panic(err)
}
return u
}
cases := []struct {
Expire time.Duration
PresignFn func(*request.Request, time.Duration) (string, http.Header, error)
SignerFn func(*request.Request)
URL string
Header http.Header
Err string
}{
{
PresignFn: presign,
Err: request.ErrCodeInvalidPresignExpire,
},
{
PresignFn: presignRequest,
Err: request.ErrCodeInvalidPresignExpire,
},
{
Expire: -1,
PresignFn: presign,
Err: request.ErrCodeInvalidPresignExpire,
},
{
// Presign clear NotHoist
Expire: 1 * time.Minute,
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
r.NotHoist = true
return presign(r, dur)
},
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
if r.NotHoist {
r.Error = fmt.Errorf("expect NotHoist to be cleared")
}
},
URL: "https://endpoint/presignedURL",
},
{
// PresignRequest does not clear NotHoist
Expire: 1 * time.Minute,
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
r.NotHoist = true
return presignRequest(r, dur)
},
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
if !r.NotHoist {
r.Error = fmt.Errorf("expect NotHoist not to be cleared")
}
},
URL: "https://endpoint/presignedURL",
},
{
// PresignRequest returns signed headers
Expire: 1 * time.Minute,
PresignFn: presignRequest,
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
r.HTTPRequest.Header.Set("UnsigndHeader", "abc")
r.SignedHeaderVals = http.Header{
"X-Amzn-Header": []string{"abc", "123"},
"X-Amzn-Header2": []string{"efg", "456"},
}
},
URL: "https://endpoint/presignedURL",
Header: http.Header{
"X-Amzn-Header": []string{"abc", "123"},
"X-Amzn-Header2": []string{"efg", "456"},
},
},
}
svc := awstesting.NewClient()
svc.Handlers.Clear()
for i, c := range cases {
req := svc.NewRequest(&request.Operation{
Name: "name", HTTPMethod: "GET", HTTPPath: "/path",
}, &struct{}{}, &struct{}{})
req.Handlers.Sign.PushBack(c.SignerFn)
u, h, err := c.PresignFn(req, c.Expire)
if len(c.Err) != 0 {
if e, a := c.Err, err.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %v to be in %v", i, e, a)
}
continue
} else if err != nil {
t.Errorf("%d, expect no error, got %v", i, err)
continue
}
if e, a := c.URL, u; e != a {
t.Errorf("%d, expect %v URL, got %v", i, e, a)
}
if e, a := c.Header, h; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v header got %v", i, e, a)
}
}
}
func TestNew_EndpointWithDefaultPort(t *testing.T) {
endpoint := "https://estest.us-east-1.es.amazonaws.com:443"
expectedRequestHost := "estest.us-east-1.es.amazonaws.com"
r := request.New(
aws.Config{},
metadata.ClientInfo{Endpoint: endpoint},
defaults.Handlers(),
client.DefaultRetryer{},
&request.Operation{},
nil,
nil,
)
if h := r.HTTPRequest.Host; h != expectedRequestHost {
t.Errorf("expect %v host, got %q", expectedRequestHost, h)
}
}
func TestSanitizeHostForHeader(t *testing.T) {
cases := []struct {
url string
expectedRequestHost string
}{
{"https://estest.us-east-1.es.amazonaws.com:443", "estest.us-east-1.es.amazonaws.com"},
{"https://estest.us-east-1.es.amazonaws.com", "estest.us-east-1.es.amazonaws.com"},
{"https://localhost:9200", "localhost:9200"},
{"http://localhost:80", "localhost"},
{"http://localhost:8080", "localhost:8080"},
}
for _, c := range cases {
r, _ := http.NewRequest("GET", c.url, nil)
request.SanitizeHostForHeader(r)
if h := r.Host; h != c.expectedRequestHost {
t.Errorf("expect %v host, got %q", c.expectedRequestHost, h)
}
}
}
func TestRequestWillRetry_ByBody(t *testing.T) {
svc := awstesting.NewClient()
cases := []struct {
WillRetry bool
HTTPMethod string
Body io.ReadSeeker
IsReqNoBody bool
}{
{
WillRetry: true,
HTTPMethod: "GET",
Body: bytes.NewReader([]byte{}),
IsReqNoBody: true,
},
{
WillRetry: true,
HTTPMethod: "GET",
Body: bytes.NewReader(nil),
IsReqNoBody: true,
},
{
WillRetry: true,
HTTPMethod: "POST",
Body: bytes.NewReader([]byte("abc123")),
},
{
WillRetry: true,
HTTPMethod: "POST",
Body: aws.ReadSeekCloser(bytes.NewReader([]byte("abc123"))),
},
{
WillRetry: true,
HTTPMethod: "GET",
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
IsReqNoBody: true,
},
{
WillRetry: true,
HTTPMethod: "POST",
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
IsReqNoBody: true,
},
{
WillRetry: false,
HTTPMethod: "POST",
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc123"))),
},
}
for i, c := range cases {
req := svc.NewRequest(&request.Operation{
Name: "Operation",
HTTPMethod: c.HTTPMethod,
HTTPPath: "/",
}, nil, nil)
req.SetReaderBody(c.Body)
req.Build()
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
if e, a := c.IsReqNoBody, request.NoBody == req.HTTPRequest.Body; e != a {
t.Errorf("%d, expect request to be no body, %t, got %t, %T", i, e, a, req.HTTPRequest.Body)
}
if e, a := c.WillRetry, req.WillRetry(); e != a {
t.Errorf("%d, expect %t willRetry, got %t", i, e, a)
}
if req.Error == nil {
t.Fatalf("%d, expect error, got none", i)
}
if e, a := "some error", req.Error.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %q error in %q", i, e, a)
}
if e, a := 0, req.RetryCount; e != a {
t.Errorf("%d, expect retry count to be %d, got %d", i, e, a)
}
}
}
func Test501NotRetrying(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"NotImplemented","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err == nil {
t.Fatal("expect error, but got none")
}
aerr := err.(awserr.Error)
if e, a := "NotImplemented", aerr.Code(); e != a {
t.Errorf("expected error code %q, but received %q", e, a)
}
if e, a := 1, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
}
+8
View File
@@ -5,6 +5,7 @@ import (
"strconv"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
)
// EnvProviderName provides a name of the provider when config is loaded from environment.
@@ -176,6 +177,13 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
if len(cfg.SharedCredentialsFile) == 0 {
cfg.SharedCredentialsFile = defaults.SharedCredentialsFilename()
}
if len(cfg.SharedConfigFile) == 0 {
cfg.SharedConfigFile = defaults.SharedConfigFilename()
}
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
return cfg
+37 -10
View File
@@ -7,6 +7,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestLoadEnvConfig_Creds(t *testing.T) {
@@ -105,6 +106,8 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
@@ -116,6 +119,8 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
@@ -128,7 +133,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
@@ -136,6 +143,10 @@ func TestLoadEnvConfig(t *testing.T) {
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Config: envConfig{
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
Env: map[string]string{
@@ -145,7 +156,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
@@ -155,7 +168,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
@@ -168,7 +183,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
@@ -182,7 +199,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "region", Profile: "profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
@@ -193,7 +212,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
@@ -205,7 +226,9 @@ func TestLoadEnvConfig(t *testing.T) {
},
Config: envConfig{
Region: "default_region", Profile: "default_profile",
EnableSharedConfig: true,
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
@@ -214,7 +237,9 @@ func TestLoadEnvConfig(t *testing.T) {
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
Config: envConfig{
CustomCABundle: "custom_ca_bundle",
CustomCABundle: "custom_ca_bundle",
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
},
{
@@ -222,8 +247,10 @@ func TestLoadEnvConfig(t *testing.T) {
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
Config: envConfig{
CustomCABundle: "custom_ca_bundle",
EnableSharedConfig: true,
CustomCABundle: "custom_ca_bundle",
EnableSharedConfig: true,
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
SharedConfigFile: shareddefaults.SharedConfigFilename(),
},
UseSharedConfigCall: true,
},
+19 -19
View File
@@ -26,7 +26,7 @@ import (
// Sessions are safe to create service clients concurrently, but it is not safe
// to mutate the Session concurrently.
//
// The Session satisfies the service client's client.ClientConfigProvider.
// The Session satisfies the service client's client.ConfigProvider.
type Session struct {
Config *aws.Config
Handlers request.Handlers
@@ -58,7 +58,12 @@ func New(cfgs ...*aws.Config) *Session {
envCfg := loadEnvConfig()
if envCfg.EnableSharedConfig {
s, err := newSession(Options{}, envCfg, cfgs...)
var cfg aws.Config
cfg.MergeIn(cfgs...)
s, err := NewSessionWithOptions(Options{
Config: cfg,
SharedConfigState: SharedConfigEnable,
})
if err != nil {
// Old session.New expected all errors to be discovered when
// a request is made, and would report the errors then. This
@@ -243,13 +248,6 @@ func NewSessionWithOptions(opts Options) (*Session, error) {
envCfg.EnableSharedConfig = true
}
if len(envCfg.SharedCredentialsFile) == 0 {
envCfg.SharedCredentialsFile = defaults.SharedCredentialsFilename()
}
if len(envCfg.SharedConfigFile) == 0 {
envCfg.SharedConfigFile = defaults.SharedConfigFilename()
}
// Only use AWS_CA_BUNDLE if session option is not provided.
if len(envCfg.CustomCABundle) != 0 && opts.CustomCABundle == nil {
f, err := os.Open(envCfg.CustomCABundle)
@@ -573,11 +571,12 @@ func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (
}
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningName: resolved.SigningName,
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}, err
}
@@ -597,10 +596,11 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
}
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningName: resolved.SigningName,
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}
}
+23 -5
View File
@@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/service/s3"
)
@@ -89,14 +90,31 @@ func TestSessionCopy(t *testing.T) {
}
func TestSessionClientConfig(t *testing.T) {
s, err := NewSession(&aws.Config{Region: aws.String("orig_region")})
s, err := NewSession(&aws.Config{
Credentials: credentials.AnonymousCredentials,
Region: aws.String("orig_region"),
EndpointResolver: endpoints.ResolverFunc(
func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if e, a := "mock-service", service; e != a {
t.Errorf("expect %q service, got %q", e, a)
}
if e, a := "other-region", region; e != a {
t.Errorf("expect %q region, got %q", e, a)
}
return endpoints.ResolvedEndpoint{
URL: "https://" + service + "." + region + ".amazonaws.com",
SigningRegion: region,
}, nil
},
),
})
assert.NoError(t, err)
cfg := s.ClientConfig("s3", &aws.Config{Region: aws.String("us-west-2")})
cfg := s.ClientConfig("mock-service", &aws.Config{Region: aws.String("other-region")})
assert.Equal(t, "https://s3-us-west-2.amazonaws.com", cfg.Endpoint)
assert.Equal(t, "us-west-2", cfg.SigningRegion)
assert.Equal(t, "us-west-2", *cfg.Config.Region)
assert.Equal(t, "https://mock-service.other-region.amazonaws.com", cfg.Endpoint)
assert.Equal(t, "other-region", cfg.SigningRegion)
assert.Equal(t, "other-region", *cfg.Config.Region)
}
func TestNewSession_NoCredentials(t *testing.T) {
+30 -11
View File
@@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestStandaloneSign(t *testing.T) {
@@ -22,7 +21,9 @@ func TestStandaloneSign(t *testing.T) {
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
assert.NoError(t, err)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
@@ -30,12 +31,20 @@ func TestStandaloneSign(t *testing.T) {
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
assert.NoError(t, err)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
actual := req.Header.Get("Authorization")
assert.Equal(t, c.ExpSig, actual)
assert.Equal(t, c.OrigURI, req.URL.Path)
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
if e, a := c.ExpSig, actual; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.OrigURI, req.URL.Path; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
}
}
@@ -48,7 +57,9 @@ func TestStandaloneSign_RawPath(t *testing.T) {
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
assert.NoError(t, err)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
@@ -57,11 +68,19 @@ func TestStandaloneSign_RawPath(t *testing.T) {
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
assert.NoError(t, err)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
actual := req.Header.Get("Authorization")
assert.Equal(t, c.ExpSig, actual)
assert.Equal(t, c.OrigURI, req.URL.Path)
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
if e, a := c.ExpSig, actual; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.OrigURI, req.URL.Path; e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
t.Errorf("expected %v, but recieved %v", e, a)
}
}
}
+88
View File
@@ -164,3 +164,91 @@ func TestStandaloneSign_CustomURIEscape(t *testing.T) {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestStandaloneSign_WithPort(t *testing.T) {
cases := []struct {
description string
url string
expectedSig string
}{
{
"default HTTPS port",
"https://estest.us-east-1.es.amazonaws.com:443/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=e573fc9aa3a156b720976419319be98fb2824a3abc2ddd895ecb1d1611c6a82d",
},
{
"default HTTP port",
"http://example.com:80/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=54ebe60c4ae03a40948b849e13c333523235f38002e2807059c64a9a8c7cb951",
},
{
"non-standard HTTP port",
"http://example.com:9200/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
},
{
"non-standard HTTPS port",
"https://example.com:9200/_search",
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
},
}
for _, c := range cases {
signer := v4.NewSigner(unit.Session.Config.Credentials)
req, _ := http.NewRequest("GET", c.url, nil)
_, err := signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
actual := req.Header.Get("Authorization")
if e, a := c.expectedSig, actual; e != a {
t.Errorf("%s, expect %v, got %v", c.description, e, a)
}
}
}
func TestStandalonePresign_WithPort(t *testing.T) {
cases := []struct {
description string
url string
expectedSig string
}{
{
"default HTTPS port",
"https://estest.us-east-1.es.amazonaws.com:443/_search",
"0abcf61a351063441296febf4b485734d780634fba8cf1e7d9769315c35255d6",
},
{
"default HTTP port",
"http://example.com:80/_search",
"fce9976dd6c849c21adfa6d3f3e9eefc651d0e4a2ccd740d43efddcccfdc8179",
},
{
"non-standard HTTP port",
"http://example.com:9200/_search",
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
},
{
"non-standard HTTPS port",
"https://example.com:9200/_search",
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
},
}
for _, c := range cases {
signer := v4.NewSigner(unit.Session.Config.Credentials)
req, _ := http.NewRequest("GET", c.url, nil)
_, err := signer.Presign(req, nil, "es", "us-east-1", 5 * time.Minute, time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
actual := req.URL.Query().Get("X-Amz-Signature")
if e, a := c.expectedSig, actual; e != a {
t.Errorf("%s, expect %v, got %v", c.description, e, a)
}
}
}
+33 -13
View File
@@ -2,8 +2,6 @@ package v4
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRuleCheckWhitelist(t *testing.T) {
@@ -13,8 +11,12 @@ func TestRuleCheckWhitelist(t *testing.T) {
},
}
assert.True(t, w.IsValid("Cache-Control"))
assert.False(t, w.IsValid("Cache-"))
if !w.IsValid("Cache-Control") {
t.Error("expected true value")
}
if w.IsValid("Cache-") {
t.Error("expected false value")
}
}
func TestRuleCheckBlacklist(t *testing.T) {
@@ -24,16 +26,26 @@ func TestRuleCheckBlacklist(t *testing.T) {
},
}
assert.False(t, b.IsValid("Cache-Control"))
assert.True(t, b.IsValid("Cache-"))
if b.IsValid("Cache-Control") {
t.Error("expected false value")
}
if !b.IsValid("Cache-") {
t.Error("expected true value")
}
}
func TestRuleCheckPattern(t *testing.T) {
p := patterns{"X-Amz-Meta-"}
assert.True(t, p.IsValid("X-Amz-Meta-"))
assert.True(t, p.IsValid("X-Amz-Meta-Star"))
assert.False(t, p.IsValid("Cache-"))
if !p.IsValid("X-Amz-Meta-") {
t.Error("expected true value")
}
if !p.IsValid("X-Amz-Meta-Star") {
t.Error("expected true value")
}
if p.IsValid("Cache-") {
t.Error("expected false value")
}
}
func TestRuleComplexWhitelist(t *testing.T) {
@@ -50,8 +62,16 @@ func TestRuleComplexWhitelist(t *testing.T) {
inclusiveRules{patterns{"X-Amz-"}, blacklist{w}},
}
assert.True(t, r.IsValid("X-Amz-Blah"))
assert.False(t, r.IsValid("X-Amz-Meta-"))
assert.False(t, r.IsValid("X-Amz-Meta-Star"))
assert.False(t, r.IsValid("Cache-Control"))
if !r.IsValid("X-Amz-Blah") {
t.Error("expected true value")
}
if r.IsValid("X-Amz-Meta-") {
t.Error("expected false value")
}
if r.IsValid("X-Amz-Meta-Star") {
t.Error("expected false value")
}
if r.IsValid("Cache-Control") {
t.Error("expected false value")
}
}
+28 -11
View File
@@ -71,6 +71,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkio"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
@@ -268,7 +269,7 @@ type signingCtx struct {
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, 0, signTime)
return v4.signWithBody(r, body, service, region, 0, false, signTime)
}
// Presign signs AWS v4 requests with the provided body, service name, region
@@ -302,10 +303,10 @@ func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region strin
// presigned request's signature you can set the "X-Amz-Content-Sha256"
// HTTP header and that will be included in the request's signature.
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, exp, signTime)
return v4.signWithBody(r, body, service, region, exp, true, signTime)
}
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, isPresign bool, signTime time.Time) (http.Header, error) {
currentTimeFn := v4.currentTimeFn
if currentTimeFn == nil {
currentTimeFn = time.Now
@@ -317,7 +318,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
Query: r.URL.Query(),
Time: signTime,
ExpireTime: exp,
isPresign: exp != 0,
isPresign: isPresign,
ServiceName: service,
Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
@@ -339,8 +340,11 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
return http.Header{}, err
}
ctx.sanitizeHostForHeader()
ctx.assignAmzQueryValues()
ctx.build(v4.DisableHeaderHoisting)
if err := ctx.build(v4.DisableHeaderHoisting); err != nil {
return nil, err
}
// If the request is not presigned the body should be attached to it. This
// prevents the confusion of wanting to send a signed request without
@@ -363,6 +367,10 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
return ctx.SignedHeaderVals, nil
}
func (ctx *signingCtx) sanitizeHostForHeader() {
request.SanitizeHostForHeader(ctx.Request)
}
func (ctx *signingCtx) handlePresignRemoval() {
if !ctx.isPresign {
return
@@ -467,7 +475,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
}
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, signingTime,
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
)
if err != nil {
req.Error = err
@@ -498,11 +506,13 @@ func (v4 *Signer) logSigningInfo(ctx *signingCtx) {
v4.Logger.Log(msg)
}
func (ctx *signingCtx) build(disableHeaderHoisting bool) {
func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
ctx.buildTime() // no depends
ctx.buildCredentialString() // no depends
ctx.buildBodyDigest()
if err := ctx.buildBodyDigest(); err != nil {
return err
}
unsignedHeaders := ctx.Request.Header
if ctx.isPresign {
@@ -530,6 +540,8 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) {
}
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
return nil
}
func (ctx *signingCtx) buildTime() {
@@ -656,7 +668,7 @@ func (ctx *signingCtx) buildSignature() {
ctx.signature = hex.EncodeToString(signature)
}
func (ctx *signingCtx) buildBodyDigest() {
func (ctx *signingCtx) buildBodyDigest() error {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if ctx.unsignedPayload || (ctx.isPresign && ctx.ServiceName == "s3") {
@@ -664,6 +676,9 @@ func (ctx *signingCtx) buildBodyDigest() {
} else if ctx.Body == nil {
hash = emptyStringSHA256
} else {
if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
}
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
}
if ctx.unsignedPayload || ctx.ServiceName == "s3" || ctx.ServiceName == "glacier" {
@@ -671,6 +686,8 @@ func (ctx *signingCtx) buildBodyDigest() {
}
}
ctx.bodyDigest = hash
return nil
}
// isRequestSigned returns if the request is currently signed or presigned
@@ -710,8 +727,8 @@ func makeSha256(data []byte) []byte {
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
start, _ := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart)
io.Copy(hash, reader)
return hash.Sum(nil)
+85 -12
View File
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"strings"
"testing"
"time"
@@ -61,17 +62,42 @@ func TestStripExcessHeaders(t *testing.T) {
}
func buildRequest(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
return buildRequestWithBodyReader(serviceName, region, reader)
}
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
var bodyLen int
type lenner interface {
Len() int
}
if lr, ok := body.(lenner); ok {
bodyLen = lr.Len()
}
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, body)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Set("X-Amz-Target", "prefix.Operation")
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
if bodyLen > 0 {
req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
}
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
return req, reader
var seeker io.ReadSeeker
if sr, ok := body.(io.ReadSeeker); ok {
seeker = sr
} else {
seeker = aws.ReadSeekCloser(body)
}
return req, seeker
}
func buildSigner() Signer {
@@ -101,7 +127,7 @@ func TestPresignRequest(t *testing.T) {
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "ea7856749041f727690c580569738282e99c79355fe0d8f125d3b5535d2ece83"
expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
@@ -135,7 +161,7 @@ func TestPresignBodyWithArrayRequest(t *testing.T) {
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "fef6002062400bbf526d70f1a6456abc0fb2e213fe1416012737eebd42a62924"
expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
@@ -166,14 +192,14 @@ func TestSignRequest(t *testing.T) {
signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
q := req.Header
if e, a := expectedSig, q.Get("Authorization"); e != a {
t.Errorf("expect %v, got %v", e, a)
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
}
@@ -207,6 +233,53 @@ func TestPresignEmptyBodyS3(t *testing.T) {
}
}
func TestSignUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err == nil {
t.Fatalf("expect error signing request")
}
if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q to be in %q", e, a)
}
}
func TestSignUnsignedPayloadUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
signer.UnsignedPayload = true
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignPreComputedHashUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "some-content-sha256", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "hello")
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
+83
View File
@@ -3,6 +3,8 @@ package aws
import (
"io"
"sync"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should
@@ -22,6 +24,22 @@ type ReaderSeekerCloser struct {
r io.Reader
}
// IsReaderSeekable returns if the underlying reader type can be seeked. A
// io.Reader might not actually be seekable if it is the ReaderSeekerCloser
// type.
func IsReaderSeekable(r io.Reader) bool {
switch v := r.(type) {
case ReaderSeekerCloser:
return v.IsSeeker()
case *ReaderSeekerCloser:
return v.IsSeeker()
case io.ReadSeeker:
return true
default:
return false
}
}
// Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned.
//
@@ -56,6 +74,71 @@ func (r ReaderSeekerCloser) IsSeeker() bool {
return ok
}
// HasLen returns the length of the underlying reader if the value implements
// the Len() int method.
func (r ReaderSeekerCloser) HasLen() (int, bool) {
type lenner interface {
Len() int
}
if lr, ok := r.r.(lenner); ok {
return lr.Len(), true
}
return 0, false
}
// GetLen returns the length of the bytes remaining in the underlying reader.
// Checks first for Len(), then io.Seeker to determine the size of the
// underlying reader.
//
// Will return -1 if the length cannot be determined.
func (r ReaderSeekerCloser) GetLen() (int64, error) {
if l, ok := r.HasLen(); ok {
return int64(l), nil
}
if s, ok := r.r.(io.Seeker); ok {
return seekerLen(s)
}
return -1, nil
}
// SeekerLen attempts to get the number of bytes remaining at the seeker's
// current position. Returns the number of bytes remaining or error.
func SeekerLen(s io.Seeker) (int64, error) {
// Determine if the seeker is actually seekable. ReaderSeekerCloser
// hides the fact that a io.Readers might not actually be seekable.
switch v := s.(type) {
case ReaderSeekerCloser:
return v.GetLen()
case *ReaderSeekerCloser:
return v.GetLen()
}
return seekerLen(s)
}
func seekerLen(s io.Seeker) (int64, error) {
curOffset, err := s.Seek(0, sdkio.SeekCurrent)
if err != nil {
return 0, err
}
endOffset, err := s.Seek(0, sdkio.SeekEnd)
if err != nil {
return 0, err
}
_, err = s.Seek(curOffset, sdkio.SeekStart)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}
// Close closes the ReaderSeekerCloser.
//
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
+28 -11
View File
@@ -1,32 +1,49 @@
package aws
import (
"bytes"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteAtBuffer(t *testing.T) {
b := &WriteAtBuffer{}
n, err := b.WriteAt([]byte{1}, 0)
assert.NoError(t, err)
assert.Equal(t, 1, n)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
assert.NoError(t, err)
assert.Equal(t, 3, n)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 3, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{2}, 1)
assert.NoError(t, err)
assert.Equal(t, 1, n)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but recieved %d", e, a)
}
n, err = b.WriteAt([]byte{3}, 2)
assert.NoError(t, err)
assert.Equal(t, 1, n)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
assert.Equal(t, []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
if !bytes.Equal([]byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes()) {
t.Errorf("expected %v, but received %v", []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
}
}
func BenchmarkWriteAtBuffer(b *testing.B) {
+1 -1
View File
@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "1.12.1"
const SDKVersion = "1.13.31"