mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-06-09 06:04:12 +00:00
Update vendored deps, including AWS SDK, openpgp, ftp, ...
This commit is contained in:
+133
-45
@@ -5,11 +5,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awsutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func ExampleCopy() {
|
||||
@@ -81,12 +81,24 @@ func TestCopy1(t *testing.T) {
|
||||
awsutil.Copy(&f2, f1)
|
||||
|
||||
// Values are equal
|
||||
assert.Equal(t, f2.A, f1.A)
|
||||
assert.Equal(t, f2.B, f1.B)
|
||||
assert.Equal(t, f2.C, f1.C)
|
||||
assert.Equal(t, f2.D, f1.D)
|
||||
assert.Equal(t, f2.E.B, f1.E.B)
|
||||
assert.Equal(t, f2.E.D, f1.E.D)
|
||||
if v1, v2 := f2.A, f1.A; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.D, f1.D; !v1.Equal(*v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.E.B, f1.E.B; !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.E.D, f1.E.D; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
|
||||
// But pointers are not!
|
||||
str3 := "nothello"
|
||||
@@ -99,14 +111,30 @@ func TestCopy1(t *testing.T) {
|
||||
*f2.E.B = int3
|
||||
f2.E.c = 5
|
||||
f2.E.D = 5
|
||||
assert.NotEqual(t, f2.A, f1.A)
|
||||
assert.NotEqual(t, f2.B, f1.B)
|
||||
assert.NotEqual(t, f2.C, f1.C)
|
||||
assert.NotEqual(t, f2.D, f1.D)
|
||||
assert.NotEqual(t, f2.E.a, f1.E.a)
|
||||
assert.NotEqual(t, f2.E.B, f1.E.B)
|
||||
assert.NotEqual(t, f2.E.c, f1.E.c)
|
||||
assert.NotEqual(t, f2.E.D, f1.E.D)
|
||||
if v1, v2 := f2.A, f1.A; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.B, f1.B; reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.C, f1.C; reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.D, f1.D; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.E.a, f1.E.a; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.E.B, f1.E.B; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.E.c, f1.E.c; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.E.D, f1.E.D; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyNestedWithUnexported(t *testing.T) {
|
||||
@@ -125,10 +153,18 @@ func TestCopyNestedWithUnexported(t *testing.T) {
|
||||
awsutil.Copy(&f2, f1)
|
||||
|
||||
// Values match
|
||||
assert.Equal(t, f2.A, f1.A)
|
||||
assert.NotEqual(t, f2.B, f1.B)
|
||||
assert.NotEqual(t, f2.B.a, f1.B.a)
|
||||
assert.Equal(t, f2.B.B, f2.B.B)
|
||||
if v1, v2 := f2.A, f1.A; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.B, f1.B; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.B.a, f1.B.a; v1 == v2 {
|
||||
t.Errorf("expected values to be not equivalent, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := f2.B.B, f2.B.B; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyIgnoreNilMembers(t *testing.T) {
|
||||
@@ -139,34 +175,56 @@ func TestCopyIgnoreNilMembers(t *testing.T) {
|
||||
}
|
||||
|
||||
f := &Foo{}
|
||||
assert.Nil(t, f.A)
|
||||
assert.Nil(t, f.B)
|
||||
assert.Nil(t, f.C)
|
||||
if v1 := f.A; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f.B; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f.C; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
|
||||
var f2 Foo
|
||||
awsutil.Copy(&f2, f)
|
||||
assert.Nil(t, f2.A)
|
||||
assert.Nil(t, f2.B)
|
||||
assert.Nil(t, f2.C)
|
||||
if v1 := f2.A; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f2.B; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f2.C; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
|
||||
fcopy := awsutil.CopyOf(f)
|
||||
f3 := fcopy.(*Foo)
|
||||
assert.Nil(t, f3.A)
|
||||
assert.Nil(t, f3.B)
|
||||
assert.Nil(t, f3.C)
|
||||
if v1 := f3.A; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f3.B; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1 := f3.C; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyPrimitive(t *testing.T) {
|
||||
str := "hello"
|
||||
var s string
|
||||
awsutil.Copy(&s, &str)
|
||||
assert.Equal(t, "hello", s)
|
||||
if v1, v2 := "hello", s; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyNil(t *testing.T) {
|
||||
var s string
|
||||
awsutil.Copy(&s, nil)
|
||||
assert.Equal(t, "", s)
|
||||
if v1, v2 := "", s; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyReader(t *testing.T) {
|
||||
@@ -174,13 +232,21 @@ func TestCopyReader(t *testing.T) {
|
||||
var r io.Reader
|
||||
awsutil.Copy(&r, buf)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello world"), b)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if v1, v2 := []byte("hello world"), b; !bytes.Equal(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
|
||||
// empty bytes because this is not a deep copy
|
||||
b, err = ioutil.ReadAll(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), b)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if v1, v2 := []byte(""), b; !bytes.Equal(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDifferentStructs(t *testing.T) {
|
||||
@@ -226,17 +292,39 @@ func TestCopyDifferentStructs(t *testing.T) {
|
||||
awsutil.Copy(&f2, f1)
|
||||
|
||||
// Values are equal
|
||||
assert.Equal(t, f2.A, f1.A)
|
||||
assert.Equal(t, f2.B, f1.B)
|
||||
assert.Equal(t, f2.C, f1.C)
|
||||
assert.Equal(t, "unique", f1.SrcUnique)
|
||||
assert.Equal(t, 1, f1.SameNameDiffType)
|
||||
assert.Equal(t, 0, f2.DstUnique)
|
||||
assert.Equal(t, "", f2.SameNameDiffType)
|
||||
assert.Equal(t, int1, *f1.unexportedPtr)
|
||||
assert.Nil(t, f2.unexportedPtr)
|
||||
assert.Equal(t, int2, *f1.ExportedPtr)
|
||||
assert.Equal(t, int2, *f2.ExportedPtr)
|
||||
if v1, v2 := f2.A, f1.A; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := "unique", f1.SrcUnique; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := 1, f1.SameNameDiffType; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := 0, f2.DstUnique; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := "", f2.SameNameDiffType; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := int1, *f1.unexportedPtr; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1 := f2.unexportedPtr; v1 != nil {
|
||||
t.Errorf("expected nil, but received %v", v1)
|
||||
}
|
||||
if v1, v2 := int2, *f1.ExportedPtr; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
if v1, v2 := int2, *f2.ExportedPtr; v1 != v2 {
|
||||
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleCopyOf() {
|
||||
|
||||
+3
-2
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awsutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeepEqual(t *testing.T) {
|
||||
@@ -24,6 +23,8 @@ func TestDeepEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
|
||||
if awsutil.DeepEqual(c.a, c.b) != c.equal {
|
||||
t.Errorf("%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+61
-21
@@ -1,10 +1,10 @@
|
||||
package awsutil_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awsutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type Struct struct {
|
||||
@@ -50,8 +50,12 @@ func TestValueAtPathSuccess(t *testing.T) {
|
||||
}
|
||||
for i, c := range testCases {
|
||||
v, err := awsutil.ValuesAtPath(c.data, c.path)
|
||||
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
|
||||
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
|
||||
if err != nil {
|
||||
t.Errorf("case %v, expected no error, %v", i, c.path)
|
||||
}
|
||||
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
|
||||
t.Errorf("case %v, %v", i, c.path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,12 +82,18 @@ func TestValueAtPathFailure(t *testing.T) {
|
||||
for i, c := range testCases {
|
||||
v, err := awsutil.ValuesAtPath(c.data, c.path)
|
||||
if c.errContains != "" {
|
||||
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
|
||||
if !strings.Contains(err.Error(), c.errContains) {
|
||||
t.Errorf("case %v, expected error, %v", i, c.path)
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
|
||||
if err != nil {
|
||||
t.Errorf("case %v, expected no error, %v", i, c.path)
|
||||
}
|
||||
}
|
||||
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
|
||||
t.Errorf("case %v, %v", i, c.path)
|
||||
}
|
||||
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,51 +102,81 @@ func TestSetValueAtPathSuccess(t *testing.T) {
|
||||
awsutil.SetValueAtPath(&s, "C", "test1")
|
||||
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
|
||||
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
|
||||
assert.Equal(t, "test1", s.C)
|
||||
assert.Equal(t, "test2", s.B.B.C)
|
||||
assert.Equal(t, "test3", s.B.D.C)
|
||||
if e, a := "test1", s.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
if e, a := "test2", s.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
if e, a := "test3", s.B.D.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
|
||||
assert.Equal(t, "test0", s.B.B.C)
|
||||
assert.Equal(t, "test0", s.B.D.C)
|
||||
if e, a := "test0", s.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
if e, a := "test0", s.B.D.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
var s2 Struct
|
||||
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
|
||||
assert.Equal(t, "test0", s2.B.B.C)
|
||||
if e, a := "test0", s2.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
|
||||
assert.Equal(t, []Struct{{}}, s2.A)
|
||||
if e, a := []Struct{{}}, s2.A; !awsutil.DeepEqual(e, a) {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
str := "foo"
|
||||
|
||||
s3 := Struct{}
|
||||
awsutil.SetValueAtPath(&s3, "b.b.c", str)
|
||||
assert.Equal(t, "foo", s3.B.B.C)
|
||||
if e, a := "foo", s3.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
|
||||
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
|
||||
assert.Equal(t, "", s3.B.B.C)
|
||||
if e, a := "", s3.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s3 = Struct{}
|
||||
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
|
||||
assert.Equal(t, "", s3.B.B.C)
|
||||
if e, a := "", s3.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s3 = Struct{}
|
||||
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
|
||||
assert.Equal(t, "foo", s3.B.B.C)
|
||||
if e, a := "foo", s3.B.B.C; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
var s4 struct{ Name *string }
|
||||
awsutil.SetValueAtPath(&s4, "Name", str)
|
||||
assert.Equal(t, str, *s4.Name)
|
||||
if e, a := str, *s4.Name; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s4 = struct{ Name *string }{}
|
||||
awsutil.SetValueAtPath(&s4, "Name", nil)
|
||||
assert.Equal(t, (*string)(nil), s4.Name)
|
||||
if e, a := (*string)(nil), s4.Name; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s4 = struct{ Name *string }{Name: &str}
|
||||
awsutil.SetValueAtPath(&s4, "Name", nil)
|
||||
assert.Equal(t, (*string)(nil), s4.Name)
|
||||
if e, a := (*string)(nil), s4.Name; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
|
||||
s4 = struct{ Name *string }{}
|
||||
awsutil.SetValueAtPath(&s4, "Name", &str)
|
||||
assert.Equal(t, str, *s4.Name)
|
||||
if e, a := str, *s4.Name; e != a {
|
||||
t.Errorf("expected %v, but received %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+6
@@ -15,6 +15,12 @@ type Config struct {
|
||||
Endpoint string
|
||||
SigningRegion string
|
||||
SigningName string
|
||||
|
||||
// States that the signing name did not come from a modeled source but
|
||||
// was derived based on other data. Used by service client constructors
|
||||
// to determine if the signin name can be overriden based on metadata the
|
||||
// service has.
|
||||
SigningNameDerived bool
|
||||
}
|
||||
|
||||
// ConfigProvider provides a generic way for a service client to receive
|
||||
|
||||
+48
-28
@@ -1,11 +1,11 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkrand"
|
||||
)
|
||||
|
||||
// DefaultRetryer implements basic retry logic using exponential backoff for
|
||||
@@ -30,25 +30,27 @@ func (d DefaultRetryer) MaxRetries() int {
|
||||
return d.NumMaxRetries
|
||||
}
|
||||
|
||||
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())})
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again
|
||||
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
|
||||
// Set the upper limit of delay in retrying at ~five minutes
|
||||
minTime := 30
|
||||
throttle := d.shouldThrottle(r)
|
||||
if throttle {
|
||||
if delay, ok := getRetryDelay(r); ok {
|
||||
return delay
|
||||
}
|
||||
|
||||
minTime = 500
|
||||
}
|
||||
|
||||
retryCount := r.RetryCount
|
||||
if retryCount > 13 {
|
||||
retryCount = 13
|
||||
} else if throttle && retryCount > 8 {
|
||||
if throttle && retryCount > 8 {
|
||||
retryCount = 8
|
||||
} else if retryCount > 13 {
|
||||
retryCount = 13
|
||||
}
|
||||
|
||||
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime)
|
||||
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
|
||||
return time.Duration(delay) * time.Millisecond
|
||||
}
|
||||
|
||||
@@ -60,7 +62,7 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
|
||||
return *r.Retryable
|
||||
}
|
||||
|
||||
if r.HTTPResponse.StatusCode >= 500 {
|
||||
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
|
||||
return true
|
||||
}
|
||||
return r.IsErrorRetryable() || d.shouldThrottle(r)
|
||||
@@ -68,29 +70,47 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
|
||||
|
||||
// ShouldThrottle returns true if the request should be throttled.
|
||||
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
|
||||
if r.HTTPResponse.StatusCode == 502 ||
|
||||
r.HTTPResponse.StatusCode == 503 ||
|
||||
r.HTTPResponse.StatusCode == 504 {
|
||||
return true
|
||||
switch r.HTTPResponse.StatusCode {
|
||||
case 429:
|
||||
case 502:
|
||||
case 503:
|
||||
case 504:
|
||||
default:
|
||||
return r.IsErrorThrottle()
|
||||
}
|
||||
return r.IsErrorThrottle()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// lockedSource is a thread-safe implementation of rand.Source
|
||||
type lockedSource struct {
|
||||
lk sync.Mutex
|
||||
src rand.Source
|
||||
// This will look in the Retry-After header, RFC 7231, for how long
|
||||
// it will wait before attempting another request
|
||||
func getRetryDelay(r *request.Request) (time.Duration, bool) {
|
||||
if !canUseRetryAfterHeader(r) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
delayStr := r.HTTPResponse.Header.Get("Retry-After")
|
||||
if len(delayStr) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
delay, err := strconv.Atoi(delayStr)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return time.Duration(delay) * time.Second, true
|
||||
}
|
||||
|
||||
func (r *lockedSource) Int63() (n int64) {
|
||||
r.lk.Lock()
|
||||
n = r.src.Int63()
|
||||
r.lk.Unlock()
|
||||
return
|
||||
}
|
||||
// Will look at the status code to see if the retry header pertains to
|
||||
// the status code.
|
||||
func canUseRetryAfterHeader(r *request.Request) bool {
|
||||
switch r.HTTPResponse.StatusCode {
|
||||
case 429:
|
||||
case 503:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *lockedSource) Seed(seed int64) {
|
||||
r.lk.Lock()
|
||||
r.src.Seed(seed)
|
||||
r.lk.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
+189
@@ -0,0 +1,189 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
func TestRetryThrottleStatusCodes(t *testing.T) {
|
||||
cases := []struct {
|
||||
expectThrottle bool
|
||||
expectRetry bool
|
||||
r request.Request
|
||||
}{
|
||||
{
|
||||
false,
|
||||
false,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 200},
|
||||
},
|
||||
},
|
||||
{
|
||||
true,
|
||||
true,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 429},
|
||||
},
|
||||
},
|
||||
{
|
||||
true,
|
||||
true,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 502},
|
||||
},
|
||||
},
|
||||
{
|
||||
true,
|
||||
true,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 503},
|
||||
},
|
||||
},
|
||||
{
|
||||
true,
|
||||
true,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 504},
|
||||
},
|
||||
},
|
||||
{
|
||||
false,
|
||||
true,
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 500},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
d := DefaultRetryer{NumMaxRetries: 10}
|
||||
for i, c := range cases {
|
||||
throttle := d.shouldThrottle(&c.r)
|
||||
retry := d.ShouldRetry(&c.r)
|
||||
|
||||
if e, a := c.expectThrottle, throttle; e != a {
|
||||
t.Errorf("%d: expected %v, but received %v", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := c.expectRetry, retry; e != a {
|
||||
t.Errorf("%d: expected %v, but received %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanUseRetryAfter(t *testing.T) {
|
||||
cases := []struct {
|
||||
r request.Request
|
||||
e bool
|
||||
}{
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 200},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 500},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 429},
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 503},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
a := canUseRetryAfterHeader(&c.r)
|
||||
if c.e != a {
|
||||
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRetryDelay(t *testing.T) {
|
||||
cases := []struct {
|
||||
r request.Request
|
||||
e time.Duration
|
||||
equal bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}},
|
||||
},
|
||||
3600 * time.Second,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
|
||||
},
|
||||
120 * time.Second,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
|
||||
},
|
||||
1 * time.Second,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
request.Request{
|
||||
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}},
|
||||
},
|
||||
0 * time.Second,
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
a, ok := getRetryDelay(&c.r)
|
||||
if c.ok != ok {
|
||||
t.Errorf("%d: expected %v, but received %v", i, c.ok, ok)
|
||||
}
|
||||
|
||||
if (c.e != a) == c.equal {
|
||||
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryDelay(t *testing.T) {
|
||||
r := request.Request{}
|
||||
for i := 0; i < 100; i++ {
|
||||
rTemp := r
|
||||
rTemp.HTTPResponse = &http.Response{StatusCode: 500, Header: http.Header{"Retry-After": []string{""}}}
|
||||
rTemp.RetryCount = i
|
||||
a, _ := getRetryDelay(&rTemp)
|
||||
if a > 5*time.Minute {
|
||||
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
rTemp := r
|
||||
rTemp.RetryCount = i
|
||||
rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}
|
||||
a, _ := getRetryDelay(&rTemp)
|
||||
if a > 5*time.Minute {
|
||||
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
|
||||
}
|
||||
}
|
||||
}
|
||||
+4
@@ -46,6 +46,7 @@ func (reader *teeReaderCloser) Close() error {
|
||||
|
||||
func logRequest(r *request.Request) {
|
||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||
bodySeekable := aws.IsReaderSeekable(r.Body)
|
||||
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
|
||||
if err != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
|
||||
@@ -53,6 +54,9 @@ func logRequest(r *request.Request) {
|
||||
}
|
||||
|
||||
if logBody {
|
||||
if !bodySeekable {
|
||||
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
|
||||
}
|
||||
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
|
||||
// Body as a NoOpCloser and will not be reset after read by the HTTP
|
||||
// client reader.
|
||||
|
||||
+87
@@ -2,8 +2,17 @@ package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
type mockCloser struct {
|
||||
@@ -55,3 +64,81 @@ func TestLogWriter(t *testing.T) {
|
||||
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogRequest(t *testing.T) {
|
||||
cases := []struct {
|
||||
Body io.ReadSeeker
|
||||
ExpectBody []byte
|
||||
LogLevel aws.LogLevelType
|
||||
}{
|
||||
{
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
|
||||
ExpectBody: []byte("body content"),
|
||||
},
|
||||
{
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
|
||||
LogLevel: aws.LogDebugWithHTTPBody,
|
||||
ExpectBody: []byte("body content"),
|
||||
},
|
||||
{
|
||||
Body: bytes.NewReader([]byte("body content")),
|
||||
ExpectBody: []byte("body content"),
|
||||
},
|
||||
{
|
||||
Body: bytes.NewReader([]byte("body content")),
|
||||
LogLevel: aws.LogDebugWithHTTPBody,
|
||||
ExpectBody: []byte("body content"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
logW := bytes.NewBuffer(nil)
|
||||
req := request.New(
|
||||
aws.Config{
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
Logger: &bufLogger{w: logW},
|
||||
LogLevel: aws.LogLevel(c.LogLevel),
|
||||
},
|
||||
metadata.ClientInfo{
|
||||
Endpoint: "https://mock-service.mock-region.amazonaws.com",
|
||||
},
|
||||
testHandlers(),
|
||||
nil,
|
||||
&request.Operation{
|
||||
Name: "APIName",
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/",
|
||||
},
|
||||
struct{}{}, nil,
|
||||
)
|
||||
req.SetReaderBody(c.Body)
|
||||
req.Build()
|
||||
|
||||
logRequest(req)
|
||||
|
||||
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("%d, expect to read SDK request Body", i)
|
||||
}
|
||||
|
||||
if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("%d, expect %v body, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type bufLogger struct {
|
||||
w *bytes.Buffer
|
||||
}
|
||||
|
||||
func (l *bufLogger) Log(args ...interface{}) {
|
||||
fmt.Fprintln(l.w, args...)
|
||||
}
|
||||
|
||||
func testHandlers() request.Handlers {
|
||||
var handlers request.Handlers
|
||||
|
||||
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
+23
-1
@@ -151,6 +151,15 @@ type Config struct {
|
||||
// with accelerate.
|
||||
S3UseAccelerate *bool
|
||||
|
||||
// S3DisableContentMD5Validation config option is temporarily disabled,
|
||||
// For S3 GetObject API calls, #1837.
|
||||
//
|
||||
// Set this to `true` to disable the S3 service client from automatically
|
||||
// adding the ContentMD5 to S3 Object Put and Upload API calls. This option
|
||||
// will also disable the SDK from performing object ContentMD5 validation
|
||||
// on GetObject API calls.
|
||||
S3DisableContentMD5Validation *bool
|
||||
|
||||
// Set this to `true` to disable the EC2Metadata client from overriding the
|
||||
// default http.Client's Timeout. This is helpful if you do not want the
|
||||
// EC2Metadata client to create a new http.Client. This options is only
|
||||
@@ -168,7 +177,7 @@ type Config struct {
|
||||
//
|
||||
EC2MetadataDisableTimeoutOverride *bool
|
||||
|
||||
// Instructs the endpiont to be generated for a service client to
|
||||
// Instructs the endpoint to be generated for a service client to
|
||||
// be the dual stack endpoint. The dual stack endpoint will support
|
||||
// both IPv4 and IPv6 addressing.
|
||||
//
|
||||
@@ -336,6 +345,15 @@ func (c *Config) WithS3Disable100Continue(disable bool) *Config {
|
||||
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
|
||||
c.S3UseAccelerate = &enable
|
||||
return c
|
||||
|
||||
}
|
||||
|
||||
// WithS3DisableContentMD5Validation sets a config
|
||||
// S3DisableContentMD5Validation value returning a Config pointer for chaining.
|
||||
func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
|
||||
c.S3DisableContentMD5Validation = &enable
|
||||
return c
|
||||
|
||||
}
|
||||
|
||||
// WithUseDualStack sets a config UseDualStack value returning a Config
|
||||
@@ -435,6 +453,10 @@ func mergeInConfig(dst *Config, other *Config) {
|
||||
dst.S3UseAccelerate = other.S3UseAccelerate
|
||||
}
|
||||
|
||||
if other.S3DisableContentMD5Validation != nil {
|
||||
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
|
||||
}
|
||||
|
||||
if other.UseDualStack != nil {
|
||||
dst.UseDualStack = other.UseDualStack
|
||||
}
|
||||
|
||||
+259
-88
@@ -1,10 +1,9 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testCasesStringSlice = [][]string{
|
||||
@@ -18,14 +17,22 @@ func TestStringSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := StringSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := StringValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,22 +46,34 @@ func TestStringValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := StringValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if out[i] != "" {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := StringSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if *(out2[i]) != "" {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *in[i], *out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,14 +89,22 @@ func TestStringMap(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := StringMap(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := StringValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,14 +118,22 @@ func TestBoolSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := BoolSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := BoolValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,22 +145,34 @@ func TestBoolValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := BoolValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if out[i] {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := BoolSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if *(out2[i]) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,14 +188,22 @@ func TestBoolMap(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := BoolMap(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := BoolValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,14 +217,22 @@ func TestIntSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := IntSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := IntValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,22 +244,34 @@ func TestIntValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := IntValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if out[i] != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := IntSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if *(out2[i]) != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,14 +287,22 @@ func TestIntMap(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := IntMap(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := IntValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,14 +316,22 @@ func TestInt64Slice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Int64Slice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Int64ValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,22 +343,34 @@ func TestInt64ValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Int64ValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if out[i] != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Int64Slice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if *(out2[i]) != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -283,14 +386,22 @@ func TestInt64Map(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Int64Map(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Int64ValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,14 +415,22 @@ func TestFloat64Slice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Float64Slice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Float64ValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,22 +442,34 @@ func TestFloat64ValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Float64ValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if out[i] != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Float64Slice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if *(out2[i]) != 0 {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -354,14 +485,22 @@ func TestFloat64Map(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := Float64Map(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := Float64ValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,14 +514,22 @@ func TestTimeSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := TimeSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := TimeValueSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,22 +541,34 @@ func TestTimeValueSlice(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := TimeValueSlice(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
|
||||
if !out[i].IsZero() {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := *(in[i]), out[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out2 := TimeSlice(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out2 {
|
||||
if in[i] == nil {
|
||||
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
|
||||
if !(*(out2[i])).IsZero() {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], out2[i]; e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -425,14 +584,22 @@ func TestTimeMap(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
out := TimeMap(in)
|
||||
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
|
||||
if e, a := len(out), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
for i := range out {
|
||||
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
|
||||
if e, a := in[i], *(out[i]); e != a {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
out2 := TimeValueMap(out)
|
||||
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
if e, a := len(out2), len(in); e != a {
|
||||
t.Errorf("Unexpected len at idx %d", idx)
|
||||
}
|
||||
if e, a := in, out2; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,13 +625,17 @@ var testCasesTimeValue = []TimeValueTestCase{
|
||||
func TestSecondsTimeValue(t *testing.T) {
|
||||
for idx, testCase := range testCasesTimeValue {
|
||||
out := SecondsTimeValue(&testCase.in)
|
||||
assert.Equal(t, testCase.outSecs, out, "Unexpected value for time value at %d", idx)
|
||||
if e, a := testCase.outSecs, out; e != a {
|
||||
t.Errorf("Unexpected value for time value at %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMillisecondsTimeValue(t *testing.T) {
|
||||
for idx, testCase := range testCasesTimeValue {
|
||||
out := MillisecondsTimeValue(&testCase.in)
|
||||
assert.Equal(t, testCase.outMillis, out, "Unexpected value for time value at %d", idx)
|
||||
if e, a := testCase.outMillis, out; e != a {
|
||||
t.Errorf("Unexpected value for time value at %d", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
-21
@@ -3,12 +3,10 @@ package corehandlers
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -36,18 +34,13 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
|
||||
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
|
||||
length, _ = strconv.ParseInt(slength, 10, 64)
|
||||
} else {
|
||||
switch body := r.Body.(type) {
|
||||
case nil:
|
||||
length = 0
|
||||
case lener:
|
||||
length = int64(body.Len())
|
||||
case io.Seeker:
|
||||
r.BodyStart, _ = body.Seek(0, 1)
|
||||
end, _ := body.Seek(0, 2)
|
||||
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
|
||||
length = end - r.BodyStart
|
||||
default:
|
||||
panic("Cannot get length of body, must provide `ContentLength`")
|
||||
if r.Body != nil {
|
||||
var err error
|
||||
length, err = aws.SeekerLen(r.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,13 +53,6 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
|
||||
}
|
||||
}}
|
||||
|
||||
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
|
||||
var SDKVersionUserAgentHandler = request.NamedHandler{
|
||||
Name: "core.SDKVersionUserAgentHandler",
|
||||
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
|
||||
runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
}
|
||||
|
||||
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
|
||||
|
||||
// ValidateReqSigHandler is a request handler to ensure that the request's
|
||||
|
||||
+67
-24
@@ -7,11 +7,10 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
@@ -32,7 +31,9 @@ func TestValidateEndpointHandler(t *testing.T) {
|
||||
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
|
||||
err := req.Build()
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
|
||||
@@ -45,8 +46,12 @@ func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
|
||||
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
|
||||
err := req.Build()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, aws.ErrMissingRegion, err)
|
||||
if err == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
if e, a := aws.ErrMissingRegion, err; e != a {
|
||||
t.Errorf("expect %v to be %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type mockCredsProvider struct {
|
||||
@@ -82,18 +87,30 @@ func TestAfterRetryRefreshCreds(t *testing.T) {
|
||||
})
|
||||
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
|
||||
|
||||
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
|
||||
assert.False(t, credProvider.retrieveCalled)
|
||||
if !svc.Config.Credentials.IsExpired() {
|
||||
t.Errorf("Expect to start out expired")
|
||||
}
|
||||
if credProvider.retrieveCalled {
|
||||
t.Errorf("expect not called")
|
||||
}
|
||||
|
||||
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
|
||||
req.Send()
|
||||
|
||||
assert.True(t, svc.Config.Credentials.IsExpired())
|
||||
assert.False(t, credProvider.retrieveCalled)
|
||||
if !svc.Config.Credentials.IsExpired() {
|
||||
t.Errorf("Expect to start out expired")
|
||||
}
|
||||
if credProvider.retrieveCalled {
|
||||
t.Errorf("expect not called")
|
||||
}
|
||||
|
||||
_, err := svc.Config.Credentials.Get()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, credProvider.retrieveCalled)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if !credProvider.retrieveCalled {
|
||||
t.Errorf("expect not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterRetryWithContextCanceled(t *testing.T) {
|
||||
@@ -202,8 +219,12 @@ func TestSendHandlerError(t *testing.T) {
|
||||
|
||||
r.Send()
|
||||
|
||||
assert.Error(t, r.Error)
|
||||
assert.NotNil(t, r.HTTPResponse)
|
||||
if r.Error == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
if r.HTTPResponse == nil {
|
||||
t.Errorf("expect response, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithoutFollowRedirects(t *testing.T) {
|
||||
@@ -273,31 +294,47 @@ func TestValidateReqSigHandler(t *testing.T) {
|
||||
|
||||
corehandlers.ValidateReqSigHandler.Fn(c.Req)
|
||||
|
||||
assert.NoError(t, c.Req.Error, "%d, expect no error", i)
|
||||
assert.Equal(t, c.Resign, resigned, "%d, expected resigning to match", i)
|
||||
if c.Req.Error != nil {
|
||||
t.Errorf("expect no error, got %v", c.Req.Error)
|
||||
}
|
||||
if e, a := c.Resign, resigned; e != a {
|
||||
t.Errorf("%d, expect %v to be %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := r.Header["Content-Length"]
|
||||
assert.Equal(t, hasContentLength, ok, "expect content length to be set, %t", hasContentLength)
|
||||
if e, a := hasContentLength, ok; e != a {
|
||||
t.Errorf("expect %v to be %v", e, a)
|
||||
}
|
||||
if hasContentLength {
|
||||
assert.Equal(t, contentLength, r.ContentLength)
|
||||
if e, a := contentLength, r.ContentLength; e != a {
|
||||
t.Errorf("expect %v to be %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
r.Body.Close()
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if hasContentLength {
|
||||
assert.Contains(t, authHeader, "content-length")
|
||||
if e, a := "content-length", authHeader; !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to be in %v", e, a)
|
||||
}
|
||||
} else {
|
||||
assert.NotContains(t, authHeader, "content-length")
|
||||
if e, a := "content-length", authHeader; strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to not be in %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, contentLength, int64(len(b)))
|
||||
if e, a := contentLength, int64(len(b)); e != a {
|
||||
t.Errorf("expect %v to be %v", e, a)
|
||||
}
|
||||
}))
|
||||
|
||||
return server
|
||||
@@ -316,7 +353,9 @@ func TestBuildContentLength_ZeroBody(t *testing.T) {
|
||||
Key: aws.String("keyname"),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContentLength_NegativeBody(t *testing.T) {
|
||||
@@ -334,7 +373,9 @@ func TestBuildContentLength_NegativeBody(t *testing.T) {
|
||||
|
||||
req.HTTPRequest.Header.Set("Content-Length", "-1")
|
||||
|
||||
assert.NoError(t, req.Send())
|
||||
if req.Error != nil {
|
||||
t.Errorf("expect no error, got %v", req.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContentLength_WithBody(t *testing.T) {
|
||||
@@ -351,5 +392,7 @@ func TestBuildContentLength_WithBody(t *testing.T) {
|
||||
Body: bytes.NewReader(make([]byte, 1024)),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+52
-20
@@ -3,8 +3,7 @@ package corehandlers_test
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/kinesis"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testSvc = func() *client.Client {
|
||||
@@ -113,7 +111,9 @@ func TestNoErrors(t *testing.T) {
|
||||
|
||||
req := testSvc.NewRequest(&request.Operation{}, input, nil)
|
||||
corehandlers.ValidateParametersHandler.Fn(req)
|
||||
require.NoError(t, req.Error)
|
||||
if req.Error != nil {
|
||||
t.Fatalf("expect no error, got %v", req.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingRequiredParameters(t *testing.T) {
|
||||
@@ -121,17 +121,33 @@ func TestMissingRequiredParameters(t *testing.T) {
|
||||
req := testSvc.NewRequest(&request.Operation{}, input, nil)
|
||||
corehandlers.ValidateParametersHandler.Fn(req)
|
||||
|
||||
require.Error(t, req.Error)
|
||||
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
|
||||
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
|
||||
if req.Error == nil {
|
||||
t.Fatalf("expect error")
|
||||
}
|
||||
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
|
||||
assert.Len(t, errs, 3)
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error())
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error())
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error())
|
||||
if e, a := 3, len(errs); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
assert.Equal(t, "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error())
|
||||
if e, a := "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedMissingRequiredParameters(t *testing.T) {
|
||||
@@ -148,15 +164,29 @@ func TestNestedMissingRequiredParameters(t *testing.T) {
|
||||
req := testSvc.NewRequest(&request.Operation{}, input, nil)
|
||||
corehandlers.ValidateParametersHandler.Fn(req)
|
||||
|
||||
require.Error(t, req.Error)
|
||||
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
|
||||
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
|
||||
if req.Error == nil {
|
||||
t.Fatalf("expect error")
|
||||
}
|
||||
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
|
||||
assert.Len(t, errs, 3)
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error())
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error())
|
||||
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error())
|
||||
if e, a := 3, len(errs); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type testInput struct {
|
||||
@@ -226,7 +256,9 @@ func TestValidateFieldMinParameter(t *testing.T) {
|
||||
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
|
||||
corehandlers.ValidateParametersHandler.Fn(req)
|
||||
|
||||
assert.Equal(t, c.err, req.Error, "%d case failed", i)
|
||||
if e, a := c.err, req.Error; !reflect.DeepEqual(e,a) {
|
||||
t.Errorf("%d, expect %v, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
package corehandlers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version
|
||||
// to the user agent.
|
||||
var SDKVersionUserAgentHandler = request.NamedHandler{
|
||||
Name: "core.SDKVersionUserAgentHandler",
|
||||
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
|
||||
runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
}
|
||||
|
||||
const execEnvVar = `AWS_EXECUTION_ENV`
|
||||
const execEnvUAKey = `exec_env`
|
||||
|
||||
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
|
||||
// execution environment to the user agent.
|
||||
//
|
||||
// If the environment variable AWS_EXECUTION_ENV is set, its value will be
|
||||
// appended to the user agent string.
|
||||
var AddHostExecEnvUserAgentHander = request.NamedHandler{
|
||||
Name: "core.AddHostExecEnvUserAgentHander",
|
||||
Fn: func(r *request.Request) {
|
||||
v := os.Getenv(execEnvVar)
|
||||
if len(v) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
request.AddToUserAgent(r, execEnvUAKey+"/"+v)
|
||||
},
|
||||
}
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
package corehandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
|
||||
cases := []struct {
|
||||
ExecEnv string
|
||||
Expect string
|
||||
}{
|
||||
{ExecEnv: "Lambda", Expect: "exec_env/Lambda"},
|
||||
{ExecEnv: "", Expect: ""},
|
||||
{ExecEnv: "someThingCool", Expect: "exec_env/someThingCool"},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
os.Clearenv()
|
||||
os.Setenv(execEnvVar, c.ExecEnv)
|
||||
|
||||
req := &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
}
|
||||
AddHostExecEnvUserAgentHander.Fn(req)
|
||||
|
||||
if err := req.Error; err != nil {
|
||||
t.Fatalf("%d, expect no error, got %v", i, err)
|
||||
}
|
||||
|
||||
if e, a := c.Expect, req.HTTPRequest.Header.Get("User-Agent"); e != a {
|
||||
t.Errorf("%d, expect %v user agent, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
+33
-2
@@ -9,6 +9,7 @@ package defaults
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -72,6 +73,7 @@ func Handlers() request.Handlers {
|
||||
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
|
||||
handlers.Validate.AfterEachFn = request.HandlerListStopOnError
|
||||
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
|
||||
handlers.Build.PushBackNamed(corehandlers.AddHostExecEnvUserAgentHander)
|
||||
handlers.Build.AfterEachFn = request.HandlerListStopOnError
|
||||
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
|
||||
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
|
||||
@@ -118,14 +120,43 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
|
||||
return ec2RoleProvider(cfg, handlers)
|
||||
}
|
||||
|
||||
var lookupHostFn = net.LookupHost
|
||||
|
||||
func isLoopbackHost(host string) (bool, error) {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
return ip.IsLoopback(), nil
|
||||
}
|
||||
|
||||
// Host is not an ip, perform lookup
|
||||
addrs, err := lookupHostFn(host)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if !net.ParseIP(addr).IsLoopback() {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
|
||||
var errMsg string
|
||||
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Sprintf("invalid URL, %v", err)
|
||||
} else if host := aws.URLHostname(parsed); !(host == "localhost" || host == "127.0.0.1") {
|
||||
errMsg = fmt.Sprintf("invalid host address, %q, only localhost and 127.0.0.1 are valid.", host)
|
||||
} else {
|
||||
host := aws.URLHostname(parsed)
|
||||
if len(host) == 0 {
|
||||
errMsg = "unable to parse host from local HTTP cred provider URL"
|
||||
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
|
||||
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr)
|
||||
} else if !isLoopback {
|
||||
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errMsg) > 0 {
|
||||
|
||||
+30
-2
@@ -13,12 +13,40 @@ import (
|
||||
)
|
||||
|
||||
func TestHTTPCredProvider(t *testing.T) {
|
||||
origFn := lookupHostFn
|
||||
defer func() { lookupHostFn = origFn }()
|
||||
|
||||
lookupHostFn = func(host string) ([]string, error) {
|
||||
m := map[string]struct {
|
||||
Addrs []string
|
||||
Err error
|
||||
}{
|
||||
"localhost": {Addrs: []string{"::1", "127.0.0.1"}},
|
||||
"actuallylocal": {Addrs: []string{"127.0.0.2"}},
|
||||
"notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}},
|
||||
"www.example.com": {Addrs: []string{"10.10.10.10"}},
|
||||
}
|
||||
|
||||
h, ok := m[host]
|
||||
if !ok {
|
||||
t.Fatalf("unknown host in test, %v", host)
|
||||
return nil, fmt.Errorf("unknown host")
|
||||
}
|
||||
|
||||
return h.Addrs, h.Err
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Host string
|
||||
Fail bool
|
||||
}{
|
||||
{"localhost", false}, {"127.0.0.1", false},
|
||||
{"www.example.com", true}, {"169.254.170.2", true},
|
||||
{"localhost", false},
|
||||
{"actuallylocal", false},
|
||||
{"127.0.0.1", false},
|
||||
{"127.1.1.1", false},
|
||||
{"[::1]", false},
|
||||
{"www.example.com", true},
|
||||
{"169.254.170.2", true},
|
||||
}
|
||||
|
||||
defer os.Clearenv()
|
||||
|
||||
+82
-35
@@ -11,8 +11,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||
@@ -71,8 +69,12 @@ func TestEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
req := c.NewRequest(op, nil, nil)
|
||||
assert.Equal(t, "http://169.254.169.254/latest", req.ClientInfo.Endpoint)
|
||||
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
|
||||
if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetadata(t *testing.T) {
|
||||
@@ -85,8 +87,12 @@ func TestGetMetadata(t *testing.T) {
|
||||
|
||||
resp, err := c.GetMetadata("some/path")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", resp)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "success", resp; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserData(t *testing.T) {
|
||||
@@ -99,8 +105,12 @@ func TestGetUserData(t *testing.T) {
|
||||
|
||||
resp, err := c.GetUserData()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", resp)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "success", resp; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserData_Error(t *testing.T) {
|
||||
@@ -126,12 +136,17 @@ func TestGetUserData_Error(t *testing.T) {
|
||||
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
|
||||
|
||||
resp, err := c.GetUserData()
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, resp)
|
||||
if err == nil {
|
||||
t.Errorf("expect error")
|
||||
}
|
||||
if len(resp) != 0 {
|
||||
t.Errorf("expect empty, got %v", resp)
|
||||
}
|
||||
|
||||
aerr, ok := err.(awserr.Error)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "NotFoundError", aerr.Code())
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := "NotFoundError", aerr.Code(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRegion(t *testing.T) {
|
||||
@@ -144,8 +159,12 @@ func TestGetRegion(t *testing.T) {
|
||||
|
||||
region, err := c.Region()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "us-west-2", region)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "us-west-2", region; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataAvailable(t *testing.T) {
|
||||
@@ -156,9 +175,9 @@ func TestMetadataAvailable(t *testing.T) {
|
||||
defer server.Close()
|
||||
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
|
||||
|
||||
available := c.Available()
|
||||
|
||||
assert.True(t, available)
|
||||
if !c.Available() {
|
||||
t.Errorf("expect available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataIAMInfo_success(t *testing.T) {
|
||||
@@ -170,10 +189,18 @@ func TestMetadataIAMInfo_success(t *testing.T) {
|
||||
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
|
||||
|
||||
iamInfo, err := c.IAMInfo()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Success", iamInfo.Code)
|
||||
assert.Equal(t, "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn)
|
||||
assert.Equal(t, "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "Success", iamInfo.Code; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataIAMInfo_failure(t *testing.T) {
|
||||
@@ -185,10 +212,18 @@ func TestMetadataIAMInfo_failure(t *testing.T) {
|
||||
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
|
||||
|
||||
iamInfo, err := c.IAMInfo()
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "", iamInfo.Code)
|
||||
assert.Equal(t, "", iamInfo.InstanceProfileArn)
|
||||
assert.Equal(t, "", iamInfo.InstanceProfileID)
|
||||
if err == nil {
|
||||
t.Errorf("expect error")
|
||||
}
|
||||
if e, a := "", iamInfo.Code; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "", iamInfo.InstanceProfileArn; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "", iamInfo.InstanceProfileID; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataNotAvailable(t *testing.T) {
|
||||
@@ -204,9 +239,9 @@ func TestMetadataNotAvailable(t *testing.T) {
|
||||
r.Retryable = aws.Bool(true) // network errors are retryable
|
||||
})
|
||||
|
||||
available := c.Available()
|
||||
|
||||
assert.False(t, available)
|
||||
if c.Available() {
|
||||
t.Errorf("expect not available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataErrorResponse(t *testing.T) {
|
||||
@@ -222,8 +257,12 @@ func TestMetadataErrorResponse(t *testing.T) {
|
||||
})
|
||||
|
||||
data, err := c.GetMetadata("uri/path")
|
||||
assert.Empty(t, data)
|
||||
assert.Contains(t, err.Error(), "error message text")
|
||||
if len(data) != 0 {
|
||||
t.Errorf("expect empty, got %v", data)
|
||||
}
|
||||
if e, a := "error message text", err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to be in %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
|
||||
@@ -235,8 +274,16 @@ func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
|
||||
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
|
||||
|
||||
doc, err := c.GetInstanceIdentityDocument()
|
||||
assert.Nil(t, err, "Expect no error, %v", err)
|
||||
assert.Equal(t, doc.AccountID, "123456789012")
|
||||
assert.Equal(t, doc.AvailabilityZone, "us-east-1d")
|
||||
assert.Equal(t, doc.Region, "us-east-1")
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := doc.AccountID, "123456789012"; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := doc.AvailabilityZone, "us-east-1d"; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := doc.Region, "us-east-1"; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+24
@@ -1,5 +1,10 @@
|
||||
// Package ec2metadata provides the client for making API calls to the
|
||||
// EC2 Metadata service.
|
||||
//
|
||||
// This package's client can be disabled completely by setting the environment
|
||||
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
|
||||
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
|
||||
// be used while the environemnt variable is set to true, (case insensitive).
|
||||
package ec2metadata
|
||||
|
||||
import (
|
||||
@@ -7,17 +12,21 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// ServiceName is the name of the service.
|
||||
const ServiceName = "ec2metadata"
|
||||
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
|
||||
|
||||
// A EC2Metadata is an EC2 Metadata service Client.
|
||||
type EC2Metadata struct {
|
||||
@@ -75,6 +84,21 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
|
||||
svc.Handlers.Validate.Clear()
|
||||
svc.Handlers.Validate.PushBack(validateEndpointHandler)
|
||||
|
||||
// Disable the EC2 Metadata service if the environment variable is set.
|
||||
// This shortcirctes the service's functionality to always fail to send
|
||||
// requests.
|
||||
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
|
||||
svc.Handlers.Send.SwapNamed(request.NamedHandler{
|
||||
Name: corehandlers.SendHandler.Name,
|
||||
Fn: func(r *request.Request) {
|
||||
r.Error = awserr.New(
|
||||
request.CanceledErrorCode,
|
||||
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
|
||||
nil)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Add additional options to the service config
|
||||
for _, option := range opts {
|
||||
option(svc.Client)
|
||||
|
||||
+52
-10
@@ -3,21 +3,30 @@ package ec2metadata_test
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
|
||||
svc := ec2metadata.New(unit.Session)
|
||||
|
||||
assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)
|
||||
assert.Equal(t, 5*time.Second, svc.Config.HTTPClient.Timeout)
|
||||
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
|
||||
t.Errorf("expect %v, not to equal %v", e, a)
|
||||
}
|
||||
|
||||
if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
|
||||
t.Errorf("expect %v to be %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
|
||||
@@ -28,18 +37,25 @@ func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
|
||||
|
||||
svc := ec2metadata.New(unit.Session)
|
||||
|
||||
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
|
||||
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, tr)
|
||||
assert.Nil(t, tr.Dial)
|
||||
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
|
||||
if tr == nil {
|
||||
t.Fatalf("expect transport not to be nil")
|
||||
}
|
||||
if tr.Dial != nil {
|
||||
t.Errorf("expect dial to be nil, was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
|
||||
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
|
||||
|
||||
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
|
||||
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
|
||||
@@ -63,6 +79,30 @@ func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
|
||||
runEC2MetadataClients(t, cfg, 100)
|
||||
}
|
||||
|
||||
func TestClientDisableIMDS(t *testing.T) {
|
||||
env := awstesting.StashEnv()
|
||||
defer awstesting.PopEnv(env)
|
||||
|
||||
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
|
||||
svc := ec2metadata.New(unit.Session)
|
||||
resp, err := svc.Region()
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if len(resp) != 0 {
|
||||
t.Errorf("expect no response, got %v", resp)
|
||||
}
|
||||
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
|
||||
t.Errorf("expect %v error code, got %v", e, a)
|
||||
}
|
||||
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v in error message, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(atOnce)
|
||||
@@ -70,7 +110,9 @@ func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
|
||||
go func() {
|
||||
svc := ec2metadata.New(unit.Session, cfg)
|
||||
_, err := svc.Region()
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
+648
-57
File diff suppressed because it is too large
Load Diff
+4
@@ -347,6 +347,10 @@ type ResolvedEndpoint struct {
|
||||
// The service name that should be used for signing requests.
|
||||
SigningName string
|
||||
|
||||
// States that the signing name for this endpoint was derived from metadata
|
||||
// passed in, but was not explicitly modeled.
|
||||
SigningNameDerived bool
|
||||
|
||||
// The signing method that should be used for signing requests.
|
||||
SigningMethod string
|
||||
}
|
||||
|
||||
+8
-4
@@ -226,16 +226,20 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
|
||||
if len(signingRegion) == 0 {
|
||||
signingRegion = region
|
||||
}
|
||||
|
||||
signingName := e.CredentialScope.Service
|
||||
var signingNameDerived bool
|
||||
if len(signingName) == 0 {
|
||||
signingName = service
|
||||
signingNameDerived = true
|
||||
}
|
||||
|
||||
return ResolvedEndpoint{
|
||||
URL: u,
|
||||
SigningRegion: signingRegion,
|
||||
SigningName: signingName,
|
||||
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
|
||||
URL: u,
|
||||
SigningRegion: signingRegion,
|
||||
SigningName: signingName,
|
||||
SigningNameDerived: signingNameDerived,
|
||||
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+265
-78
@@ -2,10 +2,9 @@ package endpoints
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnmarshalRegionRegex(t *testing.T) {
|
||||
@@ -16,12 +15,18 @@ func TestUnmarshalRegionRegex(t *testing.T) {
|
||||
|
||||
p := partition{}
|
||||
err := json.Unmarshal(input, &p)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectRegexp, err := regexp.Compile(`^(us|eu|ap|sa|ca)\-\w+\-\d+$`)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, expectRegexp.String(), p.RegionRegex.Regexp.String())
|
||||
if e, a := expectRegexp.String(), p.RegionRegex.Regexp.String(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalRegion(t *testing.T) {
|
||||
@@ -37,16 +42,28 @@ func TestUnmarshalRegion(t *testing.T) {
|
||||
|
||||
rs := regions{}
|
||||
err := json.Unmarshal(input, &rs)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, rs, 2)
|
||||
if e, a := 2, len(rs); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
r, ok := rs["aws-global"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "AWS partition-global endpoint", r.Description)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := "AWS partition-global endpoint", r.Description; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
r, ok = rs["us-east-1"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "US East (N. Virginia)", r.Description)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := "US East (N. Virginia)", r.Description; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalServices(t *testing.T) {
|
||||
@@ -75,23 +92,45 @@ func TestUnmarshalServices(t *testing.T) {
|
||||
|
||||
ss := services{}
|
||||
err := json.Unmarshal(input, &ss)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, ss, 3)
|
||||
if e, a := 3, len(ss); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
s, ok := ss["acm"]
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, s.Endpoints, 1)
|
||||
assert.Equal(t, boxedBoolUnset, s.IsRegionalized)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := 1, len(s.Endpoints); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
if e, a := boxedBoolUnset, s.IsRegionalized; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
s, ok = ss["apigateway"]
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, s.Endpoints, 2)
|
||||
assert.Equal(t, boxedTrue, s.IsRegionalized)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := 2, len(s.Endpoints); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
if e, a := boxedTrue, s.IsRegionalized; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
s, ok = ss["notRegionalized"]
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, s.Endpoints, 2)
|
||||
assert.Equal(t, boxedFalse, s.IsRegionalized)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := 2, len(s.Endpoints); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
if e, a := boxedFalse, s.IsRegionalized; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEndpoints(t *testing.T) {
|
||||
@@ -115,16 +154,32 @@ func TestUnmarshalEndpoints(t *testing.T) {
|
||||
|
||||
es := endpoints{}
|
||||
err := json.Unmarshal(inputs, &es)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, es, 2)
|
||||
if e, a := 2, len(es); e != a {
|
||||
t.Errorf("expect %v len, got %v", e, a)
|
||||
}
|
||||
s, ok := es["aws-global"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cloudfront.amazonaws.com", s.Hostname)
|
||||
assert.Equal(t, []string{"http", "https"}, s.Protocols)
|
||||
assert.Equal(t, []string{"v4"}, s.SignatureVersions)
|
||||
assert.Equal(t, credentialScope{"us-east-1", "serviceName"}, s.CredentialScope)
|
||||
assert.Equal(t, "commonName", s.SSLCommonName)
|
||||
if !ok {
|
||||
t.Errorf("expect found, was not")
|
||||
}
|
||||
if e, a := "cloudfront.amazonaws.com", s.Hostname; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := []string{"http", "https"}, s.Protocols; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := []string{"v4"}, s.SignatureVersions; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := (credentialScope{"us-east-1", "serviceName"}), s.CredentialScope; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "commonName", s.SSLCommonName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointResolve(t *testing.T) {
|
||||
@@ -155,10 +210,18 @@ func TestEndpointResolve(t *testing.T) {
|
||||
defs, Options{},
|
||||
)
|
||||
|
||||
assert.Equal(t, "https://service.region.dnsSuffix", resolved.URL)
|
||||
assert.Equal(t, "signing_service", resolved.SigningName)
|
||||
assert.Equal(t, "signing_region", resolved.SigningRegion)
|
||||
assert.Equal(t, "v4", resolved.SigningMethod)
|
||||
if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "signing_service", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "signing_region", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "v4", resolved.SigningMethod; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointMergeIn(t *testing.T) {
|
||||
@@ -185,7 +248,9 @@ func TestEndpointMergeIn(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, expected, actual)
|
||||
if e, a := expected, actual; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
var testPartitions = partitions{
|
||||
@@ -213,6 +278,11 @@ var testPartitions = partitions{
|
||||
Services: services{
|
||||
"s3": service{},
|
||||
"service1": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
Service: "service1",
|
||||
},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": {},
|
||||
"us-west-2": {
|
||||
@@ -221,7 +291,13 @@ var testPartitions = partitions{
|
||||
},
|
||||
},
|
||||
},
|
||||
"service2": service{},
|
||||
"service2": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
Service: "service2",
|
||||
},
|
||||
},
|
||||
},
|
||||
"httpService": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http"},
|
||||
@@ -246,109 +322,220 @@ var testPartitions = partitions{
|
||||
func TestResolveEndpoint(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("service2", "us-west-2")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://service2.us-west-2.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-west-2", resolved.SigningRegion)
|
||||
assert.Equal(t, "service2", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-west-2", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "service2", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name not to be derived, but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_DisableSSL(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("service2", "us-west-2", DisableSSLOption)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http://service2.us-west-2.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-west-2", resolved.SigningRegion)
|
||||
assert.Equal(t, "service2", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "http://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-west-2", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "service2", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name not to be derived, but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_UseDualStack(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("service1", "us-west-2", UseDualStackOption)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-west-2", resolved.SigningRegion)
|
||||
assert.Equal(t, "service1", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-west-2", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "service1", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name not to be derived, but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_HTTPProtocol(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("httpService", "us-west-2")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "http://httpService.us-west-2.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-west-2", resolved.SigningRegion)
|
||||
assert.Equal(t, "httpService", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "http://httpService.us-west-2.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-west-2", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "httpService", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if !resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name to be derived")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_UnknownService(t *testing.T) {
|
||||
_, err := testPartitions.EndpointFor("unknownservice", "us-west-2")
|
||||
|
||||
assert.Error(t, err)
|
||||
if err == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
|
||||
_, ok := err.(UnknownServiceError)
|
||||
assert.True(t, ok, "expect error to be UnknownServiceError")
|
||||
if !ok {
|
||||
t.Errorf("expect error to be UnknownServiceError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_ResolveUnknownService(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("unknown-service", "us-region-1",
|
||||
ResolveUnknownServiceOption)
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "https://unknown-service.us-region-1.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-region-1", resolved.SigningRegion)
|
||||
assert.Equal(t, "unknown-service", resolved.SigningName)
|
||||
if e, a := "https://unknown-service.us-region-1.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-region-1", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "unknown-service", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if !resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name to be derived")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_UnknownMatchedRegion(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("service2", "us-region-1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://service2.us-region-1.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-region-1", resolved.SigningRegion)
|
||||
assert.Equal(t, "service2", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://service2.us-region-1.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-region-1", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "service2", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name not to be derived, but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_UnknownRegion(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("service2", "unknownregion")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://service2.unknownregion.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "unknownregion", resolved.SigningRegion)
|
||||
assert.Equal(t, "service2", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://service2.unknownregion.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "unknownregion", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "service2", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name not to be derived, but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_StrictPartitionUnknownEndpoint(t *testing.T) {
|
||||
_, err := testPartitions[0].EndpointFor("service2", "unknownregion", StrictMatchingOption)
|
||||
|
||||
assert.Error(t, err)
|
||||
if err == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
|
||||
_, ok := err.(UnknownEndpointError)
|
||||
assert.True(t, ok, "expect error to be UnknownEndpointError")
|
||||
if !ok {
|
||||
t.Errorf("expect error to be UnknownEndpointError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_StrictPartitionsUnknownEndpoint(t *testing.T) {
|
||||
_, err := testPartitions.EndpointFor("service2", "us-region-1", StrictMatchingOption)
|
||||
|
||||
assert.Error(t, err)
|
||||
if err == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
|
||||
_, ok := err.(UnknownEndpointError)
|
||||
assert.True(t, ok, "expect error to be UnknownEndpointError")
|
||||
if !ok {
|
||||
t.Errorf("expect error to be UnknownEndpointError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_NotRegionalized(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("globalService", "us-west-2")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-east-1", resolved.SigningRegion)
|
||||
assert.Equal(t, "globalService", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-east-1", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "globalService", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if !resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name to be derived")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEndpoint_AwsGlobal(t *testing.T) {
|
||||
resolved, err := testPartitions.EndpointFor("globalService", "aws-global")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
|
||||
assert.Equal(t, "us-east-1", resolved.SigningRegion)
|
||||
assert.Equal(t, "globalService", resolved.SigningName)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "us-east-1", resolved.SigningRegion; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "globalService", resolved.SigningName; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if !resolved.SigningNameDerived {
|
||||
t.Errorf("expect the signing name to be derived")
|
||||
}
|
||||
}
|
||||
|
||||
+3
-1
@@ -3,6 +3,8 @@ package request
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
// offsetReader is a thread-safe io.ReadCloser to prevent racing
|
||||
@@ -15,7 +17,7 @@ type offsetReader struct {
|
||||
|
||||
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
|
||||
reader := &offsetReader{}
|
||||
buf.Seek(offset, 0)
|
||||
buf.Seek(offset, sdkio.SeekStart)
|
||||
|
||||
reader.buf = buf
|
||||
return reader
|
||||
|
||||
+4
-3
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -28,15 +29,15 @@ func TestOffsetReaderSeek(t *testing.T) {
|
||||
buf := []byte("testData")
|
||||
reader := newOffsetReader(bytes.NewReader(buf), 0)
|
||||
|
||||
orig, err := reader.Seek(0, 1)
|
||||
orig, err := reader.Seek(0, sdkio.SeekCurrent)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), orig)
|
||||
|
||||
n, err := reader.Seek(0, 2)
|
||||
n, err := reader.Seek(0, sdkio.SeekEnd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(len(buf)), n)
|
||||
|
||||
n, err = reader.Seek(orig, 0)
|
||||
n, err = reader.Seek(orig, sdkio.SeekStart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
}
|
||||
|
||||
+133
-60
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +29,10 @@ const (
|
||||
// during body reads.
|
||||
ErrCodeResponseTimeout = "ResponseTimeout"
|
||||
|
||||
// ErrCodeInvalidPresignExpire is returned when the expire time provided to
|
||||
// presign is invalid
|
||||
ErrCodeInvalidPresignExpire = "InvalidPresignExpireError"
|
||||
|
||||
// CanceledErrorCode is the error code that will be returned by an
|
||||
// API request that was canceled. Requests given a aws.Context may
|
||||
// return this error when canceled.
|
||||
@@ -42,7 +47,6 @@ type Request struct {
|
||||
|
||||
Retryer
|
||||
Time time.Time
|
||||
ExpireTime time.Duration
|
||||
Operation *Operation
|
||||
HTTPRequest *http.Request
|
||||
HTTPResponse *http.Response
|
||||
@@ -60,6 +64,11 @@ type Request struct {
|
||||
LastSignedAt time.Time
|
||||
DisableFollowRedirects bool
|
||||
|
||||
// A value greater than 0 instructs the request to be signed as Presigned URL
|
||||
// You should not set this field directly. Instead use Request's
|
||||
// Presign or PresignRequest methods.
|
||||
ExpireTime time.Duration
|
||||
|
||||
context aws.Context
|
||||
|
||||
built bool
|
||||
@@ -104,6 +113,8 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
|
||||
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
|
||||
}
|
||||
|
||||
SanitizeHostForHeader(httpReq)
|
||||
|
||||
r := &Request{
|
||||
Config: cfg,
|
||||
ClientInfo: clientInfo,
|
||||
@@ -214,6 +225,9 @@ func (r *Request) SetContext(ctx aws.Context) {
|
||||
|
||||
// WillRetry returns if the request's can be retried.
|
||||
func (r *Request) WillRetry() bool {
|
||||
if !aws.IsReaderSeekable(r.Body) && r.HTTPRequest.Body != NoBody {
|
||||
return false
|
||||
}
|
||||
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
|
||||
}
|
||||
|
||||
@@ -245,45 +259,70 @@ func (r *Request) SetStringBody(s string) {
|
||||
// SetReaderBody will set the request's body reader.
|
||||
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
|
||||
r.Body = reader
|
||||
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
|
||||
r.ResetBody()
|
||||
}
|
||||
|
||||
// Presign returns the request's signed URL. Error will be returned
|
||||
// if the signing fails.
|
||||
func (r *Request) Presign(expireTime time.Duration) (string, error) {
|
||||
r.ExpireTime = expireTime
|
||||
//
|
||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||
// error is returned if expire duration is 0 or less.
|
||||
func (r *Request) Presign(expire time.Duration) (string, error) {
|
||||
r = r.copy()
|
||||
|
||||
// Presign requires all headers be hoisted. There is no way to retrieve
|
||||
// the signed headers not hoisted without this. Making the presigned URL
|
||||
// useless.
|
||||
r.NotHoist = false
|
||||
|
||||
if r.Operation.BeforePresignFn != nil {
|
||||
r = r.copy()
|
||||
err := r.Operation.BeforePresignFn(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
r.Sign()
|
||||
if r.Error != nil {
|
||||
return "", r.Error
|
||||
}
|
||||
return r.HTTPRequest.URL.String(), nil
|
||||
u, _, err := getPresignedURL(r, expire)
|
||||
return u, err
|
||||
}
|
||||
|
||||
// PresignRequest behaves just like presign, with the addition of returning a
|
||||
// set of headers that were signed.
|
||||
//
|
||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||
// error is returned if expire duration is 0 or less.
|
||||
//
|
||||
// Returns the URL string for the API operation with signature in the query string,
|
||||
// and the HTTP headers that were included in the signature. These headers must
|
||||
// be included in any HTTP request made with the presigned URL.
|
||||
//
|
||||
// To prevent hoisting any headers to the query string set NotHoist to true on
|
||||
// this Request value prior to calling PresignRequest.
|
||||
func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, error) {
|
||||
r.ExpireTime = expireTime
|
||||
r.Sign()
|
||||
if r.Error != nil {
|
||||
return "", nil, r.Error
|
||||
func (r *Request) PresignRequest(expire time.Duration) (string, http.Header, error) {
|
||||
r = r.copy()
|
||||
return getPresignedURL(r, expire)
|
||||
}
|
||||
|
||||
// IsPresigned returns true if the request represents a presigned API url.
|
||||
func (r *Request) IsPresigned() bool {
|
||||
return r.ExpireTime != 0
|
||||
}
|
||||
|
||||
func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, error) {
|
||||
if expire <= 0 {
|
||||
return "", nil, awserr.New(
|
||||
ErrCodeInvalidPresignExpire,
|
||||
"presigned URL requires an expire duration greater than 0",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
r.ExpireTime = expire
|
||||
|
||||
if r.Operation.BeforePresignFn != nil {
|
||||
if err := r.Operation.BeforePresignFn(r); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.Sign(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
|
||||
}
|
||||
|
||||
@@ -303,7 +342,7 @@ func debugLogReqError(r *Request, stage string, retrying bool, err error) {
|
||||
|
||||
// Build will build the request's object so it can be signed and sent
|
||||
// to the service. Build will also validate all the request's parameters.
|
||||
// Anny additional build Handlers set on this request will be run
|
||||
// Any additional build Handlers set on this request will be run
|
||||
// in the order they were set.
|
||||
//
|
||||
// The request will only be built once. Multiple calls to build will have
|
||||
@@ -364,7 +403,7 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
||||
// of the SDK if they used that field.
|
||||
//
|
||||
// Related golang/go#18257
|
||||
l, err := computeBodyLength(r.Body)
|
||||
l, err := aws.SeekerLen(r.Body)
|
||||
if err != nil {
|
||||
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
|
||||
}
|
||||
@@ -382,7 +421,8 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
||||
// Transfer-Encoding: chunked bodies for these methods.
|
||||
//
|
||||
// This would only happen if a aws.ReaderSeekerCloser was used with
|
||||
// a io.Reader that was not also an io.Seeker.
|
||||
// a io.Reader that was not also an io.Seeker, or did not implement
|
||||
// Len() method.
|
||||
switch r.Operation.HTTPMethod {
|
||||
case "GET", "HEAD", "DELETE":
|
||||
body = NoBody
|
||||
@@ -394,42 +434,6 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Attempts to compute the length of the body of the reader using the
|
||||
// io.Seeker interface. If the value is not seekable because of being
|
||||
// a ReaderSeekerCloser without an unerlying Seeker -1 will be returned.
|
||||
// If no error occurs the length of the body will be returned.
|
||||
func computeBodyLength(r io.ReadSeeker) (int64, error) {
|
||||
seekable := true
|
||||
// Determine if the seeker is actually seekable. ReaderSeekerCloser
|
||||
// hides the fact that a io.Readers might not actually be seekable.
|
||||
switch v := r.(type) {
|
||||
case aws.ReaderSeekerCloser:
|
||||
seekable = v.IsSeeker()
|
||||
case *aws.ReaderSeekerCloser:
|
||||
seekable = v.IsSeeker()
|
||||
}
|
||||
if !seekable {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
curOffset, err := r.Seek(0, 1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
endOffset, err := r.Seek(0, 2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = r.Seek(curOffset, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return endOffset - curOffset, nil
|
||||
}
|
||||
|
||||
// GetBody will return an io.ReadSeeker of the Request's underlying
|
||||
// input body with a concurrency safe wrapper.
|
||||
func (r *Request) GetBody() io.ReadSeeker {
|
||||
@@ -579,3 +583,72 @@ func shouldRetryCancel(r *Request) bool {
|
||||
errStr != "net/http: request canceled while waiting for connection")
|
||||
|
||||
}
|
||||
|
||||
// SanitizeHostForHeader removes default port from host and updates request.Host
|
||||
func SanitizeHostForHeader(r *http.Request) {
|
||||
host := getHost(r)
|
||||
port := portOnly(host)
|
||||
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||
r.Host = stripPort(host)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns host from request
|
||||
func getHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
}
|
||||
|
||||
return r.URL.Host
|
||||
}
|
||||
|
||||
// Hostname returns u.Host, without any port number.
|
||||
//
|
||||
// If Host is an IPv6 literal with a port number, Hostname returns the
|
||||
// IPv6 literal without the square brackets. IPv6 literals may include
|
||||
// a zone identifier.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
|
||||
// Port returns the port part of u.Host, without the leading colon.
|
||||
// If u.Host doesn't contain a port, Port returns an empty string.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func portOnly(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return ""
|
||||
}
|
||||
if i := strings.Index(hostport, "]:"); i != -1 {
|
||||
return hostport[i+len("]:"):]
|
||||
}
|
||||
if strings.Contains(hostport, "]") {
|
||||
return ""
|
||||
}
|
||||
return hostport[colon+len(":"):]
|
||||
}
|
||||
|
||||
// Returns true if the specified URI is using the standard port
|
||||
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
|
||||
func isDefaultPort(scheme, port string) bool {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerCaseScheme := strings.ToLower(scheme)
|
||||
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
+20
-5
@@ -142,13 +142,28 @@ func (r *Request) nextPageTokens() []interface{} {
|
||||
tokens := []interface{}{}
|
||||
tokenAdded := false
|
||||
for _, outToken := range r.Operation.OutputTokens {
|
||||
v, _ := awsutil.ValuesAtPath(r.Data, outToken)
|
||||
if len(v) > 0 {
|
||||
tokens = append(tokens, v[0])
|
||||
tokenAdded = true
|
||||
} else {
|
||||
vs, _ := awsutil.ValuesAtPath(r.Data, outToken)
|
||||
if len(vs) == 0 {
|
||||
tokens = append(tokens, nil)
|
||||
continue
|
||||
}
|
||||
v := vs[0]
|
||||
|
||||
switch tv := v.(type) {
|
||||
case *string:
|
||||
if len(aws.StringValue(tv)) == 0 {
|
||||
tokens = append(tokens, nil)
|
||||
continue
|
||||
}
|
||||
case string:
|
||||
if len(tv) == 0 {
|
||||
tokens = append(tokens, nil)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tokenAdded = true
|
||||
tokens = append(tokens, v)
|
||||
}
|
||||
if !tokenAdded {
|
||||
return nil
|
||||
|
||||
+78
-63
@@ -454,78 +454,93 @@ func TestPaginationWithContextNilInput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type testPageInput struct {
|
||||
NextToken string
|
||||
}
|
||||
type testPageOutput struct {
|
||||
Value string
|
||||
NextToken *string
|
||||
}
|
||||
|
||||
func TestPagination_Standalone(t *testing.T) {
|
||||
expect := []struct {
|
||||
Value, PrevToken, NextToken string
|
||||
}{
|
||||
{"FirstValue", "InitalToken", "FirstToken"},
|
||||
{"SecondValue", "FirstToken", "SecondToken"},
|
||||
{"ThirdValue", "SecondToken", ""},
|
||||
type testPageInput struct {
|
||||
NextToken *string
|
||||
}
|
||||
input := testPageInput{
|
||||
NextToken: expect[0].PrevToken,
|
||||
type testPageOutput struct {
|
||||
Value *string
|
||||
NextToken *string
|
||||
}
|
||||
type testCase struct {
|
||||
Value, PrevToken, NextToken *string
|
||||
}
|
||||
|
||||
c := awstesting.NewClient()
|
||||
i := 0
|
||||
p := request.Pagination{
|
||||
NewRequest: func() (*request.Request, error) {
|
||||
r := c.NewRequest(
|
||||
&request.Operation{
|
||||
Name: "Operation",
|
||||
Paginator: &request.Paginator{
|
||||
InputTokens: []string{"NextToken"},
|
||||
OutputTokens: []string{"NextToken"},
|
||||
},
|
||||
},
|
||||
&input, &testPageOutput{},
|
||||
)
|
||||
// Setup handlers for testing
|
||||
r.Handlers.Clear()
|
||||
r.Handlers.Build.PushBack(func(req *request.Request) {
|
||||
in := req.Params.(*testPageInput)
|
||||
if e, a := expect[i].PrevToken, in.NextToken; e != a {
|
||||
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
|
||||
}
|
||||
})
|
||||
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
|
||||
out := &testPageOutput{
|
||||
Value: expect[i].Value,
|
||||
}
|
||||
if len(expect[i].NextToken) > 0 {
|
||||
out.NextToken = aws.String(expect[i].NextToken)
|
||||
}
|
||||
req.Data = out
|
||||
})
|
||||
return r, nil
|
||||
cases := [][]testCase{
|
||||
{
|
||||
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
|
||||
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
|
||||
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), nil},
|
||||
},
|
||||
{
|
||||
testCase{aws.String("FirstValue"), aws.String("InitalToken"), aws.String("FirstToken")},
|
||||
testCase{aws.String("SecondValue"), aws.String("FirstToken"), aws.String("SecondToken")},
|
||||
testCase{aws.String("ThirdValue"), aws.String("SecondToken"), aws.String("")},
|
||||
},
|
||||
}
|
||||
|
||||
for p.Next() {
|
||||
data := p.Page().(*testPageOutput)
|
||||
|
||||
if e, a := expect[i].Value, data.Value; e != a {
|
||||
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
|
||||
}
|
||||
if e, a := expect[i].NextToken, aws.StringValue(data.NextToken); e != a {
|
||||
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
|
||||
for _, c := range cases {
|
||||
input := testPageInput{
|
||||
NextToken: c[0].PrevToken,
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
if e, a := len(expect), i; e != a {
|
||||
t.Errorf("expected to process %d pages, did %d", e, a)
|
||||
}
|
||||
if err := p.Err(); err != nil {
|
||||
t.Fatalf("%d, expected no error, got %v", i, err)
|
||||
svc := awstesting.NewClient()
|
||||
i := 0
|
||||
p := request.Pagination{
|
||||
NewRequest: func() (*request.Request, error) {
|
||||
r := svc.NewRequest(
|
||||
&request.Operation{
|
||||
Name: "Operation",
|
||||
Paginator: &request.Paginator{
|
||||
InputTokens: []string{"NextToken"},
|
||||
OutputTokens: []string{"NextToken"},
|
||||
},
|
||||
},
|
||||
&input, &testPageOutput{},
|
||||
)
|
||||
// Setup handlers for testing
|
||||
r.Handlers.Clear()
|
||||
r.Handlers.Build.PushBack(func(req *request.Request) {
|
||||
if e, a := len(c), i+1; a > e {
|
||||
t.Fatalf("expect no more than %d requests, got %d", e, a)
|
||||
}
|
||||
in := req.Params.(*testPageInput)
|
||||
if e, a := aws.StringValue(c[i].PrevToken), aws.StringValue(in.NextToken); e != a {
|
||||
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
|
||||
}
|
||||
})
|
||||
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
|
||||
out := &testPageOutput{
|
||||
Value: c[i].Value,
|
||||
}
|
||||
if c[i].NextToken != nil {
|
||||
next := *c[i].NextToken
|
||||
out.NextToken = aws.String(next)
|
||||
}
|
||||
req.Data = out
|
||||
})
|
||||
return r, nil
|
||||
},
|
||||
}
|
||||
|
||||
for p.Next() {
|
||||
data := p.Page().(*testPageOutput)
|
||||
|
||||
if e, a := aws.StringValue(c[i].Value), aws.StringValue(data.Value); e != a {
|
||||
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
|
||||
}
|
||||
if e, a := aws.StringValue(c[i].NextToken), aws.StringValue(data.NextToken); e != a {
|
||||
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
if e, a := len(c), i; e != a {
|
||||
t.Errorf("expected to process %d pages, did %d", e, a)
|
||||
}
|
||||
if err := p.Err(); err != nil {
|
||||
t.Fatalf("%d, expected no error, got %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+60
-11
@@ -2,6 +2,7 @@ package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -25,21 +26,70 @@ func TestResetBody_WithBodyContents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
|
||||
type mockReader struct{}
|
||||
|
||||
func (mockReader) Read([]byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) {
|
||||
cases := []struct {
|
||||
Method string
|
||||
Body io.ReadSeeker
|
||||
IsNoBody bool
|
||||
}{
|
||||
{"GET", true},
|
||||
{"HEAD", true},
|
||||
{"DELETE", true},
|
||||
{"PUT", false},
|
||||
{"PATCH", false},
|
||||
{"POST", false},
|
||||
{
|
||||
Method: "GET",
|
||||
IsNoBody: true,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "HEAD",
|
||||
IsNoBody: true,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "DELETE",
|
||||
IsNoBody: true,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "PUT",
|
||||
IsNoBody: false,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "PATCH",
|
||||
IsNoBody: false,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "POST",
|
||||
IsNoBody: false,
|
||||
Body: aws.ReadSeekCloser(mockReader{}),
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
IsNoBody: false,
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
IsNoBody: true,
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
|
||||
},
|
||||
{
|
||||
Method: "POST",
|
||||
IsNoBody: false,
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc"))),
|
||||
},
|
||||
{
|
||||
Method: "POST",
|
||||
IsNoBody: true,
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
reader := aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc")))
|
||||
|
||||
for i, c := range cases {
|
||||
r := Request{
|
||||
HTTPRequest: &http.Request{},
|
||||
@@ -47,8 +97,7 @@ func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
|
||||
HTTPMethod: c.Method,
|
||||
},
|
||||
}
|
||||
|
||||
r.SetReaderBody(reader)
|
||||
r.SetReaderBody(c.Body)
|
||||
|
||||
if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e {
|
||||
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
|
||||
|
||||
+280
-3
@@ -8,9 +8,11 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +22,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
@@ -80,7 +83,7 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
|
||||
reqNum := 0
|
||||
reqs := []http.Response{
|
||||
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
|
||||
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
|
||||
{StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
|
||||
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
|
||||
}
|
||||
|
||||
@@ -112,7 +115,8 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
|
||||
reqNum := 0
|
||||
reqs := []http.Response{
|
||||
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
|
||||
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
|
||||
{StatusCode: 400, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
|
||||
{StatusCode: 429, Body: body(`{"__type":"FooException","message":"Rate exceeded."}`)},
|
||||
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
|
||||
}
|
||||
|
||||
@@ -131,7 +135,7 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 2, int(r.RetryCount); e != a {
|
||||
if e, a := 3, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
@@ -842,3 +846,276 @@ func TestRequest_TemporaryRetry(t *testing.T) {
|
||||
t.Errorf("expect temporary error, was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_Presign(t *testing.T) {
|
||||
presign := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
|
||||
u, err := r.Presign(expire)
|
||||
return u, nil, err
|
||||
}
|
||||
presignRequest := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
|
||||
return r.PresignRequest(expire)
|
||||
}
|
||||
mustParseURL := func(v string) *url.URL {
|
||||
u, err := url.Parse(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Expire time.Duration
|
||||
PresignFn func(*request.Request, time.Duration) (string, http.Header, error)
|
||||
SignerFn func(*request.Request)
|
||||
URL string
|
||||
Header http.Header
|
||||
Err string
|
||||
}{
|
||||
{
|
||||
PresignFn: presign,
|
||||
Err: request.ErrCodeInvalidPresignExpire,
|
||||
},
|
||||
{
|
||||
PresignFn: presignRequest,
|
||||
Err: request.ErrCodeInvalidPresignExpire,
|
||||
},
|
||||
{
|
||||
Expire: -1,
|
||||
PresignFn: presign,
|
||||
Err: request.ErrCodeInvalidPresignExpire,
|
||||
},
|
||||
{
|
||||
// Presign clear NotHoist
|
||||
Expire: 1 * time.Minute,
|
||||
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
|
||||
r.NotHoist = true
|
||||
return presign(r, dur)
|
||||
},
|
||||
SignerFn: func(r *request.Request) {
|
||||
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
|
||||
if r.NotHoist {
|
||||
r.Error = fmt.Errorf("expect NotHoist to be cleared")
|
||||
}
|
||||
},
|
||||
URL: "https://endpoint/presignedURL",
|
||||
},
|
||||
{
|
||||
// PresignRequest does not clear NotHoist
|
||||
Expire: 1 * time.Minute,
|
||||
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
|
||||
r.NotHoist = true
|
||||
return presignRequest(r, dur)
|
||||
},
|
||||
SignerFn: func(r *request.Request) {
|
||||
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
|
||||
if !r.NotHoist {
|
||||
r.Error = fmt.Errorf("expect NotHoist not to be cleared")
|
||||
}
|
||||
},
|
||||
URL: "https://endpoint/presignedURL",
|
||||
},
|
||||
{
|
||||
// PresignRequest returns signed headers
|
||||
Expire: 1 * time.Minute,
|
||||
PresignFn: presignRequest,
|
||||
SignerFn: func(r *request.Request) {
|
||||
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
|
||||
r.HTTPRequest.Header.Set("UnsigndHeader", "abc")
|
||||
r.SignedHeaderVals = http.Header{
|
||||
"X-Amzn-Header": []string{"abc", "123"},
|
||||
"X-Amzn-Header2": []string{"efg", "456"},
|
||||
}
|
||||
},
|
||||
URL: "https://endpoint/presignedURL",
|
||||
Header: http.Header{
|
||||
"X-Amzn-Header": []string{"abc", "123"},
|
||||
"X-Amzn-Header2": []string{"efg", "456"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := awstesting.NewClient()
|
||||
svc.Handlers.Clear()
|
||||
for i, c := range cases {
|
||||
req := svc.NewRequest(&request.Operation{
|
||||
Name: "name", HTTPMethod: "GET", HTTPPath: "/path",
|
||||
}, &struct{}{}, &struct{}{})
|
||||
req.Handlers.Sign.PushBack(c.SignerFn)
|
||||
|
||||
u, h, err := c.PresignFn(req, c.Expire)
|
||||
if len(c.Err) != 0 {
|
||||
if e, a := c.Err, err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d, expect %v to be in %v", i, e, a)
|
||||
}
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Errorf("%d, expect no error, got %v", i, err)
|
||||
continue
|
||||
}
|
||||
if e, a := c.URL, u; e != a {
|
||||
t.Errorf("%d, expect %v URL, got %v", i, e, a)
|
||||
}
|
||||
if e, a := c.Header, h; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("%d, expect %v header got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_EndpointWithDefaultPort(t *testing.T) {
|
||||
endpoint := "https://estest.us-east-1.es.amazonaws.com:443"
|
||||
expectedRequestHost := "estest.us-east-1.es.amazonaws.com"
|
||||
|
||||
r := request.New(
|
||||
aws.Config{},
|
||||
metadata.ClientInfo{Endpoint: endpoint},
|
||||
defaults.Handlers(),
|
||||
client.DefaultRetryer{},
|
||||
&request.Operation{},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
if h := r.HTTPRequest.Host; h != expectedRequestHost {
|
||||
t.Errorf("expect %v host, got %q", expectedRequestHost, h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeHostForHeader(t *testing.T) {
|
||||
cases := []struct {
|
||||
url string
|
||||
expectedRequestHost string
|
||||
}{
|
||||
{"https://estest.us-east-1.es.amazonaws.com:443", "estest.us-east-1.es.amazonaws.com"},
|
||||
{"https://estest.us-east-1.es.amazonaws.com", "estest.us-east-1.es.amazonaws.com"},
|
||||
{"https://localhost:9200", "localhost:9200"},
|
||||
{"http://localhost:80", "localhost"},
|
||||
{"http://localhost:8080", "localhost:8080"},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
r, _ := http.NewRequest("GET", c.url, nil)
|
||||
request.SanitizeHostForHeader(r)
|
||||
|
||||
if h := r.Host; h != c.expectedRequestHost {
|
||||
t.Errorf("expect %v host, got %q", c.expectedRequestHost, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestWillRetry_ByBody(t *testing.T) {
|
||||
svc := awstesting.NewClient()
|
||||
|
||||
cases := []struct {
|
||||
WillRetry bool
|
||||
HTTPMethod string
|
||||
Body io.ReadSeeker
|
||||
IsReqNoBody bool
|
||||
}{
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "GET",
|
||||
Body: bytes.NewReader([]byte{}),
|
||||
IsReqNoBody: true,
|
||||
},
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "GET",
|
||||
Body: bytes.NewReader(nil),
|
||||
IsReqNoBody: true,
|
||||
},
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "POST",
|
||||
Body: bytes.NewReader([]byte("abc123")),
|
||||
},
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "POST",
|
||||
Body: aws.ReadSeekCloser(bytes.NewReader([]byte("abc123"))),
|
||||
},
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "GET",
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
|
||||
IsReqNoBody: true,
|
||||
},
|
||||
{
|
||||
WillRetry: true,
|
||||
HTTPMethod: "POST",
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)),
|
||||
IsReqNoBody: true,
|
||||
},
|
||||
{
|
||||
WillRetry: false,
|
||||
HTTPMethod: "POST",
|
||||
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc123"))),
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
req := svc.NewRequest(&request.Operation{
|
||||
Name: "Operation",
|
||||
HTTPMethod: c.HTTPMethod,
|
||||
HTTPPath: "/",
|
||||
}, nil, nil)
|
||||
req.SetReaderBody(c.Body)
|
||||
req.Build()
|
||||
|
||||
req.Error = fmt.Errorf("some error")
|
||||
req.Retryable = aws.Bool(true)
|
||||
req.HTTPResponse = &http.Response{
|
||||
StatusCode: 500,
|
||||
}
|
||||
|
||||
if e, a := c.IsReqNoBody, request.NoBody == req.HTTPRequest.Body; e != a {
|
||||
t.Errorf("%d, expect request to be no body, %t, got %t, %T", i, e, a, req.HTTPRequest.Body)
|
||||
}
|
||||
|
||||
if e, a := c.WillRetry, req.WillRetry(); e != a {
|
||||
t.Errorf("%d, expect %t willRetry, got %t", i, e, a)
|
||||
}
|
||||
|
||||
if req.Error == nil {
|
||||
t.Fatalf("%d, expect error, got none", i)
|
||||
}
|
||||
if e, a := "some error", req.Error.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d, expect %q error in %q", i, e, a)
|
||||
}
|
||||
if e, a := 0, req.RetryCount; e != a {
|
||||
t.Errorf("%d, expect retry count to be %d, got %d", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test501NotRetrying(t *testing.T) {
|
||||
reqNum := 0
|
||||
reqs := []http.Response{
|
||||
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
|
||||
{StatusCode: 501, Body: body(`{"__type":"NotImplemented","message":"An error occurred."}`)},
|
||||
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
|
||||
}
|
||||
|
||||
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
|
||||
s.Handlers.Validate.Clear()
|
||||
s.Handlers.Unmarshal.PushBack(unmarshal)
|
||||
s.Handlers.UnmarshalError.PushBack(unmarshalError)
|
||||
s.Handlers.Send.Clear() // mock sending
|
||||
s.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
r.HTTPResponse = &reqs[reqNum]
|
||||
reqNum++
|
||||
})
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
if err == nil {
|
||||
t.Fatal("expect error, but got none")
|
||||
}
|
||||
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := "NotImplemented", aerr.Code(); e != a {
|
||||
t.Errorf("expected error code %q, but received %q", e, a)
|
||||
}
|
||||
if e, a := 1, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+8
@@ -5,6 +5,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
)
|
||||
|
||||
// EnvProviderName provides a name of the provider when config is loaded from environment.
|
||||
@@ -176,6 +177,13 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
|
||||
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
|
||||
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
|
||||
|
||||
if len(cfg.SharedCredentialsFile) == 0 {
|
||||
cfg.SharedCredentialsFile = defaults.SharedCredentialsFilename()
|
||||
}
|
||||
if len(cfg.SharedConfigFile) == 0 {
|
||||
cfg.SharedConfigFile = defaults.SharedConfigFilename()
|
||||
}
|
||||
|
||||
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
|
||||
|
||||
return cfg
|
||||
|
||||
+37
-10
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
)
|
||||
|
||||
func TestLoadEnvConfig_Creds(t *testing.T) {
|
||||
@@ -105,6 +106,8 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -116,6 +119,8 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -128,7 +133,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -136,6 +143,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_REGION": "default_region",
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
},
|
||||
Config: envConfig{
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
@@ -145,7 +156,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -155,7 +168,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
@@ -168,7 +183,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
@@ -182,7 +199,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
@@ -193,7 +212,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
@@ -205,7 +226,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
},
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
@@ -214,7 +237,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_CA_BUNDLE": "custom_ca_bundle",
|
||||
},
|
||||
Config: envConfig{
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -222,8 +247,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_CA_BUNDLE": "custom_ca_bundle",
|
||||
},
|
||||
Config: envConfig{
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
EnableSharedConfig: true,
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
EnableSharedConfig: true,
|
||||
SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(),
|
||||
SharedConfigFile: shareddefaults.SharedConfigFilename(),
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
|
||||
+19
-19
@@ -26,7 +26,7 @@ import (
|
||||
// Sessions are safe to create service clients concurrently, but it is not safe
|
||||
// to mutate the Session concurrently.
|
||||
//
|
||||
// The Session satisfies the service client's client.ClientConfigProvider.
|
||||
// The Session satisfies the service client's client.ConfigProvider.
|
||||
type Session struct {
|
||||
Config *aws.Config
|
||||
Handlers request.Handlers
|
||||
@@ -58,7 +58,12 @@ func New(cfgs ...*aws.Config) *Session {
|
||||
envCfg := loadEnvConfig()
|
||||
|
||||
if envCfg.EnableSharedConfig {
|
||||
s, err := newSession(Options{}, envCfg, cfgs...)
|
||||
var cfg aws.Config
|
||||
cfg.MergeIn(cfgs...)
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
Config: cfg,
|
||||
SharedConfigState: SharedConfigEnable,
|
||||
})
|
||||
if err != nil {
|
||||
// Old session.New expected all errors to be discovered when
|
||||
// a request is made, and would report the errors then. This
|
||||
@@ -243,13 +248,6 @@ func NewSessionWithOptions(opts Options) (*Session, error) {
|
||||
envCfg.EnableSharedConfig = true
|
||||
}
|
||||
|
||||
if len(envCfg.SharedCredentialsFile) == 0 {
|
||||
envCfg.SharedCredentialsFile = defaults.SharedCredentialsFilename()
|
||||
}
|
||||
if len(envCfg.SharedConfigFile) == 0 {
|
||||
envCfg.SharedConfigFile = defaults.SharedConfigFilename()
|
||||
}
|
||||
|
||||
// Only use AWS_CA_BUNDLE if session option is not provided.
|
||||
if len(envCfg.CustomCABundle) != 0 && opts.CustomCABundle == nil {
|
||||
f, err := os.Open(envCfg.CustomCABundle)
|
||||
@@ -573,11 +571,12 @@ func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (
|
||||
}
|
||||
|
||||
return client.Config{
|
||||
Config: s.Config,
|
||||
Handlers: s.Handlers,
|
||||
Endpoint: resolved.URL,
|
||||
SigningRegion: resolved.SigningRegion,
|
||||
SigningName: resolved.SigningName,
|
||||
Config: s.Config,
|
||||
Handlers: s.Handlers,
|
||||
Endpoint: resolved.URL,
|
||||
SigningRegion: resolved.SigningRegion,
|
||||
SigningNameDerived: resolved.SigningNameDerived,
|
||||
SigningName: resolved.SigningName,
|
||||
}, err
|
||||
}
|
||||
|
||||
@@ -597,10 +596,11 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
|
||||
}
|
||||
|
||||
return client.Config{
|
||||
Config: s.Config,
|
||||
Handlers: s.Handlers,
|
||||
Endpoint: resolved.URL,
|
||||
SigningRegion: resolved.SigningRegion,
|
||||
SigningName: resolved.SigningName,
|
||||
Config: s.Config,
|
||||
Handlers: s.Handlers,
|
||||
Endpoint: resolved.URL,
|
||||
SigningRegion: resolved.SigningRegion,
|
||||
SigningNameDerived: resolved.SigningNameDerived,
|
||||
SigningName: resolved.SigningName,
|
||||
}
|
||||
}
|
||||
|
||||
+23
-5
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
@@ -89,14 +90,31 @@ func TestSessionCopy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionClientConfig(t *testing.T) {
|
||||
s, err := NewSession(&aws.Config{Region: aws.String("orig_region")})
|
||||
s, err := NewSession(&aws.Config{
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
Region: aws.String("orig_region"),
|
||||
EndpointResolver: endpoints.ResolverFunc(
|
||||
func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
|
||||
if e, a := "mock-service", service; e != a {
|
||||
t.Errorf("expect %q service, got %q", e, a)
|
||||
}
|
||||
if e, a := "other-region", region; e != a {
|
||||
t.Errorf("expect %q region, got %q", e, a)
|
||||
}
|
||||
return endpoints.ResolvedEndpoint{
|
||||
URL: "https://" + service + "." + region + ".amazonaws.com",
|
||||
SigningRegion: region,
|
||||
}, nil
|
||||
},
|
||||
),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
cfg := s.ClientConfig("s3", &aws.Config{Region: aws.String("us-west-2")})
|
||||
cfg := s.ClientConfig("mock-service", &aws.Config{Region: aws.String("other-region")})
|
||||
|
||||
assert.Equal(t, "https://s3-us-west-2.amazonaws.com", cfg.Endpoint)
|
||||
assert.Equal(t, "us-west-2", cfg.SigningRegion)
|
||||
assert.Equal(t, "us-west-2", *cfg.Config.Region)
|
||||
assert.Equal(t, "https://mock-service.other-region.amazonaws.com", cfg.Endpoint)
|
||||
assert.Equal(t, "other-region", cfg.SigningRegion)
|
||||
assert.Equal(t, "other-region", *cfg.Config.Region)
|
||||
}
|
||||
|
||||
func TestNewSession_NoCredentials(t *testing.T) {
|
||||
|
||||
+30
-11
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStandaloneSign(t *testing.T) {
|
||||
@@ -22,7 +21,9 @@ func TestStandaloneSign(t *testing.T) {
|
||||
c.SubDomain, c.Region, c.Service)
|
||||
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
// URL.EscapedPath() will be used by the signer to get the
|
||||
// escaped form of the request's URI path.
|
||||
@@ -30,12 +31,20 @@ func TestStandaloneSign(t *testing.T) {
|
||||
req.URL.RawQuery = c.OrigQuery
|
||||
|
||||
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
assert.Equal(t, c.ExpSig, actual)
|
||||
assert.Equal(t, c.OrigURI, req.URL.Path)
|
||||
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
|
||||
if e, a := c.ExpSig, actual; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.OrigURI, req.URL.Path; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +57,9 @@ func TestStandaloneSign_RawPath(t *testing.T) {
|
||||
c.SubDomain, c.Region, c.Service)
|
||||
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
// URL.EscapedPath() will be used by the signer to get the
|
||||
// escaped form of the request's URI path.
|
||||
@@ -57,11 +68,19 @@ func TestStandaloneSign_RawPath(t *testing.T) {
|
||||
req.URL.RawQuery = c.OrigQuery
|
||||
|
||||
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
assert.Equal(t, c.ExpSig, actual)
|
||||
assert.Equal(t, c.OrigURI, req.URL.Path)
|
||||
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
|
||||
if e, a := c.ExpSig, actual; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.OrigURI, req.URL.Path; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+88
@@ -164,3 +164,91 @@ func TestStandaloneSign_CustomURIEscape(t *testing.T) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandaloneSign_WithPort(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
description string
|
||||
url string
|
||||
expectedSig string
|
||||
}{
|
||||
{
|
||||
"default HTTPS port",
|
||||
"https://estest.us-east-1.es.amazonaws.com:443/_search",
|
||||
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=e573fc9aa3a156b720976419319be98fb2824a3abc2ddd895ecb1d1611c6a82d",
|
||||
},
|
||||
{
|
||||
"default HTTP port",
|
||||
"http://example.com:80/_search",
|
||||
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=54ebe60c4ae03a40948b849e13c333523235f38002e2807059c64a9a8c7cb951",
|
||||
},
|
||||
{
|
||||
"non-standard HTTP port",
|
||||
"http://example.com:9200/_search",
|
||||
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
|
||||
},
|
||||
{
|
||||
"non-standard HTTPS port",
|
||||
"https://example.com:9200/_search",
|
||||
"AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=cd9d926a460f8d3b58b57beadbd87666dc667e014c0afaa4cea37b2867f51b4f",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
signer := v4.NewSigner(unit.Session.Config.Credentials)
|
||||
req, _ := http.NewRequest("GET", c.url, nil)
|
||||
_, err := signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
if e, a := c.expectedSig, actual; e != a {
|
||||
t.Errorf("%s, expect %v, got %v", c.description, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandalonePresign_WithPort(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
description string
|
||||
url string
|
||||
expectedSig string
|
||||
}{
|
||||
{
|
||||
"default HTTPS port",
|
||||
"https://estest.us-east-1.es.amazonaws.com:443/_search",
|
||||
"0abcf61a351063441296febf4b485734d780634fba8cf1e7d9769315c35255d6",
|
||||
},
|
||||
{
|
||||
"default HTTP port",
|
||||
"http://example.com:80/_search",
|
||||
"fce9976dd6c849c21adfa6d3f3e9eefc651d0e4a2ccd740d43efddcccfdc8179",
|
||||
},
|
||||
{
|
||||
"non-standard HTTP port",
|
||||
"http://example.com:9200/_search",
|
||||
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
|
||||
},
|
||||
{
|
||||
"non-standard HTTPS port",
|
||||
"https://example.com:9200/_search",
|
||||
"f33c25a81c735e42bef35ed5e9f720c43940562e3e616ff0777bf6dde75249b0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
signer := v4.NewSigner(unit.Session.Config.Credentials)
|
||||
req, _ := http.NewRequest("GET", c.url, nil)
|
||||
_, err := signer.Presign(req, nil, "es", "us-east-1", 5 * time.Minute, time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
actual := req.URL.Query().Get("X-Amz-Signature")
|
||||
if e, a := c.expectedSig, actual; e != a {
|
||||
t.Errorf("%s, expect %v, got %v", c.description, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+33
-13
@@ -2,8 +2,6 @@ package v4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRuleCheckWhitelist(t *testing.T) {
|
||||
@@ -13,8 +11,12 @@ func TestRuleCheckWhitelist(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, w.IsValid("Cache-Control"))
|
||||
assert.False(t, w.IsValid("Cache-"))
|
||||
if !w.IsValid("Cache-Control") {
|
||||
t.Error("expected true value")
|
||||
}
|
||||
if w.IsValid("Cache-") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleCheckBlacklist(t *testing.T) {
|
||||
@@ -24,16 +26,26 @@ func TestRuleCheckBlacklist(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
assert.False(t, b.IsValid("Cache-Control"))
|
||||
assert.True(t, b.IsValid("Cache-"))
|
||||
if b.IsValid("Cache-Control") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
if !b.IsValid("Cache-") {
|
||||
t.Error("expected true value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleCheckPattern(t *testing.T) {
|
||||
p := patterns{"X-Amz-Meta-"}
|
||||
|
||||
assert.True(t, p.IsValid("X-Amz-Meta-"))
|
||||
assert.True(t, p.IsValid("X-Amz-Meta-Star"))
|
||||
assert.False(t, p.IsValid("Cache-"))
|
||||
if !p.IsValid("X-Amz-Meta-") {
|
||||
t.Error("expected true value")
|
||||
}
|
||||
if !p.IsValid("X-Amz-Meta-Star") {
|
||||
t.Error("expected true value")
|
||||
}
|
||||
if p.IsValid("Cache-") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleComplexWhitelist(t *testing.T) {
|
||||
@@ -50,8 +62,16 @@ func TestRuleComplexWhitelist(t *testing.T) {
|
||||
inclusiveRules{patterns{"X-Amz-"}, blacklist{w}},
|
||||
}
|
||||
|
||||
assert.True(t, r.IsValid("X-Amz-Blah"))
|
||||
assert.False(t, r.IsValid("X-Amz-Meta-"))
|
||||
assert.False(t, r.IsValid("X-Amz-Meta-Star"))
|
||||
assert.False(t, r.IsValid("Cache-Control"))
|
||||
if !r.IsValid("X-Amz-Blah") {
|
||||
t.Error("expected true value")
|
||||
}
|
||||
if r.IsValid("X-Amz-Meta-") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
if r.IsValid("X-Amz-Meta-Star") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
if r.IsValid("Cache-Control") {
|
||||
t.Error("expected false value")
|
||||
}
|
||||
}
|
||||
|
||||
+28
-11
@@ -71,6 +71,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
@@ -268,7 +269,7 @@ type signingCtx struct {
|
||||
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
|
||||
// only compute the hash if the request header value is empty.
|
||||
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
|
||||
return v4.signWithBody(r, body, service, region, 0, signTime)
|
||||
return v4.signWithBody(r, body, service, region, 0, false, signTime)
|
||||
}
|
||||
|
||||
// Presign signs AWS v4 requests with the provided body, service name, region
|
||||
@@ -302,10 +303,10 @@ func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region strin
|
||||
// presigned request's signature you can set the "X-Amz-Content-Sha256"
|
||||
// HTTP header and that will be included in the request's signature.
|
||||
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
|
||||
return v4.signWithBody(r, body, service, region, exp, signTime)
|
||||
return v4.signWithBody(r, body, service, region, exp, true, signTime)
|
||||
}
|
||||
|
||||
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
|
||||
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, isPresign bool, signTime time.Time) (http.Header, error) {
|
||||
currentTimeFn := v4.currentTimeFn
|
||||
if currentTimeFn == nil {
|
||||
currentTimeFn = time.Now
|
||||
@@ -317,7 +318,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
|
||||
Query: r.URL.Query(),
|
||||
Time: signTime,
|
||||
ExpireTime: exp,
|
||||
isPresign: exp != 0,
|
||||
isPresign: isPresign,
|
||||
ServiceName: service,
|
||||
Region: region,
|
||||
DisableURIPathEscaping: v4.DisableURIPathEscaping,
|
||||
@@ -339,8 +340,11 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
|
||||
return http.Header{}, err
|
||||
}
|
||||
|
||||
ctx.sanitizeHostForHeader()
|
||||
ctx.assignAmzQueryValues()
|
||||
ctx.build(v4.DisableHeaderHoisting)
|
||||
if err := ctx.build(v4.DisableHeaderHoisting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the request is not presigned the body should be attached to it. This
|
||||
// prevents the confusion of wanting to send a signed request without
|
||||
@@ -363,6 +367,10 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
|
||||
return ctx.SignedHeaderVals, nil
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) sanitizeHostForHeader() {
|
||||
request.SanitizeHostForHeader(ctx.Request)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) handlePresignRemoval() {
|
||||
if !ctx.isPresign {
|
||||
return
|
||||
@@ -467,7 +475,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
|
||||
}
|
||||
|
||||
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
|
||||
name, region, req.ExpireTime, signingTime,
|
||||
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
|
||||
)
|
||||
if err != nil {
|
||||
req.Error = err
|
||||
@@ -498,11 +506,13 @@ func (v4 *Signer) logSigningInfo(ctx *signingCtx) {
|
||||
v4.Logger.Log(msg)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) build(disableHeaderHoisting bool) {
|
||||
func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
|
||||
ctx.buildTime() // no depends
|
||||
ctx.buildCredentialString() // no depends
|
||||
|
||||
ctx.buildBodyDigest()
|
||||
if err := ctx.buildBodyDigest(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unsignedHeaders := ctx.Request.Header
|
||||
if ctx.isPresign {
|
||||
@@ -530,6 +540,8 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) {
|
||||
}
|
||||
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildTime() {
|
||||
@@ -656,7 +668,7 @@ func (ctx *signingCtx) buildSignature() {
|
||||
ctx.signature = hex.EncodeToString(signature)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildBodyDigest() {
|
||||
func (ctx *signingCtx) buildBodyDigest() error {
|
||||
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
|
||||
if hash == "" {
|
||||
if ctx.unsignedPayload || (ctx.isPresign && ctx.ServiceName == "s3") {
|
||||
@@ -664,6 +676,9 @@ func (ctx *signingCtx) buildBodyDigest() {
|
||||
} else if ctx.Body == nil {
|
||||
hash = emptyStringSHA256
|
||||
} else {
|
||||
if !aws.IsReaderSeekable(ctx.Body) {
|
||||
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
|
||||
}
|
||||
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
|
||||
}
|
||||
if ctx.unsignedPayload || ctx.ServiceName == "s3" || ctx.ServiceName == "glacier" {
|
||||
@@ -671,6 +686,8 @@ func (ctx *signingCtx) buildBodyDigest() {
|
||||
}
|
||||
}
|
||||
ctx.bodyDigest = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRequestSigned returns if the request is currently signed or presigned
|
||||
@@ -710,8 +727,8 @@ func makeSha256(data []byte) []byte {
|
||||
|
||||
func makeSha256Reader(reader io.ReadSeeker) []byte {
|
||||
hash := sha256.New()
|
||||
start, _ := reader.Seek(0, 1)
|
||||
defer reader.Seek(start, 0)
|
||||
start, _ := reader.Seek(0, sdkio.SeekCurrent)
|
||||
defer reader.Seek(start, sdkio.SeekStart)
|
||||
|
||||
io.Copy(hash, reader)
|
||||
return hash.Sum(nil)
|
||||
|
||||
+85
-12
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -61,17 +62,42 @@ func TestStripExcessHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
func buildRequest(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
reader := strings.NewReader(body)
|
||||
req, _ := http.NewRequest("POST", endpoint, reader)
|
||||
return buildRequestWithBodyReader(serviceName, region, reader)
|
||||
}
|
||||
|
||||
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
|
||||
var bodyLen int
|
||||
|
||||
type lenner interface {
|
||||
Len() int
|
||||
}
|
||||
if lr, ok := body.(lenner); ok {
|
||||
bodyLen = lr.Len()
|
||||
}
|
||||
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
req, _ := http.NewRequest("POST", endpoint, body)
|
||||
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
|
||||
req.Header.Add("X-Amz-Target", "prefix.Operation")
|
||||
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Add("Content-Length", string(len(body)))
|
||||
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Set("X-Amz-Target", "prefix.Operation")
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
|
||||
if bodyLen > 0 {
|
||||
req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
|
||||
}
|
||||
|
||||
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
return req, reader
|
||||
|
||||
var seeker io.ReadSeeker
|
||||
if sr, ok := body.(io.ReadSeeker); ok {
|
||||
seeker = sr
|
||||
} else {
|
||||
seeker = aws.ReadSeekCloser(body)
|
||||
}
|
||||
|
||||
return req, seeker
|
||||
}
|
||||
|
||||
func buildSigner() Signer {
|
||||
@@ -101,7 +127,7 @@ func TestPresignRequest(t *testing.T) {
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
|
||||
expectedSig := "ea7856749041f727690c580569738282e99c79355fe0d8f125d3b5535d2ece83"
|
||||
expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581"
|
||||
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
@@ -135,7 +161,7 @@ func TestPresignBodyWithArrayRequest(t *testing.T) {
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
|
||||
expectedSig := "fef6002062400bbf526d70f1a6456abc0fb2e213fe1416012737eebd42a62924"
|
||||
expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1"
|
||||
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
@@ -166,14 +192,14 @@ func TestSignRequest(t *testing.T) {
|
||||
signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21"
|
||||
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
|
||||
|
||||
q := req.Header
|
||||
if e, a := expectedSig, q.Get("Authorization"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,6 +233,53 @@ func TestPresignEmptyBodyS3(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignUnseekableBody(t *testing.T) {
|
||||
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
|
||||
signer := buildSigner()
|
||||
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
|
||||
if err == nil {
|
||||
t.Fatalf("expect error signing request")
|
||||
}
|
||||
|
||||
if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %q to be in %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignUnsignedPayloadUnseekableBody(t *testing.T) {
|
||||
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
|
||||
|
||||
signer := buildSigner()
|
||||
signer.UnsignedPayload = true
|
||||
|
||||
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignPreComputedHashUnseekableBody(t *testing.T) {
|
||||
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
|
||||
|
||||
signer := buildSigner()
|
||||
|
||||
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
|
||||
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
if e, a := "some-content-sha256", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignPrecomputedBodyChecksum(t *testing.T) {
|
||||
req, body := buildRequest("dynamodb", "us-east-1", "hello")
|
||||
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
|
||||
|
||||
+83
@@ -3,6 +3,8 @@ package aws
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should
|
||||
@@ -22,6 +24,22 @@ type ReaderSeekerCloser struct {
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
// IsReaderSeekable returns if the underlying reader type can be seeked. A
|
||||
// io.Reader might not actually be seekable if it is the ReaderSeekerCloser
|
||||
// type.
|
||||
func IsReaderSeekable(r io.Reader) bool {
|
||||
switch v := r.(type) {
|
||||
case ReaderSeekerCloser:
|
||||
return v.IsSeeker()
|
||||
case *ReaderSeekerCloser:
|
||||
return v.IsSeeker()
|
||||
case io.ReadSeeker:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from the reader up to size of p. The number of bytes read, and
|
||||
// error if it occurred will be returned.
|
||||
//
|
||||
@@ -56,6 +74,71 @@ func (r ReaderSeekerCloser) IsSeeker() bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasLen returns the length of the underlying reader if the value implements
|
||||
// the Len() int method.
|
||||
func (r ReaderSeekerCloser) HasLen() (int, bool) {
|
||||
type lenner interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
if lr, ok := r.r.(lenner); ok {
|
||||
return lr.Len(), true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetLen returns the length of the bytes remaining in the underlying reader.
|
||||
// Checks first for Len(), then io.Seeker to determine the size of the
|
||||
// underlying reader.
|
||||
//
|
||||
// Will return -1 if the length cannot be determined.
|
||||
func (r ReaderSeekerCloser) GetLen() (int64, error) {
|
||||
if l, ok := r.HasLen(); ok {
|
||||
return int64(l), nil
|
||||
}
|
||||
|
||||
if s, ok := r.r.(io.Seeker); ok {
|
||||
return seekerLen(s)
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// SeekerLen attempts to get the number of bytes remaining at the seeker's
|
||||
// current position. Returns the number of bytes remaining or error.
|
||||
func SeekerLen(s io.Seeker) (int64, error) {
|
||||
// Determine if the seeker is actually seekable. ReaderSeekerCloser
|
||||
// hides the fact that a io.Readers might not actually be seekable.
|
||||
switch v := s.(type) {
|
||||
case ReaderSeekerCloser:
|
||||
return v.GetLen()
|
||||
case *ReaderSeekerCloser:
|
||||
return v.GetLen()
|
||||
}
|
||||
|
||||
return seekerLen(s)
|
||||
}
|
||||
|
||||
func seekerLen(s io.Seeker) (int64, error) {
|
||||
curOffset, err := s.Seek(0, sdkio.SeekCurrent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
endOffset, err := s.Seek(0, sdkio.SeekEnd)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.Seek(curOffset, sdkio.SeekStart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return endOffset - curOffset, nil
|
||||
}
|
||||
|
||||
// Close closes the ReaderSeekerCloser.
|
||||
//
|
||||
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
|
||||
|
||||
+28
-11
@@ -1,32 +1,49 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWriteAtBuffer(t *testing.T) {
|
||||
b := &WriteAtBuffer{}
|
||||
|
||||
n, err := b.WriteAt([]byte{1}, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := 1, n; e != a {
|
||||
t.Errorf("expected %d, but recieved %d", e, a)
|
||||
}
|
||||
|
||||
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, n)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := 3, n; e != a {
|
||||
t.Errorf("expected %d, but recieved %d", e, a)
|
||||
}
|
||||
|
||||
n, err = b.WriteAt([]byte{2}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := 1, n; e != a {
|
||||
t.Errorf("expected %d, but recieved %d", e, a)
|
||||
}
|
||||
|
||||
n, err = b.WriteAt([]byte{3}, 2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
if e, a := 1, n; e != a {
|
||||
t.Errorf("expected %d, but received %d", e, a)
|
||||
}
|
||||
|
||||
assert.Equal(t, []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
|
||||
if !bytes.Equal([]byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes()) {
|
||||
t.Errorf("expected %v, but received %v", []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteAtBuffer(b *testing.B) {
|
||||
|
||||
+1
-1
@@ -5,4 +5,4 @@ package aws
|
||||
const SDKName = "aws-sdk-go"
|
||||
|
||||
// SDKVersion is the version of this SDK
|
||||
const SDKVersion = "1.12.1"
|
||||
const SDKVersion = "1.13.31"
|
||||
|
||||
Reference in New Issue
Block a user