mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-06-13 06:40:41 +00:00
Upgrade AWS SDK to the latest version
This commit is contained in:
+86
@@ -0,0 +1,86 @@
|
||||
// Package arn provides a parser for interacting with Amazon Resource Names.
|
||||
package arn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
arnDelimiter = ":"
|
||||
arnSections = 6
|
||||
arnPrefix = "arn:"
|
||||
|
||||
// zero-indexed
|
||||
sectionPartition = 1
|
||||
sectionService = 2
|
||||
sectionRegion = 3
|
||||
sectionAccountID = 4
|
||||
sectionResource = 5
|
||||
|
||||
// errors
|
||||
invalidPrefix = "arn: invalid prefix"
|
||||
invalidSections = "arn: not enough sections"
|
||||
)
|
||||
|
||||
// ARN captures the individual fields of an Amazon Resource Name.
|
||||
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
|
||||
type ARN struct {
|
||||
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
|
||||
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
|
||||
// (Beijing) region is "aws-cn".
|
||||
Partition string
|
||||
|
||||
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
|
||||
// namespaces, see
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
|
||||
Service string
|
||||
|
||||
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
|
||||
// component might be omitted.
|
||||
Region string
|
||||
|
||||
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
|
||||
// ARNs for some resources don't require an account number, so this component might be omitted.
|
||||
AccountID string
|
||||
|
||||
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
|
||||
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
|
||||
// resource name itself. Some services allows paths for resource names, as described in
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
|
||||
Resource string
|
||||
}
|
||||
|
||||
// Parse parses an ARN into its constituent parts.
|
||||
//
|
||||
// Some example ARNs:
|
||||
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
|
||||
// arn:aws:iam::123456789012:user/David
|
||||
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
|
||||
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
|
||||
func Parse(arn string) (ARN, error) {
|
||||
if !strings.HasPrefix(arn, arnPrefix) {
|
||||
return ARN{}, errors.New(invalidPrefix)
|
||||
}
|
||||
sections := strings.SplitN(arn, arnDelimiter, arnSections)
|
||||
if len(sections) != arnSections {
|
||||
return ARN{}, errors.New(invalidSections)
|
||||
}
|
||||
return ARN{
|
||||
Partition: sections[sectionPartition],
|
||||
Service: sections[sectionService],
|
||||
Region: sections[sectionRegion],
|
||||
AccountID: sections[sectionAccountID],
|
||||
Resource: sections[sectionResource],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// String returns the canonical representation of the ARN
|
||||
func (arn ARN) String() string {
|
||||
return arnPrefix +
|
||||
arn.Partition + arnDelimiter +
|
||||
arn.Service + arnDelimiter +
|
||||
arn.Region + arnDelimiter +
|
||||
arn.AccountID + arnDelimiter +
|
||||
arn.Resource
|
||||
}
|
||||
+90
@@ -0,0 +1,90 @@
|
||||
// +build go1.7
|
||||
|
||||
package arn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseARN(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
arn ARN
|
||||
err error
|
||||
}{
|
||||
{
|
||||
input: "invalid",
|
||||
err: errors.New(invalidPrefix),
|
||||
},
|
||||
{
|
||||
input: "arn:nope",
|
||||
err: errors.New(invalidSections),
|
||||
},
|
||||
{
|
||||
input: "arn:aws:ecr:us-west-2:123456789012:repository/foo/bar",
|
||||
arn: ARN{
|
||||
Partition: "aws",
|
||||
Service: "ecr",
|
||||
Region: "us-west-2",
|
||||
AccountID: "123456789012",
|
||||
Resource: "repository/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment",
|
||||
arn: ARN{
|
||||
Partition: "aws",
|
||||
Service: "elasticbeanstalk",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
Resource: "environment/My App/MyEnvironment",
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "arn:aws:iam::123456789012:user/David",
|
||||
arn: ARN{
|
||||
Partition: "aws",
|
||||
Service: "iam",
|
||||
Region: "",
|
||||
AccountID: "123456789012",
|
||||
Resource: "user/David",
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "arn:aws:rds:eu-west-1:123456789012:db:mysql-db",
|
||||
arn: ARN{
|
||||
Partition: "aws",
|
||||
Service: "rds",
|
||||
Region: "eu-west-1",
|
||||
AccountID: "123456789012",
|
||||
Resource: "db:mysql-db",
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "arn:aws:s3:::my_corporate_bucket/exampleobject.png",
|
||||
arn: ARN{
|
||||
Partition: "aws",
|
||||
Service: "s3",
|
||||
Region: "",
|
||||
AccountID: "",
|
||||
Resource: "my_corporate_bucket/exampleobject.png",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
spec, err := Parse(tc.input)
|
||||
if tc.arn != spec {
|
||||
t.Errorf("Expected %q to parse as %v, but got %v", tc.input, tc.arn, spec)
|
||||
}
|
||||
if err == nil && tc.err != nil {
|
||||
t.Errorf("Expected err to be %v, but got nil", tc.err)
|
||||
} else if err != nil && tc.err == nil {
|
||||
t.Errorf("Expected err to be nil, but got %v", err)
|
||||
} else if err != nil && tc.err != nil && err.Error() != tc.err.Error() {
|
||||
t.Errorf("Expected err to be %v, but got %v", tc.err, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+3
-59
@@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
@@ -46,7 +45,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
|
||||
svc := &Client{
|
||||
Config: cfg,
|
||||
ClientInfo: info,
|
||||
Handlers: handlers,
|
||||
Handlers: handlers.Copy(),
|
||||
}
|
||||
|
||||
switch retryer, ok := cfg.Retryer.(request.Retryer); {
|
||||
@@ -86,61 +85,6 @@ func (c *Client) AddDebugHandlers() {
|
||||
return
|
||||
}
|
||||
|
||||
c.Handlers.Send.PushFront(logRequest)
|
||||
c.Handlers.Send.PushBack(logResponse)
|
||||
}
|
||||
|
||||
const logReqMsg = `DEBUG: Request %s/%s Details:
|
||||
---[ REQUEST POST-SIGN ]-----------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
|
||||
---[ REQUEST DUMP ERROR ]-----------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
func logRequest(r *request.Request) {
|
||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
|
||||
if err != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
if logBody {
|
||||
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
|
||||
// Body as a NoOpCloser and will not be reset after read by the HTTP
|
||||
// client reader.
|
||||
r.ResetBody()
|
||||
}
|
||||
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
|
||||
}
|
||||
|
||||
const logRespMsg = `DEBUG: Response %s/%s Details:
|
||||
---[ RESPONSE ]--------------------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
|
||||
---[ RESPONSE DUMP ERROR ]-----------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
func logResponse(r *request.Request) {
|
||||
var msg = "no response data"
|
||||
if r.HTTPResponse != nil {
|
||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||
dumpedBody, err := httputil.DumpResponse(r.HTTPResponse, logBody)
|
||||
if err != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
msg = string(dumpedBody)
|
||||
} else if r.Error != nil {
|
||||
msg = r.Error.Error()
|
||||
}
|
||||
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ClientInfo.ServiceName, r.Operation.Name, msg))
|
||||
c.Handlers.Send.PushFrontNamed(request.NamedHandler{Name: "awssdk.client.LogRequest", Fn: logRequest})
|
||||
c.Handlers.Send.PushBackNamed(request.NamedHandler{Name: "awssdk.client.LogResponse", Fn: logResponse})
|
||||
}
|
||||
|
||||
+78
@@ -0,0 +1,78 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
func pushBackTestHandler(name string, list *request.HandlerList) *bool {
|
||||
called := false
|
||||
(*list).PushBackNamed(request.NamedHandler{
|
||||
Name: name,
|
||||
Fn: func(r *request.Request) {
|
||||
called = true
|
||||
},
|
||||
})
|
||||
|
||||
return &called
|
||||
}
|
||||
|
||||
func pushFrontTestHandler(name string, list *request.HandlerList) *bool {
|
||||
called := false
|
||||
(*list).PushFrontNamed(request.NamedHandler{
|
||||
Name: name,
|
||||
Fn: func(r *request.Request) {
|
||||
called = true
|
||||
},
|
||||
})
|
||||
|
||||
return &called
|
||||
}
|
||||
|
||||
func TestNewClient_CopyHandlers(t *testing.T) {
|
||||
handlers := request.Handlers{}
|
||||
firstCalled := pushBackTestHandler("first", &handlers.Send)
|
||||
secondCalled := pushBackTestHandler("second", &handlers.Send)
|
||||
|
||||
var clientHandlerCalled *bool
|
||||
c := New(aws.Config{}, metadata.ClientInfo{}, handlers,
|
||||
func(c *Client) {
|
||||
clientHandlerCalled = pushFrontTestHandler("client handler", &c.Handlers.Send)
|
||||
},
|
||||
)
|
||||
|
||||
if e, a := 2, handlers.Send.Len(); e != a {
|
||||
t.Errorf("expect %d original handlers, got %d", e, a)
|
||||
}
|
||||
if e, a := 3, c.Handlers.Send.Len(); e != a {
|
||||
t.Errorf("expect %d client handlers, got %d", e, a)
|
||||
}
|
||||
|
||||
handlers.Send.Run(nil)
|
||||
if !*firstCalled {
|
||||
t.Errorf("expect first handler to of been called")
|
||||
}
|
||||
*firstCalled = false
|
||||
if !*secondCalled {
|
||||
t.Errorf("expect second handler to of been called")
|
||||
}
|
||||
*secondCalled = false
|
||||
if *clientHandlerCalled {
|
||||
t.Errorf("expect client handler to not of been called, but was")
|
||||
}
|
||||
|
||||
c.Handlers.Send.Run(nil)
|
||||
if !*firstCalled {
|
||||
t.Errorf("expect client's first handler to of been called")
|
||||
}
|
||||
if !*secondCalled {
|
||||
t.Errorf("expect client's second handler to of been called")
|
||||
}
|
||||
if !*clientHandlerCalled {
|
||||
t.Errorf("expect client's client handler to of been called")
|
||||
}
|
||||
|
||||
}
|
||||
+8
-2
@@ -15,11 +15,11 @@ import (
|
||||
// the MaxRetries method:
|
||||
//
|
||||
// type retryer struct {
|
||||
// service.DefaultRetryer
|
||||
// client.DefaultRetryer
|
||||
// }
|
||||
//
|
||||
// // This implementation always has 100 max retries
|
||||
// func (d retryer) MaxRetries() uint { return 100 }
|
||||
// func (d retryer) MaxRetries() int { return 100 }
|
||||
type DefaultRetryer struct {
|
||||
NumMaxRetries int
|
||||
}
|
||||
@@ -54,6 +54,12 @@ func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
|
||||
|
||||
// ShouldRetry returns true if the request should be retried.
|
||||
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable != nil {
|
||||
return *r.Retryable
|
||||
}
|
||||
|
||||
if r.HTTPResponse.StatusCode >= 500 {
|
||||
return true
|
||||
}
|
||||
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
const logReqMsg = `DEBUG: Request %s/%s Details:
|
||||
---[ REQUEST POST-SIGN ]-----------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
|
||||
---[ REQUEST DUMP ERROR ]-----------------------------
|
||||
%s
|
||||
------------------------------------------------------`
|
||||
|
||||
type logWriter struct {
|
||||
// Logger is what we will use to log the payload of a response.
|
||||
Logger aws.Logger
|
||||
// buf stores the contents of what has been read
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (logger *logWriter) Write(b []byte) (int, error) {
|
||||
return logger.buf.Write(b)
|
||||
}
|
||||
|
||||
type teeReaderCloser struct {
|
||||
// io.Reader will be a tee reader that is used during logging.
|
||||
// This structure will read from a body and write the contents to a logger.
|
||||
io.Reader
|
||||
// Source is used just to close when we are done reading.
|
||||
Source io.ReadCloser
|
||||
}
|
||||
|
||||
func (reader *teeReaderCloser) Close() error {
|
||||
return reader.Source.Close()
|
||||
}
|
||||
|
||||
func logRequest(r *request.Request) {
|
||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
|
||||
if err != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
if logBody {
|
||||
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
|
||||
// Body as a NoOpCloser and will not be reset after read by the HTTP
|
||||
// client reader.
|
||||
r.ResetBody()
|
||||
}
|
||||
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
|
||||
}
|
||||
|
||||
const logRespMsg = `DEBUG: Response %s/%s Details:
|
||||
---[ RESPONSE ]--------------------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
|
||||
---[ RESPONSE DUMP ERROR ]-----------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
func logResponse(r *request.Request) {
|
||||
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
|
||||
r.HTTPResponse.Body = &teeReaderCloser{
|
||||
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
|
||||
Source: r.HTTPResponse.Body,
|
||||
}
|
||||
|
||||
handlerFn := func(req *request.Request) {
|
||||
body, err := httputil.DumpResponse(req.HTTPResponse, false)
|
||||
if err != nil {
|
||||
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(lw.buf)
|
||||
if err != nil {
|
||||
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
lw.Logger.Log(fmt.Sprintf(logRespMsg, req.ClientInfo.ServiceName, req.Operation.Name, string(body)))
|
||||
if req.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
|
||||
lw.Logger.Log(string(b))
|
||||
}
|
||||
}
|
||||
|
||||
const handlerName = "awsdk.client.LogResponse.ResponseBody"
|
||||
|
||||
r.Handlers.Unmarshal.SetBackNamed(request.NamedHandler{
|
||||
Name: handlerName, Fn: handlerFn,
|
||||
})
|
||||
r.Handlers.UnmarshalError.SetBackNamed(request.NamedHandler{
|
||||
Name: handlerName, Fn: handlerFn,
|
||||
})
|
||||
}
|
||||
+57
@@ -0,0 +1,57 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockCloser struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (closer *mockCloser) Read(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (closer *mockCloser) Close() error {
|
||||
closer.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTeeReaderCloser(t *testing.T) {
|
||||
expected := "FOO"
|
||||
buf := bytes.NewBuffer([]byte(expected))
|
||||
lw := bytes.NewBuffer(nil)
|
||||
c := &mockCloser{}
|
||||
closer := teeReaderCloser{
|
||||
io.TeeReader(buf, lw),
|
||||
c,
|
||||
}
|
||||
|
||||
b := make([]byte, len(expected))
|
||||
_, err := closer.Read(b)
|
||||
closer.Close()
|
||||
|
||||
if expected != lw.String() {
|
||||
t.Errorf("Expected %q, but received %q", expected, lw.String())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected 'nil', but received %v", err)
|
||||
}
|
||||
|
||||
if !c.closed {
|
||||
t.Error("Expected 'true', but received 'false'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter(t *testing.T) {
|
||||
expected := "FOO"
|
||||
lw := &logWriter{nil, bytes.NewBuffer(nil)}
|
||||
lw.Write([]byte(expected))
|
||||
|
||||
if expected != lw.buf.String() {
|
||||
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
|
||||
}
|
||||
}
|
||||
+12
-1
@@ -53,6 +53,13 @@ type Config struct {
|
||||
// to use based on region.
|
||||
EndpointResolver endpoints.Resolver
|
||||
|
||||
// EnforceShouldRetryCheck is used in the AfterRetryHandler to always call
|
||||
// ShouldRetry regardless of whether or not if request.Retryable is set.
|
||||
// This will utilize ShouldRetry method of custom retryers. If EnforceShouldRetryCheck
|
||||
// is not set, then ShouldRetry will only be called if request.Retryable is nil.
|
||||
// Proper handling of the request.Retryable field is important when setting this field.
|
||||
EnforceShouldRetryCheck *bool
|
||||
|
||||
// The region to send requests to. This parameter is required and must
|
||||
// be configured globally or on a per-client basis unless otherwise
|
||||
// noted. A full list of regions is found in the "Regions and Endpoints"
|
||||
@@ -88,7 +95,7 @@ type Config struct {
|
||||
// recoverable failures.
|
||||
//
|
||||
// When nil or the value does not implement the request.Retryer interface,
|
||||
// the request.DefaultRetryer will be used.
|
||||
// the client.DefaultRetryer will be used.
|
||||
//
|
||||
// When both Retryer and MaxRetries are non-nil, the former is used and
|
||||
// the latter ignored.
|
||||
@@ -443,6 +450,10 @@ func mergeInConfig(dst *Config, other *Config) {
|
||||
if other.DisableRestProtocolURICleaning != nil {
|
||||
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
|
||||
}
|
||||
|
||||
if other.EnforceShouldRetryCheck != nil {
|
||||
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
|
||||
}
|
||||
}
|
||||
|
||||
// Copy will return a shallow copy of the Config object. If any additional
|
||||
|
||||
+3
-3
@@ -4,9 +4,9 @@ package aws
|
||||
|
||||
import "time"
|
||||
|
||||
// An emptyCtx is a copy of the the Go 1.7 context.emptyCtx type. This
|
||||
// is copied to provide a 1.6 and 1.5 safe version of context that is compatible
|
||||
// with Go 1.7's Context.
|
||||
// An emptyCtx is a copy of the Go 1.7 context.emptyCtx type. This is copied to
|
||||
// provide a 1.6 and 1.5 safe version of context that is compatible with Go
|
||||
// 1.7's Context.
|
||||
//
|
||||
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
||||
// struct{}, since vars of this type must have distinct addresses.
|
||||
|
||||
+18
@@ -311,6 +311,24 @@ func TimeValue(v *time.Time) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// SecondsTimeValue converts an int64 pointer to a time.Time value
|
||||
// representing seconds since Epoch or time.Time{} if the pointer is nil.
|
||||
func SecondsTimeValue(v *int64) time.Time {
|
||||
if v != nil {
|
||||
return time.Unix((*v / 1000), 0)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// MillisecondsTimeValue converts an int64 pointer to a time.Time value
|
||||
// representing milliseconds sinch Epoch or time.Time{} if the pointer is nil.
|
||||
func MillisecondsTimeValue(v *int64) time.Time {
|
||||
if v != nil {
|
||||
return time.Unix(0, (*v * 1000000))
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// TimeUnixMilli returns a Unix timestamp in milliseconds from "January 1, 1970 UTC".
|
||||
// The result is undefined if the Unix time cannot be represented by an int64.
|
||||
// Which includes calling TimeUnixMilli on a zero Time is undefined.
|
||||
|
||||
+33
@@ -435,3 +435,36 @@ func TestTimeMap(t *testing.T) {
|
||||
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
type TimeValueTestCase struct {
|
||||
in int64
|
||||
outSecs time.Time
|
||||
outMillis time.Time
|
||||
}
|
||||
|
||||
var testCasesTimeValue = []TimeValueTestCase{
|
||||
{
|
||||
in: int64(1501558289000),
|
||||
outSecs: time.Unix(1501558289, 0),
|
||||
outMillis: time.Unix(1501558289, 0),
|
||||
},
|
||||
{
|
||||
in: int64(1501558289001),
|
||||
outSecs: time.Unix(1501558289, 0),
|
||||
outMillis: time.Unix(1501558289, 1*1000000),
|
||||
},
|
||||
}
|
||||
|
||||
func TestSecondsTimeValue(t *testing.T) {
|
||||
for idx, testCase := range testCasesTimeValue {
|
||||
out := SecondsTimeValue(&testCase.in)
|
||||
assert.Equal(t, testCase.outSecs, out, "Unexpected value for time value at %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMillisecondsTimeValue(t *testing.T) {
|
||||
for idx, testCase := range testCasesTimeValue {
|
||||
out := MillisecondsTimeValue(&testCase.in)
|
||||
assert.Equal(t, testCase.outMillis, out, "Unexpected value for time value at %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
+86
-45
@@ -27,7 +27,7 @@ type lener interface {
|
||||
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
|
||||
// to determine request body length and no "Content-Length" was specified it will panic.
|
||||
//
|
||||
// The Content-Length will only be aded to the request if the length of the body
|
||||
// The Content-Length will only be added to the request if the length of the body
|
||||
// is greater than 0. If the body is empty or the current `Content-Length`
|
||||
// header is <= 0, the header will also be stripped.
|
||||
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
|
||||
@@ -71,8 +71,8 @@ var reStatusCode = regexp.MustCompile(`^(\d{3})`)
|
||||
|
||||
// ValidateReqSigHandler is a request handler to ensure that the request's
|
||||
// signature doesn't expire before it is sent. This can happen when a request
|
||||
// is built and signed signficantly before it is sent. Or significant delays
|
||||
// occur whne retrying requests that would cause the signature to expire.
|
||||
// is built and signed significantly before it is sent. Or significant delays
|
||||
// occur when retrying requests that would cause the signature to expire.
|
||||
var ValidateReqSigHandler = request.NamedHandler{
|
||||
Name: "core.ValidateReqSigHandler",
|
||||
Fn: func(r *request.Request) {
|
||||
@@ -98,54 +98,95 @@ var ValidateReqSigHandler = request.NamedHandler{
|
||||
}
|
||||
|
||||
// SendHandler is a request handler to send service request using HTTP client.
|
||||
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
|
||||
var err error
|
||||
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest)
|
||||
if err != nil {
|
||||
// Prevent leaking if an HTTPResponse was returned. Clean up
|
||||
// the body.
|
||||
if r.HTTPResponse != nil {
|
||||
r.HTTPResponse.Body.Close()
|
||||
var SendHandler = request.NamedHandler{
|
||||
Name: "core.SendHandler",
|
||||
Fn: func(r *request.Request) {
|
||||
sender := sendFollowRedirects
|
||||
if r.DisableFollowRedirects {
|
||||
sender = sendWithoutFollowRedirects
|
||||
}
|
||||
// Capture the case where url.Error is returned for error processing
|
||||
// response. e.g. 301 without location header comes back as string
|
||||
// error and r.HTTPResponse is nil. Other url redirect errors will
|
||||
// comeback in a similar method.
|
||||
if e, ok := err.(*url.Error); ok && e.Err != nil {
|
||||
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
|
||||
code, _ := strconv.ParseInt(s[1], 10, 64)
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: int(code),
|
||||
Status: http.StatusText(int(code)),
|
||||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if request.NoBody == r.HTTPRequest.Body {
|
||||
// Strip off the request body if the NoBody reader was used as a
|
||||
// place holder for a request body. This prevents the SDK from
|
||||
// making requests with a request body when it would be invalid
|
||||
// to do so.
|
||||
//
|
||||
// Use a shallow copy of the http.Request to ensure the race condition
|
||||
// of transport on Body will not trigger
|
||||
reqOrig, reqCopy := r.HTTPRequest, *r.HTTPRequest
|
||||
reqCopy.Body = nil
|
||||
r.HTTPRequest = &reqCopy
|
||||
defer func() {
|
||||
r.HTTPRequest = reqOrig
|
||||
}()
|
||||
}
|
||||
if r.HTTPResponse == nil {
|
||||
// Add a dummy request response object to ensure the HTTPResponse
|
||||
// value is consistent.
|
||||
|
||||
var err error
|
||||
r.HTTPResponse, err = sender(r)
|
||||
if err != nil {
|
||||
handleSendError(r, err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func sendFollowRedirects(r *request.Request) (*http.Response, error) {
|
||||
return r.Config.HTTPClient.Do(r.HTTPRequest)
|
||||
}
|
||||
|
||||
func sendWithoutFollowRedirects(r *request.Request) (*http.Response, error) {
|
||||
transport := r.Config.HTTPClient.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
return transport.RoundTrip(r.HTTPRequest)
|
||||
}
|
||||
|
||||
func handleSendError(r *request.Request, err error) {
|
||||
// Prevent leaking if an HTTPResponse was returned. Clean up
|
||||
// the body.
|
||||
if r.HTTPResponse != nil {
|
||||
r.HTTPResponse.Body.Close()
|
||||
}
|
||||
// Capture the case where url.Error is returned for error processing
|
||||
// response. e.g. 301 without location header comes back as string
|
||||
// error and r.HTTPResponse is nil. Other URL redirect errors will
|
||||
// comeback in a similar method.
|
||||
if e, ok := err.(*url.Error); ok && e.Err != nil {
|
||||
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
|
||||
code, _ := strconv.ParseInt(s[1], 10, 64)
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: int(0),
|
||||
Status: http.StatusText(int(0)),
|
||||
StatusCode: int(code),
|
||||
Status: http.StatusText(int(code)),
|
||||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
}
|
||||
// Catch all other request errors.
|
||||
r.Error = awserr.New("RequestError", "send request failed", err)
|
||||
r.Retryable = aws.Bool(true) // network errors are retryable
|
||||
|
||||
// Override the error with a context canceled error, if that was canceled.
|
||||
ctx := r.Context()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.Error = awserr.New(request.CanceledErrorCode,
|
||||
"request context canceled", ctx.Err())
|
||||
r.Retryable = aws.Bool(false)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}}
|
||||
if r.HTTPResponse == nil {
|
||||
// Add a dummy request response object to ensure the HTTPResponse
|
||||
// value is consistent.
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: int(0),
|
||||
Status: http.StatusText(int(0)),
|
||||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
}
|
||||
// Catch all other request errors.
|
||||
r.Error = awserr.New("RequestError", "send request failed", err)
|
||||
r.Retryable = aws.Bool(true) // network errors are retryable
|
||||
|
||||
// Override the error with a context canceled error, if that was canceled.
|
||||
ctx := r.Context()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.Error = awserr.New(request.CanceledErrorCode,
|
||||
"request context canceled", ctx.Err())
|
||||
r.Retryable = aws.Bool(false)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateResponseHandler is a request handler to validate service response.
|
||||
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
|
||||
@@ -160,7 +201,7 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
|
||||
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable == nil {
|
||||
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
|
||||
r.Retryable = aws.Bool(r.ShouldRetry(r))
|
||||
}
|
||||
|
||||
|
||||
+64
@@ -0,0 +1,64 @@
|
||||
// +build go1.8
|
||||
|
||||
package corehandlers_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestSendHandler_HEADNoBody(t *testing.T) {
|
||||
TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err := awstesting.CreateTLSBundleFiles()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
|
||||
|
||||
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport)
|
||||
// test server's certificate is self-signed certificate
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
sess, err := session.NewSessionWithOptions(session.Options{
|
||||
Config: aws.Config{
|
||||
HTTPClient: &http.Client{},
|
||||
Endpoint: aws.String(endpoint),
|
||||
Region: aws.String("mock-region"),
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
},
|
||||
})
|
||||
|
||||
svc := s3.New(sess)
|
||||
|
||||
req, _ := svc.HeadObjectRequest(&s3.HeadObjectInput{
|
||||
Bucket: aws.String("bucketname"),
|
||||
Key: aws.String("keyname"),
|
||||
})
|
||||
|
||||
if e, a := request.NoBody, req.HTTPRequest.Body; e != a {
|
||||
t.Fatalf("expect %T request body, got %T", e, a)
|
||||
}
|
||||
|
||||
err = req.Send()
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, req.HTTPResponse.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
+33
@@ -206,6 +206,39 @@ func TestSendHandlerError(t *testing.T) {
|
||||
assert.NotNil(t, r.HTTPResponse)
|
||||
}
|
||||
|
||||
func TestSendWithoutFollowRedirects(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/original":
|
||||
w.Header().Set("Location", "/redirected")
|
||||
w.WriteHeader(301)
|
||||
case "/redirected":
|
||||
t.Fatalf("expect not to redirect, but was")
|
||||
}
|
||||
}))
|
||||
|
||||
svc := awstesting.NewClient(&aws.Config{
|
||||
DisableSSL: aws.Bool(true),
|
||||
Endpoint: aws.String(server.URL),
|
||||
})
|
||||
svc.Handlers.Clear()
|
||||
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
|
||||
|
||||
r := svc.NewRequest(&request.Operation{
|
||||
Name: "Operation",
|
||||
HTTPPath: "/original",
|
||||
}, nil, nil)
|
||||
r.DisableFollowRedirects = true
|
||||
|
||||
err := r.Send()
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := 301, r.HTTPResponse.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReqSigHandler(t *testing.T) {
|
||||
cases := []struct {
|
||||
Req *request.Request
|
||||
|
||||
+8
-6
@@ -13,7 +13,7 @@ var (
|
||||
//
|
||||
// @readonly
|
||||
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
|
||||
`no valid providers in chain. Deprecated.
|
||||
`no valid providers in chain. Deprecated.
|
||||
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
|
||||
nil)
|
||||
)
|
||||
@@ -39,16 +39,18 @@ var (
|
||||
// does not return any credentials ChainProvider will return the error
|
||||
// ErrNoValidProvidersFoundInChain
|
||||
//
|
||||
// creds := NewChainCredentials(
|
||||
// []Provider{
|
||||
// &EnvProvider{},
|
||||
// &EC2RoleProvider{
|
||||
// creds := credentials.NewChainCredentials(
|
||||
// []credentials.Provider{
|
||||
// &credentials.EnvProvider{},
|
||||
// &ec2rolecreds.EC2RoleProvider{
|
||||
// Client: ec2metadata.New(sess),
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// // Usage of ChainCredentials with aws.Config
|
||||
// svc := ec2.New(&aws.Config{Credentials: creds})
|
||||
// svc := ec2.New(session.Must(session.NewSession(&aws.Config{
|
||||
// Credentials: creds,
|
||||
// })))
|
||||
//
|
||||
type ChainProvider struct {
|
||||
Providers []Provider
|
||||
|
||||
+28
-5
@@ -14,7 +14,7 @@
|
||||
//
|
||||
// Example of using the environment variable credentials.
|
||||
//
|
||||
// creds := NewEnvCredentials()
|
||||
// creds := credentials.NewEnvCredentials()
|
||||
//
|
||||
// // Retrieve the credentials value
|
||||
// credValue, err := creds.Get()
|
||||
@@ -26,7 +26,7 @@
|
||||
// This may be helpful to proactively expire credentials and refresh them sooner
|
||||
// than they would naturally expire on their own.
|
||||
//
|
||||
// creds := NewCredentials(&EC2RoleProvider{})
|
||||
// creds := credentials.NewCredentials(&ec2rolecreds.EC2RoleProvider{})
|
||||
// creds.Expire()
|
||||
// credsValue, err := creds.Get()
|
||||
// // New credentials will be retrieved instead of from cache.
|
||||
@@ -43,7 +43,7 @@
|
||||
// func (m *MyProvider) Retrieve() (Value, error) {...}
|
||||
// func (m *MyProvider) IsExpired() bool {...}
|
||||
//
|
||||
// creds := NewCredentials(&MyProvider{})
|
||||
// creds := credentials.NewCredentials(&MyProvider{})
|
||||
// credValue, err := creds.Get()
|
||||
//
|
||||
package credentials
|
||||
@@ -60,7 +60,9 @@ import (
|
||||
// when making service API calls. For example, when accessing public
|
||||
// s3 buckets.
|
||||
//
|
||||
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
|
||||
// svc := s3.New(session.Must(session.NewSession(&aws.Config{
|
||||
// Credentials: credentials.AnonymousCredentials,
|
||||
// })))
|
||||
// // Access public S3 buckets.
|
||||
//
|
||||
// @readonly
|
||||
@@ -88,7 +90,7 @@ type Value struct {
|
||||
// The Provider should not need to implement its own mutexes, because
|
||||
// that will be managed by Credentials.
|
||||
type Provider interface {
|
||||
// Refresh returns nil if it successfully retrieved the value.
|
||||
// Retrieve returns nil if it successfully retrieved the value.
|
||||
// Error is returned if the value were not obtainable, or empty.
|
||||
Retrieve() (Value, error)
|
||||
|
||||
@@ -97,6 +99,27 @@ type Provider interface {
|
||||
IsExpired() bool
|
||||
}
|
||||
|
||||
// An ErrorProvider is a stub credentials provider that always returns an error
|
||||
// this is used by the SDK when construction a known provider is not possible
|
||||
// due to an error.
|
||||
type ErrorProvider struct {
|
||||
// The error to be returned from Retrieve
|
||||
Err error
|
||||
|
||||
// The provider name to set on the Retrieved returned Value
|
||||
ProviderName string
|
||||
}
|
||||
|
||||
// Retrieve will always return the error that the ErrorProvider was created with.
|
||||
func (p ErrorProvider) Retrieve() (Value, error) {
|
||||
return Value{ProviderName: p.ProviderName}, p.Err
|
||||
}
|
||||
|
||||
// IsExpired will always return not expired.
|
||||
func (p ErrorProvider) IsExpired() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// A Expiry provides shared expiration logic to be used by credentials
|
||||
// providers to implement expiry functionality.
|
||||
//
|
||||
|
||||
+1
@@ -29,6 +29,7 @@ var (
|
||||
// Environment variables used:
|
||||
//
|
||||
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
|
||||
//
|
||||
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
|
||||
type EnvProvider struct {
|
||||
retrieved bool
|
||||
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
// +build !go1.8
|
||||
|
||||
// Package plugincreds provides usage of Go plugins for providing credentials
|
||||
// to the SDK. Only available with Go 1.8 and above.
|
||||
package plugincreds
|
||||
+211
@@ -0,0 +1,211 @@
|
||||
// +build go1.8
|
||||
|
||||
// Package plugincreds implements a credentials provider sourced from a Go
|
||||
// plugin. This package allows you to use a Go plugin to retrieve AWS credentials
|
||||
// for the SDK to use for service API calls.
|
||||
//
|
||||
// As of Go 1.8 plugins are only supported on the Linux platform.
|
||||
//
|
||||
// Plugin Symbol Name
|
||||
//
|
||||
// The "GetAWSSDKCredentialProvider" is the symbol name that will be used to
|
||||
// lookup the credentials provider getter from the plugin. If you want to use a
|
||||
// custom symbol name you should use GetPluginProviderFnsByName to lookup the
|
||||
// symbol by a custom name.
|
||||
//
|
||||
// This symbol is a function that returns two additional functions. One to
|
||||
// retrieve the credentials, and another to determine if the credentials have
|
||||
// expired.
|
||||
//
|
||||
// Plugin Symbol Signature
|
||||
//
|
||||
// The plugin credential provider requires the symbol to match the
|
||||
// following signature.
|
||||
//
|
||||
// func() (RetrieveFn func() (key, secret, token string, err error), IsExpiredFn func() bool)
|
||||
//
|
||||
// Plugin Implementation Exmaple
|
||||
//
|
||||
// The following is an example implementation of a SDK credential provider using
|
||||
// the plugin provider in this package. See the SDK's example/aws/credential/plugincreds/plugin
|
||||
// folder for a runnable example of this.
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// func main() {}
|
||||
//
|
||||
// var myCredProvider provider
|
||||
//
|
||||
// // Build: go build -o plugin.so -buildmode=plugin plugin.go
|
||||
// func init() {
|
||||
// // Initialize a mock credential provider with stubs
|
||||
// myCredProvider = provider{"a","b","c"}
|
||||
// }
|
||||
//
|
||||
// // GetAWSSDKCredentialProvider is the symbol SDK will lookup and use to
|
||||
// // get the credential provider's retrieve and isExpired functions.
|
||||
// func GetAWSSDKCredentialProvider() (func() (key, secret, token string, err error), func() bool) {
|
||||
// return myCredProvider.Retrieve, myCredProvider.IsExpired
|
||||
// }
|
||||
//
|
||||
// // mock implementation of a type that returns retrieves credentials and
|
||||
// // returns if they have expired.
|
||||
// type provider struct {
|
||||
// key, secret, token string
|
||||
// }
|
||||
//
|
||||
// func (p provider) Retrieve() (key, secret, token string, err error) {
|
||||
// return p.key, p.secret, p.token, nil
|
||||
// }
|
||||
//
|
||||
// func (p *provider) IsExpired() bool {
|
||||
// return false;
|
||||
// }
|
||||
//
|
||||
// Configuring SDK for Plugin Credentials
|
||||
//
|
||||
// To configure the SDK to use a plugin's credential provider you'll need to first
|
||||
// open the plugin file using the plugin standard library package. Once you have
|
||||
// a handle to the plugin you can use the NewCredentials function of this package
|
||||
// to create a new credentials.Credentials value that can be set as the
|
||||
// credentials loader of a Session or Config. See the SDK's example/aws/credential/plugincreds
|
||||
// folder for a runnable example of this.
|
||||
//
|
||||
// // Open plugin, and load it into the process.
|
||||
// p, err := plugin.Open("somefile.so")
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // Create a new Credentials value which will source the provider's Retrieve
|
||||
// // and IsExpired functions from the plugin.
|
||||
// creds, err := plugincreds.NewCredentials(p)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // Example to configure a Session with the newly created credentials that
|
||||
// // will be sourced using the plugin's functionality.
|
||||
// sess := session.Must(session.NewSession(&aws.Config{
|
||||
// Credentials: creds,
|
||||
// }))
|
||||
package plugincreds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"plugin"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
// ProviderSymbolName the symbol name the SDK will use to lookup the plugin
|
||||
// provider value from.
|
||||
const ProviderSymbolName = `GetAWSSDKCredentialProvider`
|
||||
|
||||
// ProviderName is the name this credentials provider will label any returned
|
||||
// credentials Value with.
|
||||
const ProviderName = `PluginCredentialsProvider`
|
||||
|
||||
const (
|
||||
// ErrCodeLookupSymbolError failed to lookup symbol
|
||||
ErrCodeLookupSymbolError = "LookupSymbolError"
|
||||
|
||||
// ErrCodeInvalidSymbolError symbol invalid
|
||||
ErrCodeInvalidSymbolError = "InvalidSymbolError"
|
||||
|
||||
// ErrCodePluginRetrieveNil Retrieve function was nil
|
||||
ErrCodePluginRetrieveNil = "PluginRetrieveNilError"
|
||||
|
||||
// ErrCodePluginIsExpiredNil IsExpired Function was nil
|
||||
ErrCodePluginIsExpiredNil = "PluginIsExpiredNilError"
|
||||
|
||||
// ErrCodePluginProviderRetrieve plugin provider's retrieve returned error
|
||||
ErrCodePluginProviderRetrieve = "PluginProviderRetrieveError"
|
||||
)
|
||||
|
||||
// Provider is the credentials provider that will use the plugin provided
|
||||
// Retrieve and IsExpired functions to retrieve credentials.
|
||||
type Provider struct {
|
||||
RetrieveFn func() (key, secret, token string, err error)
|
||||
IsExpiredFn func() bool
|
||||
}
|
||||
|
||||
// NewCredentials returns a new Credentials loader using the plugin provider.
|
||||
// If the symbol isn't found or is invalid in the plugin an error will be
|
||||
// returned.
|
||||
func NewCredentials(p *plugin.Plugin) (*credentials.Credentials, error) {
|
||||
retrieve, isExpired, err := GetPluginProviderFns(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return credentials.NewCredentials(Provider{
|
||||
RetrieveFn: retrieve,
|
||||
IsExpiredFn: isExpired,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// Retrieve will return the credentials Value if they were successfully retrieved
|
||||
// from the underlying plugin provider. An error will be returned otherwise.
|
||||
func (p Provider) Retrieve() (credentials.Value, error) {
|
||||
creds := credentials.Value{
|
||||
ProviderName: ProviderName,
|
||||
}
|
||||
|
||||
k, s, t, err := p.RetrieveFn()
|
||||
if err != nil {
|
||||
return creds, awserr.New(ErrCodePluginProviderRetrieve,
|
||||
"failed to retrieve credentials with plugin provider", err)
|
||||
}
|
||||
|
||||
creds.AccessKeyID = k
|
||||
creds.SecretAccessKey = s
|
||||
creds.SessionToken = t
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// IsExpired will return the expired state of the underlying plugin provider.
|
||||
func (p Provider) IsExpired() bool {
|
||||
return p.IsExpiredFn()
|
||||
}
|
||||
|
||||
// GetPluginProviderFns returns the plugin's Retrieve and IsExpired functions
|
||||
// returned by the plugin's credential provider getter.
|
||||
//
|
||||
// Uses ProviderSymbolName as the symbol name when lookup up the symbol. If you
|
||||
// want to use a different symbol name, use GetPluginProviderFnsByName.
|
||||
func GetPluginProviderFns(p *plugin.Plugin) (func() (key, secret, token string, err error), func() bool, error) {
|
||||
return GetPluginProviderFnsByName(p, ProviderSymbolName)
|
||||
}
|
||||
|
||||
// GetPluginProviderFnsByName returns the plugin's Retrieve and IsExpired functions
|
||||
// returned by the plugin's credential provider getter.
|
||||
//
|
||||
// Same as GetPluginProviderFns, but takes a custom symbolName to lookup with.
|
||||
func GetPluginProviderFnsByName(p *plugin.Plugin, symbolName string) (func() (key, secret, token string, err error), func() bool, error) {
|
||||
sym, err := p.Lookup(symbolName)
|
||||
if err != nil {
|
||||
return nil, nil, awserr.New(ErrCodeLookupSymbolError,
|
||||
fmt.Sprintf("failed to lookup %s plugin provider symbol", symbolName), err)
|
||||
}
|
||||
|
||||
fn, ok := sym.(func() (func() (key, secret, token string, err error), func() bool))
|
||||
if !ok {
|
||||
return nil, nil, awserr.New(ErrCodeInvalidSymbolError,
|
||||
fmt.Sprintf("symbol %T, does not match the 'func() (func() (key, secret, token string, err error), func() bool)' type", sym), nil)
|
||||
}
|
||||
|
||||
retrieveFn, isExpiredFn := fn()
|
||||
if retrieveFn == nil {
|
||||
return nil, nil, awserr.New(ErrCodePluginRetrieveNil,
|
||||
"the plugin provider retrieve function cannot be nil", nil)
|
||||
}
|
||||
if isExpiredFn == nil {
|
||||
return nil, nil, awserr.New(ErrCodePluginIsExpiredNil,
|
||||
"the plugin provider isExpired function cannot be nil", nil)
|
||||
}
|
||||
|
||||
return retrieveFn, isExpiredFn, nil
|
||||
}
|
||||
+71
@@ -0,0 +1,71 @@
|
||||
// +build go1.8,awsinclude
|
||||
|
||||
package plugincreds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
func TestProvider_Passthrough(t *testing.T) {
|
||||
p := Provider{
|
||||
RetrieveFn: func() (string, string, string, error) {
|
||||
return "key", "secret", "token", nil
|
||||
},
|
||||
IsExpiredFn: func() bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
actual, err := p.Retrieve()
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expect := credentials.Value{
|
||||
AccessKeyID: "key",
|
||||
SecretAccessKey: "secret",
|
||||
SessionToken: "token",
|
||||
ProviderName: ProviderName,
|
||||
}
|
||||
if expect != actual {
|
||||
t.Errorf("expect %+v credentials, got %+v", expect, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Error(t *testing.T) {
|
||||
expectErr := fmt.Errorf("expect error")
|
||||
|
||||
p := Provider{
|
||||
RetrieveFn: func() (string, string, string, error) {
|
||||
return "", "", "", expectErr
|
||||
},
|
||||
IsExpiredFn: func() bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
actual, err := p.Retrieve()
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := ErrCodePluginProviderRetrieve, aerr.Code(); e != a {
|
||||
t.Errorf("expect %s error code, got %s", e, a)
|
||||
}
|
||||
|
||||
if e, a := expectErr, aerr.OrigErr(); e != a {
|
||||
t.Errorf("expect %v cause error, got %v", e, a)
|
||||
}
|
||||
|
||||
expect := credentials.Value{
|
||||
ProviderName: ProviderName,
|
||||
}
|
||||
if expect != actual {
|
||||
t.Errorf("expect %+v credentials, got %+v", expect, actual)
|
||||
}
|
||||
}
|
||||
+16
-17
@@ -3,11 +3,11 @@ package credentials
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-ini/ini"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
)
|
||||
|
||||
// SharedCredsProviderName provides a name of SharedCreds provider
|
||||
@@ -15,8 +15,6 @@ const SharedCredsProviderName = "SharedCredentialsProvider"
|
||||
|
||||
var (
|
||||
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
|
||||
//
|
||||
// @readonly
|
||||
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
|
||||
)
|
||||
|
||||
@@ -117,22 +115,23 @@ func loadProfile(filename, profile string) (Value, error) {
|
||||
//
|
||||
// Will return an error if the user's home directory path cannot be found.
|
||||
func (p *SharedCredentialsProvider) filename() (string, error) {
|
||||
if p.Filename == "" {
|
||||
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" {
|
||||
return p.Filename, nil
|
||||
}
|
||||
|
||||
homeDir := os.Getenv("HOME") // *nix
|
||||
if homeDir == "" { // Windows
|
||||
homeDir = os.Getenv("USERPROFILE")
|
||||
}
|
||||
if homeDir == "" {
|
||||
return "", ErrSharedCredentialsHomeNotFound
|
||||
}
|
||||
|
||||
p.Filename = filepath.Join(homeDir, ".aws", "credentials")
|
||||
if len(p.Filename) != 0 {
|
||||
return p.Filename, nil
|
||||
}
|
||||
|
||||
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); len(p.Filename) != 0 {
|
||||
return p.Filename, nil
|
||||
}
|
||||
|
||||
if home := shareddefaults.UserHomeDir(); len(home) == 0 {
|
||||
// Backwards compatibility of home directly not found error being returned.
|
||||
// This error is too verbose, failure when opening the file would of been
|
||||
// a better error to return.
|
||||
return "", ErrSharedCredentialsHomeNotFound
|
||||
}
|
||||
|
||||
p.Filename = shareddefaults.SharedCredentialsFilename()
|
||||
|
||||
return p.Filename, nil
|
||||
}
|
||||
|
||||
|
||||
Generated
Vendored
+20
@@ -5,6 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -97,6 +98,25 @@ func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
|
||||
assert.Empty(t, creds.SessionToken, "Expect no token")
|
||||
}
|
||||
|
||||
func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("USERPROFILE", "profile_dir")
|
||||
os.Setenv("HOME", "home_dir")
|
||||
|
||||
// default filename and profile
|
||||
p := SharedCredentialsProvider{}
|
||||
|
||||
filename, err := p.filename()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if e, a := shareddefaults.SharedCredentialsFilename(), filename; e != a {
|
||||
t.Errorf("expect %q filename, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSharedCredentialsProvider(b *testing.B) {
|
||||
os.Clearenv()
|
||||
|
||||
|
||||
Generated
Vendored
+1
-1
@@ -12,7 +12,7 @@ between multiple Credentials, Sessions or service clients.
|
||||
Assume Role
|
||||
|
||||
To assume an IAM role using STS with the SDK you can create a new Credentials
|
||||
with the SDKs's stscreds package.
|
||||
with the SDKs's stscreds package.
|
||||
|
||||
// Initial credentials loaded from SDK's default credential chain. Such as
|
||||
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
|
||||
|
||||
+38
-8
@@ -10,10 +10,12 @@ package defaults
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
|
||||
@@ -96,23 +98,51 @@ func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credenti
|
||||
})
|
||||
}
|
||||
|
||||
// RemoteCredProvider returns a credenitials provider for the default remote
|
||||
const (
|
||||
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
||||
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
|
||||
)
|
||||
|
||||
// RemoteCredProvider returns a credentials provider for the default remote
|
||||
// endpoints such as EC2 or ECS Roles.
|
||||
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
|
||||
ecsCredURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
|
||||
if u := os.Getenv(httpProviderEnvVar); len(u) > 0 {
|
||||
return localHTTPCredProvider(cfg, handlers, u)
|
||||
}
|
||||
|
||||
if len(ecsCredURI) > 0 {
|
||||
return ecsCredProvider(cfg, handlers, ecsCredURI)
|
||||
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
|
||||
u := fmt.Sprintf("http://169.254.170.2%s", uri)
|
||||
return httpCredProvider(cfg, handlers, u)
|
||||
}
|
||||
|
||||
return ec2RoleProvider(cfg, handlers)
|
||||
}
|
||||
|
||||
func ecsCredProvider(cfg aws.Config, handlers request.Handlers, uri string) credentials.Provider {
|
||||
const host = `169.254.170.2`
|
||||
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
|
||||
var errMsg string
|
||||
|
||||
return endpointcreds.NewProviderClient(cfg, handlers,
|
||||
fmt.Sprintf("http://%s%s", host, uri),
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Sprintf("invalid URL, %v", err)
|
||||
} else if host := aws.URLHostname(parsed); !(host == "localhost" || host == "127.0.0.1") {
|
||||
errMsg = fmt.Sprintf("invalid host address, %q, only localhost and 127.0.0.1 are valid.", host)
|
||||
}
|
||||
|
||||
if len(errMsg) > 0 {
|
||||
if cfg.Logger != nil {
|
||||
cfg.Logger.Log("Ignoring, HTTP credential provider", errMsg, err)
|
||||
}
|
||||
return credentials.ErrorProvider{
|
||||
Err: awserr.New("CredentialsEndpointError", errMsg, err),
|
||||
ProviderName: endpointcreds.ProviderName,
|
||||
}
|
||||
}
|
||||
|
||||
return httpCredProvider(cfg, handlers, u)
|
||||
}
|
||||
|
||||
func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
|
||||
return endpointcreds.NewProviderClient(cfg, handlers, u,
|
||||
func(p *endpointcreds.Provider) {
|
||||
p.ExpiryWindow = 5 * time.Minute
|
||||
},
|
||||
|
||||
+64
-20
@@ -6,39 +6,83 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHTTPCredProvider(t *testing.T) {
|
||||
cases := []struct {
|
||||
Host string
|
||||
Fail bool
|
||||
}{
|
||||
{"localhost", false}, {"127.0.0.1", false},
|
||||
{"www.example.com", true}, {"169.254.170.2", true},
|
||||
}
|
||||
|
||||
defer os.Clearenv()
|
||||
|
||||
for i, c := range cases {
|
||||
u := fmt.Sprintf("http://%s/abc/123", c.Host)
|
||||
os.Setenv(httpProviderEnvVar, u)
|
||||
|
||||
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
|
||||
if provider == nil {
|
||||
t.Fatalf("%d, expect provider not to be nil, but was", i)
|
||||
}
|
||||
|
||||
if c.Fail {
|
||||
creds, err := provider.Retrieve()
|
||||
if err == nil {
|
||||
t.Fatalf("%d, expect error but got none", i)
|
||||
} else {
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
|
||||
t.Errorf("%d, expect %s error code, got %s", i, e, a)
|
||||
}
|
||||
}
|
||||
if e, a := endpointcreds.ProviderName, creds.ProviderName; e != a {
|
||||
t.Errorf("%d, expect %s provider name got %s", i, e, a)
|
||||
}
|
||||
} else {
|
||||
httpProvider := provider.(*endpointcreds.Provider)
|
||||
if e, a := u, httpProvider.Client.Endpoint; e != a {
|
||||
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestECSCredProvider(t *testing.T) {
|
||||
defer os.Clearenv()
|
||||
os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/abc/123")
|
||||
os.Setenv(ecsCredsProviderEnvVar, "/abc/123")
|
||||
|
||||
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
|
||||
if provider == nil {
|
||||
t.Fatalf("expect provider not to be nil, but was")
|
||||
}
|
||||
|
||||
assert.NotNil(t, provider)
|
||||
|
||||
ecsProvider, ok := provider.(*endpointcreds.Provider)
|
||||
assert.NotNil(t, ecsProvider)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, fmt.Sprintf("http://169.254.170.2/abc/123"),
|
||||
ecsProvider.Client.Endpoint)
|
||||
httpProvider := provider.(*endpointcreds.Provider)
|
||||
if httpProvider == nil {
|
||||
t.Fatalf("expect provider not to be nil, but was")
|
||||
}
|
||||
if e, a := "http://169.254.170.2/abc/123", httpProvider.Client.Endpoint; e != a {
|
||||
t.Errorf("expect %q endpoint, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultEC2RoleProvider(t *testing.T) {
|
||||
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
|
||||
if provider == nil {
|
||||
t.Fatalf("expect provider not to be nil, but was")
|
||||
}
|
||||
|
||||
assert.NotNil(t, provider)
|
||||
|
||||
ec2Provider, ok := provider.(*ec2rolecreds.EC2RoleProvider)
|
||||
assert.NotNil(t, ec2Provider)
|
||||
assert.True(t, ok)
|
||||
|
||||
fmt.Println(ec2Provider.Client.Endpoint)
|
||||
|
||||
assert.Equal(t, fmt.Sprintf("http://169.254.169.254/latest"),
|
||||
ec2Provider.Client.Endpoint)
|
||||
ec2Provider := provider.(*ec2rolecreds.EC2RoleProvider)
|
||||
if ec2Provider == nil {
|
||||
t.Fatalf("expect provider not to be nil, but was")
|
||||
}
|
||||
if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a {
|
||||
t.Errorf("expect %q endpoint, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
package defaults
|
||||
|
||||
import (
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
)
|
||||
|
||||
// SharedCredentialsFilename returns the SDK's default file path
|
||||
// for the shared credentials file.
|
||||
//
|
||||
// Builds the shared config file path based on the OS's platform.
|
||||
//
|
||||
// - Linux/Unix: $HOME/.aws/credentials
|
||||
// - Windows: %USERPROFILE%\.aws\credentials
|
||||
func SharedCredentialsFilename() string {
|
||||
return shareddefaults.SharedCredentialsFilename()
|
||||
}
|
||||
|
||||
// SharedConfigFilename returns the SDK's default file path for
|
||||
// the shared config file.
|
||||
//
|
||||
// Builds the shared config file path based on the OS's platform.
|
||||
//
|
||||
// - Linux/Unix: $HOME/.aws/config
|
||||
// - Windows: %USERPROFILE%\.aws\config
|
||||
func SharedConfigFilename() string {
|
||||
return shareddefaults.SharedConfigFilename()
|
||||
}
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
// Package aws provides the core SDK's utilities and shared types. Use this package's
|
||||
// utilities to simplify setting and reading API operations parameters.
|
||||
//
|
||||
// Value and Pointer Conversion Utilities
|
||||
//
|
||||
// This package includes a helper conversion utility for each scalar type the SDK's
|
||||
// API use. These utilities make getting a pointer of the scalar, and dereferencing
|
||||
// a pointer easier.
|
||||
//
|
||||
// Each conversion utility comes in two forms. Value to Pointer and Pointer to Value.
|
||||
// The Pointer to value will safely dereference the pointer and return its value.
|
||||
// If the pointer was nil, the scalar's zero value will be returned.
|
||||
//
|
||||
// The value to pointer functions will be named after the scalar type. So get a
|
||||
// *string from a string value use the "String" function. This makes it easy to
|
||||
// to get pointer of a literal string value, because getting the address of a
|
||||
// literal requires assigning the value to a variable first.
|
||||
//
|
||||
// var strPtr *string
|
||||
//
|
||||
// // Without the SDK's conversion functions
|
||||
// str := "my string"
|
||||
// strPtr = &str
|
||||
//
|
||||
// // With the SDK's conversion functions
|
||||
// strPtr = aws.String("my string")
|
||||
//
|
||||
// // Convert *string to string value
|
||||
// str = aws.StringValue(strPtr)
|
||||
//
|
||||
// In addition to scalars the aws package also includes conversion utilities for
|
||||
// map and slice for commonly types used in API parameters. The map and slice
|
||||
// conversion functions use similar naming pattern as the scalar conversion
|
||||
// functions.
|
||||
//
|
||||
// var strPtrs []*string
|
||||
// var strs []string = []string{"Go", "Gophers", "Go"}
|
||||
//
|
||||
// // Convert []string to []*string
|
||||
// strPtrs = aws.StringSlice(strs)
|
||||
//
|
||||
// // Convert []*string to []string
|
||||
// strs = aws.StringValueSlice(strPtrs)
|
||||
//
|
||||
// SDK Default HTTP Client
|
||||
//
|
||||
// The SDK will use the http.DefaultClient if a HTTP client is not provided to
|
||||
// the SDK's Session, or service client constructor. This means that if the
|
||||
// http.DefaultClient is modified by other components of your application the
|
||||
// modifications will be picked up by the SDK as well.
|
||||
//
|
||||
// In some cases this might be intended, but it is a better practice to create
|
||||
// a custom HTTP Client to share explicitly through your application. You can
|
||||
// configure the SDK to use the custom HTTP Client by setting the HTTPClient
|
||||
// value of the SDK's Config type when creating a Session or service client.
|
||||
package aws
|
||||
+365
-34
@@ -1,4 +1,4 @@
|
||||
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
|
||||
// Code generated by aws/endpoints/v3model_codegen.go. DO NOT EDIT.
|
||||
|
||||
package endpoints
|
||||
|
||||
@@ -47,6 +47,7 @@ const (
|
||||
ApigatewayServiceID = "apigateway" // Apigateway.
|
||||
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
|
||||
Appstream2ServiceID = "appstream2" // Appstream2.
|
||||
AthenaServiceID = "athena" // Athena.
|
||||
AutoscalingServiceID = "autoscaling" // Autoscaling.
|
||||
BatchServiceID = "batch" // Batch.
|
||||
BudgetsServiceID = "budgets" // Budgets.
|
||||
@@ -54,12 +55,14 @@ const (
|
||||
CloudformationServiceID = "cloudformation" // Cloudformation.
|
||||
CloudfrontServiceID = "cloudfront" // Cloudfront.
|
||||
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
|
||||
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
|
||||
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
|
||||
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
|
||||
CodebuildServiceID = "codebuild" // Codebuild.
|
||||
CodecommitServiceID = "codecommit" // Codecommit.
|
||||
CodedeployServiceID = "codedeploy" // Codedeploy.
|
||||
CodepipelineServiceID = "codepipeline" // Codepipeline.
|
||||
CodestarServiceID = "codestar" // Codestar.
|
||||
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
|
||||
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
|
||||
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
|
||||
@@ -83,11 +86,14 @@ const (
|
||||
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
|
||||
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
|
||||
EmailServiceID = "email" // Email.
|
||||
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
|
||||
EsServiceID = "es" // Es.
|
||||
EventsServiceID = "events" // Events.
|
||||
FirehoseServiceID = "firehose" // Firehose.
|
||||
GameliftServiceID = "gamelift" // Gamelift.
|
||||
GlacierServiceID = "glacier" // Glacier.
|
||||
GlueServiceID = "glue" // Glue.
|
||||
GreengrassServiceID = "greengrass" // Greengrass.
|
||||
HealthServiceID = "health" // Health.
|
||||
IamServiceID = "iam" // Iam.
|
||||
ImportexportServiceID = "importexport" // Importexport.
|
||||
@@ -102,7 +108,9 @@ const (
|
||||
MachinelearningServiceID = "machinelearning" // Machinelearning.
|
||||
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
|
||||
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
|
||||
MghServiceID = "mgh" // Mgh.
|
||||
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
|
||||
ModelsLexServiceID = "models.lex" // ModelsLex.
|
||||
MonitoringServiceID = "monitoring" // Monitoring.
|
||||
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
|
||||
OpsworksServiceID = "opsworks" // Opsworks.
|
||||
@@ -131,6 +139,7 @@ const (
|
||||
StsServiceID = "sts" // Sts.
|
||||
SupportServiceID = "support" // Support.
|
||||
SwfServiceID = "swf" // Swf.
|
||||
TaggingServiceID = "tagging" // Tagging.
|
||||
WafServiceID = "waf" // Waf.
|
||||
WafRegionalServiceID = "waf-regional" // WafRegional.
|
||||
WorkdocsServiceID = "workdocs" // Workdocs.
|
||||
@@ -141,17 +150,20 @@ const (
|
||||
// DefaultResolver returns an Endpoint resolver that will be able
|
||||
// to resolve endpoints for: AWS Standard, AWS China, and AWS GovCloud (US).
|
||||
//
|
||||
// Casting the return value of this func to a EnumPartitions will
|
||||
// allow you to get a list of the partitions in the order the endpoints
|
||||
// will be resolved in.
|
||||
// Use DefaultPartitions() to get the list of the default partitions.
|
||||
func DefaultResolver() Resolver {
|
||||
return defaultPartitions
|
||||
}
|
||||
|
||||
// DefaultPartitions returns a list of the partitions the SDK is bundled
|
||||
// with. The available partitions are: AWS Standard, AWS China, and AWS GovCloud (US).
|
||||
//
|
||||
// resolver := endpoints.DefaultResolver()
|
||||
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
|
||||
// partitions := endpoints.DefaultPartitions
|
||||
// for _, p := range partitions {
|
||||
// // ... inspect partitions
|
||||
// }
|
||||
func DefaultResolver() Resolver {
|
||||
return defaultPartitions
|
||||
func DefaultPartitions() []Partition {
|
||||
return defaultPartitions.Partitions()
|
||||
}
|
||||
|
||||
var defaultPartitions = partitions{
|
||||
@@ -249,11 +261,14 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
@@ -299,6 +314,17 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"athena": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"autoscaling": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
@@ -323,7 +349,15 @@ var awsPartition = partition{
|
||||
"batch": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"budgets": service{
|
||||
@@ -345,6 +379,7 @@ var awsPartition = partition{
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
@@ -398,6 +433,15 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"cloudhsmv2": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"cloudsearch": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -435,19 +479,36 @@ var awsPartition = partition{
|
||||
"codebuild": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"codecommit": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"codedeploy": service{
|
||||
@@ -473,13 +534,32 @@ var awsPartition = partition{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"codestar": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
@@ -488,6 +568,8 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
@@ -502,6 +584,8 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
@@ -516,6 +600,8 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
@@ -614,11 +700,16 @@ var awsPartition = partition{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
@@ -755,15 +846,17 @@ var awsPartition = partition{
|
||||
"elasticfilesystem": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"elasticloadbalancing": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
Protocols: []string{"https"},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
@@ -829,6 +922,16 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"entitlement.marketplace": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
Service: "aws-marketplace",
|
||||
},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"es": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -837,8 +940,10 @@ var awsPartition = partition{
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
@@ -868,9 +973,12 @@ var awsPartition = partition{
|
||||
"firehose": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
"ap-northeast-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"gamelift": service{
|
||||
@@ -880,10 +988,15 @@ var awsPartition = partition{
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
@@ -895,6 +1008,7 @@ var awsPartition = partition{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
@@ -906,6 +1020,27 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"glue": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"greengrass": service{
|
||||
IsRegionalized: boxedTrue,
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"https"},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"health": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -949,6 +1084,7 @@ var awsPartition = partition{
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
@@ -1022,11 +1158,14 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
@@ -1036,7 +1175,16 @@ var awsPartition = partition{
|
||||
"lightsail": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"logs": service{
|
||||
@@ -1094,12 +1242,28 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"mgh": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"mobileanalytics": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"models.lex": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
Service: "lex",
|
||||
},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"us-east-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"monitoring": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
@@ -1346,6 +1510,7 @@ var awsPartition = partition{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
@@ -1370,20 +1535,27 @@ var awsPartition = partition{
|
||||
"sms": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"snowball": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
@@ -1440,10 +1612,13 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
@@ -1455,8 +1630,10 @@ var awsPartition = partition{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
@@ -1467,6 +1644,7 @@ var awsPartition = partition{
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
@@ -1482,7 +1660,7 @@ var awsPartition = partition{
|
||||
},
|
||||
"streams.dynamodb": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "http", "https", "https"},
|
||||
Protocols: []string{"http", "https"},
|
||||
CredentialScope: credentialScope{
|
||||
Service: "dynamodb",
|
||||
},
|
||||
@@ -1537,9 +1715,33 @@ var awsPartition = partition{
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
"us-east-1-fips": endpoint{
|
||||
Hostname: "sts-fips.us-east-1.amazonaws.com",
|
||||
CredentialScope: credentialScope{
|
||||
Region: "us-east-1",
|
||||
},
|
||||
},
|
||||
"us-east-2": endpoint{},
|
||||
"us-east-2-fips": endpoint{
|
||||
Hostname: "sts-fips.us-east-2.amazonaws.com",
|
||||
CredentialScope: credentialScope{
|
||||
Region: "us-east-2",
|
||||
},
|
||||
},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-1-fips": endpoint{
|
||||
Hostname: "sts-fips.us-west-1.amazonaws.com",
|
||||
CredentialScope: credentialScope{
|
||||
Region: "us-west-1",
|
||||
},
|
||||
},
|
||||
"us-west-2": endpoint{},
|
||||
"us-west-2-fips": endpoint{
|
||||
Hostname: "sts-fips.us-west-2.amazonaws.com",
|
||||
CredentialScope: credentialScope{
|
||||
Region: "us-west-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"support": service{
|
||||
@@ -1567,6 +1769,25 @@ var awsPartition = partition{
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"tagging": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"ap-northeast-2": endpoint{},
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
"waf": service{
|
||||
PartitionEndpoint: "aws-global",
|
||||
IsRegionalized: boxedFalse,
|
||||
@@ -1586,6 +1807,7 @@ var awsPartition = partition{
|
||||
"ap-northeast-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-west-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
},
|
||||
@@ -1608,6 +1830,7 @@ var awsPartition = partition{
|
||||
"ap-southeast-2": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-west-2": endpoint{},
|
||||
},
|
||||
@@ -1620,8 +1843,10 @@ var awsPartition = partition{
|
||||
"ap-south-1": endpoint{},
|
||||
"ap-southeast-1": endpoint{},
|
||||
"ap-southeast-2": endpoint{},
|
||||
"ca-central-1": endpoint{},
|
||||
"eu-central-1": endpoint{},
|
||||
"eu-west-1": endpoint{},
|
||||
"eu-west-2": endpoint{},
|
||||
"sa-east-1": endpoint{},
|
||||
"us-east-1": endpoint{},
|
||||
"us-east-2": endpoint{},
|
||||
@@ -1658,6 +1883,18 @@ var awscnPartition = partition{
|
||||
},
|
||||
},
|
||||
Services: services{
|
||||
"application-autoscaling": service{
|
||||
Defaults: endpoint{
|
||||
Hostname: "autoscaling.{region}.amazonaws.com",
|
||||
Protocols: []string{"http", "https"},
|
||||
CredentialScope: credentialScope{
|
||||
Service: "application-autoscaling",
|
||||
},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"autoscaling": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
@@ -1678,6 +1915,12 @@ var awscnPartition = partition{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"codedeploy": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"config": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1717,6 +1960,18 @@ var awscnPartition = partition{
|
||||
},
|
||||
},
|
||||
},
|
||||
"ecr": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"ecs": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"elasticache": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1731,7 +1986,7 @@ var awscnPartition = partition{
|
||||
},
|
||||
"elasticloadbalancing": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
Protocols: []string{"https"},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
@@ -1772,6 +2027,16 @@ var awscnPartition = partition{
|
||||
},
|
||||
},
|
||||
},
|
||||
"iot": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
Service: "execute-api",
|
||||
},
|
||||
},
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"kinesis": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1813,6 +2078,12 @@ var awscnPartition = partition{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"snowball": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"sns": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "https"},
|
||||
@@ -1830,6 +2101,12 @@ var awscnPartition = partition{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"ssm": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"storagegateway": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1838,7 +2115,7 @@ var awscnPartition = partition{
|
||||
},
|
||||
"streams.dynamodb": service{
|
||||
Defaults: endpoint{
|
||||
Protocols: []string{"http", "http", "https", "https"},
|
||||
Protocols: []string{"http", "https"},
|
||||
CredentialScope: credentialScope{
|
||||
Service: "dynamodb",
|
||||
},
|
||||
@@ -1855,6 +2132,12 @@ var awscnPartition = partition{
|
||||
},
|
||||
"swf": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"tagging": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"cn-north-1": endpoint{},
|
||||
},
|
||||
@@ -1888,6 +2171,18 @@ var awsusgovPartition = partition{
|
||||
},
|
||||
},
|
||||
Services: services{
|
||||
"acm": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"apigateway": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"autoscaling": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1914,6 +2209,12 @@ var awsusgovPartition = partition{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"codedeploy": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"config": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -1971,6 +2272,12 @@ var awsusgovPartition = partition{
|
||||
},
|
||||
},
|
||||
},
|
||||
"events": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"glacier": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -2004,6 +2311,12 @@ var awsusgovPartition = partition{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"lambda": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"logs": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -2028,6 +2341,12 @@ var awsusgovPartition = partition{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"rekognition": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"s3": service{
|
||||
Defaults: endpoint{
|
||||
SignatureVersions: []string{"s3", "s3v4"},
|
||||
@@ -2045,6 +2364,12 @@ var awsusgovPartition = partition{
|
||||
},
|
||||
},
|
||||
},
|
||||
"sms": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"snowball": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
@@ -2068,6 +2393,12 @@ var awsusgovPartition = partition{
|
||||
},
|
||||
},
|
||||
},
|
||||
"ssm": service{
|
||||
|
||||
Endpoints: endpoints{
|
||||
"us-gov-west-1": endpoint{},
|
||||
},
|
||||
},
|
||||
"streams.dynamodb": service{
|
||||
Defaults: endpoint{
|
||||
CredentialScope: credentialScope{
|
||||
|
||||
+2
-2
@@ -21,12 +21,12 @@
|
||||
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
|
||||
//
|
||||
// for _, p := range partitions {
|
||||
// fmt.Println("Regions for", p.Name)
|
||||
// fmt.Println("Regions for", p.ID())
|
||||
// for id, _ := range p.Regions() {
|
||||
// fmt.Println("*", id)
|
||||
// }
|
||||
//
|
||||
// fmt.Println("Services for", p.Name)
|
||||
// fmt.Println("Services for", p.ID())
|
||||
// for id, _ := range p.Services() {
|
||||
// fmt.Println("*", id)
|
||||
// }
|
||||
|
||||
+77
-35
@@ -124,6 +124,49 @@ type EnumPartitions interface {
|
||||
Partitions() []Partition
|
||||
}
|
||||
|
||||
// RegionsForService returns a map of regions for the partition and service.
|
||||
// If either the partition or service does not exist false will be returned
|
||||
// as the second parameter.
|
||||
//
|
||||
// This example shows how to get the regions for DynamoDB in the AWS partition.
|
||||
// rs, exists := endpoints.RegionsForService(endpoints.DefaultPartitions(), endpoints.AwsPartitionID, endpoints.DynamodbServiceID)
|
||||
//
|
||||
// This is equivalent to using the partition directly.
|
||||
// rs := endpoints.AwsPartition().Services()[endpoints.DynamodbServiceID].Regions()
|
||||
func RegionsForService(ps []Partition, partitionID, serviceID string) (map[string]Region, bool) {
|
||||
for _, p := range ps {
|
||||
if p.ID() != partitionID {
|
||||
continue
|
||||
}
|
||||
if _, ok := p.p.Services[serviceID]; !ok {
|
||||
break
|
||||
}
|
||||
|
||||
s := Service{
|
||||
id: serviceID,
|
||||
p: p.p,
|
||||
}
|
||||
return s.Regions(), true
|
||||
}
|
||||
|
||||
return map[string]Region{}, false
|
||||
}
|
||||
|
||||
// PartitionForRegion returns the first partition which includes the region
|
||||
// passed in. This includes both known regions and regions which match
|
||||
// a pattern supported by the partition which may include regions that are
|
||||
// not explicitly known by the partition. Use the Regions method of the
|
||||
// returned Partition if explicit support is needed.
|
||||
func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
|
||||
for _, p := range ps {
|
||||
if _, ok := p.p.Regions[regionID]; ok || p.p.RegionRegex.MatchString(regionID) {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
|
||||
return Partition{}, false
|
||||
}
|
||||
|
||||
// A Partition provides the ability to enumerate the partition's regions
|
||||
// and services.
|
||||
type Partition struct {
|
||||
@@ -132,7 +175,7 @@ type Partition struct {
|
||||
}
|
||||
|
||||
// ID returns the identifier of the partition.
|
||||
func (p *Partition) ID() string { return p.id }
|
||||
func (p Partition) ID() string { return p.id }
|
||||
|
||||
// EndpointFor attempts to resolve the endpoint based on service and region.
|
||||
// See Options for information on configuring how the endpoint is resolved.
|
||||
@@ -155,13 +198,13 @@ func (p *Partition) ID() string { return p.id }
|
||||
// Errors that can be returned.
|
||||
// * UnknownServiceError
|
||||
// * UnknownEndpointError
|
||||
func (p *Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
return p.p.EndpointFor(service, region, opts...)
|
||||
}
|
||||
|
||||
// Regions returns a map of Regions indexed by their ID. This is useful for
|
||||
// enumerating over the regions in a partition.
|
||||
func (p *Partition) Regions() map[string]Region {
|
||||
func (p Partition) Regions() map[string]Region {
|
||||
rs := map[string]Region{}
|
||||
for id := range p.p.Regions {
|
||||
rs[id] = Region{
|
||||
@@ -175,7 +218,7 @@ func (p *Partition) Regions() map[string]Region {
|
||||
|
||||
// Services returns a map of Service indexed by their ID. This is useful for
|
||||
// enumerating over the services in a partition.
|
||||
func (p *Partition) Services() map[string]Service {
|
||||
func (p Partition) Services() map[string]Service {
|
||||
ss := map[string]Service{}
|
||||
for id := range p.p.Services {
|
||||
ss[id] = Service{
|
||||
@@ -195,16 +238,16 @@ type Region struct {
|
||||
}
|
||||
|
||||
// ID returns the region's identifier.
|
||||
func (r *Region) ID() string { return r.id }
|
||||
func (r Region) ID() string { return r.id }
|
||||
|
||||
// ResolveEndpoint resolves an endpoint from the context of the region given
|
||||
// a service. See Partition.EndpointFor for usage and errors that can be returned.
|
||||
func (r *Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
return r.p.EndpointFor(service, r.id, opts...)
|
||||
}
|
||||
|
||||
// Services returns a list of all services that are known to be in this region.
|
||||
func (r *Region) Services() map[string]Service {
|
||||
func (r Region) Services() map[string]Service {
|
||||
ss := map[string]Service{}
|
||||
for id, s := range r.p.Services {
|
||||
if _, ok := s.Endpoints[r.id]; ok {
|
||||
@@ -226,17 +269,38 @@ type Service struct {
|
||||
}
|
||||
|
||||
// ID returns the identifier for the service.
|
||||
func (s *Service) ID() string { return s.id }
|
||||
func (s Service) ID() string { return s.id }
|
||||
|
||||
// ResolveEndpoint resolves an endpoint from the context of a service given
|
||||
// a region. See Partition.EndpointFor for usage and errors that can be returned.
|
||||
func (s *Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
return s.p.EndpointFor(s.id, region, opts...)
|
||||
}
|
||||
|
||||
// Regions returns a map of Regions that the service is present in.
|
||||
//
|
||||
// A region is the AWS region the service exists in. Whereas a Endpoint is
|
||||
// an URL that can be resolved to a instance of a service.
|
||||
func (s Service) Regions() map[string]Region {
|
||||
rs := map[string]Region{}
|
||||
for id := range s.p.Services[s.id].Endpoints {
|
||||
if _, ok := s.p.Regions[id]; ok {
|
||||
rs[id] = Region{
|
||||
id: id,
|
||||
p: s.p,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// Endpoints returns a map of Endpoints indexed by their ID for all known
|
||||
// endpoints for a service.
|
||||
func (s *Service) Endpoints() map[string]Endpoint {
|
||||
//
|
||||
// A region is the AWS region the service exists in. Whereas a Endpoint is
|
||||
// an URL that can be resolved to a instance of a service.
|
||||
func (s Service) Endpoints() map[string]Endpoint {
|
||||
es := map[string]Endpoint{}
|
||||
for id := range s.p.Services[s.id].Endpoints {
|
||||
es[id] = Endpoint{
|
||||
@@ -259,15 +323,15 @@ type Endpoint struct {
|
||||
}
|
||||
|
||||
// ID returns the identifier for an endpoint.
|
||||
func (e *Endpoint) ID() string { return e.id }
|
||||
func (e Endpoint) ID() string { return e.id }
|
||||
|
||||
// ServiceID returns the identifier the endpoint belongs to.
|
||||
func (e *Endpoint) ServiceID() string { return e.serviceID }
|
||||
func (e Endpoint) ServiceID() string { return e.serviceID }
|
||||
|
||||
// ResolveEndpoint resolves an endpoint from the context of a service and
|
||||
// region the endpoint represents. See Partition.EndpointFor for usage and
|
||||
// errors that can be returned.
|
||||
func (e *Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
func (e Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
|
||||
return e.p.EndpointFor(e.serviceID, e.id, opts...)
|
||||
}
|
||||
|
||||
@@ -300,28 +364,6 @@ type EndpointNotFoundError struct {
|
||||
Region string
|
||||
}
|
||||
|
||||
//// NewEndpointNotFoundError builds and returns NewEndpointNotFoundError.
|
||||
//func NewEndpointNotFoundError(p, s, r string) EndpointNotFoundError {
|
||||
// return EndpointNotFoundError{
|
||||
// awsError: awserr.New("EndpointNotFoundError", "unable to find endpoint", nil),
|
||||
// Partition: p,
|
||||
// Service: s,
|
||||
// Region: r,
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Error returns string representation of the error.
|
||||
//func (e EndpointNotFoundError) Error() string {
|
||||
// extra := fmt.Sprintf("partition: %q, service: %q, region: %q",
|
||||
// e.Partition, e.Service, e.Region)
|
||||
// return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
|
||||
//}
|
||||
//
|
||||
//// String returns the string representation of the error.
|
||||
//func (e EndpointNotFoundError) String() string {
|
||||
// return e.Error()
|
||||
//}
|
||||
|
||||
// A UnknownServiceError is returned when the service does not resolve to an
|
||||
// endpoint. Includes a list of all known services for the partition. Returned
|
||||
// when a partition does not support the service.
|
||||
|
||||
+91
@@ -84,6 +84,22 @@ func TestEnumRegionServices(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumServiceRegions(t *testing.T) {
|
||||
p := testPartitions[0].Partition()
|
||||
|
||||
rs := p.Services()["service1"].Regions()
|
||||
if e, a := 2, len(rs); e != a {
|
||||
t.Errorf("expect %d regions, got %d", e, a)
|
||||
}
|
||||
|
||||
if _, ok := rs["us-east-1"]; !ok {
|
||||
t.Errorf("expect region to be found")
|
||||
}
|
||||
if _, ok := rs["us-west-2"]; !ok {
|
||||
t.Errorf("expect region to be found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumServicesEndpoints(t *testing.T) {
|
||||
p := testPartitions[0].Partition()
|
||||
|
||||
@@ -242,3 +258,78 @@ func TestOptionsSet(t *testing.T) {
|
||||
t.Errorf("expect %v options got %v", expect, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegionsForService(t *testing.T) {
|
||||
ps := DefaultPartitions()
|
||||
|
||||
var expect map[string]Region
|
||||
var serviceID string
|
||||
for _, s := range ps[0].Services() {
|
||||
expect = s.Regions()
|
||||
serviceID = s.ID()
|
||||
if len(expect) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
actual, ok := RegionsForService(ps, ps[0].ID(), serviceID)
|
||||
if !ok {
|
||||
t.Fatalf("expect regions to be found, was not")
|
||||
}
|
||||
|
||||
if len(actual) == 0 {
|
||||
t.Fatalf("expect service %s to have regions", serviceID)
|
||||
}
|
||||
if e, a := len(expect), len(actual); e != a {
|
||||
t.Fatalf("expect %d regions, got %d", e, a)
|
||||
}
|
||||
|
||||
for id, r := range actual {
|
||||
if e, a := id, r.ID(); e != a {
|
||||
t.Errorf("expect %s region id, got %s", e, a)
|
||||
}
|
||||
if _, ok := expect[id]; !ok {
|
||||
t.Errorf("expect %s region to be found", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegionsForService_NotFound(t *testing.T) {
|
||||
ps := testPartitions.Partitions()
|
||||
|
||||
actual, ok := RegionsForService(ps, ps[0].ID(), "service-not-exists")
|
||||
if ok {
|
||||
t.Fatalf("expect no regions to be found, but were")
|
||||
}
|
||||
if len(actual) != 0 {
|
||||
t.Errorf("expect no regions, got %v", actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionForRegion(t *testing.T) {
|
||||
ps := DefaultPartitions()
|
||||
expect := ps[len(ps)%2]
|
||||
|
||||
var regionID string
|
||||
for id := range expect.Regions() {
|
||||
regionID = id
|
||||
break
|
||||
}
|
||||
|
||||
actual, ok := PartitionForRegion(ps, regionID)
|
||||
if !ok {
|
||||
t.Fatalf("expect partition to be found")
|
||||
}
|
||||
if e, a := expect.ID(), actual.ID(); e != a {
|
||||
t.Errorf("expect %s partition, got %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionForRegion_NotFound(t *testing.T) {
|
||||
ps := DefaultPartitions()
|
||||
|
||||
actual, ok := PartitionForRegion(ps, "regionNotExists")
|
||||
if ok {
|
||||
t.Errorf("expect no partition to be found, got %v", actual)
|
||||
}
|
||||
}
|
||||
|
||||
+11
-8
@@ -158,7 +158,7 @@ var funcMap = template.FuncMap{
|
||||
|
||||
const v3Tmpl = `
|
||||
{{ define "defaults" -}}
|
||||
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
|
||||
// Code generated by aws/endpoints/v3model_codegen.go. DO NOT EDIT.
|
||||
|
||||
package endpoints
|
||||
|
||||
@@ -209,17 +209,20 @@ import (
|
||||
// DefaultResolver returns an Endpoint resolver that will be able
|
||||
// to resolve endpoints for: {{ ListPartitionNames . }}.
|
||||
//
|
||||
// Casting the return value of this func to a EnumPartitions will
|
||||
// allow you to get a list of the partitions in the order the endpoints
|
||||
// will be resolved in.
|
||||
// Use DefaultPartitions() to get the list of the default partitions.
|
||||
func DefaultResolver() Resolver {
|
||||
return defaultPartitions
|
||||
}
|
||||
|
||||
// DefaultPartitions returns a list of the partitions the SDK is bundled
|
||||
// with. The available partitions are: {{ ListPartitionNames . }}.
|
||||
//
|
||||
// resolver := endpoints.DefaultResolver()
|
||||
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
|
||||
// partitions := endpoints.DefaultPartitions
|
||||
// for _, p := range partitions {
|
||||
// // ... inspect partitions
|
||||
// }
|
||||
func DefaultResolver() Resolver {
|
||||
return defaultPartitions
|
||||
func DefaultPartitions() []Partition {
|
||||
return defaultPartitions.Partitions()
|
||||
}
|
||||
|
||||
var defaultPartitions = partitions{
|
||||
|
||||
+2
-1
@@ -4,7 +4,8 @@ package aws
|
||||
// into a json string. This type can be used just like any other map.
|
||||
//
|
||||
// Example:
|
||||
// values := JSONValue{
|
||||
//
|
||||
// values := aws.JSONValue{
|
||||
// "Foo": "Bar",
|
||||
// }
|
||||
// values["Baz"] = "Qux"
|
||||
|
||||
+2
-2
@@ -26,14 +26,14 @@ func (l *LogLevelType) Value() LogLevelType {
|
||||
|
||||
// Matches returns true if the v LogLevel is enabled by this LogLevel. Should be
|
||||
// used with logging sub levels. Is safe to use on nil value LogLevelTypes. If
|
||||
// LogLevel is nill, will default to LogOff comparison.
|
||||
// LogLevel is nil, will default to LogOff comparison.
|
||||
func (l *LogLevelType) Matches(v LogLevelType) bool {
|
||||
c := l.Value()
|
||||
return c&v == v
|
||||
}
|
||||
|
||||
// AtLeast returns true if this LogLevel is at least high enough to satisfies v.
|
||||
// Is safe to use on nil value LogLevelTypes. If LogLevel is nill, will default
|
||||
// Is safe to use on nil value LogLevelTypes. If LogLevel is nil, will default
|
||||
// to LogOff comparison.
|
||||
func (l *LogLevelType) AtLeast(v LogLevelType) bool {
|
||||
c := l.Value()
|
||||
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
// +build !appengine,!plan9
|
||||
|
||||
package request
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func isErrConnectionReset(err error) bool {
|
||||
if opErr, ok := err.(*net.OpError); ok {
|
||||
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
|
||||
return sysErr.Err == syscall.ECONNRESET
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
+11
@@ -0,0 +1,11 @@
|
||||
// +build appengine plan9
|
||||
|
||||
package request
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isErrConnectionReset(err error) bool {
|
||||
return strings.Contains(err.Error(), "connection reset")
|
||||
}
|
||||
Generated
Vendored
+9
@@ -0,0 +1,9 @@
|
||||
// +build appengine plan9
|
||||
|
||||
package request_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var stubConnectionResetError = errors.New("connection reset")
|
||||
+11
@@ -0,0 +1,11 @@
|
||||
// +build !appengine,!plan9
|
||||
|
||||
package request_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var stubConnectionResetError = &net.OpError{Err: &os.SyscallError{Syscall: "read", Err: syscall.ECONNRESET}}
|
||||
+69
-13
@@ -88,13 +88,17 @@ func (l *HandlerList) copy() HandlerList {
|
||||
n := HandlerList{
|
||||
AfterEachFn: l.AfterEachFn,
|
||||
}
|
||||
if len(l.list) == 0 {
|
||||
return n
|
||||
}
|
||||
|
||||
n.list = append(make([]NamedHandler, 0, len(l.list)), l.list...)
|
||||
return n
|
||||
}
|
||||
|
||||
// Clear clears the handler list.
|
||||
func (l *HandlerList) Clear() {
|
||||
l.list = []NamedHandler{}
|
||||
l.list = l.list[0:0]
|
||||
}
|
||||
|
||||
// Len returns the number of handlers in the list.
|
||||
@@ -104,33 +108,85 @@ func (l *HandlerList) Len() int {
|
||||
|
||||
// PushBack pushes handler f to the back of the handler list.
|
||||
func (l *HandlerList) PushBack(f func(*Request)) {
|
||||
l.list = append(l.list, NamedHandler{"__anonymous", f})
|
||||
}
|
||||
|
||||
// PushFront pushes handler f to the front of the handler list.
|
||||
func (l *HandlerList) PushFront(f func(*Request)) {
|
||||
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
|
||||
l.PushBackNamed(NamedHandler{"__anonymous", f})
|
||||
}
|
||||
|
||||
// PushBackNamed pushes named handler f to the back of the handler list.
|
||||
func (l *HandlerList) PushBackNamed(n NamedHandler) {
|
||||
if cap(l.list) == 0 {
|
||||
l.list = make([]NamedHandler, 0, 5)
|
||||
}
|
||||
l.list = append(l.list, n)
|
||||
}
|
||||
|
||||
// PushFront pushes handler f to the front of the handler list.
|
||||
func (l *HandlerList) PushFront(f func(*Request)) {
|
||||
l.PushFrontNamed(NamedHandler{"__anonymous", f})
|
||||
}
|
||||
|
||||
// PushFrontNamed pushes named handler f to the front of the handler list.
|
||||
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
|
||||
l.list = append([]NamedHandler{n}, l.list...)
|
||||
if cap(l.list) == len(l.list) {
|
||||
// Allocating new list required
|
||||
l.list = append([]NamedHandler{n}, l.list...)
|
||||
} else {
|
||||
// Enough room to prepend into list.
|
||||
l.list = append(l.list, NamedHandler{})
|
||||
copy(l.list[1:], l.list)
|
||||
l.list[0] = n
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes a NamedHandler n
|
||||
func (l *HandlerList) Remove(n NamedHandler) {
|
||||
newlist := []NamedHandler{}
|
||||
for _, m := range l.list {
|
||||
if m.Name != n.Name {
|
||||
newlist = append(newlist, m)
|
||||
l.RemoveByName(n.Name)
|
||||
}
|
||||
|
||||
// RemoveByName removes a NamedHandler by name.
|
||||
func (l *HandlerList) RemoveByName(name string) {
|
||||
for i := 0; i < len(l.list); i++ {
|
||||
m := l.list[i]
|
||||
if m.Name == name {
|
||||
// Shift array preventing creating new arrays
|
||||
copy(l.list[i:], l.list[i+1:])
|
||||
l.list[len(l.list)-1] = NamedHandler{}
|
||||
l.list = l.list[:len(l.list)-1]
|
||||
|
||||
// decrement list so next check to length is correct
|
||||
i--
|
||||
}
|
||||
}
|
||||
l.list = newlist
|
||||
}
|
||||
|
||||
// SwapNamed will swap out any existing handlers with the same name as the
|
||||
// passed in NamedHandler returning true if handlers were swapped. False is
|
||||
// returned otherwise.
|
||||
func (l *HandlerList) SwapNamed(n NamedHandler) (swapped bool) {
|
||||
for i := 0; i < len(l.list); i++ {
|
||||
if l.list[i].Name == n.Name {
|
||||
l.list[i].Fn = n.Fn
|
||||
swapped = true
|
||||
}
|
||||
}
|
||||
|
||||
return swapped
|
||||
}
|
||||
|
||||
// SetBackNamed will replace the named handler if it exists in the handler list.
|
||||
// If the handler does not exist the handler will be added to the end of the list.
|
||||
func (l *HandlerList) SetBackNamed(n NamedHandler) {
|
||||
if !l.SwapNamed(n) {
|
||||
l.PushBackNamed(n)
|
||||
}
|
||||
}
|
||||
|
||||
// SetFrontNamed will replace the named handler if it exists in the handler list.
|
||||
// If the handler does not exist the handler will be added to the beginning of
|
||||
// the list.
|
||||
func (l *HandlerList) SetFrontNamed(n NamedHandler) {
|
||||
if !l.SwapNamed(n) {
|
||||
l.PushFrontNamed(n)
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes all handlers in the list with a given request object.
|
||||
|
||||
+188
-9
@@ -1,12 +1,13 @@
|
||||
package request_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
func TestHandlerList(t *testing.T) {
|
||||
@@ -18,8 +19,12 @@ func TestHandlerList(t *testing.T) {
|
||||
r.Data = s
|
||||
})
|
||||
l.Run(r)
|
||||
assert.Equal(t, "a", s)
|
||||
assert.Equal(t, "a", r.Data)
|
||||
if e, a := "a", s; e != a {
|
||||
t.Errorf("expect %q update got %q", e, a)
|
||||
}
|
||||
if e, a := "a", r.Data.(string); e != a {
|
||||
t.Errorf("expect %q data update got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleHandlers(t *testing.T) {
|
||||
@@ -41,9 +46,110 @@ func TestNamedHandlers(t *testing.T) {
|
||||
l.PushBackNamed(named)
|
||||
l.PushBackNamed(named2)
|
||||
l.PushBack(func(r *request.Request) {})
|
||||
assert.Equal(t, 4, l.Len())
|
||||
if e, a := 4, l.Len(); e != a {
|
||||
t.Errorf("expect %d list length, got %d", e, a)
|
||||
}
|
||||
l.Remove(named)
|
||||
assert.Equal(t, 2, l.Len())
|
||||
if e, a := 2, l.Len(); e != a {
|
||||
t.Errorf("expect %d list length, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwapHandlers(t *testing.T) {
|
||||
firstHandlerCalled := 0
|
||||
swappedOutHandlerCalled := 0
|
||||
swappedInHandlerCalled := 0
|
||||
|
||||
l := request.HandlerList{}
|
||||
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
|
||||
firstHandlerCalled++
|
||||
}}
|
||||
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
|
||||
swappedOutHandlerCalled++
|
||||
}}
|
||||
l.PushBackNamed(named)
|
||||
l.PushBackNamed(named2)
|
||||
l.PushBackNamed(named)
|
||||
|
||||
l.SwapNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
|
||||
swappedInHandlerCalled++
|
||||
}})
|
||||
|
||||
l.Run(&request.Request{})
|
||||
|
||||
if e, a := 2, firstHandlerCalled; e != a {
|
||||
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
if n := swappedOutHandlerCalled; n != 0 {
|
||||
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
|
||||
}
|
||||
if e, a := 1, swappedInHandlerCalled; e != a {
|
||||
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetBackNamed_Exists(t *testing.T) {
|
||||
firstHandlerCalled := 0
|
||||
swappedOutHandlerCalled := 0
|
||||
swappedInHandlerCalled := 0
|
||||
|
||||
l := request.HandlerList{}
|
||||
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
|
||||
firstHandlerCalled++
|
||||
}}
|
||||
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
|
||||
swappedOutHandlerCalled++
|
||||
}}
|
||||
l.PushBackNamed(named)
|
||||
l.PushBackNamed(named2)
|
||||
|
||||
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
|
||||
swappedInHandlerCalled++
|
||||
}})
|
||||
|
||||
l.Run(&request.Request{})
|
||||
|
||||
if e, a := 1, firstHandlerCalled; e != a {
|
||||
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
if n := swappedOutHandlerCalled; n != 0 {
|
||||
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
|
||||
}
|
||||
if e, a := 1, swappedInHandlerCalled; e != a {
|
||||
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetBackNamed_NotExists(t *testing.T) {
|
||||
firstHandlerCalled := 0
|
||||
secondHandlerCalled := 0
|
||||
swappedInHandlerCalled := 0
|
||||
|
||||
l := request.HandlerList{}
|
||||
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
|
||||
firstHandlerCalled++
|
||||
}}
|
||||
named2 := request.NamedHandler{Name: "OtherName", Fn: func(r *request.Request) {
|
||||
secondHandlerCalled++
|
||||
}}
|
||||
l.PushBackNamed(named)
|
||||
l.PushBackNamed(named2)
|
||||
|
||||
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
|
||||
swappedInHandlerCalled++
|
||||
}})
|
||||
|
||||
l.Run(&request.Request{})
|
||||
|
||||
if e, a := 1, firstHandlerCalled; e != a {
|
||||
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
if e, a := 1, secondHandlerCalled; e != a {
|
||||
t.Errorf("expect second handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
if e, a := 1, swappedInHandlerCalled; e != a {
|
||||
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggedHandlers(t *testing.T) {
|
||||
@@ -61,7 +167,10 @@ func TestLoggedHandlers(t *testing.T) {
|
||||
l.PushBackNamed(named2)
|
||||
l.Run(&request.Request{Config: cfg})
|
||||
|
||||
assert.Equal(t, expectedHandlers, loggedHandlers)
|
||||
if !reflect.DeepEqual(expectedHandlers, loggedHandlers) {
|
||||
t.Errorf("expect handlers executed %v to match logged handlers, %v",
|
||||
expectedHandlers, loggedHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopHandlers(t *testing.T) {
|
||||
@@ -79,9 +188,79 @@ func TestStopHandlers(t *testing.T) {
|
||||
called++
|
||||
}})
|
||||
l.PushBackNamed(request.NamedHandler{Name: "name3", Fn: func(r *request.Request) {
|
||||
assert.Fail(t, "third handler should not be called")
|
||||
t.Fatalf("third handler should not be called")
|
||||
}})
|
||||
l.Run(&request.Request{})
|
||||
|
||||
assert.Equal(t, 2, called, "Expect only two handlers to be called")
|
||||
if e, a := 2, called; e != a {
|
||||
t.Errorf("expect %d handlers called, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewRequest(b *testing.B) {
|
||||
svc := s3.New(unit.Session)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r, _ := svc.GetObjectRequest(nil)
|
||||
if r == nil {
|
||||
b.Fatal("r should not be nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlersCopy(b *testing.B) {
|
||||
handlers := request.Handlers{}
|
||||
|
||||
handlers.Validate.PushBack(func(r *request.Request) {})
|
||||
handlers.Validate.PushBack(func(r *request.Request) {})
|
||||
handlers.Build.PushBack(func(r *request.Request) {})
|
||||
handlers.Build.PushBack(func(r *request.Request) {})
|
||||
handlers.Send.PushBack(func(r *request.Request) {})
|
||||
handlers.Send.PushBack(func(r *request.Request) {})
|
||||
handlers.Unmarshal.PushBack(func(r *request.Request) {})
|
||||
handlers.Unmarshal.PushBack(func(r *request.Request) {})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := handlers.Copy()
|
||||
if e, a := handlers.Validate.Len(), h.Validate.Len(); e != a {
|
||||
b.Fatalf("expected %d handlers got %d", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlersPushBack(b *testing.B) {
|
||||
handlers := request.Handlers{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := handlers.Copy()
|
||||
h.Validate.PushBack(func(r *request.Request) {})
|
||||
h.Validate.PushBack(func(r *request.Request) {})
|
||||
h.Validate.PushBack(func(r *request.Request) {})
|
||||
h.Validate.PushBack(func(r *request.Request) {})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlersPushFront(b *testing.B) {
|
||||
handlers := request.Handlers{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := handlers.Copy()
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandlersClear(b *testing.B) {
|
||||
handlers := request.Handlers{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := handlers.Copy()
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Validate.PushFront(func(r *request.Request) {})
|
||||
h.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
+92
-38
@@ -16,10 +16,23 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
)
|
||||
|
||||
// CanceledErrorCode is the error code that will be returned by an
|
||||
// API request that was canceled. Requests given a aws.Context may
|
||||
// return this error when canceled.
|
||||
const CanceledErrorCode = "RequestCanceled"
|
||||
const (
|
||||
// ErrCodeSerialization is the serialization error code that is received
|
||||
// during protocol unmarshaling.
|
||||
ErrCodeSerialization = "SerializationError"
|
||||
|
||||
// ErrCodeRead is an error that is returned during HTTP reads.
|
||||
ErrCodeRead = "ReadError"
|
||||
|
||||
// ErrCodeResponseTimeout is the connection timeout error that is received
|
||||
// during body reads.
|
||||
ErrCodeResponseTimeout = "ResponseTimeout"
|
||||
|
||||
// CanceledErrorCode is the error code that will be returned by an
|
||||
// API request that was canceled. Requests given a aws.Context may
|
||||
// return this error when canceled.
|
||||
CanceledErrorCode = "RequestCanceled"
|
||||
)
|
||||
|
||||
// A Request is the service request to be made.
|
||||
type Request struct {
|
||||
@@ -28,23 +41,24 @@ type Request struct {
|
||||
Handlers Handlers
|
||||
|
||||
Retryer
|
||||
Time time.Time
|
||||
ExpireTime time.Duration
|
||||
Operation *Operation
|
||||
HTTPRequest *http.Request
|
||||
HTTPResponse *http.Response
|
||||
Body io.ReadSeeker
|
||||
BodyStart int64 // offset from beginning of Body that the request body starts
|
||||
Params interface{}
|
||||
Error error
|
||||
Data interface{}
|
||||
RequestID string
|
||||
RetryCount int
|
||||
Retryable *bool
|
||||
RetryDelay time.Duration
|
||||
NotHoist bool
|
||||
SignedHeaderVals http.Header
|
||||
LastSignedAt time.Time
|
||||
Time time.Time
|
||||
ExpireTime time.Duration
|
||||
Operation *Operation
|
||||
HTTPRequest *http.Request
|
||||
HTTPResponse *http.Response
|
||||
Body io.ReadSeeker
|
||||
BodyStart int64 // offset from beginning of Body that the request body starts
|
||||
Params interface{}
|
||||
Error error
|
||||
Data interface{}
|
||||
RequestID string
|
||||
RetryCount int
|
||||
Retryable *bool
|
||||
RetryDelay time.Duration
|
||||
NotHoist bool
|
||||
SignedHeaderVals http.Header
|
||||
LastSignedAt time.Time
|
||||
DisableFollowRedirects bool
|
||||
|
||||
context aws.Context
|
||||
|
||||
@@ -114,6 +128,40 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
|
||||
// using a WithContext API operation method.
|
||||
type Option func(*Request)
|
||||
|
||||
// WithGetResponseHeader builds a request Option which will retrieve a single
|
||||
// header value from the HTTP Response. If there are multiple values for the
|
||||
// header key use WithGetResponseHeaders instead to access the http.Header
|
||||
// map directly. The passed in val pointer must be non-nil.
|
||||
//
|
||||
// This Option can be used multiple times with a single API operation.
|
||||
//
|
||||
// var id2, versionID string
|
||||
// svc.PutObjectWithContext(ctx, params,
|
||||
// request.WithGetResponseHeader("x-amz-id-2", &id2),
|
||||
// request.WithGetResponseHeader("x-amz-version-id", &versionID),
|
||||
// )
|
||||
func WithGetResponseHeader(key string, val *string) Option {
|
||||
return func(r *Request) {
|
||||
r.Handlers.Complete.PushBack(func(req *Request) {
|
||||
*val = req.HTTPResponse.Header.Get(key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithGetResponseHeaders builds a request Option which will retrieve the
|
||||
// headers from the HTTP response and assign them to the passed in headers
|
||||
// variable. The passed in headers pointer must be non-nil.
|
||||
//
|
||||
// var headers http.Header
|
||||
// svc.PutObjectWithContext(ctx, params, request.WithGetResponseHeaders(&headers))
|
||||
func WithGetResponseHeaders(headers *http.Header) Option {
|
||||
return func(r *Request) {
|
||||
r.Handlers.Complete.PushBack(func(req *Request) {
|
||||
*headers = req.HTTPResponse.Header
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogLevel is a request option that will set the request to use a specific
|
||||
// log level when the request is made.
|
||||
//
|
||||
@@ -221,11 +269,17 @@ func (r *Request) Presign(expireTime time.Duration) (string, error) {
|
||||
return r.HTTPRequest.URL.String(), nil
|
||||
}
|
||||
|
||||
// PresignRequest behaves just like presign, but hoists all headers and signs them.
|
||||
// Also returns the signed hash back to the user
|
||||
// PresignRequest behaves just like presign, with the addition of returning a
|
||||
// set of headers that were signed.
|
||||
//
|
||||
// Returns the URL string for the API operation with signature in the query string,
|
||||
// and the HTTP headers that were included in the signature. These headers must
|
||||
// be included in any HTTP request made with the presigned URL.
|
||||
//
|
||||
// To prevent hoisting any headers to the query string set NotHoist to true on
|
||||
// this Request value prior to calling PresignRequest.
|
||||
func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, error) {
|
||||
r.ExpireTime = expireTime
|
||||
r.NotHoist = true
|
||||
r.Sign()
|
||||
if r.Error != nil {
|
||||
return "", nil, r.Error
|
||||
@@ -290,10 +344,7 @@ func (r *Request) Sign() error {
|
||||
return r.Error
|
||||
}
|
||||
|
||||
// ResetBody rewinds the request body backto its starting position, and
|
||||
// set's the HTTP Request body reference. When the body is read prior
|
||||
// to being sent in the HTTP request it will need to be rewound.
|
||||
func (r *Request) ResetBody() {
|
||||
func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
||||
if r.safeBody != nil {
|
||||
r.safeBody.Close()
|
||||
}
|
||||
@@ -315,14 +366,14 @@ func (r *Request) ResetBody() {
|
||||
// Related golang/go#18257
|
||||
l, err := computeBodyLength(r.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to compute request body size", err)
|
||||
return
|
||||
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
|
||||
}
|
||||
|
||||
var body io.ReadCloser
|
||||
if l == 0 {
|
||||
r.HTTPRequest.Body = noBodyReader
|
||||
body = NoBody
|
||||
} else if l > 0 {
|
||||
r.HTTPRequest.Body = r.safeBody
|
||||
body = r.safeBody
|
||||
} else {
|
||||
// Hack to prevent sending bodies for methods where the body
|
||||
// should be ignored by the server. Sending bodies on these
|
||||
@@ -334,11 +385,13 @@ func (r *Request) ResetBody() {
|
||||
// a io.Reader that was not also an io.Seeker.
|
||||
switch r.Operation.HTTPMethod {
|
||||
case "GET", "HEAD", "DELETE":
|
||||
r.HTTPRequest.Body = noBodyReader
|
||||
body = NoBody
|
||||
default:
|
||||
r.HTTPRequest.Body = r.safeBody
|
||||
body = r.safeBody
|
||||
}
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Attempts to compute the length of the body of the reader using the
|
||||
@@ -440,7 +493,7 @@ func (r *Request) Send() error {
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Send Request", false, r.Error)
|
||||
debugLogReqError(r, "Send Request", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Send Request", true, err)
|
||||
@@ -449,12 +502,13 @@ func (r *Request) Send() error {
|
||||
r.Handlers.UnmarshalMeta.Run(r)
|
||||
r.Handlers.ValidateResponse.Run(r)
|
||||
if r.Error != nil {
|
||||
err := r.Error
|
||||
r.Handlers.UnmarshalError.Run(r)
|
||||
err := r.Error
|
||||
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Validate Response", false, r.Error)
|
||||
debugLogReqError(r, "Validate Response", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Validate Response", true, err)
|
||||
@@ -467,7 +521,7 @@ func (r *Request) Send() error {
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Unmarshal Response", false, r.Error)
|
||||
debugLogReqError(r, "Unmarshal Response", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Unmarshal Response", true, err)
|
||||
|
||||
+20
-2
@@ -16,6 +16,24 @@ func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
|
||||
func (noBody) Close() error { return nil }
|
||||
func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
|
||||
|
||||
// Is an empty reader that will trigger the Go HTTP client to not include
|
||||
// NoBody is an empty reader that will trigger the Go HTTP client to not include
|
||||
// and body in the HTTP request.
|
||||
var noBodyReader = noBody{}
|
||||
var NoBody = noBody{}
|
||||
|
||||
// ResetBody rewinds the request body back to its starting position, and
|
||||
// set's the HTTP Request body reference. When the body is read prior
|
||||
// to being sent in the HTTP request it will need to be rewound.
|
||||
//
|
||||
// ResetBody will automatically be called by the SDK's build handler, but if
|
||||
// the request is being used directly ResetBody must be called before the request
|
||||
// is Sent. SetStringBody, SetBufferBody, and SetReaderBody will automatically
|
||||
// call ResetBody.
|
||||
func (r *Request) ResetBody() {
|
||||
body, err := r.getNextRequestBody()
|
||||
if err != nil {
|
||||
r.Error = err
|
||||
return
|
||||
}
|
||||
|
||||
r.HTTPRequest.Body = body
|
||||
}
|
||||
|
||||
+27
-3
@@ -2,8 +2,32 @@
|
||||
|
||||
package request
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Is a http.NoBody reader instructing Go HTTP client to not include
|
||||
// NoBody is a http.NoBody reader instructing Go HTTP client to not include
|
||||
// and body in the HTTP request.
|
||||
var noBodyReader = http.NoBody
|
||||
var NoBody = http.NoBody
|
||||
|
||||
// ResetBody rewinds the request body back to its starting position, and
|
||||
// set's the HTTP Request body reference. When the body is read prior
|
||||
// to being sent in the HTTP request it will need to be rewound.
|
||||
//
|
||||
// ResetBody will automatically be called by the SDK's build handler, but if
|
||||
// the request is being used directly ResetBody must be called before the request
|
||||
// is Sent. SetStringBody, SetBufferBody, and SetReaderBody will automatically
|
||||
// call ResetBody.
|
||||
//
|
||||
// Will also set the Go 1.8's http.Request.GetBody member to allow retrying
|
||||
// PUT/POST redirects.
|
||||
func (r *Request) ResetBody() {
|
||||
body, err := r.getNextRequestBody()
|
||||
if err != nil {
|
||||
r.Error = err
|
||||
return
|
||||
}
|
||||
|
||||
r.HTTPRequest.Body = body
|
||||
r.HTTPRequest.GetBody = r.getNextRequestBody
|
||||
}
|
||||
|
||||
+62
-2
@@ -1,15 +1,23 @@
|
||||
// +build go1.8
|
||||
|
||||
package request
|
||||
package request_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
)
|
||||
|
||||
func TestResetBody_WithEmptyBody(t *testing.T) {
|
||||
r := Request{
|
||||
r := request.Request{
|
||||
HTTPRequest: &http.Request{},
|
||||
}
|
||||
|
||||
@@ -23,3 +31,55 @@ func TestResetBody_WithEmptyBody(t *testing.T) {
|
||||
r.HTTPRequest.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_FollowPUTRedirects(t *testing.T) {
|
||||
const bodySize = 1024
|
||||
|
||||
redirectHit := 0
|
||||
endpointHit := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/redirect-me":
|
||||
u := *r.URL
|
||||
u.Path = "/endpoint"
|
||||
w.Header().Set("Location", u.String())
|
||||
w.WriteHeader(307)
|
||||
redirectHit++
|
||||
case "/endpoint":
|
||||
b := bytes.Buffer{}
|
||||
io.Copy(&b, r.Body)
|
||||
r.Body.Close()
|
||||
if e, a := bodySize, b.Len(); e != a {
|
||||
t.Fatalf("expect %d body size, got %d", e, a)
|
||||
}
|
||||
endpointHit++
|
||||
default:
|
||||
t.Fatalf("unexpected endpoint used, %q", r.URL.String())
|
||||
}
|
||||
}))
|
||||
|
||||
svc := awstesting.NewClient(&aws.Config{
|
||||
Region: unit.Session.Config.Region,
|
||||
DisableSSL: aws.Bool(true),
|
||||
Endpoint: aws.String(server.URL),
|
||||
})
|
||||
|
||||
req := svc.NewRequest(&request.Operation{
|
||||
Name: "Operation",
|
||||
HTTPMethod: "PUT",
|
||||
HTTPPath: "/redirect-me",
|
||||
}, &struct{}{}, &struct{}{})
|
||||
req.SetReaderBody(bytes.NewReader(make([]byte, bodySize)))
|
||||
|
||||
err := req.Send()
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := 1, redirectHit; e != a {
|
||||
t.Errorf("expect %d redirect hits, got %d", e, a)
|
||||
}
|
||||
if e, a := 1, endpointHit; e != a {
|
||||
t.Errorf("expect %d endpoint hits, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+11
-4
@@ -2,8 +2,6 @@ package request
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
@@ -15,6 +13,15 @@ func TestCopy(t *testing.T) {
|
||||
req.Handlers = handlers
|
||||
|
||||
r := req.copy()
|
||||
assert.NotEqual(t, req, r)
|
||||
assert.Equal(t, req.Operation.HTTPMethod, r.Operation.HTTPMethod)
|
||||
|
||||
if r == req {
|
||||
t.Fatal("expect request pointer copy to be different")
|
||||
}
|
||||
if r.Operation == req.Operation {
|
||||
t.Errorf("expect request operation pointer to be different")
|
||||
}
|
||||
|
||||
if e, a := req.Operation.HTTPMethod, r.Operation.HTTPMethod; e != a {
|
||||
t.Errorf("expect %q http method, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+73
@@ -381,6 +381,79 @@ func TestPaginationNilToken(t *testing.T) {
|
||||
assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results)
|
||||
}
|
||||
|
||||
func TestPaginationNilInput(t *testing.T) {
|
||||
// Code generation doesn't have a great way to verify the code is correct
|
||||
// other than being run via unit tests in the SDK. This should be fixed
|
||||
// So code generation can be validated independently.
|
||||
|
||||
client := s3.New(unit.Session)
|
||||
client.Handlers.Validate.Clear()
|
||||
client.Handlers.Send.Clear() // mock sending
|
||||
client.Handlers.Unmarshal.Clear()
|
||||
client.Handlers.UnmarshalMeta.Clear()
|
||||
client.Handlers.ValidateResponse.Clear()
|
||||
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
|
||||
r.Data = &s3.ListObjectsOutput{}
|
||||
})
|
||||
|
||||
gotToEnd := false
|
||||
numPages := 0
|
||||
err := client.ListObjectsPages(nil, func(p *s3.ListObjectsOutput, last bool) bool {
|
||||
numPages++
|
||||
if last {
|
||||
gotToEnd = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 1, numPages; e != a {
|
||||
t.Errorf("expect %d number pages but got %d", e, a)
|
||||
}
|
||||
if !gotToEnd {
|
||||
t.Errorf("expect to of gotten to end, did not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginationWithContextNilInput(t *testing.T) {
|
||||
// Code generation doesn't have a great way to verify the code is correct
|
||||
// other than being run via unit tests in the SDK. This should be fixed
|
||||
// So code generation can be validated independently.
|
||||
|
||||
client := s3.New(unit.Session)
|
||||
client.Handlers.Validate.Clear()
|
||||
client.Handlers.Send.Clear() // mock sending
|
||||
client.Handlers.Unmarshal.Clear()
|
||||
client.Handlers.UnmarshalMeta.Clear()
|
||||
client.Handlers.ValidateResponse.Clear()
|
||||
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
|
||||
r.Data = &s3.ListObjectsOutput{}
|
||||
})
|
||||
|
||||
gotToEnd := false
|
||||
numPages := 0
|
||||
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
|
||||
err := client.ListObjectsPagesWithContext(ctx, nil, func(p *s3.ListObjectsOutput, last bool) bool {
|
||||
numPages++
|
||||
if last {
|
||||
gotToEnd = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 1, numPages; e != a {
|
||||
t.Errorf("expect %d number pages but got %d", e, a)
|
||||
}
|
||||
if !gotToEnd {
|
||||
t.Errorf("expect to of gotten to end, did not")
|
||||
}
|
||||
}
|
||||
|
||||
type testPageInput struct {
|
||||
NextToken string
|
||||
}
|
||||
|
||||
+1
-1
@@ -50,7 +50,7 @@ func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
|
||||
|
||||
r.SetReaderBody(reader)
|
||||
|
||||
if a, e := r.HTTPRequest.Body == noBodyReader, c.IsNoBody; a != e {
|
||||
if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e {
|
||||
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+456
-52
@@ -8,18 +8,23 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
@@ -91,9 +96,15 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, int(r.RetryCount))
|
||||
assert.Equal(t, "valid", out.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 2, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
t.Errorf("expect %q output got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
|
||||
@@ -117,9 +128,15 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, int(r.RetryCount))
|
||||
assert.Equal(t, "valid", out.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 2, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
t.Errorf("expect %q output got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// test that retries don't occur for 4xx status codes with a response type that can't be retried
|
||||
@@ -135,15 +152,22 @@ func TestRequest4xxUnretryable(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.NotNil(t, err)
|
||||
if e, ok := err.(awserr.RequestFailure); ok {
|
||||
assert.Equal(t, 401, e.StatusCode())
|
||||
} else {
|
||||
assert.Fail(t, "Expected error to be a service failure")
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, but did not get one")
|
||||
}
|
||||
aerr := err.(awserr.RequestFailure)
|
||||
if e, a := 401, aerr.StatusCode(); e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
if e, a := "SignatureDoesNotMatch", aerr.Code(); e != a {
|
||||
t.Errorf("expect %q error code, got %q", e, a)
|
||||
}
|
||||
if e, a := "Signature does not match.", aerr.Message(); e != a {
|
||||
t.Errorf("expect %q error message, got %q", e, a)
|
||||
}
|
||||
if e, a := 0, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code())
|
||||
assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message())
|
||||
assert.Equal(t, 0, int(r.RetryCount))
|
||||
}
|
||||
|
||||
func TestRequestExhaustRetries(t *testing.T) {
|
||||
@@ -171,22 +195,31 @@ func TestRequestExhaustRetries(t *testing.T) {
|
||||
})
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
|
||||
err := r.Send()
|
||||
assert.NotNil(t, err)
|
||||
if e, ok := err.(awserr.RequestFailure); ok {
|
||||
assert.Equal(t, 500, e.StatusCode())
|
||||
} else {
|
||||
assert.Fail(t, "Expected error to be a service failure")
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, but did not get one")
|
||||
}
|
||||
aerr := err.(awserr.RequestFailure)
|
||||
if e, a := 500, aerr.StatusCode(); e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
if e, a := "UnknownError", aerr.Code(); e != a {
|
||||
t.Errorf("expect %q error code, got %q", e, a)
|
||||
}
|
||||
if e, a := "An error occurred.", aerr.Message(); e != a {
|
||||
t.Errorf("expect %q error message, got %q", e, a)
|
||||
}
|
||||
if e, a := 3, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
assert.Equal(t, "UnknownError", err.(awserr.Error).Code())
|
||||
assert.Equal(t, "An error occurred.", err.(awserr.Error).Message())
|
||||
assert.Equal(t, 3, int(r.RetryCount))
|
||||
|
||||
expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}}
|
||||
for i, v := range delays {
|
||||
min := expectDelays[i].min * time.Millisecond
|
||||
max := expectDelays[i].max * time.Millisecond
|
||||
assert.True(t, min <= v && v <= max,
|
||||
"Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max)
|
||||
if !(min <= v && v <= max) {
|
||||
t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s",
|
||||
i, v, min, max)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,14 +255,26 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.Nil(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
|
||||
assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
|
||||
assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")
|
||||
if credExpiredBeforeRetry {
|
||||
t.Errorf("Expect valid creds before retry check")
|
||||
}
|
||||
if !credExpiredAfterRetry {
|
||||
t.Errorf("Expect expired creds after retry check")
|
||||
}
|
||||
if s.Config.Credentials.IsExpired() {
|
||||
t.Errorf("Expect valid creds after cred expired recovery")
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, int(r.RetryCount))
|
||||
assert.Equal(t, "valid", out.Data)
|
||||
if e, a := 1, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
t.Errorf("expect %q output got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeAddtoUserAgentHandler(t *testing.T) {
|
||||
@@ -238,7 +283,9 @@ func TestMakeAddtoUserAgentHandler(t *testing.T) {
|
||||
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
|
||||
fn(r)
|
||||
|
||||
assert.Equal(t, "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"))
|
||||
if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); e != a {
|
||||
t.Errorf("expect %q user agent, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) {
|
||||
@@ -247,7 +294,9 @@ func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) {
|
||||
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
|
||||
fn(r)
|
||||
|
||||
assert.Equal(t, "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"))
|
||||
if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); e != a {
|
||||
t.Errorf("expect %q user agent, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestUserAgent(t *testing.T) {
|
||||
@@ -256,11 +305,15 @@ func TestRequestUserAgent(t *testing.T) {
|
||||
|
||||
req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{})
|
||||
req.HTTPRequest.Header.Set("User-Agent", "foo/bar")
|
||||
assert.NoError(t, req.Build())
|
||||
if err := req.Build(); err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)",
|
||||
aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH)
|
||||
assert.Equal(t, expectUA, req.HTTPRequest.Header.Get("User-Agent"))
|
||||
if e, a := expectUA, req.HTTPRequest.Header.Get("User-Agent"); e != a {
|
||||
t.Errorf("expect %q user agent, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestThrottleRetries(t *testing.T) {
|
||||
@@ -288,22 +341,31 @@ func TestRequestThrottleRetries(t *testing.T) {
|
||||
})
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
|
||||
err := r.Send()
|
||||
assert.NotNil(t, err)
|
||||
if e, ok := err.(awserr.RequestFailure); ok {
|
||||
assert.Equal(t, 500, e.StatusCode())
|
||||
} else {
|
||||
assert.Fail(t, "Expected error to be a service failure")
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, but did not get one")
|
||||
}
|
||||
aerr := err.(awserr.RequestFailure)
|
||||
if e, a := 500, aerr.StatusCode(); e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
if e, a := "Throttling", aerr.Code(); e != a {
|
||||
t.Errorf("expect %q error code, got %q", e, a)
|
||||
}
|
||||
if e, a := "An error occurred.", aerr.Message(); e != a {
|
||||
t.Errorf("expect %q error message, got %q", e, a)
|
||||
}
|
||||
if e, a := 3, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
assert.Equal(t, "Throttling", err.(awserr.Error).Code())
|
||||
assert.Equal(t, "An error occurred.", err.(awserr.Error).Message())
|
||||
assert.Equal(t, 3, int(r.RetryCount))
|
||||
|
||||
expectDelays := []struct{ min, max time.Duration }{{500, 999}, {1000, 1998}, {2000, 3996}}
|
||||
for i, v := range delays {
|
||||
min := expectDelays[i].min * time.Millisecond
|
||||
max := expectDelays[i].max * time.Millisecond
|
||||
assert.True(t, min <= v && v <= max,
|
||||
"Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max)
|
||||
if !(min <= v && v <= max) {
|
||||
t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s",
|
||||
i, v, min, max)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,9 +401,15 @@ func TestRequestRecoverTimeoutWithNilBody(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, int(r.RetryCount))
|
||||
assert.Equal(t, "valid", out.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 1, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
t.Errorf("expect %q output got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) {
|
||||
@@ -376,9 +444,15 @@ func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) {
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, int(r.RetryCount))
|
||||
assert.Equal(t, "valid", out.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
if e, a := 1, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if e, a := "valid", out.Data; e != a {
|
||||
t.Errorf("expect %q output got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_NoBody(t *testing.T) {
|
||||
@@ -438,3 +512,333 @@ func TestRequest_NoBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSerializationErrorRetryable(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
err: awserr.New(request.ErrCodeSerialization, "foo error", nil),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
err: awserr.New("ErrFoo", "foo error", nil),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
err: awserr.New(request.ErrCodeSerialization, "foo error", stubConnectionResetError),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range testCases {
|
||||
r := &request.Request{
|
||||
Error: c.err,
|
||||
}
|
||||
if r.IsErrorRetryable() != c.expected {
|
||||
t.Errorf("Case %d: Expected %v, but received %v", i+1, c.expected, !c.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogLevel(t *testing.T) {
|
||||
r := &request.Request{}
|
||||
|
||||
opt := request.WithLogLevel(aws.LogDebugWithHTTPBody)
|
||||
r.ApplyOptions(opt)
|
||||
|
||||
if !r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
|
||||
t.Errorf("expect log level to be set, but was not, %v",
|
||||
r.Config.LogLevel.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGetResponseHeader(t *testing.T) {
|
||||
r := &request.Request{}
|
||||
|
||||
var val, val2 string
|
||||
r.ApplyOptions(
|
||||
request.WithGetResponseHeader("x-a-header", &val),
|
||||
request.WithGetResponseHeader("x-second-header", &val2),
|
||||
)
|
||||
|
||||
r.HTTPResponse = &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("x-a-header", "first")
|
||||
h.Set("x-second-header", "second")
|
||||
return h
|
||||
}(),
|
||||
}
|
||||
r.Handlers.Complete.Run(r)
|
||||
|
||||
if e, a := "first", val; e != a {
|
||||
t.Errorf("expect %q header value got %q", e, a)
|
||||
}
|
||||
if e, a := "second", val2; e != a {
|
||||
t.Errorf("expect %q header value got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGetResponseHeaders(t *testing.T) {
|
||||
r := &request.Request{}
|
||||
|
||||
var headers http.Header
|
||||
opt := request.WithGetResponseHeaders(&headers)
|
||||
|
||||
r.ApplyOptions(opt)
|
||||
|
||||
r.HTTPResponse = &http.Response{
|
||||
Header: func() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("x-a-header", "headerValue")
|
||||
return h
|
||||
}(),
|
||||
}
|
||||
r.Handlers.Complete.Run(r)
|
||||
|
||||
if e, a := "headerValue", headers.Get("x-a-header"); e != a {
|
||||
t.Errorf("expect %q header value got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type connResetCloser struct {
|
||||
}
|
||||
|
||||
func (rc *connResetCloser) Read(b []byte) (int, error) {
|
||||
return 0, stubConnectionResetError
|
||||
}
|
||||
|
||||
func (rc *connResetCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSerializationErrConnectionReset(t *testing.T) {
|
||||
count := 0
|
||||
handlers := request.Handlers{}
|
||||
handlers.Send.PushBack(func(r *request.Request) {
|
||||
count++
|
||||
r.HTTPResponse = &http.Response{}
|
||||
r.HTTPResponse.Body = &connResetCloser{}
|
||||
})
|
||||
|
||||
handlers.Sign.PushBackNamed(v4.SignRequestHandler)
|
||||
handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
|
||||
handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
|
||||
handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
|
||||
handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
|
||||
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
|
||||
|
||||
op := &request.Operation{
|
||||
Name: "op",
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/",
|
||||
}
|
||||
|
||||
meta := metadata.ClientInfo{
|
||||
ServiceName: "fooService",
|
||||
SigningName: "foo",
|
||||
SigningRegion: "foo",
|
||||
Endpoint: "localhost",
|
||||
APIVersion: "2001-01-01",
|
||||
JSONVersion: "1.1",
|
||||
TargetPrefix: "Foo",
|
||||
}
|
||||
cfg := unit.Session.Config.Copy()
|
||||
cfg.MaxRetries = aws.Int(5)
|
||||
|
||||
req := request.New(
|
||||
*cfg,
|
||||
meta,
|
||||
handlers,
|
||||
client.DefaultRetryer{NumMaxRetries: 5},
|
||||
op,
|
||||
&struct {
|
||||
}{},
|
||||
&struct {
|
||||
}{},
|
||||
)
|
||||
|
||||
osErr := stubConnectionResetError
|
||||
req.ApplyOptions(request.WithResponseReadTimeout(time.Second))
|
||||
err := req.Send()
|
||||
if err == nil {
|
||||
t.Error("Expected rror 'SerializationError', but received nil")
|
||||
}
|
||||
if aerr, ok := err.(awserr.Error); ok && aerr.Code() != "SerializationError" {
|
||||
t.Errorf("Expected 'SerializationError', but received %q", aerr.Code())
|
||||
} else if !ok {
|
||||
t.Errorf("Expected 'awserr.Error', but received %v", reflect.TypeOf(err))
|
||||
} else if aerr.OrigErr().Error() != osErr.Error() {
|
||||
t.Errorf("Expected %q, but received %q", osErr.Error(), aerr.OrigErr().Error())
|
||||
}
|
||||
|
||||
if count != 6 {
|
||||
t.Errorf("Expected '6', but received %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
type testRetryer struct {
|
||||
shouldRetry bool
|
||||
}
|
||||
|
||||
func (d *testRetryer) MaxRetries() int {
|
||||
return 3
|
||||
}
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again
|
||||
func (d *testRetryer) RetryRules(r *request.Request) time.Duration {
|
||||
return time.Duration(time.Millisecond)
|
||||
}
|
||||
|
||||
func (d *testRetryer) ShouldRetry(r *request.Request) bool {
|
||||
d.shouldRetry = true
|
||||
if r.Retryable != nil {
|
||||
return *r.Retryable
|
||||
}
|
||||
|
||||
if r.HTTPResponse.StatusCode >= 500 {
|
||||
return true
|
||||
}
|
||||
return r.IsErrorRetryable()
|
||||
}
|
||||
|
||||
func TestEnforceShouldRetryCheck(t *testing.T) {
|
||||
tp := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ResponseHeaderTimeout: 1 * time.Millisecond,
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: tp}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This server should wait forever. Requests will timeout and the SDK should
|
||||
// attempt to retry.
|
||||
select {}
|
||||
}))
|
||||
|
||||
retryer := &testRetryer{}
|
||||
s := awstesting.NewClient(&aws.Config{
|
||||
Region: aws.String("mock-region"),
|
||||
MaxRetries: aws.Int(0),
|
||||
Endpoint: aws.String(server.URL),
|
||||
DisableSSL: aws.Bool(true),
|
||||
Retryer: retryer,
|
||||
HTTPClient: client,
|
||||
EnforceShouldRetryCheck: aws.Bool(true),
|
||||
})
|
||||
|
||||
s.Handlers.Validate.Clear()
|
||||
s.Handlers.Unmarshal.PushBack(unmarshal)
|
||||
s.Handlers.UnmarshalError.PushBack(unmarshalError)
|
||||
|
||||
out := &testData{}
|
||||
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
|
||||
err := r.Send()
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, but got nil")
|
||||
}
|
||||
if e, a := 3, int(r.RetryCount); e != a {
|
||||
t.Errorf("expect %d retry count, got %d", e, a)
|
||||
}
|
||||
if !retryer.shouldRetry {
|
||||
t.Errorf("expect 'true' for ShouldRetry, but got %v", retryer.shouldRetry)
|
||||
}
|
||||
}
|
||||
|
||||
type errReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (reader *errReader) Read(b []byte) (int, error) {
|
||||
return 0, reader.err
|
||||
}
|
||||
|
||||
func (reader *errReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIsNoBodyReader(t *testing.T) {
|
||||
cases := []struct {
|
||||
reader io.ReadCloser
|
||||
expect bool
|
||||
}{
|
||||
{ioutil.NopCloser(bytes.NewReader([]byte("abc"))), false},
|
||||
{ioutil.NopCloser(bytes.NewReader(nil)), false},
|
||||
{nil, false},
|
||||
{request.NoBody, true},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
if e, a := c.expect, request.NoBody == c.reader; e != a {
|
||||
t.Errorf("%d, expect %t match, but was %t", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_TemporaryRetry(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "1024")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
w.Write(make([]byte, 100))
|
||||
|
||||
f := w.(http.Flusher)
|
||||
f.Flush()
|
||||
|
||||
<-done
|
||||
}))
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
svc := awstesting.NewClient(&aws.Config{
|
||||
Region: unit.Session.Config.Region,
|
||||
MaxRetries: aws.Int(1),
|
||||
HTTPClient: client,
|
||||
DisableSSL: aws.Bool(true),
|
||||
Endpoint: aws.String(server.URL),
|
||||
})
|
||||
|
||||
req := svc.NewRequest(&request.Operation{
|
||||
Name: "name", HTTPMethod: "GET", HTTPPath: "/path",
|
||||
}, &struct{}{}, &struct{}{})
|
||||
|
||||
req.Handlers.Unmarshal.PushBack(func(r *request.Request) {
|
||||
defer req.HTTPResponse.Body.Close()
|
||||
_, err := io.Copy(ioutil.Discard, req.HTTPResponse.Body)
|
||||
r.Error = awserr.New(request.ErrCodeSerialization, "error", err)
|
||||
})
|
||||
|
||||
err := req.Send()
|
||||
if err == nil {
|
||||
t.Errorf("expect error, got none")
|
||||
}
|
||||
close(done)
|
||||
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := request.ErrCodeSerialization, aerr.Code(); e != a {
|
||||
t.Errorf("expect %q error code, got %q", e, a)
|
||||
}
|
||||
|
||||
if e, a := 1, req.RetryCount; e != a {
|
||||
t.Errorf("expect %d retries, got %d", e, a)
|
||||
}
|
||||
|
||||
type temporary interface {
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
terr := aerr.OrigErr().(temporary)
|
||||
if !terr.Temporary() {
|
||||
t.Errorf("expect temporary error, was not")
|
||||
}
|
||||
}
|
||||
|
||||
+79
-20
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// Retryer is an interface to control retry logic for a given service.
|
||||
// The default implementation used by most services is the service.DefaultRetryer
|
||||
// The default implementation used by most services is the client.DefaultRetryer
|
||||
// structure, which contains basic retry logic using exponential backoff.
|
||||
type Retryer interface {
|
||||
RetryRules(*Request) time.Duration
|
||||
@@ -26,8 +26,10 @@ func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
|
||||
// retryableCodes is a collection of service response codes which are retry-able
|
||||
// without any further action.
|
||||
var retryableCodes = map[string]struct{}{
|
||||
"RequestError": {},
|
||||
"RequestTimeout": {},
|
||||
"RequestError": {},
|
||||
"RequestTimeout": {},
|
||||
ErrCodeResponseTimeout: {},
|
||||
"RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout
|
||||
}
|
||||
|
||||
var throttleCodes = map[string]struct{}{
|
||||
@@ -36,7 +38,6 @@ var throttleCodes = map[string]struct{}{
|
||||
"ThrottlingException": {},
|
||||
"RequestLimitExceeded": {},
|
||||
"RequestThrottled": {},
|
||||
"LimitExceededException": {}, // Deleting 10+ DynamoDb tables at once
|
||||
"TooManyRequestsException": {}, // Lambda functions
|
||||
"PriorRequestNotComplete": {}, // Route53
|
||||
}
|
||||
@@ -68,35 +69,93 @@ func isCodeExpiredCreds(code string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
var validParentCodes = map[string]struct{}{
|
||||
ErrCodeSerialization: {},
|
||||
ErrCodeRead: {},
|
||||
}
|
||||
|
||||
type temporaryError interface {
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
func isNestedErrorRetryable(parentErr awserr.Error) bool {
|
||||
if parentErr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := validParentCodes[parentErr.Code()]; !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
err := parentErr.OrigErr()
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeRetryable(aerr.Code())
|
||||
}
|
||||
|
||||
if t, ok := err.(temporaryError); ok {
|
||||
return t.Temporary()
|
||||
}
|
||||
|
||||
return isErrConnectionReset(err)
|
||||
}
|
||||
|
||||
// IsErrorRetryable returns whether the error is retryable, based on its Code.
|
||||
// Returns false if the request has no Error set.
|
||||
func (r *Request) IsErrorRetryable() bool {
|
||||
if r.Error != nil {
|
||||
if err, ok := r.Error.(awserr.Error); ok {
|
||||
return isCodeRetryable(err.Code())
|
||||
// Returns false if error is nil.
|
||||
func IsErrorRetryable(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorThrottle returns whether the error is to be throttled based on its code.
|
||||
// Returns false if the request has no Error set
|
||||
func (r *Request) IsErrorThrottle() bool {
|
||||
if r.Error != nil {
|
||||
if err, ok := r.Error.(awserr.Error); ok {
|
||||
return isCodeThrottle(err.Code())
|
||||
// Returns false if error is nil.
|
||||
func IsErrorThrottle(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeThrottle(aerr.Code())
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorExpired returns whether the error code is a credential expiry error.
|
||||
// Returns false if the request has no Error set.
|
||||
func (r *Request) IsErrorExpired() bool {
|
||||
if r.Error != nil {
|
||||
if err, ok := r.Error.(awserr.Error); ok {
|
||||
return isCodeExpiredCreds(err.Code())
|
||||
// IsErrorExpiredCreds returns whether the error code is a credential expiry error.
|
||||
// Returns false if error is nil.
|
||||
func IsErrorExpiredCreds(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeExpiredCreds(aerr.Code())
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorRetryable returns whether the error is retryable, based on its Code.
|
||||
// Returns false if the request has no Error set.
|
||||
//
|
||||
// Alias for the utility function IsErrorRetryable
|
||||
func (r *Request) IsErrorRetryable() bool {
|
||||
return IsErrorRetryable(r.Error)
|
||||
}
|
||||
|
||||
// IsErrorThrottle returns whether the error is to be throttled based on its code.
|
||||
// Returns false if the request has no Error set
|
||||
//
|
||||
// Alias for the utility function IsErrorThrottle
|
||||
func (r *Request) IsErrorThrottle() bool {
|
||||
return IsErrorThrottle(r.Error)
|
||||
}
|
||||
|
||||
// IsErrorExpired returns whether the error code is a credential expiry error.
|
||||
// Returns false if the request has no Error set.
|
||||
//
|
||||
// Alias for the utility function IsErrorExpiredCreds
|
||||
func (r *Request) IsErrorExpired() bool {
|
||||
return IsErrorExpiredCreds(r.Error)
|
||||
}
|
||||
|
||||
+49
-3
@@ -1,10 +1,10 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
@@ -12,5 +12,51 @@ func TestRequestThrottling(t *testing.T) {
|
||||
req := Request{}
|
||||
|
||||
req.Error = awserr.New("Throttling", "", nil)
|
||||
assert.True(t, req.IsErrorThrottle())
|
||||
if e, a := true, req.IsErrorThrottle(); e != a {
|
||||
t.Errorf("expect %t to be throttled, was %t", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type mockTempError bool
|
||||
|
||||
func (e mockTempError) Error() string {
|
||||
return fmt.Sprintf("mock temporary error: %t", e.Temporary())
|
||||
}
|
||||
func (e mockTempError) Temporary() bool {
|
||||
return bool(e)
|
||||
}
|
||||
|
||||
func TestIsErrorRetryable(t *testing.T) {
|
||||
cases := []struct {
|
||||
Err error
|
||||
IsTemp bool
|
||||
}{
|
||||
{
|
||||
Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(true)),
|
||||
IsTemp: true,
|
||||
},
|
||||
{
|
||||
Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(false)),
|
||||
IsTemp: false,
|
||||
},
|
||||
{
|
||||
Err: awserr.New(ErrCodeSerialization, "some error", errors.New("blah")),
|
||||
IsTemp: false,
|
||||
},
|
||||
{
|
||||
Err: awserr.New("SomeError", "some error", nil),
|
||||
IsTemp: false,
|
||||
},
|
||||
{
|
||||
Err: awserr.New("RequestError", "some error", nil),
|
||||
IsTemp: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
retryable := IsErrorRetryable(c.Err)
|
||||
if e, a := c.IsTemp, retryable; e != a {
|
||||
t.Errorf("%d, expect %t temporary error, got %t", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+94
@@ -0,0 +1,94 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
var timeoutErr = awserr.New(
|
||||
ErrCodeResponseTimeout,
|
||||
"read on body has reached the timeout limit",
|
||||
nil,
|
||||
)
|
||||
|
||||
type readResult struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// timeoutReadCloser will handle body reads that take too long.
|
||||
// We will return a ErrReadTimeout error if a timeout occurs.
|
||||
type timeoutReadCloser struct {
|
||||
reader io.ReadCloser
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
// Read will spin off a goroutine to call the reader's Read method. We will
|
||||
// select on the timer's channel or the read's channel. Whoever completes first
|
||||
// will be returned.
|
||||
func (r *timeoutReadCloser) Read(b []byte) (int, error) {
|
||||
timer := time.NewTimer(r.duration)
|
||||
c := make(chan readResult, 1)
|
||||
|
||||
go func() {
|
||||
n, err := r.reader.Read(b)
|
||||
timer.Stop()
|
||||
c <- readResult{n: n, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case data := <-c:
|
||||
return data.n, data.err
|
||||
case <-timer.C:
|
||||
return 0, timeoutErr
|
||||
}
|
||||
}
|
||||
|
||||
func (r *timeoutReadCloser) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
|
||||
const (
|
||||
// HandlerResponseTimeout is what we use to signify the name of the
|
||||
// response timeout handler.
|
||||
HandlerResponseTimeout = "ResponseTimeoutHandler"
|
||||
)
|
||||
|
||||
// adaptToResponseTimeoutError is a handler that will replace any top level error
|
||||
// to a ErrCodeResponseTimeout, if its child is that.
|
||||
func adaptToResponseTimeoutError(req *Request) {
|
||||
if err, ok := req.Error.(awserr.Error); ok {
|
||||
aerr, ok := err.OrigErr().(awserr.Error)
|
||||
if ok && aerr.Code() == ErrCodeResponseTimeout {
|
||||
req.Error = aerr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithResponseReadTimeout is a request option that will wrap the body in a timeout read closer.
|
||||
// This will allow for per read timeouts. If a timeout occurred, we will return the
|
||||
// ErrCodeResponseTimeout.
|
||||
//
|
||||
// svc.PutObjectWithContext(ctx, params, request.WithTimeoutReadCloser(30 * time.Second)
|
||||
func WithResponseReadTimeout(duration time.Duration) Option {
|
||||
return func(r *Request) {
|
||||
|
||||
var timeoutHandler = NamedHandler{
|
||||
HandlerResponseTimeout,
|
||||
func(req *Request) {
|
||||
req.HTTPResponse.Body = &timeoutReadCloser{
|
||||
reader: req.HTTPResponse.Body,
|
||||
duration: duration,
|
||||
}
|
||||
}}
|
||||
|
||||
// remove the handler so we are not stomping over any new durations.
|
||||
r.Handlers.Send.RemoveByName(HandlerResponseTimeout)
|
||||
r.Handlers.Send.PushBackNamed(timeoutHandler)
|
||||
|
||||
r.Handlers.Unmarshal.PushBack(adaptToResponseTimeoutError)
|
||||
r.Handlers.UnmarshalError.PushBack(adaptToResponseTimeoutError)
|
||||
}
|
||||
}
|
||||
Generated
Vendored
+76
@@ -0,0 +1,76 @@
|
||||
package request_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
|
||||
)
|
||||
|
||||
func BenchmarkTimeoutReadCloser(b *testing.B) {
|
||||
resp := `
|
||||
{
|
||||
"Bar": "qux"
|
||||
}
|
||||
`
|
||||
|
||||
handlers := request.Handlers{}
|
||||
|
||||
handlers.Send.PushBack(func(r *request.Request) {
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer([]byte(resp))),
|
||||
}
|
||||
})
|
||||
handlers.Sign.PushBackNamed(v4.SignRequestHandler)
|
||||
handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
|
||||
handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
|
||||
handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
|
||||
handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
|
||||
|
||||
op := &request.Operation{
|
||||
Name: "op",
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/",
|
||||
}
|
||||
|
||||
meta := metadata.ClientInfo{
|
||||
ServiceName: "fooService",
|
||||
SigningName: "foo",
|
||||
SigningRegion: "foo",
|
||||
Endpoint: "localhost",
|
||||
APIVersion: "2001-01-01",
|
||||
JSONVersion: "1.1",
|
||||
TargetPrefix: "Foo",
|
||||
}
|
||||
|
||||
req := request.New(
|
||||
*unit.Session.Config,
|
||||
meta,
|
||||
handlers,
|
||||
client.DefaultRetryer{NumMaxRetries: 5},
|
||||
op,
|
||||
&struct {
|
||||
Foo *string
|
||||
}{},
|
||||
&struct {
|
||||
Bar *string
|
||||
}{},
|
||||
)
|
||||
|
||||
req.ApplyOptions(request.WithResponseReadTimeout(15 * time.Second))
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := req.Send()
|
||||
if err != nil {
|
||||
b.Errorf("Expected no error, but received %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
+118
@@ -0,0 +1,118 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
type testReader struct {
|
||||
duration time.Duration
|
||||
count int
|
||||
}
|
||||
|
||||
func (r *testReader) Read(b []byte) (int, error) {
|
||||
if r.count > 0 {
|
||||
r.count--
|
||||
return len(b), nil
|
||||
}
|
||||
time.Sleep(r.duration)
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (r *testReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTimeoutReadCloser(t *testing.T) {
|
||||
reader := timeoutReadCloser{
|
||||
reader: &testReader{
|
||||
duration: time.Second,
|
||||
count: 5,
|
||||
},
|
||||
duration: time.Millisecond,
|
||||
}
|
||||
b := make([]byte, 100)
|
||||
_, err := reader.Read(b)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutReadCloserSameDuration(t *testing.T) {
|
||||
reader := timeoutReadCloser{
|
||||
reader: &testReader{
|
||||
duration: time.Millisecond,
|
||||
count: 5,
|
||||
},
|
||||
duration: time.Millisecond,
|
||||
}
|
||||
b := make([]byte, 100)
|
||||
_, err := reader.Read(b)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithResponseReadTimeout(t *testing.T) {
|
||||
r := Request{
|
||||
HTTPResponse: &http.Response{
|
||||
Body: ioutil.NopCloser(bytes.NewReader(nil)),
|
||||
},
|
||||
}
|
||||
r.ApplyOptions(WithResponseReadTimeout(time.Second))
|
||||
err := r.Send()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
v, ok := r.HTTPResponse.Body.(*timeoutReadCloser)
|
||||
if !ok {
|
||||
t.Error("Expected the body to be a timeoutReadCloser")
|
||||
}
|
||||
if v.duration != time.Second {
|
||||
t.Errorf("Expected %v, but receive %v\n", time.Second, v.duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdaptToResponseTimeout(t *testing.T) {
|
||||
testCases := []struct {
|
||||
childErr error
|
||||
r Request
|
||||
expectedRootCode string
|
||||
}{
|
||||
{
|
||||
childErr: awserr.New(ErrCodeResponseTimeout, "timeout!", nil),
|
||||
r: Request{
|
||||
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout, "timeout!", nil)),
|
||||
},
|
||||
expectedRootCode: ErrCodeResponseTimeout,
|
||||
},
|
||||
{
|
||||
childErr: awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil),
|
||||
r: Request{
|
||||
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil)),
|
||||
},
|
||||
expectedRootCode: "ErrTest",
|
||||
},
|
||||
{
|
||||
r: Request{
|
||||
Error: awserr.New("ErrTest", "FooBar", nil),
|
||||
},
|
||||
expectedRootCode: "ErrTest",
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range testCases {
|
||||
adaptToResponseTimeoutError(&c.r)
|
||||
if aerr, ok := c.r.Error.(awserr.Error); !ok {
|
||||
t.Errorf("Case %d: Expected 'awserr', but received %v", i+1, c.r.Error)
|
||||
} else if aerr.Code() != c.expectedRootCode {
|
||||
t.Errorf("Case %d: Expected %q, but received %s", i+1, c.expectedRootCode, aerr.Code())
|
||||
}
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -220,7 +220,7 @@ type ErrParamMinLen struct {
|
||||
func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
|
||||
return &ErrParamMinLen{
|
||||
errInvalidParam: errInvalidParam{
|
||||
code: ParamMinValueErrCode,
|
||||
code: ParamMinLenErrCode,
|
||||
field: field,
|
||||
msg: fmt.Sprintf("minimum field size of %v", min),
|
||||
},
|
||||
|
||||
+17
-15
@@ -66,8 +66,8 @@ func WithWaiterRequestOptions(opts ...Option) WaiterOption {
|
||||
}
|
||||
}
|
||||
|
||||
// A Waiter provides the functionality to performing blocking call which will
|
||||
// wait for an resource state to be satisfied a service.
|
||||
// A Waiter provides the functionality to perform a blocking call which will
|
||||
// wait for a resource state to be satisfied by a service.
|
||||
//
|
||||
// This type should not be used directly. The API operations provided in the
|
||||
// service packages prefixed with "WaitUntil" should be used instead.
|
||||
@@ -79,8 +79,9 @@ type Waiter struct {
|
||||
MaxAttempts int
|
||||
Delay WaiterDelay
|
||||
|
||||
RequestOptions []Option
|
||||
NewRequest func([]Option) (*Request, error)
|
||||
RequestOptions []Option
|
||||
NewRequest func([]Option) (*Request, error)
|
||||
SleepWithContext func(aws.Context, time.Duration) error
|
||||
}
|
||||
|
||||
// ApplyOptions updates the waiter with the list of waiter options provided.
|
||||
@@ -178,14 +179,8 @@ func (w Waiter) WaitWithContext(ctx aws.Context) error {
|
||||
|
||||
// See if any of the acceptors match the request's response, or error
|
||||
for _, a := range w.Acceptors {
|
||||
var matched bool
|
||||
matched, err = a.match(w.Name, w.Logger, req, err)
|
||||
if err != nil {
|
||||
// Error occurred during current waiter call
|
||||
return err
|
||||
} else if matched {
|
||||
// Match was found can stop here and return
|
||||
return nil
|
||||
if matched, matchErr := a.match(w.Name, w.Logger, req, err); matched {
|
||||
return matchErr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,8 +196,15 @@ func (w Waiter) WaitWithContext(ctx aws.Context) error {
|
||||
if sleepFn := req.Config.SleepDelay; sleepFn != nil {
|
||||
// Support SleepDelay for backwards compatibility and testing
|
||||
sleepFn(delay)
|
||||
} else if err := aws.SleepWithContext(ctx, delay); err != nil {
|
||||
return awserr.New(CanceledErrorCode, "waiter context canceled", err)
|
||||
} else {
|
||||
sleepCtxFn := w.SleepWithContext
|
||||
if sleepCtxFn == nil {
|
||||
sleepCtxFn = aws.SleepWithContext
|
||||
}
|
||||
|
||||
if err := sleepCtxFn(ctx, delay); err != nil {
|
||||
return awserr.New(CanceledErrorCode, "waiter context canceled", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,7 +276,7 @@ func (a *WaiterAcceptor) match(name string, l aws.Logger, req *Request, err erro
|
||||
return true, nil
|
||||
case FailureWaiterState:
|
||||
// Waiter failure state triggered
|
||||
return false, awserr.New("ResourceNotReady",
|
||||
return true, awserr.New(WaiterResourceNotReadyErrorCode,
|
||||
"failed waiting for successful resource state", err)
|
||||
case RetryWaiterState:
|
||||
// clear the error and retry the operation
|
||||
|
||||
+122
-27
@@ -15,6 +15,8 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
@@ -98,8 +100,9 @@ func TestWaiterPathAll(t *testing.T) {
|
||||
})
|
||||
|
||||
w := request.Waiter{
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -162,8 +165,9 @@ func TestWaiterPath(t *testing.T) {
|
||||
})
|
||||
|
||||
w := request.Waiter{
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -226,8 +230,9 @@ func TestWaiterFailure(t *testing.T) {
|
||||
})
|
||||
|
||||
w := request.Waiter{
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -271,7 +276,9 @@ func TestWaiterError(t *testing.T) {
|
||||
{State: aws.String("pending")},
|
||||
},
|
||||
},
|
||||
{ // Request 2, error case
|
||||
{ // Request 1, error case retry
|
||||
},
|
||||
{ // Request 2, error case failure
|
||||
},
|
||||
{ // Request 3
|
||||
States: []*MockState{
|
||||
@@ -280,6 +287,9 @@ func TestWaiterError(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
reqErrs := make([]error, len(resps))
|
||||
reqErrs[1] = awserr.New("MockException", "mock exception message", nil)
|
||||
reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil)
|
||||
|
||||
numBuiltReq := 0
|
||||
svc.Handlers.Build.PushBack(func(r *request.Request) {
|
||||
@@ -305,17 +315,18 @@ func TestWaiterError(t *testing.T) {
|
||||
reqNum++
|
||||
})
|
||||
svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
|
||||
if reqNum == 1 {
|
||||
r.Error = awserr.New("MockException", "mock exception message", nil)
|
||||
// If there was an error unmarshal error will be called instead of unmarshal
|
||||
// need to increment count here also
|
||||
// If there was an error unmarshal error will be called instead of unmarshal
|
||||
// need to increment count here also
|
||||
if err := reqErrs[reqNum]; err != nil {
|
||||
r.Error = err
|
||||
reqNum++
|
||||
}
|
||||
})
|
||||
|
||||
w := request.Waiter{
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -329,14 +340,30 @@ func TestWaiterError(t *testing.T) {
|
||||
Argument: "",
|
||||
Expected: "MockException",
|
||||
},
|
||||
{
|
||||
State: request.FailureWaiterState,
|
||||
Matcher: request.ErrorWaiterMatch,
|
||||
Argument: "",
|
||||
Expected: "FailureException",
|
||||
},
|
||||
},
|
||||
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
|
||||
}
|
||||
|
||||
err := w.WaitWithContext(aws.BackgroundContext())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, numBuiltReq)
|
||||
assert.Equal(t, 3, reqNum)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, but did not get one")
|
||||
}
|
||||
aerr := err.(awserr.Error)
|
||||
if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
|
||||
t.Errorf("expect %q error code, got %q", e, a)
|
||||
}
|
||||
if e, a := 3, numBuiltReq; e != a {
|
||||
t.Errorf("expect %d built requests got %d", e, a)
|
||||
}
|
||||
if e, a := 3, reqNum; e != a {
|
||||
t.Errorf("expect %d reqNum got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaiterStatus(t *testing.T) {
|
||||
@@ -366,8 +393,9 @@ func TestWaiterStatus(t *testing.T) {
|
||||
})
|
||||
|
||||
w := request.Waiter{
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(0),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -424,9 +452,10 @@ func TestWaiter_WithContextCanceled(t *testing.T) {
|
||||
reqCount := 0
|
||||
|
||||
w := request.Waiter{
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -452,6 +481,16 @@ func TestWaiter_WithContextCanceled(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
w.SleepWithContext = func(c aws.Context, delay time.Duration) error {
|
||||
context := c.(*awstesting.FakeContext)
|
||||
select {
|
||||
case <-context.DoneCh:
|
||||
return context.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := w.WaitWithContext(ctx)
|
||||
|
||||
if err == nil {
|
||||
@@ -475,9 +514,10 @@ func TestWaiter_WithContext(t *testing.T) {
|
||||
statuses := []int{http.StatusNotFound, http.StatusOK}
|
||||
|
||||
w := request.Waiter{
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 10,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -520,9 +560,10 @@ func TestWaiter_AttemptsExpires(t *testing.T) {
|
||||
reqCount := 0
|
||||
|
||||
w := request.Waiter{
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 2,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
Name: "TestWaiter",
|
||||
MaxAttempts: 2,
|
||||
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
|
||||
SleepWithContext: aws.SleepWithContext,
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{
|
||||
State: request.SuccessWaiterState,
|
||||
@@ -557,3 +598,57 @@ func TestWaiter_AttemptsExpires(t *testing.T) {
|
||||
t.Errorf("expect %d requests, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaiterNilInput(t *testing.T) {
|
||||
// Code generation doesn't have a great way to verify the code is correct
|
||||
// other than being run via unit tests in the SDK. This should be fixed
|
||||
// So code generation can be validated independently.
|
||||
|
||||
client := s3.New(unit.Session)
|
||||
client.Handlers.Validate.Clear()
|
||||
client.Handlers.Send.Clear() // mock sending
|
||||
client.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
})
|
||||
client.Handlers.Unmarshal.Clear()
|
||||
client.Handlers.UnmarshalMeta.Clear()
|
||||
client.Handlers.ValidateResponse.Clear()
|
||||
client.Config.SleepDelay = func(dur time.Duration) {}
|
||||
|
||||
// Ensure waiters do not panic on nil input. It doesn't make sense to
|
||||
// call a waiter without an input, Validation will
|
||||
err := client.WaitUntilBucketExists(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaiterWithContextNilInput(t *testing.T) {
|
||||
// Code generation doesn't have a great way to verify the code is correct
|
||||
// other than being run via unit tests in the SDK. This should be fixed
|
||||
// So code generation can be validated independently.
|
||||
|
||||
client := s3.New(unit.Session)
|
||||
client.Handlers.Validate.Clear()
|
||||
client.Handlers.Send.Clear() // mock sending
|
||||
client.Handlers.Send.PushBack(func(r *request.Request) {
|
||||
r.HTTPResponse = &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
})
|
||||
client.Handlers.Unmarshal.Clear()
|
||||
client.Handlers.UnmarshalMeta.Clear()
|
||||
client.Handlers.ValidateResponse.Clear()
|
||||
|
||||
// Ensure waiters do not panic on nil input
|
||||
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
|
||||
err := client.WaitUntilBucketExistsWithContext(ctx, nil,
|
||||
request.WithWaiterDelay(request.ConstantWaiterDelay(0)),
|
||||
request.WithWaiterMaxAttempts(1),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+118
-194
@@ -2,158 +2,166 @@ package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
)
|
||||
|
||||
func createTLSServer(cert, key []byte, done <-chan struct{}) (*httptest.Server, error) {
|
||||
c, err := tls.X509KeyPair(cert, key)
|
||||
var TLSBundleCertFile string
|
||||
var TLSBundleKeyFile string
|
||||
var TLSBundleCAFile string
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var err error
|
||||
|
||||
TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err = awstesting.CreateTLSBundleFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
s.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{c},
|
||||
}
|
||||
s.TLS.BuildNameToCertificate()
|
||||
s.StartTLS()
|
||||
fmt.Println("TestMain", TLSBundleCertFile, TLSBundleKeyFile)
|
||||
|
||||
go func() {
|
||||
<-done
|
||||
s.Close()
|
||||
}()
|
||||
code := m.Run()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func setupTestCAFile(b []byte) (string, error) {
|
||||
bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
|
||||
err = awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = bundleFile.Write(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer bundleFile.Close()
|
||||
return bundleFile.Name(), nil
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestNewSession_WithCustomCABundle_Env(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
done := make(chan struct{})
|
||||
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
|
||||
assert.NoError(t, err)
|
||||
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
// Write bundle to file
|
||||
caFilename, err := setupTestCAFile(testTLSBundleCA)
|
||||
defer func() {
|
||||
os.Remove(caFilename)
|
||||
}()
|
||||
assert.NoError(t, err)
|
||||
|
||||
os.Setenv("AWS_CA_BUNDLE", caFilename)
|
||||
os.Setenv("AWS_CA_BUNDLE", TLSBundleCAFile)
|
||||
|
||||
s, err := NewSession(&aws.Config{
|
||||
HTTPClient: &http.Client{},
|
||||
Endpoint: aws.String(server.URL),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Region: aws.String("mock-region"),
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("expect session to be created, got none")
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
|
||||
resp, err := s.Config.HTTPClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession_WithCustomCABundle_EnvNotExists(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
|
||||
|
||||
s, err := NewSession()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "LoadCustomCABundleError", err.(awserr.Error).Code())
|
||||
assert.Nil(t, s)
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expect %s error code, got %s", e, a)
|
||||
}
|
||||
if s != nil {
|
||||
t.Errorf("expect nil session, got %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession_WithCustomCABundle_Option(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
done := make(chan struct{})
|
||||
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
|
||||
assert.NoError(t, err)
|
||||
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
Config: aws.Config{
|
||||
HTTPClient: &http.Client{},
|
||||
Endpoint: aws.String(server.URL),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Region: aws.String("mock-region"),
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
},
|
||||
CustomCABundle: bytes.NewReader(testTLSBundleCA),
|
||||
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("expect session to be created, got none")
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
|
||||
resp, err := s.Config.HTTPClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession_WithCustomCABundle_OptionPriority(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
done := make(chan struct{})
|
||||
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
|
||||
assert.NoError(t, err)
|
||||
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
|
||||
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
Config: aws.Config{
|
||||
HTTPClient: &http.Client{},
|
||||
Endpoint: aws.String(server.URL),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Region: aws.String("mock-region"),
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
},
|
||||
CustomCABundle: bytes.NewReader(testTLSBundleCA),
|
||||
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("expect session to be created, got none")
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
|
||||
resp, err := s.Config.HTTPClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type mockRoundTripper struct{}
|
||||
@@ -164,7 +172,7 @@ func (m *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
Config: aws.Config{
|
||||
@@ -172,25 +180,35 @@ func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) {
|
||||
Transport: &mockRoundTripper{},
|
||||
},
|
||||
},
|
||||
CustomCABundle: bytes.NewReader(testTLSBundleCA),
|
||||
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "LoadCustomCABundleError", err.(awserr.Error).Code())
|
||||
assert.Contains(t, err.(awserr.Error).Message(), "transport unsupported type")
|
||||
assert.Nil(t, s)
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("expect %s error code, got %s", e, a)
|
||||
}
|
||||
if s != nil {
|
||||
t.Errorf("expect nil session, got %v", s)
|
||||
}
|
||||
aerrMsg := err.(awserr.Error).Message()
|
||||
if e, a := "transport unsupported type", aerrMsg; !strings.Contains(a, e) {
|
||||
t.Errorf("expect %s to be in %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
done := make(chan struct{})
|
||||
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
|
||||
assert.NoError(t, err)
|
||||
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
Config: aws.Config{
|
||||
Endpoint: aws.String(server.URL),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Region: aws.String("mock-region"),
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
HTTPClient: &http.Client{
|
||||
@@ -205,115 +223,21 @@ func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
CustomCABundle: bytes.NewReader(testTLSBundleCA),
|
||||
CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("expect session to be created, got none")
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
|
||||
resp, err := s.Config.HTTPClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %d status code, got %d", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
/* Cert generation steps
|
||||
# Create the CA key
|
||||
openssl genrsa -des3 -out ca.key 1024
|
||||
|
||||
# Create the CA Cert
|
||||
openssl req -new -sha256 -x509 -days 3650 \
|
||||
-subj "/C=GO/ST=Gopher/O=Testing ROOT CA" \
|
||||
-key ca.key -out ca.crt
|
||||
|
||||
# Create config
|
||||
cat > csr_details.txt <<-EOF
|
||||
|
||||
[req]
|
||||
default_bits = 1024
|
||||
prompt = no
|
||||
default_md = sha256
|
||||
req_extensions = SAN
|
||||
distinguished_name = dn
|
||||
|
||||
[ dn ]
|
||||
C=GO
|
||||
ST=Gopher
|
||||
O=Testing Certificate
|
||||
OU=Testing IP
|
||||
|
||||
[SAN]
|
||||
subjectAltName = IP:127.0.0.1
|
||||
EOF
|
||||
|
||||
# Create certificate signing request
|
||||
openssl req -new -sha256 -nodes -newkey rsa:1024 \
|
||||
-config <( cat csr_details.txt ) \
|
||||
-keyout ia.key -out ia.csr
|
||||
|
||||
# Create a signed certificate
|
||||
openssl x509 -req -days 3650 \
|
||||
-CAcreateserial \
|
||||
-extfile <( cat csr_details.txt ) \
|
||||
-extensions SAN \
|
||||
-CA ca.crt -CAkey ca.key -in ia.csr -out ia.crt
|
||||
|
||||
# Verify
|
||||
openssl req -noout -text -in ia.csr
|
||||
openssl x509 -noout -text -in ia.crt
|
||||
*/
|
||||
var (
|
||||
// ca.crt
|
||||
testTLSBundleCA = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICiTCCAfKgAwIBAgIJAJ5X1olt05XjMA0GCSqGSIb3DQEBCwUAMDgxCzAJBgNV
|
||||
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
|
||||
QTAeFw0xNzAzMDkwMDAyMDZaFw0yNzAzMDcwMDAyMDZaMDgxCzAJBgNVBAYTAkdP
|
||||
MQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBDQTCBnzAN
|
||||
BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAw/8DN+t9XQR60jx42rsQ2WE2Dx85rb3n
|
||||
GQxnKZZLNddsT8rDyxJNP18aFalbRbFlyln5fxWxZIblu9Xkm/HRhOpbSimSqo1y
|
||||
uDx21NVZ1YsOvXpHby71jx3gPrrhSc/t/zikhi++6D/C6m1CiIGuiJ0GBiJxtrub
|
||||
UBMXT0QtI2ECAwEAAaOBmjCBlzAdBgNVHQ4EFgQU8XG3X/YHBA6T04kdEkq6+4GV
|
||||
YykwaAYDVR0jBGEwX4AU8XG3X/YHBA6T04kdEkq6+4GVYymhPKQ6MDgxCzAJBgNV
|
||||
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
|
||||
QYIJAJ5X1olt05XjMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAeILv
|
||||
z49+uxmPcfOZzonuOloRcpdvyjiXblYxbzz6ch8GsE7Q886FTZbvwbgLhzdwSVgG
|
||||
G8WHkodDUsymVepdqAamS3f8PdCUk8xIk9mop8LgaB9Ns0/TssxDvMr3sOD2Grb3
|
||||
xyWymTWMcj6uCiEBKtnUp4rPiefcvCRYZ17/hLE=
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
|
||||
// ai.crt
|
||||
testTLSBundleCert = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICGjCCAYOgAwIBAgIJAIIu+NOoxxM0MA0GCSqGSIb3DQEBBQUAMDgxCzAJBgNV
|
||||
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
|
||||
QTAeFw0xNzAzMDkwMDAzMTRaFw0yNzAzMDcwMDAzMTRaMFExCzAJBgNVBAYTAkdP
|
||||
MQ8wDQYDVQQIDAZHb3BoZXIxHDAaBgNVBAoME1Rlc3RpbmcgQ2VydGlmaWNhdGUx
|
||||
EzARBgNVBAsMClRlc3RpbmcgSVAwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB
|
||||
AN1hWHeioo/nASvbrjwCQzXCiWiEzGkw353NxsAB54/NqDL3LXNATtiSJu8kJBrm
|
||||
Ah12IFLtWLGXjGjjYlHbQWnOR6awveeXnQZukJyRWh7m/Qlt9Ho0CgZE1U+832ac
|
||||
5GWVldNxW1Lz4I+W9/ehzqe8I80RS6eLEKfUFXGiW+9RAgMBAAGjEzARMA8GA1Ud
|
||||
EQQIMAaHBH8AAAEwDQYJKoZIhvcNAQEFBQADgYEAdF4WQHfVdPCbgv9sxgJjcR1H
|
||||
Hgw9rZ47gO1IiIhzglnLXQ6QuemRiHeYFg4kjcYBk1DJguxzDTGnUwhUXOibAB+S
|
||||
zssmrkdYYvn9aUhjc3XK3tjAoDpsPpeBeTBamuUKDHoH/dNRXxerZ8vu6uPR3Pgs
|
||||
5v/KCV6IAEcvNyOXMPo=
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
|
||||
// ai.key
|
||||
testTLSBundleKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQDdYVh3oqKP5wEr2648AkM1wolohMxpMN+dzcbAAeePzagy9y1z
|
||||
QE7YkibvJCQa5gIddiBS7Vixl4xo42JR20FpzkemsL3nl50GbpCckVoe5v0JbfR6
|
||||
NAoGRNVPvN9mnORllZXTcVtS8+CPlvf3oc6nvCPNEUunixCn1BVxolvvUQIDAQAB
|
||||
AoGBAMISrcirddGrlLZLLrKC1ULS2T0cdkqdQtwHYn4+7S5+/z42vMx1iumHLsSk
|
||||
rVY7X41OWkX4trFxhvEIrc/O48bo2zw78P7flTxHy14uxXnllU8cLThE29SlUU7j
|
||||
AVBNxJZMsXMlS/DowwD4CjFe+x4Pu9wZcReF2Z9ntzMpySABAkEA+iWoJCPE2JpS
|
||||
y78q3HYYgpNY3gF3JqQ0SI/zTNkb3YyEIUffEYq0Y9pK13HjKtdsSuX4osTIhQkS
|
||||
+UgRp6tCAQJBAOKPYTfQ2FX8ijgUpHZRuEAVaxASAS0UATiLgzXxLvOh/VC2at5x
|
||||
wjOX6sD65pPz/0D8Qj52Cq6Q1TQ+377SDVECQAIy0od+yPweXxvrUjUd1JlRMjbB
|
||||
TIrKZqs8mKbUQapw0bh5KTy+O1elU4MRPS3jNtBxtP25PQnuSnxmZcFTgAECQFzg
|
||||
DiiFcsn9FuRagfkHExMiNJuH5feGxeFaP9WzI144v9GAllrOI6Bm3JNzx2ZLlg4b
|
||||
20Qju8lIEj6yr6JYFaECQHM1VSojGRKpOl9Ox/R4yYSA9RV5Gyn00/aJNxVYyPD5
|
||||
i3acL2joQm2kLD/LO8paJ4+iQdRXCOMMIpjxSNjGQjQ=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`)
|
||||
)
|
||||
|
||||
+4
-5
@@ -23,7 +23,7 @@ additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
|
||||
Alternatively you can explicitly create a Session with shared config enabled.
|
||||
To do this you can use NewSessionWithOptions to configure how the Session will
|
||||
be created. Using the NewSessionWithOptions with SharedConfigState set to
|
||||
SharedConfigEnabled will create the session as if the AWS_SDK_LOAD_CONFIG
|
||||
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
|
||||
environment variable was set.
|
||||
|
||||
Creating Sessions
|
||||
@@ -84,7 +84,7 @@ override the shared config state (AWS_SDK_LOAD_CONFIG).
|
||||
|
||||
// Force enable Shared Config support
|
||||
sess := session.Must(session.NewSessionWithOptions(session.Options{
|
||||
SharedConfigState: SharedConfigEnable,
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}))
|
||||
|
||||
Adding Handlers
|
||||
@@ -124,9 +124,8 @@ file (~/.aws/config) and shared credentials file (~/.aws/credentials). Both
|
||||
files have the same format.
|
||||
|
||||
If both config files are present the configuration from both files will be
|
||||
read. The Session will be created from configuration values from the shared
|
||||
credentials file (~/.aws/credentials) over those in the shared credentials
|
||||
file (~/.aws/config).
|
||||
read. The Session will be created from configuration values from the shared
|
||||
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config).
|
||||
|
||||
Credentials are the values the SDK should use for authenticating requests with
|
||||
AWS Services. They arfrom a configuration file will need to include both
|
||||
|
||||
+13
-30
@@ -2,12 +2,14 @@ package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
// EnvProviderName provides a name of the provider when config is loaded from environment.
|
||||
const EnvProviderName = "EnvConfigCredentials"
|
||||
|
||||
// envConfig is a collection of environment values the SDK will read
|
||||
// setup config from. All environment values are optional. But some values
|
||||
// such as credentials require multiple values to be complete or the values
|
||||
@@ -77,7 +79,7 @@ type envConfig struct {
|
||||
SharedConfigFile string
|
||||
|
||||
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file
|
||||
// that the SDK will use instead of the the system's root CA bundle.
|
||||
// that the SDK will use instead of the system's root CA bundle.
|
||||
// Only use this if you want to configure the SDK to use a custom set
|
||||
// of CAs.
|
||||
//
|
||||
@@ -116,6 +118,12 @@ var (
|
||||
"AWS_PROFILE",
|
||||
"AWS_DEFAULT_PROFILE", // Only read if AWS_SDK_LOAD_CONFIG is also set
|
||||
}
|
||||
sharedCredsFileEnvKey = []string{
|
||||
"AWS_SHARED_CREDENTIALS_FILE",
|
||||
}
|
||||
sharedConfigFileEnvKey = []string{
|
||||
"AWS_CONFIG_FILE",
|
||||
}
|
||||
)
|
||||
|
||||
// loadEnvConfig retrieves the SDK's environment configuration.
|
||||
@@ -152,7 +160,7 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
|
||||
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
|
||||
cfg.Creds = credentials.Value{}
|
||||
} else {
|
||||
cfg.Creds.ProviderName = "EnvConfigCredentials"
|
||||
cfg.Creds.ProviderName = EnvProviderName
|
||||
}
|
||||
|
||||
regionKeys := regionEnvKeys
|
||||
@@ -165,8 +173,8 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
|
||||
setFromEnvVal(&cfg.Region, regionKeys)
|
||||
setFromEnvVal(&cfg.Profile, profileKeys)
|
||||
|
||||
cfg.SharedCredentialsFile = sharedCredentialsFilename()
|
||||
cfg.SharedConfigFile = sharedConfigFilename()
|
||||
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
|
||||
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
|
||||
|
||||
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
|
||||
|
||||
@@ -181,28 +189,3 @@ func setFromEnvVal(dst *string, keys []string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sharedCredentialsFilename() string {
|
||||
if name := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); len(name) > 0 {
|
||||
return name
|
||||
}
|
||||
|
||||
return filepath.Join(userHomeDir(), ".aws", "credentials")
|
||||
}
|
||||
|
||||
func sharedConfigFilename() string {
|
||||
if name := os.Getenv("AWS_CONFIG_FILE"); len(name) > 0 {
|
||||
return name
|
||||
}
|
||||
|
||||
return filepath.Join(userHomeDir(), ".aws", "config")
|
||||
}
|
||||
|
||||
func userHomeDir() string {
|
||||
homeDir := os.Getenv("HOME") // *nix
|
||||
if len(homeDir) == 0 { // windows
|
||||
homeDir = os.Getenv("USERPROFILE")
|
||||
}
|
||||
|
||||
return homeDir
|
||||
}
|
||||
|
||||
+70
-82
@@ -2,17 +2,16 @@ package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
)
|
||||
|
||||
func TestLoadEnvConfig_Creds(t *testing.T) {
|
||||
env := stashEnv()
|
||||
defer popEnv(env)
|
||||
env := awstesting.StashEnv()
|
||||
defer awstesting.PopEnv(env)
|
||||
|
||||
cases := []struct {
|
||||
Env map[string]string
|
||||
@@ -83,26 +82,30 @@ func TestLoadEnvConfig_Creds(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := loadEnvConfig()
|
||||
assert.Equal(t, c.Val, cfg.Creds)
|
||||
if !reflect.DeepEqual(c.Val, cfg.Creds) {
|
||||
t.Errorf("expect credentials to match.\n%s",
|
||||
awstesting.SprintExpectActual(c.Val, cfg.Creds))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvConfig(t *testing.T) {
|
||||
env := stashEnv()
|
||||
defer popEnv(env)
|
||||
env := awstesting.StashEnv()
|
||||
defer awstesting.PopEnv(env)
|
||||
|
||||
cases := []struct {
|
||||
Env map[string]string
|
||||
Region, Profile string
|
||||
CustomCABundle string
|
||||
UseSharedConfigCall bool
|
||||
Config envConfig
|
||||
}{
|
||||
{
|
||||
Env: map[string]string{
|
||||
"AWS_REGION": "region",
|
||||
"AWS_PROFILE": "profile",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
@@ -111,7 +114,9 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_PROFILE": "profile",
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
@@ -121,7 +126,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
"AWS_SDK_LOAD_CONFIG": "1",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
@@ -135,14 +143,20 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
"AWS_SDK_LOAD_CONFIG": "1",
|
||||
},
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
"AWS_REGION": "region",
|
||||
"AWS_PROFILE": "profile",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
@@ -152,7 +166,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_PROFILE": "profile",
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
@@ -163,7 +180,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
"AWS_SDK_LOAD_CONFIG": "1",
|
||||
},
|
||||
Region: "region", Profile: "profile",
|
||||
Config: envConfig{
|
||||
Region: "region", Profile: "profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
@@ -171,7 +191,10 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_REGION": "default_region",
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
},
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
@@ -180,22 +203,40 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
"AWS_DEFAULT_PROFILE": "default_profile",
|
||||
"AWS_SDK_LOAD_CONFIG": "1",
|
||||
},
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
Config: envConfig{
|
||||
Region: "default_region", Profile: "default_profile",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
"AWS_CA_BUNDLE": "custom_ca_bundle",
|
||||
},
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
Config: envConfig{
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
},
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
"AWS_CA_BUNDLE": "custom_ca_bundle",
|
||||
},
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
Config: envConfig{
|
||||
CustomCABundle: "custom_ca_bundle",
|
||||
EnableSharedConfig: true,
|
||||
},
|
||||
UseSharedConfigCall: true,
|
||||
},
|
||||
{
|
||||
Env: map[string]string{
|
||||
"AWS_SHARED_CREDENTIALS_FILE": "/path/to/credentials/file",
|
||||
"AWS_CONFIG_FILE": "/path/to/config/file",
|
||||
},
|
||||
Config: envConfig{
|
||||
SharedCredentialsFile: "/path/to/credentials/file",
|
||||
SharedConfigFile: "/path/to/config/file",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
@@ -212,55 +253,16 @@ func TestLoadEnvConfig(t *testing.T) {
|
||||
cfg = loadEnvConfig()
|
||||
}
|
||||
|
||||
assert.Equal(t, c.Region, cfg.Region)
|
||||
assert.Equal(t, c.Profile, cfg.Profile)
|
||||
assert.Equal(t, c.CustomCABundle, cfg.CustomCABundle)
|
||||
if !reflect.DeepEqual(c.Config, cfg) {
|
||||
t.Errorf("expect config to match.\n%s",
|
||||
awstesting.SprintExpectActual(c.Config, cfg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedCredsFilename(t *testing.T) {
|
||||
env := stashEnv()
|
||||
defer popEnv(env)
|
||||
|
||||
os.Setenv("USERPROFILE", "profile_dir")
|
||||
expect := filepath.Join("profile_dir", ".aws", "credentials")
|
||||
name := sharedCredentialsFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
|
||||
os.Setenv("HOME", "home_dir")
|
||||
expect = filepath.Join("home_dir", ".aws", "credentials")
|
||||
name = sharedCredentialsFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
|
||||
expect = filepath.Join("path/to/credentials/file")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", expect)
|
||||
name = sharedCredentialsFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
}
|
||||
|
||||
func TestSharedConfigFilename(t *testing.T) {
|
||||
env := stashEnv()
|
||||
defer popEnv(env)
|
||||
|
||||
os.Setenv("USERPROFILE", "profile_dir")
|
||||
expect := filepath.Join("profile_dir", ".aws", "config")
|
||||
name := sharedConfigFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
|
||||
os.Setenv("HOME", "home_dir")
|
||||
expect = filepath.Join("home_dir", ".aws", "config")
|
||||
name = sharedConfigFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
|
||||
expect = filepath.Join("path/to/config/file")
|
||||
os.Setenv("AWS_CONFIG_FILE", expect)
|
||||
name = sharedConfigFilename()
|
||||
assert.Equal(t, expect, name)
|
||||
}
|
||||
|
||||
func TestSetEnvValue(t *testing.T) {
|
||||
env := stashEnv()
|
||||
defer popEnv(env)
|
||||
env := awstesting.StashEnv()
|
||||
defer awstesting.PopEnv(env)
|
||||
|
||||
os.Setenv("empty_key", "")
|
||||
os.Setenv("second_key", "2")
|
||||
@@ -271,21 +273,7 @@ func TestSetEnvValue(t *testing.T) {
|
||||
"empty_key", "first_key", "second_key", "third_key",
|
||||
})
|
||||
|
||||
assert.Equal(t, "2", dst)
|
||||
}
|
||||
|
||||
func stashEnv() []string {
|
||||
env := os.Environ()
|
||||
os.Clearenv()
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
func popEnv(env []string) {
|
||||
os.Clearenv()
|
||||
|
||||
for _, e := range env {
|
||||
p := strings.SplitN(e, "=", 2)
|
||||
os.Setenv(p[0], p[1])
|
||||
if e, a := "2", dst; e != a {
|
||||
t.Errorf("expect %s value from environment, got %s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+23
-7
@@ -155,6 +155,10 @@ type Options struct {
|
||||
// and enable or disable the shared config functionality.
|
||||
SharedConfigState SharedConfigState
|
||||
|
||||
// Ordered list of files the session will load configuration from.
|
||||
// It will override environment variable AWS_SHARED_CREDENTIALS_FILE, AWS_CONFIG_FILE.
|
||||
SharedConfigFiles []string
|
||||
|
||||
// When the SDK's shared config is configured to assume a role with MFA
|
||||
// this option is required in order to provide the mechanism that will
|
||||
// retrieve the MFA token. There is no default value for this field. If
|
||||
@@ -218,7 +222,7 @@ type Options struct {
|
||||
//
|
||||
// // Force enable Shared Config support
|
||||
// sess := session.Must(session.NewSessionWithOptions(session.Options{
|
||||
// SharedConfigState: SharedConfigEnable,
|
||||
// SharedConfigState: session.SharedConfigEnable,
|
||||
// }))
|
||||
func NewSessionWithOptions(opts Options) (*Session, error) {
|
||||
var envCfg envConfig
|
||||
@@ -239,6 +243,13 @@ func NewSessionWithOptions(opts Options) (*Session, error) {
|
||||
envCfg.EnableSharedConfig = true
|
||||
}
|
||||
|
||||
if len(envCfg.SharedCredentialsFile) == 0 {
|
||||
envCfg.SharedCredentialsFile = defaults.SharedCredentialsFilename()
|
||||
}
|
||||
if len(envCfg.SharedConfigFile) == 0 {
|
||||
envCfg.SharedConfigFile = defaults.SharedConfigFilename()
|
||||
}
|
||||
|
||||
// Only use AWS_CA_BUNDLE if session option is not provided.
|
||||
if len(envCfg.CustomCABundle) != 0 && opts.CustomCABundle == nil {
|
||||
f, err := os.Open(envCfg.CustomCABundle)
|
||||
@@ -304,13 +315,18 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
||||
userCfg := &aws.Config{}
|
||||
userCfg.MergeIn(cfgs...)
|
||||
|
||||
// Order config files will be loaded in with later files overwriting
|
||||
// Ordered config files will be loaded in with later files overwriting
|
||||
// previous config file values.
|
||||
cfgFiles := []string{envCfg.SharedConfigFile, envCfg.SharedCredentialsFile}
|
||||
if !envCfg.EnableSharedConfig {
|
||||
// The shared config file (~/.aws/config) is only loaded if instructed
|
||||
// to load via the envConfig.EnableSharedConfig (AWS_SDK_LOAD_CONFIG).
|
||||
cfgFiles = cfgFiles[1:]
|
||||
var cfgFiles []string
|
||||
if opts.SharedConfigFiles != nil {
|
||||
cfgFiles = opts.SharedConfigFiles
|
||||
} else {
|
||||
cfgFiles = []string{envCfg.SharedConfigFile, envCfg.SharedCredentialsFile}
|
||||
if !envCfg.EnableSharedConfig {
|
||||
// The shared config file (~/.aws/config) is only loaded if instructed
|
||||
// to load via the envConfig.EnableSharedConfig (AWS_SDK_LOAD_CONFIG).
|
||||
cfgFiles = cfgFiles[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Load additional config from file(s)
|
||||
|
||||
+39
-15
@@ -14,12 +14,13 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
func TestNewDefaultSession(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
s := New(&aws.Config{Region: aws.String("region")})
|
||||
|
||||
@@ -31,7 +32,7 @@ func TestNewDefaultSession(t *testing.T) {
|
||||
|
||||
func TestNew_WithCustomCreds(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
customCreds := credentials.NewStaticCredentials("AKID", "SECRET", "TOKEN")
|
||||
s := New(&aws.Config{Credentials: customCreds})
|
||||
@@ -49,7 +50,7 @@ func (w mockLogger) Log(args ...interface{}) {
|
||||
|
||||
func TestNew_WithSessionLoadError(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
|
||||
@@ -72,7 +73,7 @@ func TestNew_WithSessionLoadError(t *testing.T) {
|
||||
|
||||
func TestSessionCopy(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_REGION", "orig_region")
|
||||
|
||||
@@ -100,7 +101,7 @@ func TestSessionClientConfig(t *testing.T) {
|
||||
|
||||
func TestNewSession_NoCredentials(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
s, err := NewSession()
|
||||
assert.NoError(t, err)
|
||||
@@ -111,7 +112,7 @@ func TestNewSession_NoCredentials(t *testing.T) {
|
||||
|
||||
func TestNewSessionWithOptions_OverrideProfile(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
@@ -134,7 +135,7 @@ func TestNewSessionWithOptions_OverrideProfile(t *testing.T) {
|
||||
|
||||
func TestNewSessionWithOptions_OverrideSharedConfigEnable(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
@@ -157,7 +158,7 @@ func TestNewSessionWithOptions_OverrideSharedConfigEnable(t *testing.T) {
|
||||
|
||||
func TestNewSessionWithOptions_OverrideSharedConfigDisable(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
@@ -178,6 +179,29 @@ func TestNewSessionWithOptions_OverrideSharedConfigDisable(t *testing.T) {
|
||||
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
|
||||
}
|
||||
|
||||
func TestNewSessionWithOptions_OverrideSharedConfigFiles(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
os.Setenv("AWS_PROFILE", "config_file_load_order")
|
||||
|
||||
s, err := NewSessionWithOptions(Options{
|
||||
SharedConfigFiles: []string{testConfigOtherFilename},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "shared_config_other_region", *s.Config.Region)
|
||||
|
||||
creds, err := s.Config.Credentials.Get()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "shared_config_other_akid", creds.AccessKeyID)
|
||||
assert.Equal(t, "shared_config_other_secret", creds.SecretAccessKey)
|
||||
assert.Empty(t, creds.SessionToken)
|
||||
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
|
||||
}
|
||||
|
||||
func TestNewSessionWithOptions_Overrides(t *testing.T) {
|
||||
cases := []struct {
|
||||
InEnvs map[string]string
|
||||
@@ -235,7 +259,7 @@ func TestNewSessionWithOptions_Overrides(t *testing.T) {
|
||||
|
||||
for _, c := range cases {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
for k, v := range c.InEnvs {
|
||||
os.Setenv(k, v)
|
||||
@@ -279,7 +303,7 @@ const assumeRoleRespMsg = `
|
||||
|
||||
func TestSesisonAssumeRole(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_REGION", "us-east-1")
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
@@ -302,7 +326,7 @@ func TestSesisonAssumeRole(t *testing.T) {
|
||||
|
||||
func TestSessionAssumeRole_WithMFA(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_REGION", "us-east-1")
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
@@ -345,7 +369,7 @@ func TestSessionAssumeRole_WithMFA(t *testing.T) {
|
||||
|
||||
func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) {
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_REGION", "us-east-1")
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
@@ -363,7 +387,7 @@ func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) {
|
||||
// Backwards compatibility with Shared config disabled
|
||||
// assume role should not be built into the config.
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
@@ -383,7 +407,7 @@ func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) {
|
||||
// Backwards compatibility with Shared config disabled
|
||||
// assume role should not be built into the config.
|
||||
oldEnv := initSessionTestEnv()
|
||||
defer popEnv(oldEnv)
|
||||
defer awstesting.PopEnv(oldEnv)
|
||||
|
||||
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
|
||||
@@ -396,7 +420,7 @@ func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) {
|
||||
}
|
||||
|
||||
func initSessionTestEnv() (oldEnv []string) {
|
||||
oldEnv = stashEnv()
|
||||
oldEnv = awstesting.StashEnv()
|
||||
os.Setenv("AWS_CONFIG_FILE", "file_not_exists")
|
||||
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists")
|
||||
|
||||
|
||||
+1
-1
@@ -113,7 +113,7 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
|
||||
|
||||
f, err := ini.Load(b)
|
||||
if err != nil {
|
||||
return nil, SharedConfigLoadError{Filename: filename}
|
||||
return nil, SharedConfigLoadError{Filename: filename, Err: err}
|
||||
}
|
||||
|
||||
files = append(files, sharedConfigFile{
|
||||
|
||||
+73
-27
@@ -3,6 +3,8 @@ package v4_test
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -10,7 +12,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var standaloneSignCases = []struct {
|
||||
@@ -40,24 +41,43 @@ func TestPresignHandler(t *testing.T) {
|
||||
req.Time = time.Unix(0, 0)
|
||||
urlstr, err := req.Presign(5 * time.Minute)
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectedHost := "bucket.s3.mock-region.amazonaws.com"
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-disposition;host;x-amz-acl"
|
||||
expectedSig := "2d76a414208c0eac2a23ef9c834db9635ecd5a0fbb447a00ad191f82d854f55b"
|
||||
expectedSig := "a46583256431b09eb45ba4af2e6286d96a9835ed13721023dc8076dfdcb90fcb"
|
||||
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
|
||||
|
||||
u, _ := url.Parse(urlstr)
|
||||
urlQ := u.Query()
|
||||
assert.Equal(t, expectedHost, u.Host)
|
||||
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
|
||||
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
|
||||
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
|
||||
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
|
||||
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
|
||||
if e, a := expectedHost, u.Host; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedSig, urlQ.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, urlQ.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, urlQ.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "300", urlQ.Get("X-Amz-Expires"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "UNSIGNED-PAYLOAD", urlQ.Get("X-Amz-Content-Sha256"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
assert.NotContains(t, urlstr, "+") // + encoded as %20
|
||||
if e, a := "+", urlstr; strings.Contains(a, e) { // + encoded as %20
|
||||
t.Errorf("expect %v not to be in %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresignRequest(t *testing.T) {
|
||||
@@ -71,30 +91,50 @@ func TestPresignRequest(t *testing.T) {
|
||||
req.Time = time.Unix(0, 0)
|
||||
urlstr, headers, err := req.PresignRequest(5 * time.Minute)
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectedHost := "bucket.s3.mock-region.amazonaws.com"
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-disposition;host;x-amz-acl;x-amz-content-sha256"
|
||||
expectedSig := "a5b2b500dfbf2eab5b4f55bec3e3752e04536ea1d5c047aa93bc9f1130a72cd2"
|
||||
expectedHeaders := "content-disposition;host;x-amz-acl"
|
||||
expectedSig := "a46583256431b09eb45ba4af2e6286d96a9835ed13721023dc8076dfdcb90fcb"
|
||||
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
|
||||
expectedHeaderMap := http.Header{
|
||||
"x-amz-acl": []string{"public-read"},
|
||||
"content-disposition": []string{"a+b c$d"},
|
||||
"x-amz-content-sha256": []string{"UNSIGNED-PAYLOAD"},
|
||||
"x-amz-acl": []string{"public-read"},
|
||||
"content-disposition": []string{"a+b c$d"},
|
||||
}
|
||||
|
||||
u, _ := url.Parse(urlstr)
|
||||
urlQ := u.Query()
|
||||
assert.Equal(t, expectedHost, u.Host)
|
||||
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
|
||||
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
|
||||
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
|
||||
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
|
||||
assert.Equal(t, expectedHeaderMap, headers)
|
||||
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
|
||||
if e, a := expectedHost, u.Host; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedSig, urlQ.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, urlQ.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, urlQ.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaderMap, headers; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "300", urlQ.Get("X-Amz-Expires"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := "UNSIGNED-PAYLOAD", urlQ.Get("X-Amz-Content-Sha256"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
assert.NotContains(t, urlstr, "+") // + encoded as %20
|
||||
if e, a := "+", urlstr; strings.Contains(a, e) { // + encoded as %20
|
||||
t.Errorf("expect %v not to be in %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandaloneSign_CustomURIEscape(t *testing.T) {
|
||||
@@ -107,14 +147,20 @@ func TestStandaloneSign_CustomURIEscape(t *testing.T) {
|
||||
|
||||
host := "https://subdomain.us-east-1.es.amazonaws.com"
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
req.URL.Path = `/log-*/_search`
|
||||
req.URL.Opaque = "//subdomain.us-east-1.es.amazonaws.com/log-%2A/_search"
|
||||
|
||||
_, err = signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
assert.Equal(t, expectSig, actual)
|
||||
if e, a := expectSig, actual; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
+43
-42
@@ -45,7 +45,7 @@
|
||||
// If signing a request intended for HTTP2 server, and you're using Go 1.6.2
|
||||
// through 1.7.4 you should use the URL.RawPath as the pre-escaped form of the
|
||||
// request URL. https://github.com/golang/go/issues/16847 points to a bug in
|
||||
// Go pre 1.8 that failes to make HTTP2 requests using absolute URL in the HTTP
|
||||
// Go pre 1.8 that fails to make HTTP2 requests using absolute URL in the HTTP
|
||||
// message. URL.Opaque generally will force Go to make requests with absolute URL.
|
||||
// URL.RawPath does not do this, but RawPath must be a valid escaping of Path
|
||||
// or url.EscapedPath will ignore the RawPath escaping.
|
||||
@@ -55,7 +55,6 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -402,7 +401,7 @@ var SignRequestHandler = request.NamedHandler{
|
||||
}
|
||||
|
||||
// SignSDKRequest signs an AWS request with the V4 signature. This
|
||||
// request handler is bested used only with the SDK's built in service client's
|
||||
// request handler should only be used with the SDK's built in service client's
|
||||
// API operation requests.
|
||||
//
|
||||
// This function should not be used on its on its own, but in conjunction with
|
||||
@@ -503,6 +502,8 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) {
|
||||
ctx.buildTime() // no depends
|
||||
ctx.buildCredentialString() // no depends
|
||||
|
||||
ctx.buildBodyDigest()
|
||||
|
||||
unsignedHeaders := ctx.Request.Header
|
||||
if ctx.isPresign {
|
||||
if !disableHeaderHoisting {
|
||||
@@ -514,7 +515,6 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) {
|
||||
}
|
||||
}
|
||||
|
||||
ctx.buildBodyDigest()
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
|
||||
ctx.buildCanonicalString() // depends on canon headers / signed headers
|
||||
ctx.buildStringToSign() // depends on canon string
|
||||
@@ -604,14 +604,18 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
|
||||
headerValues := make([]string, len(headers))
|
||||
for i, k := range headers {
|
||||
if k == "host" {
|
||||
headerValues[i] = "host:" + ctx.Request.URL.Host
|
||||
if ctx.Request.Host != "" {
|
||||
headerValues[i] = "host:" + ctx.Request.Host
|
||||
} else {
|
||||
headerValues[i] = "host:" + ctx.Request.URL.Host
|
||||
}
|
||||
} else {
|
||||
headerValues[i] = k + ":" +
|
||||
strings.Join(ctx.SignedHeaderVals[k], ",")
|
||||
}
|
||||
}
|
||||
|
||||
ctx.canonicalHeaders = strings.Join(stripExcessSpaces(headerValues), "\n")
|
||||
stripExcessSpaces(headerValues)
|
||||
ctx.canonicalHeaders = strings.Join(headerValues, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCanonicalString() {
|
||||
@@ -713,49 +717,46 @@ func makeSha256Reader(reader io.ReadSeeker) []byte {
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
const doubleSpaces = " "
|
||||
const doubleSpace = " "
|
||||
|
||||
var doubleSpaceBytes = []byte(doubleSpaces)
|
||||
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
||||
// contain muliple side-by-side spaces.
|
||||
func stripExcessSpaces(vals []string) {
|
||||
var j, k, l, m, spaces int
|
||||
for i, str := range vals {
|
||||
// Trim trailing spaces
|
||||
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
|
||||
}
|
||||
|
||||
func stripExcessSpaces(headerVals []string) []string {
|
||||
vals := make([]string, len(headerVals))
|
||||
for i, str := range headerVals {
|
||||
// Trim leading and trailing spaces
|
||||
trimmed := strings.TrimSpace(str)
|
||||
// Trim leading spaces
|
||||
for k = 0; k < j && str[k] == ' '; k++ {
|
||||
}
|
||||
str = str[k : j+1]
|
||||
|
||||
idx := strings.Index(trimmed, doubleSpaces)
|
||||
var buf []byte
|
||||
for idx > -1 {
|
||||
// Multiple adjacent spaces found
|
||||
if buf == nil {
|
||||
// first time create the buffer
|
||||
buf = []byte(trimmed)
|
||||
}
|
||||
// Strip multiple spaces.
|
||||
j = strings.Index(str, doubleSpace)
|
||||
if j < 0 {
|
||||
vals[i] = str
|
||||
continue
|
||||
}
|
||||
|
||||
stripToIdx := -1
|
||||
for j := idx + 1; j < len(buf); j++ {
|
||||
if buf[j] != ' ' {
|
||||
buf = append(buf[:idx+1], buf[j:]...)
|
||||
stripToIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if stripToIdx >= 0 {
|
||||
idx = bytes.Index(buf[stripToIdx:], doubleSpaceBytes)
|
||||
if idx >= 0 {
|
||||
idx += stripToIdx
|
||||
buf := []byte(str)
|
||||
for k, m, l = j, j, len(buf); k < l; k++ {
|
||||
if buf[k] == ' ' {
|
||||
if spaces == 0 {
|
||||
// First space.
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
spaces++
|
||||
} else {
|
||||
idx = -1
|
||||
// End of multiple spaces.
|
||||
spaces = 0
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
}
|
||||
|
||||
if buf != nil {
|
||||
vals[i] = string(buf)
|
||||
} else {
|
||||
vals[i] = trimmed
|
||||
}
|
||||
vals[i] = string(buf[:m])
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
+211
-64
@@ -6,12 +6,11 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
@@ -20,30 +19,44 @@ import (
|
||||
|
||||
func TestStripExcessHeaders(t *testing.T) {
|
||||
vals := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3 ",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2",
|
||||
"1 2",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
newVals := stripExcessSpaces(vals)
|
||||
for i := 0; i < len(newVals); i++ {
|
||||
assert.Equal(t, expected[i], newVals[i], "test: %d", i)
|
||||
stripExcessSpaces(vals)
|
||||
for i := 0; i < len(vals); i++ {
|
||||
if e, a := expected[i], vals[i]; e != a {
|
||||
t.Errorf("%d, expect %v, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +106,24 @@ func TestPresignRequest(t *testing.T) {
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
q := req.URL.Query()
|
||||
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
|
||||
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
|
||||
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
|
||||
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
|
||||
assert.Empty(t, q.Get("X-Amz-Meta-Other-Header"))
|
||||
assert.Equal(t, expectedTarget, q.Get("X-Amz-Target"))
|
||||
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty", a)
|
||||
}
|
||||
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresignBodyWithArrayRequest(t *testing.T) {
|
||||
@@ -115,12 +140,24 @@ func TestPresignBodyWithArrayRequest(t *testing.T) {
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
q := req.URL.Query()
|
||||
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
|
||||
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
|
||||
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
|
||||
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
|
||||
assert.Empty(t, q.Get("X-Amz-Meta-Other-Header"))
|
||||
assert.Equal(t, expectedTarget, q.Get("X-Amz-Target"))
|
||||
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
@@ -132,8 +169,12 @@ func TestSignRequest(t *testing.T) {
|
||||
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21"
|
||||
|
||||
q := req.Header
|
||||
assert.Equal(t, expectedSig, q.Get("Authorization"))
|
||||
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
|
||||
if e, a := expectedSig, q.Get("Authorization"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignBodyS3(t *testing.T) {
|
||||
@@ -141,7 +182,9 @@ func TestSignBodyS3(t *testing.T) {
|
||||
signer := buildSigner()
|
||||
signer.Sign(req, body, "s3", "us-east-1", time.Now())
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
|
||||
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignBodyGlacier(t *testing.T) {
|
||||
@@ -149,7 +192,9 @@ func TestSignBodyGlacier(t *testing.T) {
|
||||
signer := buildSigner()
|
||||
signer.Sign(req, body, "glacier", "us-east-1", time.Now())
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
|
||||
if e, a := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresignEmptyBodyS3(t *testing.T) {
|
||||
@@ -157,7 +202,9 @@ func TestPresignEmptyBodyS3(t *testing.T) {
|
||||
signer := buildSigner()
|
||||
signer.Presign(req, body, "s3", "us-east-1", 5*time.Minute, time.Now())
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
|
||||
if e, a := "UNSIGNED-PAYLOAD", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignPrecomputedBodyChecksum(t *testing.T) {
|
||||
@@ -166,7 +213,9 @@ func TestSignPrecomputedBodyChecksum(t *testing.T) {
|
||||
signer := buildSigner()
|
||||
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
assert.Equal(t, "PRECOMPUTED", hash)
|
||||
if e, a := "PRECOMPUTED", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymousCredentials(t *testing.T) {
|
||||
@@ -183,14 +232,26 @@ func TestAnonymousCredentials(t *testing.T) {
|
||||
SignSDKRequest(r)
|
||||
|
||||
urlQ := r.HTTPRequest.URL.Query()
|
||||
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
|
||||
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
|
||||
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
|
||||
assert.Empty(t, urlQ.Get("X-Amz-Date"))
|
||||
if a := urlQ.Get("X-Amz-Signature"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if a := urlQ.Get("X-Amz-Credential"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if a := urlQ.Get("X-Amz-SignedHeaders"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if a := urlQ.Get("X-Amz-Date"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
|
||||
hQ := r.HTTPRequest.Header
|
||||
assert.Empty(t, hQ.Get("Authorization"))
|
||||
assert.Empty(t, hQ.Get("X-Amz-Date"))
|
||||
if a := hQ.Get("Authorization"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if a := hQ.Get("X-Amz-Date"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
|
||||
@@ -216,7 +277,9 @@ func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
|
||||
// when it is resigned.
|
||||
return time.Now().Add(1 * time.Second)
|
||||
})
|
||||
assert.NotEqual(t, sig, r.HTTPRequest.Header.Get("Authorization"))
|
||||
if e, a := sig, r.HTTPRequest.Header.Get("Authorization"); e == a {
|
||||
t.Errorf("expect %v to be %v, but was not", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
|
||||
@@ -243,7 +306,9 @@ func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
|
||||
// when it is resigned.
|
||||
return time.Now().Add(1 * time.Second)
|
||||
})
|
||||
assert.NotEqual(t, sig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
|
||||
if e, a := sig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"); e == a {
|
||||
t.Errorf("expect %v to be %v, but was not", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResignRequestExpiredCreds(t *testing.T) {
|
||||
@@ -267,8 +332,12 @@ func TestResignRequestExpiredCreds(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, origSignedHeaders)
|
||||
assert.NotContains(t, origSignedHeaders, "authorization")
|
||||
if a := origSignedHeaders; len(a) == 0 {
|
||||
t.Errorf("expect not to be empty, but was")
|
||||
}
|
||||
if e, a := origSignedHeaders, "authorization"; strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to not be in %v, but was", e, a)
|
||||
}
|
||||
origSignedAt := r.LastSignedAt
|
||||
|
||||
creds.Expire()
|
||||
@@ -279,7 +348,9 @@ func TestResignRequestExpiredCreds(t *testing.T) {
|
||||
return time.Now().Add(1 * time.Second)
|
||||
})
|
||||
updatedQuerySig := r.HTTPRequest.Header.Get("Authorization")
|
||||
assert.NotEqual(t, querySig, updatedQuerySig)
|
||||
if e, a := querySig, updatedQuerySig; e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
|
||||
var updatedSignedHeaders string
|
||||
for _, p := range strings.Split(updatedQuerySig, ", ") {
|
||||
@@ -288,9 +359,15 @@ func TestResignRequestExpiredCreds(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, updatedSignedHeaders)
|
||||
assert.NotContains(t, updatedQuerySig, "authorization")
|
||||
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
|
||||
if a := updatedSignedHeaders; len(a) == 0 {
|
||||
t.Errorf("expect not to be empty, but was")
|
||||
}
|
||||
if e, a := updatedQuerySig, "authorization"; strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to not be in %v, but was", e, a)
|
||||
}
|
||||
if e, a := origSignedAt, r.LastSignedAt; e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreResignRequestExpiredCreds(t *testing.T) {
|
||||
@@ -315,7 +392,9 @@ func TestPreResignRequestExpiredCreds(t *testing.T) {
|
||||
SignSDKRequest(r)
|
||||
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
|
||||
signedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
|
||||
assert.NotEmpty(t, signedHeaders)
|
||||
if a := signedHeaders; len(a) == 0 {
|
||||
t.Errorf("expect not to be empty, but was")
|
||||
}
|
||||
origSignedAt := r.LastSignedAt
|
||||
|
||||
creds.Expire()
|
||||
@@ -324,11 +403,19 @@ func TestPreResignRequestExpiredCreds(t *testing.T) {
|
||||
// Simulate the request occurred 15 minutes in the past
|
||||
return time.Now().Add(-48 * time.Hour)
|
||||
})
|
||||
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
|
||||
if e, a := querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"); e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
resignedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
|
||||
assert.Equal(t, signedHeaders, resignedHeaders)
|
||||
assert.NotContains(t, signedHeaders, "x-amz-signedHeaders")
|
||||
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
|
||||
if e, a := signedHeaders, resignedHeaders; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signedHeaders, "x-amz-signedHeaders"; strings.Contains(a, e) {
|
||||
t.Errorf("expect %v to not be in %v, but was", e, a)
|
||||
}
|
||||
if e, a := origSignedAt, r.LastSignedAt; e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResignRequestExpiredRequest(t *testing.T) {
|
||||
@@ -352,8 +439,12 @@ func TestResignRequestExpiredRequest(t *testing.T) {
|
||||
// Simulate the request occurred 15 minutes in the past
|
||||
return time.Now().Add(15 * time.Minute)
|
||||
})
|
||||
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
|
||||
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
|
||||
if e, a := querySig, r.HTTPRequest.Header.Get("Authorization"); e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
if e, a := origSignedAt, r.LastSignedAt; e == a {
|
||||
t.Errorf("expect %v to be %v, was not", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithRequestBody(t *testing.T) {
|
||||
@@ -365,19 +456,29 @@ func TestSignWithRequestBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectBody, b)
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL, nil)
|
||||
|
||||
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithRequestBody_Overwrite(t *testing.T) {
|
||||
@@ -389,8 +490,12 @@ func TestSignWithRequestBody_Overwrite(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(expectBody), len(b))
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := len(expectBody), len(b); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
@@ -399,11 +504,17 @@ func TestSignWithRequestBody_Overwrite(t *testing.T) {
|
||||
_, err = signer.Sign(req, nil, "service", "region", time.Now())
|
||||
req.ContentLength = 0
|
||||
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCanonicalRequest(t *testing.T) {
|
||||
@@ -421,7 +532,9 @@ func TestBuildCanonicalRequest(t *testing.T) {
|
||||
|
||||
ctx.buildCanonicalString()
|
||||
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
assert.Equal(t, expected, ctx.Request.URL.String())
|
||||
if e, a := expected, ctx.Request.URL.String(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
|
||||
@@ -463,7 +576,27 @@ func TestSignWithBody_NoReplaceRequestBody(t *testing.T) {
|
||||
}
|
||||
|
||||
if req.Body != origBody {
|
||||
t.Errorf("expeect request body to not be chagned")
|
||||
t.Errorf("expect request body to not be chagned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHost(t *testing.T) {
|
||||
req, body := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
req.Host = "myhost"
|
||||
ctx := &signingCtx{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Body: body,
|
||||
Query: req.URL.Query(),
|
||||
Time: time.Now(),
|
||||
ExpireTime: 5 * time.Second,
|
||||
}
|
||||
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
|
||||
if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
|
||||
t.Errorf("canonical host header invalid")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,15 +616,29 @@ func BenchmarkSignRequest(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStripExcessSpaces(b *testing.B) {
|
||||
vals := []string{
|
||||
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
|
||||
`123 321 123 321`,
|
||||
` 123 321 123 321 `,
|
||||
}
|
||||
var stripExcessSpaceCases = []string{
|
||||
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
|
||||
`123 321 123 321`,
|
||||
` 123 321 123 321 `,
|
||||
` 123 321 123 321 `,
|
||||
"123",
|
||||
"1 2 3",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
func BenchmarkStripExcessSpaces(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
stripExcessSpaces(vals)
|
||||
// Make sure to start with a copy of the cases
|
||||
cases := append([]string{}, stripExcessSpaceCases...)
|
||||
stripExcessSpaces(cases)
|
||||
}
|
||||
}
|
||||
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
// +build go1.8
|
||||
|
||||
package aws
|
||||
|
||||
import "net/url"
|
||||
|
||||
// URLHostname will extract the Hostname without port from the URL value.
|
||||
//
|
||||
// Wrapper of net/url#URL.Hostname for backwards Go version compatibility.
|
||||
func URLHostname(url *url.URL) string {
|
||||
return url.Hostname()
|
||||
}
|
||||
+29
@@ -0,0 +1,29 @@
|
||||
// +build !go1.8
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// URLHostname will extract the Hostname without port from the URL value.
|
||||
//
|
||||
// Copy of Go 1.8's net/url#URL.Hostname functionality.
|
||||
func URLHostname(url *url.URL) string {
|
||||
return stripPort(url.Host)
|
||||
|
||||
}
|
||||
|
||||
// stripPort is copy of Go 1.8 url#URL.Hostname functionality.
|
||||
// https://golang.org/src/net/url/url.go
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
+1
-1
@@ -5,4 +5,4 @@ package aws
|
||||
const SDKName = "aws-sdk-go"
|
||||
|
||||
// SDKVersion is the version of this SDK
|
||||
const SDKVersion = "1.8.0"
|
||||
const SDKVersion = "1.12.1"
|
||||
|
||||
Reference in New Issue
Block a user