Update Go AWS SDK to the latest version

This commit is contained in:
Andrey Smirnov
2019-07-13 00:03:55 +03:00
committed by Andrey Smirnov
parent d08be990ef
commit 94a72b23ff
2183 changed files with 885887 additions and 228114 deletions
+1 -1
View File
@@ -10,5 +10,5 @@ Please fill out the sections below to help us address your issue.
### Steps to reproduce
If you have have an runnable example, please include it.
If you have an runnable example, please include it.
@@ -0,0 +1,21 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: feature-request
assignees: ''
---
### Is this related to a problem?
A clear and concise description of the issue, e.g. I'm always frustrated when...
### Feature description
Describe what you want to happen.
### Describe alternatives you've considered
Any alternative solutions or features you've considered.
### Additional context
Add any other context or screenshots about the feature request here.
@@ -0,0 +1,24 @@
---
name: General issue
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Please fill out the sections below to help us address your issue.
### Version of AWS SDK for Go?
### Version of Go (`go version`)?
### What issue did you see?
### Steps to reproduce
If you have an runnable example, please include it.
+1 -1
View File
@@ -3,7 +3,7 @@
"Pattern": "/sdk-for-go/api/",
"StripPrefix": "/sdk-for-go/api",
"Include": ["/src/github.com/aws/aws-sdk-go/aws", "/src/github.com/aws/aws-sdk-go/service"],
"Exclude": ["/src/cmd", "/src/github.com/aws/aws-sdk-go/awstesting", "/src/github.com/aws/aws-sdk-go/awsmigrate"],
"Exclude": ["/src/cmd", "/src/github.com/aws/aws-sdk-go/awstesting", "/src/github.com/aws/aws-sdk-go/awsmigrate", "/src/github.com/aws/aws-sdk-go/private"],
"IgnoredSuffixes": ["iface"]
},
"Github": {
+46 -19
View File
@@ -2,28 +2,55 @@ language: go
sudo: required
os:
- linux
- osx
go:
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- tip
# Use Go 1.5's vendoring experiment for 1.5 tests.
env:
- GO15VENDOREXPERIMENT=1
install:
- make get-deps
script:
- make unit-with-race-cover
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- 1.12.x
- tip
matrix:
allow_failures:
- go: tip
allow_failures:
- go: tip
- os: windows
exclude:
# OSX 1.6.4 is not present in travis.
# https://github.com/travis-ci/travis-ci/issues/10309
- go: 1.6.x
os: osx
include:
- os: windows
go: 1.12.x
- os: linux
go: 1.5.x
# Use Go 1.5's vendoring experiment for 1.5 tests.
env: GO15VENDOREXPERIMENT=1
before_install:
- if [ "$TRAVIS_OS_NAME" = "windows" ]; then choco install make; fi
script:
- if [ "$TRAVIS_OS_NAME" = "windows" ]; then
make get-deps;
make unit-no-verify;
else
if [ $TRAVIS_GO_VERSION == "1.10.x" ] ||
[ $TRAVIS_GO_VERSION == "1.11.x" ] ||
[ $TRAVIS_GO_VERSION == "1.12.x" ] ||
[ $TRAVIS_GO_VERSION == "tip" ]; then
make get-deps;
make ci-test;
else
make get-deps-tests;
make unit-old-go-race-cover;
fi
fi
branches:
only:
+3522 -10
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -1,4 +1,4 @@
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
+4 -4
View File
@@ -56,7 +56,7 @@ Please be aware of the following notes prior to opening a pull request:
3. Wherever possible, pull requests should contain tests as appropriate.
Bugfixes should contain tests that exercise the corrected behavior (i.e., the
test should fail without the bugfix and pass with it), and new features
test should fail without the bugfix and pass with it), and new features
should be accompanied by tests exercising the feature.
4. Pull requests that contain failing tests will not be merged until the test
@@ -71,7 +71,7 @@ Please be aware of the following notes prior to opening a pull request:
### Testing
To run the tests locally, running the `make unit` command will `go get` the
To run the tests locally, running the `make unit` command will `go get` the
SDK's testing dependencies, and run vet, link and unit tests for the SDK.
```
@@ -88,8 +88,8 @@ go test -tags codegen ./private/...
See the `Makefile` for additional testing tags that can be used in testing.
To test on multiple platform the SDK includes several DockerFiles under the
`awstesting/sandbox` folder, and associated make recipes to to execute
To test on multiple platform the SDK includes several DockerFiles under the
`awstesting/sandbox` folder, and associated make recipes to execute
unit testing within environments configured for specific Go versions.
```
+4 -8
View File
@@ -2,19 +2,15 @@
[[projects]]
name = "github.com/go-ini/ini"
packages = ["."]
revision = "300e940a926eb277d3901b20bdfcc54928ad3642"
version = "v1.25.4"
[[projects]]
digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40"
name = "github.com/jmespath/go-jmespath"
packages = ["."]
revision = "0b12d6b5"
pruneopts = ""
revision = "c2b33e84"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "51a86a867df617990082dec6b868e4efe2fdb2ed0e02a3daa93cd30f962b5085"
input-imports = ["github.com/jmespath/go-jmespath"]
solver-name = "gps-cdcl"
solver-version = 1
+1 -6
View File
@@ -37,12 +37,7 @@ ignored = [
"golang.org/x/tools/go/loader",
]
[[constraint]]
name = "github.com/go-ini/ini"
version = "1.25.4"
[[constraint]]
name = "github.com/jmespath/go-jmespath"
revision = "0b12d6b5"
revision = "c2b33e84"
#version = "0.2.2"
+145 -95
View File
@@ -1,123 +1,173 @@
LINTIGNOREDOT='awstesting/integration.+should not use dot imports'
LINTIGNOREDOC='service/[^/]+/(api|service|waiters)\.go:.+(comment on exported|should have comment or be unexported)'
LINTIGNORECONST='service/[^/]+/(api|service|waiters)\.go:.+(type|struct field|const|func) ([^ ]+) should be ([^ ]+)'
LINTIGNORESTUTTER='service/[^/]+/(api|service)\.go:.+(and that stutters)'
LINTIGNOREINFLECT='service/[^/]+/(api|errors|service)\.go:.+(method|const) .+ should be '
LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMSKeyId should be '
LINTIGNOREENDPOINTS='aws/endpoints/(defaults|dep_service_ids).go:.+(method|const) .+ should be '
LINTIGNOREDEPS='vendor/.+\.go'
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
UNIT_TEST_TAGS="example codegen awsinclude"
SDK_WITH_VENDOR_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/src")
SDK_ONLY_PKGS=$(shell go list ./... | grep -v "/vendor/")
SDK_UNIT_TEST_ONLY_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/")
SDK_GO_1_4=$(shell go version | grep "go1.4")
SDK_GO_1_5=$(shell go version | grep "go1.5")
SDK_GO_VERSION=$(shell go version | awk '''{print $$3}''' | tr -d '''\n''')
# SDK's Core and client packages that are compatable with Go 1.5+.
SDK_CORE_PKGS=./aws/... ./private/... ./internal/...
SDK_CLIENT_PKGS=./service/...
SDK_COMPA_PKGS=${SDK_CORE_PKGS} ${SDK_CLIENT_PKGS}
all: get-deps generate unit
# SDK additional packages that are used for development of the SDK.
SDK_EXAMPLES_PKGS=./example/...
SDK_TESTING_PKGS=./awstesting/...
SDK_MODELS_PKGS=./models/...
SDK_ALL_PKGS=${SDK_COMPA_PKGS} ${SDK_TESTING_PKGS} ${SDK_EXAMPLES_PKGS} ${SDK_MODELS_PKGS}
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " api_info to print a list of services and versions"
@echo " docs to build SDK documentation"
@echo " build to go build the SDK"
@echo " unit to run unit tests"
@echo " integration to run integration tests"
@echo " performance to run performance tests"
@echo " verify to verify tests"
@echo " lint to lint the SDK"
@echo " vet to vet the SDK"
@echo " generate to go generate and make services"
@echo " gen-test to generate protocol tests"
@echo " gen-services to generate services"
@echo " get-deps to go get the SDK dependencies"
@echo " get-deps-tests to get the SDK's test dependencies"
@echo " get-deps-verify to get the SDK's verification dependencies"
all: generate unit
generate: gen-test gen-endpoints gen-services
###################
# Code Generation #
###################
generate: cleanup-models gen-test gen-endpoints gen-services
gen-test: gen-protocol-test
gen-test: gen-protocol-test gen-codegen-test
gen-codegen-test:
@echo "Generating SDK API tests"
go generate ./private/model/api/codegentest/service
gen-services:
@echo "Generating SDK clients"
go generate ./service
gen-protocol-test:
@echo "Generating SDK protocol tests"
go generate ./private/protocol/...
gen-endpoints:
go generate ./models/endpoints/
@echo "Generating SDK endpoints"
go generate ./models/endpoints
build:
@echo "go build SDK and vendor packages"
@go build ${SDK_ONLY_PKGS}
cleanup-models:
@echo "Cleaning up stale model versions"
go run -tags codegen ./private/model/cli/cleanup-models/* "./models/apis/*/*/api-2.json"
unit: get-deps-tests build verify
###################
# Unit/CI Testing #
###################
unit-no-verify:
@echo "go test SDK and vendor packages with no linting"
go test -count=1 -tags ${UNIT_TEST_TAGS} ${SDK_ALL_PKGS}
unit: verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} $(SDK_UNIT_TEST_ONLY_PKGS)
go test -count=1 -tags ${UNIT_TEST_TAGS} ${SDK_ALL_PKGS}
unit-with-race-cover: get-deps-tests build verify
unit-with-race-cover: verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 $(SDK_UNIT_TEST_ONLY_PKGS)
go test -count=1 -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 ${SDK_ALL_PKGS}
integration: get-deps-tests integ-custom smoke-tests performance
unit-old-go-race-cover:
@echo "go test SDK only packages for old Go versions"
go test -count=1 -race -cpu=1,2,4 ${SDK_COMPA_PKGS}
integ-custom:
go test -tags "integration" ./awstesting/integration/customizations/...
ci-test: generate unit-with-race-cover ci-test-generate-validate
cleanup-integ:
ci-test-generate-validate:
@echo "CI test validate no generated code changes"
git update-index --assume-unchanged go.mod go.sum
git add . -A
gitstatus=`git diff --cached --ignore-space-change`; \
git update-index --no-assume-unchanged go.mod go.sum
echo "$$gitstatus"; \
if [ "$$gitstatus" != "" ] && [ "$$gitstatus" != "skipping validation" ]; then echo "$$gitstatus"; exit 1; fi
#######################
# Integration Testing #
#######################
integration: core-integ client-integ
core-integ:
@echo "Integration Testing SDK core"
AWS_REGION="" go test -count=1 -tags "integration" -v -run '^TestInteg_' ${SDK_CORE_PKGS} ./awstesting/...
client-integ:
@echo "Integration Testing SDK clients"
AWS_REGION="" go test -count=1 -tags "integration" -v -run '^TestInteg_' ./service/...
s3crypto-integ:
@echo "Integration Testing S3 Cyrpto utility"
AWS_REGION="" go test -count=1 -tags "s3crypto_integ integration" -v -run '^TestInteg_' ./service/s3/s3crypto
cleanup-integ-buckets:
@echo "Cleaning up SDK integraiton resources"
go run -tags "integration" ./awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
smoke-tests: get-deps-tests
gucumber -go-tags "integration" ./awstesting/integration/smoke
###################
# Sandbox Testing #
###################
sandbox-tests: sandbox-test-go1.5 sandbox-test-go1.6 sandbox-test-go1.7 sandbox-test-go1.8 sandbox-test-go1.9 sandbox-test-go1.10 sandbox-test-go1.11 sandbox-test-go1.12 sandbox-test-gotip
performance: get-deps-tests
AWS_TESTING_LOG_RESULTS=${log-detailed} AWS_TESTING_REGION=$(region) AWS_TESTING_DB_TABLE=$(table) gucumber -go-tags "integration" ./awstesting/performance
sandbox-tests: sandbox-test-go15 sandbox-test-go15-novendorexp sandbox-test-go16 sandbox-test-go17 sandbox-test-go18 sandbox-test-go19 sandbox-test-gotip
sandbox-build-go15:
sandbox-build-go1.5:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.5 -t "aws-sdk-go-1.5" .
sandbox-go15: sandbox-build-go15
sandbox-go1.5: sandbox-build-go1.5
docker run -i -t aws-sdk-go-1.5 bash
sandbox-test-go15: sandbox-build-go15
sandbox-test-go1.5: sandbox-build-go1.5
docker run -t aws-sdk-go-1.5
sandbox-build-go15-novendorexp:
sandbox-build-go1.5-novendorexp:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.5-novendorexp -t "aws-sdk-go-1.5-novendorexp" .
sandbox-go15-novendorexp: sandbox-build-go15-novendorexp
sandbox-go1.5-novendorexp: sandbox-build-go1.5-novendorexp
docker run -i -t aws-sdk-go-1.5-novendorexp bash
sandbox-test-go15-novendorexp: sandbox-build-go15-novendorexp
sandbox-test-go1.5-novendorexp: sandbox-build-go1.5-novendorexp
docker run -t aws-sdk-go-1.5-novendorexp
sandbox-build-go16:
sandbox-build-go1.6:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.6 -t "aws-sdk-go-1.6" .
sandbox-go16: sandbox-build-go16
sandbox-go1.6: sandbox-build-go1.6
docker run -i -t aws-sdk-go-1.6 bash
sandbox-test-go16: sandbox-build-go16
sandbox-test-go1.6: sandbox-build-go1.6
docker run -t aws-sdk-go-1.6
sandbox-build-go17:
sandbox-build-go1.7:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.7 -t "aws-sdk-go-1.7" .
sandbox-go17: sandbox-build-go17
sandbox-go1.7: sandbox-build-go17
docker run -i -t aws-sdk-go-1.7 bash
sandbox-test-go17: sandbox-build-go17
sandbox-test-go1.7: sandbox-build-go17
docker run -t aws-sdk-go-1.7
sandbox-build-go18:
sandbox-build-go1.8:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.8 -t "aws-sdk-go-1.8" .
sandbox-go18: sandbox-build-go18
sandbox-go1.8: sandbox-build-go1.8
docker run -i -t aws-sdk-go-1.8 bash
sandbox-test-go18: sandbox-build-go18
sandbox-test-go1.8: sandbox-build-go1.8
docker run -t aws-sdk-go-1.8
sandbox-build-go19:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.8 -t "aws-sdk-go-1.9" .
sandbox-go19: sandbox-build-go19
sandbox-build-go1.9:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.9 -t "aws-sdk-go-1.9" .
sandbox-go1.9: sandbox-build-go1.9
docker run -i -t aws-sdk-go-1.9 bash
sandbox-test-go19: sandbox-build-go19
sandbox-test-go1.9: sandbox-build-go1.9
docker run -t aws-sdk-go-1.9
sandbox-build-go1.10:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.10 -t "aws-sdk-go-1.10" .
sandbox-go1.10: sandbox-build-go1.10
docker run -i -t aws-sdk-go-1.10 bash
sandbox-test-go1.10: sandbox-build-go1.10
docker run -t aws-sdk-go-1.10
sandbox-build-go1.11:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.11 -t "aws-sdk-go-1.11" .
sandbox-go1.11: sandbox-build-go1.11
docker run -i -t aws-sdk-go-1.11 bash
sandbox-test-go1.11: sandbox-build-go1.11
docker run -t aws-sdk-go-1.11
sandbox-build-go1.12:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.12 -t "aws-sdk-go-1.12" .
sandbox-go1.12: sandbox-build-go1.12
docker run -i -t aws-sdk-go-1.12 bash
sandbox-test-go1.12: sandbox-build-go1.12
docker run -t aws-sdk-go-1.12
sandbox-build-gotip:
@echo "Run make update-aws-golang-tip, if this test fails because missing aws-golang:tip container"
docker build -f ./awstesting/sandbox/Dockerfile.test.gotip -t "aws-sdk-go-tip" .
@@ -129,59 +179,59 @@ sandbox-test-gotip: sandbox-build-gotip
update-aws-golang-tip:
docker build --no-cache=true -f ./awstesting/sandbox/Dockerfile.golang-tip -t "aws-golang:tip" .
verify: get-deps-verify lint vet
##################
# Linting/Verify #
##################
verify: lint vet
lint:
@echo "go lint SDK and vendor packages"
@lint=`if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then golint ./...; else echo "skipping golint"; fi`; \
lint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOT} -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT}`; \
echo "$$lint"; \
if [ "$$lint" != "" ] && [ "$$lint" != "skipping golint" ]; then exit 1; fi
SDK_BASE_FOLDERS=$(shell ls -d */ | grep -v vendor | grep -v awsmigrate)
ifneq (,$(findstring go1.4, ${SDK_GO_VERSION}))
GO_VET_CMD=echo skipping go vet, ${SDK_GO_VERSION}
else ifneq (,$(findstring go1.6, ${SDK_GO_VERSION}))
GO_VET_CMD=go tool vet --all -shadow -example=false
else
GO_VET_CMD=go tool vet --all -shadow
endif
@lint=`golint ./...`; \
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS}`; \
echo "$$dolint"; \
if [ "$$dolint" != "" ]; then exit 1; fi
vet:
${GO_VET_CMD} ${SDK_BASE_FOLDERS}
go vet -tags "example codegen awsinclude integration" --all ${SDK_ALL_PKGS}
get-deps: get-deps-tests get-deps-verify
@echo "go get SDK dependencies"
@go get -v $(SDK_ONLY_PKGS)
################
# Dependencies #
################
get-deps: get-deps-tests get-deps-x-tests get-deps-codegen get-deps-verify
get-deps-tests:
@echo "go get SDK testing dependencies"
go get github.com/gucumber/gucumber/cmd/gucumber
go get github.com/stretchr/testify
go get github.com/smartystreets/goconvey
go get golang.org/x/net/html
get-deps-x-tests:
@echo "go get SDK testing golang.org/x dependencies"
go get golang.org/x/net/http2
get-deps-codegen: get-deps-x-tests
@echo "go get SDK codegen dependencies"
go get golang.org/x/net/html
get-deps-verify:
@echo "go get SDK verification utilities"
@if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then go get github.com/golang/lint/golint; else echo "skipped getting golint"; fi
go get golang.org/x/lint/golint
##############
# Benchmarks #
##############
bench:
@echo "go bench SDK packages"
@go test -run NONE -bench . -benchmem -tags 'bench' $(SDK_ONLY_PKGS)
go test -count=1 -run NONE -bench . -benchmem -tags 'bench' ${SDK_ALL_PKGS}
bench-protocol:
@echo "go bench SDK protocol marshallers"
@go test -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
go test -count=1 -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
#############
# Utilities #
#############
docs:
@echo "generate SDK docs"
@# This env variable, DOCS, is for internal use
@if [ -z ${AWS_DOC_GEN_TOOL} ]; then\
rm -rf doc && bundle install && bundle exec yard;\
else\
$(AWS_DOC_GEN_TOOL) `pwd`;\
fi
$(AWS_DOC_GEN_TOOL) `pwd`
api_info:
@go run private/model/cli/api-info/api-info.go
+1 -1
View File
@@ -1,3 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.
+74 -22
View File
@@ -1,35 +1,61 @@
[![API Reference](http://img.shields.io/badge/api-reference-blue.svg)](http://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](http://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
[![API Reference](https://img.shields.io/badge/api-reference-blue.svg)](https://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](https://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
# AWS SDK for Go
aws-sdk-go is the official AWS SDK for the Go programming language.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for information about the latest bug fixes, updates, and features added to the SDK.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for
information about the latest bug fixes, updates, and features added to the SDK.
We [announced](https://aws.amazon.com/blogs/developer/aws-sdk-for-go-2-0-developer-preview/) the Developer Preview for the [v2 AWS SDK for Go](https://github.com/aws/aws-sdk-go-v2). The v2 SDK is available at https://github.com/aws/aws-sdk-go-v2, and `go get github.com/aws/aws-sdk-go-v2` via `go get`. Check out the v2 SDK's [changes and updates](https://github.com/aws/aws-sdk-go-v2/blob/master/CHANGELOG.md), and let us know what you think. We want your feedback.
We [announced](https://aws.amazon.com/blogs/developer/aws-sdk-for-go-2-0-developer-preview/) the Developer Preview for the [v2 AWS SDK for Go](https://github.com/aws/aws-sdk-go-v2). The v2 SDK source is available at https://github.com/aws/aws-sdk-go-v2, and add it to your project with `go get github.com/aws/aws-sdk-go-v2`. Check out the v2 SDK's [changes and updates](https://github.com/aws/aws-sdk-go-v2/blob/master/CHANGELOG.md), and let us know what you think. We want your feedback.
We have a pilot redesign of the [AWS SDK for Go API reference documentation](https://docs.aws.amazon.com/sdk-for-go/v1/api/gosdk-apiref.html). Let us know what you think.
## Installing
If you are using Go 1.5 with the `GO15VENDOREXPERIMENT=1` vendoring flag, or 1.6 and higher you can use the following command to retrieve the SDK. The SDK's non-testing dependencies will be included and are vendored in the `vendor` folder.
Use `go get` to retrieve the SDK to add it to your `GOPATH` workspace, or
project's Go module dependencies.
go get -u github.com/aws/aws-sdk-go
go get github.com/aws/aws-sdk-go
Otherwise if your Go environment does not have vendoring support enabled, or you do not want to include the vendored SDK's dependencies you can use the following command to retrieve the SDK and its non-testing dependencies using `go get`.
To update the SDK use `go get -u` to retrieve the latest version of the SDK.
go get -u github.com/aws/aws-sdk-go/aws/...
go get -u github.com/aws/aws-sdk-go/service/...
go get -u github.com/aws/aws-sdk-go
If you're looking to retrieve just the SDK without any dependencies use the following command.
### Dependencies
go get -d github.com/aws/aws-sdk-go/
The SDK includes a `vendor` folder containing the runtime dependencies of the
SDK. The metadata of the SDK's dependencies can be found in the Go module file
`go.mod` or Dep file `Gopkg.toml`.
These two processes will still include the `vendor` folder and it should be deleted if its not going to be used by your environment.
### Go Modules
If you are using Go modules, your `go get` will default to the latest tagged
release version of the SDK. To get a specific release version of the SDK use
`@<tag>` in your `go get` command.
go get github.com/aws/aws-sdk-go@v1.15.77
To get the latest SDK repository change use `@latest`.
go get github.com/aws/aws-sdk-go@latest
### Go 1.5
If you are using Go 1.5 without vendoring enabled, (`GO15VENDOREXPERIMENT=1`),
you will need to use `...` when retrieving the SDK to get its dependencies.
go get github.com/aws/aws-sdk-go/...
This will still include the `vendor` folder. The `vendor` folder can be deleted
if not used by your environment.
rm -rf $GOPATH/src/github.com/aws/aws-sdk-go/vendor
## Getting Help
Please use these community resources for getting help. We use the GitHub issues for tracking bugs and feature requests.
Please use these community resources for getting help. We use the GitHub issues
for tracking bugs and feature requests.
* Ask a question on [StackOverflow](http://stackoverflow.com/) and tag it with the [`aws-sdk-go`](http://stackoverflow.com/questions/tagged/aws-sdk-go) tag.
* Come join the AWS SDK for Go community chat on [gitter](https://gitter.im/aws/aws-sdk-go).
@@ -38,19 +64,44 @@ Please use these community resources for getting help. We use the GitHub issues
## Opening Issues
If you encounter a bug with the AWS SDK for Go we would like to hear about it. Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see if others are also experiencing the issue before opening a new issue. Please include the version of AWS SDK for Go, Go language, and OS youre using. Please also include repro case when appropriate.
If you encounter a bug with the AWS SDK for Go we would like to hear about it.
Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see
if others are also experiencing the issue before opening a new issue. Please
include the version of AWS SDK for Go, Go language, and OS youre using. Please
also include reproduction case when appropriate.
The GitHub issues are intended for bug reports and feature requests. For help and questions with using AWS SDK for GO please make use of the resources listed in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section. Keeping the list of open issues lean will help us respond in a timely manner.
The GitHub issues are intended for bug reports and feature requests. For help
and questions with using AWS SDK for GO please make use of the resources listed
in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section.
Keeping the list of open issues lean will help us respond in a timely manner.
## Reference Documentation
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document is a general introduction how to configure and make requests with the SDK. If this is your first time using the SDK, this documentation and the API documentation will help you get started. This document focuses on the syntax and behavior of the SDK. The [Service Developer Guide](https://aws.amazon.com/documentation/) will help you get started using specific AWS services.
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document
is a general introduction on how to configure and make requests with the SDK.
If this is your first time using the SDK, this documentation and the API
documentation will help you get started. This document focuses on the syntax
and behavior of the SDK. The [Service Developer
Guide](https://aws.amazon.com/documentation/) will help you get started using
specific AWS services.
[`SDK API Reference Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this document to look up all API operation input and output parameters for AWS services supported by the SDK. The API reference also includes documentation of the SDK, and examples how to using the SDK, service client API operations, and API operation require parameters.
[`SDK API Reference
Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this
document to look up all API operation input and output parameters for AWS
services supported by the SDK. The API reference also includes documentation of
the SDK, and examples how to using the SDK, service client API operations, and
API operation require parameters.
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this documentation to learn how to interface with an AWS service. These are great guides both, if you're getting started with a service, or looking for more information on a service. You should not need this document for coding, though in some cases, services may supply helpful samples that you might want to look out for.
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this
documentation to learn how to interface with AWS services. These are great
guides both, if you're getting started with a service, or looking for more
information on a service. You should not need this document for coding, though
in some cases, services may supply helpful samples that you might want to look
out for.
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) - Included in the SDK's repo are a several hand crafted examples using the SDK features and AWS services.
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) -
Included in the SDK's repo are several hand crafted examples using the SDK
features and AWS services.
## Overview of SDK's Packages
@@ -94,8 +145,7 @@ package under the service folder at the root of the SDK.
The SDK includes the Go types and utilities you can use to make requests to
AWS service APIs. Within the service folder at the root of the SDK you'll find
a package for each AWS service the SDK supports. All service clients follows
a common pattern of creation and usage.
a package for each AWS service the SDK supports. All service clients follow common pattern of creation and usage.
When creating a client for an AWS service you'll first need to have a Session
value constructed. The Session provides shared configuration that can be shared
@@ -334,7 +384,7 @@ take a callback function that will be called for each page of the API's response
```
Waiter helper methods provide the functionality to wait for an AWS resource
state. These methods abstract the logic needed to to check the state of an
state. These methods abstract the logic needed to check the state of an
AWS resource, and wait until that resource is in a desired state. The waiter
will block until the resource is in the state that is desired, an error occurs,
or the waiter times out. If a resource times out the error code returned will
@@ -420,7 +470,9 @@ response.
}
// Ensure the context is canceled to prevent leaking.
// See context package for more information, https://golang.org/pkg/context/
defer cancelFn()
if cancelFn {
defer cancelFn()
}
// Uploads the object to S3. The Context will interrupt the request if the
// timeout expires.
+21 -2
View File
@@ -138,8 +138,27 @@ type RequestFailure interface {
RequestID() string
}
// NewRequestFailure returns a new request error wrapper for the given Error
// provided.
// NewRequestFailure returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID)
}
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}
+28 -1
View File
@@ -1,6 +1,9 @@
package awserr
import "fmt"
import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code.
//
@@ -119,6 +122,7 @@ type requestError struct {
awsError
statusCode int
requestID string
bytes []byte
}
// newRequestError returns a wrapped error with additional information for
@@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
return []error{r.OrigErr()}
}
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface
type errorList []error
+1 -1
View File
@@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal
// If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid {
+12 -13
View File
@@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
case reflect.Struct:
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
ft := v.Type().Field(i)
fv := v.Field(i)
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
buf.WriteString(ft.Name + ": ")
if i < len(names)-1 {
buf.WriteString(",\n")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
}
buf.WriteString(",\n")
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
+51
View File
@@ -0,0 +1,51 @@
// +build go1.7
package awsutil
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
)
type testStruct struct {
Field1 string
Field2 *string
Field3 []byte `sensitive:"true"`
Value []*string
}
func TestStringValue(t *testing.T) {
cases := map[string]struct {
Value interface{}
Expect string
}{
"general": {
Value: testStruct{
Field1: "abc123",
Field2: aws.String("abc123"),
Field3: []byte("don't show me"),
Value: []*string{
aws.String("first"),
aws.String("second"),
},
},
Expect: `{
Field1: "abc123",
Field2: "abc123",
Field3: <sensitive>,
Value: ["first","second"],
}`,
},
}
for d, c := range cases {
t.Run(d, func(t *testing.T) {
actual := StringValue(c.Value)
if e, a := c.Expect, actual; e != a {
t.Errorf("expect:\n%v\nactual:\n%v\n", e, a)
}
})
}
}
+3 -3
View File
@@ -18,7 +18,7 @@ type Config struct {
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the
// to determine if the signin name can be overridden based on metadata the
// service has.
SigningNameDerived bool
}
@@ -91,6 +91,6 @@ func (c *Client) AddDebugHandlers() {
return
}
c.Handlers.Send.PushFrontNamed(request.NamedHandler{Name: "awssdk.client.LogRequest", Fn: logRequest})
c.Handlers.Send.PushBackNamed(request.NamedHandler{Name: "awssdk.client.LogResponse", Fn: logResponse})
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
}
+93 -15
View File
@@ -44,12 +44,22 @@ func (reader *teeReaderCloser) Close() error {
return reader.Source.Close()
}
// LogHTTPRequestHandler is a SDK request handler to log the HTTP request sent
// to a service. Will include the HTTP request body if the LogLevel of the
// request matches LogDebugWithHTTPBody.
var LogHTTPRequestHandler = request.NamedHandler{
Name: "awssdk.client.LogRequest",
Fn: logRequest,
}
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
b, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
@@ -63,7 +73,28 @@ func logRequest(r *request.Request) {
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
// LogHTTPRequestHeaderHandler is a SDK request handler to log the HTTP request sent
// to a service. Will only log the HTTP request's headers. The request payload
// will not be read.
var LogHTTPRequestHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogRequestHeader",
Fn: logRequestHeader,
}
func logRequestHeader(r *request.Request) {
b, err := httputil.DumpRequestOut(r.HTTPRequest, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
@@ -76,27 +107,50 @@ const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
%s
-----------------------------------------------------`
// LogHTTPResponseHandler is a SDK request handler to log the HTTP response
// received from a service. Will include the HTTP response body if the LogLevel
// of the request matches LogDebugWithHTTPBody.
var LogHTTPResponseHandler = request.NamedHandler{
Name: "awssdk.client.LogResponse",
Fn: logResponse,
}
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody {
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
}
}
handlerFn := func(req *request.Request) {
body, err := httputil.DumpResponse(req.HTTPResponse, false)
b, err := httputil.DumpResponse(req.HTTPResponse, false)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(fmt.Sprintf(logRespMsg, req.ClientInfo.ServiceName, req.Operation.Name, string(body)))
if req.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
lw.Logger.Log(fmt.Sprintf(logRespMsg,
req.ClientInfo.ServiceName, req.Operation.Name, string(b)))
if logBody {
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(string(b))
}
}
@@ -110,3 +164,27 @@ func logResponse(r *request.Request) {
Name: handlerName, Fn: handlerFn,
})
}
// LogHTTPResponseHeaderHandler is a SDK request handler to log the HTTP
// response received from a service. Will only log the HTTP response's headers.
// The response payload will not be read.
var LogHTTPResponseHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogResponseHeader",
Fn: logResponseHeader,
}
func logResponseHeader(r *request.Request) {
if r.Config.Logger == nil {
return
}
b, err := httputil.DumpResponse(r.HTTPResponse, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
+78
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
@@ -127,6 +128,83 @@ func TestLogRequest(t *testing.T) {
}
}
func TestLogResponse(t *testing.T) {
cases := []struct {
Body *bytes.Buffer
ExpectBody []byte
ReadBody bool
LogLevel aws.LogLevelType
}{
{
Body: bytes.NewBuffer([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebug,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ReadBody: true,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
var logW bytes.Buffer
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: &logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Status: "OK",
Header: http.Header{
"ABC": []string{"123"},
},
Body: ioutil.NopCloser(c.Body),
}
logResponse(req)
req.Handlers.Unmarshal.Run(req)
if c.ReadBody {
if e, a := len(c.ExpectBody), c.Body.Len(); e != a {
t.Errorf("%d, expect original body not to of been read", i)
}
}
if logW.Len() == 0 {
t.Errorf("%d, expect HTTP Response headers to be logged", i)
}
b, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !bytes.Equal(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
type bufLogger struct {
w *bytes.Buffer
}
+1
View File
@@ -3,6 +3,7 @@ package metadata
// ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct {
ServiceName string
ServiceID string
APIVersion string
Endpoint string
SigningName string
+52 -8
View File
@@ -18,7 +18,7 @@ const UseServiceDefaultRetries = -1
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig tructure.
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
@@ -45,8 +45,8 @@ type Config struct {
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
// Note: You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
@@ -65,8 +65,8 @@ type Config struct {
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// Regions and Endpoints.
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
@@ -120,9 +120,10 @@ type Config struct {
// will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
// Note: This configuration option is specific to the Amazon S3 service.
//
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
@@ -223,6 +224,28 @@ type Config struct {
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
// request endpoint hosts with modeled information.
//
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
}
// NewConfig returns a new Config pointer that can be chained with builder
@@ -377,6 +400,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c
}
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
// when making requests.
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
c.DisableEndpointHostPrefix = &t
return c
}
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
@@ -476,6 +512,14 @@ func mergeInConfig(dst *Config, other *Config) {
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
}
// Copy will return a shallow copy of the Config object. If any additional
@@ -1,8 +1,8 @@
// +build !go1.9
package aws
import (
"time"
)
import "time"
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
@@ -35,37 +35,3 @@ type Context interface {
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}
-9
View File
@@ -1,9 +0,0 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)
+11
View File
@@ -0,0 +1,11 @@
// +build go1.9
package aws
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context
@@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
var (
backgroundCtx = new(emptyCtx)
)
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
+20
View File
@@ -0,0 +1,20 @@
// +build go1.7
package aws
import "context"
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.Background()
}
+24
View File
@@ -0,0 +1,24 @@
package aws
import (
"time"
)
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}
+1 -1
View File
@@ -26,7 +26,7 @@ func TestSleepWithContext_Canceled(t *testing.T) {
ctx.Error = expectErr
close(ctx.DoneCh)
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
err := aws.SleepWithContext(ctx, 10*time.Second)
if err == nil {
t.Fatalf("expect error, did not get one")
}
+2 -2
View File
@@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// 5 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
if signedTime.Add(5 * time.Minute).After(time.Now()) {
return
}
@@ -1,4 +1,4 @@
// +build go1.8
// +build go1.10
package corehandlers_test
+8 -6
View File
@@ -6,7 +6,6 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
@@ -18,12 +17,13 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/internal/sdktesting"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
@@ -37,8 +37,8 @@ func TestValidateEndpointHandler(t *testing.T) {
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
@@ -69,7 +69,9 @@ func (m *mockCredsProvider) IsExpired() bool {
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
credProvider := &mockCredsProvider{}
svc := awstesting.NewClient(&aws.Config{
+2 -2
View File
@@ -2,8 +2,8 @@ package corehandlers_test
import (
"fmt"
"testing"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -256,7 +256,7 @@ func TestValidateFieldMinParameter(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if e, a := c.err, req.Error; !reflect.DeepEqual(e,a) {
if e, a := c.err, req.Error; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
+1 -1
View File
@@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
}
const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env`
const execEnvUAKey = `exec-env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent.
+4 -3
View File
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdktesting"
)
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
@@ -13,13 +14,13 @@ func TestAddHostExecEnvUserAgentHander(t *testing.T) {
ExecEnv string
Expect string
}{
{ExecEnv: "Lambda", Expect: "exec_env/Lambda"},
{ExecEnv: "Lambda", Expect: execEnvUAKey + "/Lambda"},
{ExecEnv: "", Expect: ""},
{ExecEnv: "someThingCool", Expect: "exec_env/someThingCool"},
{ExecEnv: "someThingCool", Expect: execEnvUAKey + "/someThingCool"},
}
for i, c := range cases {
os.Clearenv()
sdktesting.StashEnv()
os.Setenv(execEnvVar, c.ExecEnv)
req := &request.Request{
+1 -3
View File
@@ -9,9 +9,7 @@ var (
// providers in the ChainProvider.
//
// This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true
//
// @readonly
// aws.Config.CredentialsChainVerboseErrors to true.
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
+62 -30
View File
@@ -1,10 +1,10 @@
package credentials
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type secondStubProvider struct {
@@ -45,13 +45,23 @@ func TestChainProviderWithNames(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "secondStubProvider", creds.ProviderName, "Expect provider name to match")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "secondStubProvider", creds.ProviderName; e != a {
t.Errorf("Expect provider name to match, %v got, %v", e, a)
}
// Also check credentials
assert.Equal(t, "AKIF", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "NOSECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if e, a := "AKIF", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
@@ -71,10 +81,18 @@ func TestChainProviderGet(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestChainProviderIsExpired(t *testing.T) {
@@ -85,16 +103,26 @@ func TestChainProviderIsExpired(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
if !p.IsExpired() {
t.Errorf("Expect expired to be true before any Retrieve")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
if !p.IsExpired() {
t.Errorf("Expect return of expired provider")
}
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
}
func TestChainProviderWithNoProvider(t *testing.T) {
@@ -102,12 +130,13 @@ func TestChainProviderWithNoProvider(t *testing.T) {
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
if e, a := ErrNoValidProvidersFoundInChain, err; e != a {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
@@ -122,13 +151,14 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
if e, a := ErrNoValidProvidersFoundInChain, err; e != a {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}
func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
@@ -144,11 +174,13 @@ func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs),
err,
"Expect no providers error returned")
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}
+56 -9
View File
@@ -49,8 +49,11 @@
package credentials
import (
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// AnonymousCredentials is an empty Credential object that can be used as
@@ -64,8 +67,6 @@ import (
// Credentials: credentials.AnonymousCredentials,
// })))
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
@@ -99,6 +100,14 @@ type Provider interface {
IsExpired() bool
}
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
type Expirer interface {
// The time at which the credentials are no longer valid
ExpiresAt() time.Time
}
// An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible
// due to an error.
@@ -158,13 +167,19 @@ func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
curTime := e.CurrentTime
if curTime == nil {
curTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
return e.expiration.Before(curTime())
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// ExpiresAt returns the expiration time of the credential
func (e *Expiry) ExpiresAt() time.Time {
return e.expiration
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
@@ -178,7 +193,8 @@ func (e *Expiry) IsExpired() bool {
type Credentials struct {
creds Value
forceRefresh bool
m sync.Mutex
m sync.RWMutex
provider Provider
}
@@ -201,6 +217,17 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock.
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
}
c.m.RUnlock()
// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
@@ -234,8 +261,8 @@ func (c *Credentials) Expire() {
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.Lock()
defer c.m.Unlock()
c.m.RLock()
defer c.m.RUnlock()
return c.isExpired()
}
@@ -244,3 +271,23 @@ func (c *Credentials) IsExpired() bool {
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
nil)
}
if c.forceRefresh {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}
@@ -0,0 +1,90 @@
// +build go1.9
package credentials
import (
"fmt"
"strconv"
"sync"
"testing"
"time"
)
func BenchmarkCredentials_Get(b *testing.B) {
stub := &stubProvider{}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, c := range cases {
b.Run(strconv.Itoa(c), func(b *testing.B) {
creds := NewCredentials(stub)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func() {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
}
wg.Done()
}()
}
b.ResetTimer()
wg.Wait()
})
}
}
func BenchmarkCredentials_Get_Expire(b *testing.B) {
p := &blockProvider{}
expRates := []int{10000, 1000, 100}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, expRate := range expRates {
for _, c := range cases {
b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) {
creds := NewCredentials(p)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func(id int) {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
// periodically expire creds to cause rwlock
if id == 0 && j%expRate == 0 {
creds.Expire()
}
}
wg.Done()
}(i)
}
b.ResetTimer()
wg.Wait()
})
}
}
}
type blockProvider struct {
creds Value
expired bool
err error
}
func (s *blockProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "blockProvider"
time.Sleep(time.Millisecond)
return s.creds, s.err
}
func (s *blockProvider) IsExpired() bool {
return s.expired
}
+108 -12
View File
@@ -1,10 +1,11 @@
package credentials
import (
"math/rand"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
@@ -33,17 +34,27 @@ func TestCredentialsGet(t *testing.T) {
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
if e, a := "provider error", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected provider error, %v got %v", e, a)
}
}
func TestCredentialsExpire(t *testing.T) {
@@ -51,15 +62,31 @@ func TestCredentialsExpire(t *testing.T) {
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
if !c.IsExpired() {
t.Errorf("Expected to start out expired")
}
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
if !c.IsExpired() {
t.Errorf("Expected to be expired")
}
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
if c.IsExpired() {
t.Errorf("Expected not to be expired")
}
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
if !c.IsExpired() {
t.Errorf("Expected to be expired")
}
}
type MockProvider struct {
Expiry
}
func (*MockProvider) Retrieve() (Value, error) {
return Value{}, nil
}
func TestCredentialsGetWithProviderName(t *testing.T) {
@@ -68,6 +95,75 @@ func TestCredentialsGetWithProviderName(t *testing.T) {
c := NewCredentials(stub)
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := creds.ProviderName, "stubProvider"; e != a {
t.Errorf("Expected provider name to match, %v got %v", e, a)
}
}
func TestCredentialsIsExpired_Race(t *testing.T) {
creds := NewChainCredentials([]Provider{&MockProvider{}})
starter := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
<-starter
for {
creds.IsExpired()
}
}()
}
close(starter)
time.Sleep(10 * time.Second)
}
func TestCredentialsExpiresAt_NoExpirer(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
_, err := c.ExpiresAt()
if e, a := "ProviderNotExpirer", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected provider error, %v got %v", e, a)
}
}
type stubProviderExpirer struct {
stubProvider
expiration time.Time
}
func (s *stubProviderExpirer) ExpiresAt() time.Time {
return s.expiration
}
func TestCredentialsExpiresAt_HasExpirer(t *testing.T) {
stub := &stubProviderExpirer{}
c := NewCredentials(stub)
// fetch initial credentials so that forceRefresh is set false
_, err := c.Get()
if err != nil {
t.Errorf("Unexpecte error: %v", err)
}
stub.expiration = time.Unix(rand.Int63(), 0)
expiration, err := c.ExpiresAt()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if stub.expiration != expiration {
t.Errorf("Expected matching expiration, %v got %v", stub.expiration, expiration)
}
c.Expire()
expiration, err = c.ExpiresAt()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !expiration.IsZero() {
t.Errorf("Expected distant past expiration, got %v", expiration)
}
}
@@ -4,7 +4,6 @@ import (
"bufio"
"encoding/json"
"fmt"
"path"
"strings"
"time"
@@ -12,6 +11,8 @@ import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// ProviderName provides a name of EC2Role provider
@@ -125,7 +126,7 @@ type ec2RoleCredRespBody struct {
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
const iamSecurityCredsPath = "iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
@@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
}
if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err)
return nil, awserr.New(request.ErrCodeSerialization,
"failed to read EC2 instance role from metadata service", err)
}
return credsList, nil
@@ -153,7 +155,7 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
@@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err)
}
@@ -7,8 +7,6 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
@@ -34,7 +32,7 @@ const credsFailRespTmpl = `{
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
@@ -59,11 +57,19 @@ func TestEC2RoleProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("Expect session token to match, %v got %v", e, a)
}
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
@@ -75,16 +81,30 @@ func TestEC2RoleProviderFailAssume(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
if err == nil {
t.Errorf("Expect error")
}
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
if e, a := "ErrorCode", e.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ErrorMsg", e.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := e.OrigErr(); v != nil {
t.Errorf("expect nil, got %v", v)
}
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
if e, a := "", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if e, a := "", creds.SessionToken; e != a {
t.Errorf("Expect session token to match, %v got %v", e, a)
}
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
@@ -98,18 +118,26 @@ func TestEC2RoleProviderIsExpired(t *testing.T) {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if v := err; v != nil {
t.Errorf("Expect no error, %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired.")
}
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
@@ -124,18 +152,26 @@ func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if v := err; v != nil {
t.Errorf("Expect no error, %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired.")
}
}
func BenchmarkEC3RoleProvider(b *testing.B) {
+17 -5
View File
@@ -39,6 +39,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
)
// ProviderName is the name of the credentials provider.
@@ -65,6 +66,10 @@ type Provider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
}
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
@@ -152,6 +157,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send()
}
@@ -167,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError",
r.Error = awserr.New(request.ErrCodeSerialization,
"failed to decode endpoint credentials",
err,
)
@@ -178,11 +186,15 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to decode error message", err),
r.HTTPResponse.StatusCode,
r.RequestID,
)
return
}
// Response body format is not consistent between metadata endpoints.
@@ -11,14 +11,19 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestRetrieveRefreshableCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/path/to/endpoint", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("Accept"))
assert.Equal(t, "else", r.URL.Query().Get("something"))
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "else", r.URL.Query().Get("something"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
@@ -39,18 +44,30 @@ func TestRetrieveRefreshableCredentials(t *testing.T) {
)
creds, err := client.Retrieve()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "TOKEN", creds.SessionToken)
assert.False(t, client.IsExpired())
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
assert.True(t, client.IsExpired())
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestRetrieveStaticCredentials(t *testing.T) {
@@ -69,12 +86,22 @@ func TestRetrieveStaticCredentials(t *testing.T) {
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.False(t, client.IsExpired())
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no SessionToken, got %#v", v)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
}
func TestFailedRetrieveCredentials(t *testing.T) {
@@ -94,18 +121,98 @@ func TestFailedRetrieveCredentials(t *testing.T) {
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error, got none")
}
aerr := err.(awserr.Error)
assert.Equal(t, "CredentialsEndpointError", aerr.Code())
assert.Equal(t, "failed to load credentials", aerr.Message())
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "failed to load credentials", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
aerr = aerr.OrigErr().(awserr.Error)
assert.Equal(t, "Error", aerr.Code())
assert.Equal(t, "Message", aerr.Message())
if e, a := "Error", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "Message", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.True(t, client.IsExpired())
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestAuthorizationToken(t *testing.T) {
const expectAuthToken = "Basic abc123"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a {
t.Fatalf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
func(p *endpointcreds.Provider) {
p.AuthorizationToken = expectAuthToken
},
)
creds, err := client.Retrieve()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
-4
View File
@@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)
+57 -21
View File
@@ -1,70 +1,106 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
"github.com/aws/aws-sdk-go/internal/sdktesting"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "access", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
if !e.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
if e.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
_, err := e.Retrieve()
if e, a := ErrAccessKeyIDNotFound, err; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
_, err := e.Retrieve()
if e, a := ErrSecretAccessKeyNotFound, err; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
if e, a := "access", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expected no token, %v", v)
}
}
+1 -1
View File
@@ -24,7 +24,7 @@
//
// func() (RetrieveFn func() (key, secret, token string, err error), IsExpiredFn func() bool)
//
// Plugin Implementation Exmaple
// Plugin Implementation Example
//
// The following is an example implementation of a SDK credential provider using
// the plugin provider in this package. See the SDK's example/aws/credential/plugincreds/plugin
@@ -0,0 +1,425 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.
WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.
You can use credentials from a `credential_process` in a variety of ways.
One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
[default]
credential_process = /command/to/call
Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.
// Initialize a session to load credentials.
sess, _ := session.NewSession(&aws.Config{
Region: aws.String("us-east-1")},
)
// Create S3 service client to use the credentials.
svc := s3.New(sess)
Another way to use the `credential_process` method is by using
`credentials.NewCredentials()` and providing a command to be executed to
retrieve credentials:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentials("/path/to/command")
// Create service client value configured for credentials.
svc := s3.New(sess, &aws.Config{Credentials: creds})
You can set a non-default timeout for the `credential_process` with another
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
set a one minute timeout:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentialsTimeout(
"/path/to/command",
time.Duration(500) * time.Millisecond)
If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.
creds := processcreds.NewCredentials(
"/path/to/command",
func(opt *ProcessProvider) {
opt.Timeout = time.Duration(2) * time.Minute
opt.Duration = time.Duration(60) * time.Minute
opt.MaxBufSize = 2048
})
You can also use your own `exec.Cmd`:
// Create an exec.Cmd
myCommand := exec.Command("/path/to/command")
// Create credentials using your exec.Cmd and custom timeout
creds := processcreds.NewCredentialsCommand(
myCommand,
func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
*/
package processcreds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
const (
// ProviderName is the name this credentials provider will label any
// returned credentials Value with.
ProviderName = `ProcessProvider`
// ErrCodeProcessProviderParse error parsing process output
ErrCodeProcessProviderParse = "ProcessProviderParseError"
// ErrCodeProcessProviderVersion version error in output
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
// ErrCodeProcessProviderRequired required attribute missing in output
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
// ErrCodeProcessProviderExecution execution of command failed
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
// errMsgProcessProviderTimeout process took longer than allowed
errMsgProcessProviderTimeout = "credential process timed out"
// errMsgProcessProviderProcess process error
errMsgProcessProviderProcess = "error in credential_process"
// errMsgProcessProviderParse problem parsing output
errMsgProcessProviderParse = "parse failed of credential_process output"
// errMsgProcessProviderVersion version error in output
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
// errMsgProcessProviderMissKey missing access key id in output
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
// errMsgProcessProviderMissSecret missing secret acess key in output
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
// errMsgProcessProviderPrepareCmd prepare of command failed
errMsgProcessProviderPrepareCmd = "failed to prepare command"
// errMsgProcessProviderEmptyCmd command must not be empty
errMsgProcessProviderEmptyCmd = "command must not be empty"
// errMsgProcessProviderPipe failed to initialize pipe
errMsgProcessProviderPipe = "failed to initialize pipe"
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for.
DefaultDuration = time.Duration(15) * time.Minute
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = 1024
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute
)
// ProcessProvider satisfies the credentials.Provider interface, and is a
// client to retrieve credentials from a process.
type ProcessProvider struct {
staticCreds bool
credentials.Expiry
originalCommand []string
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// A string representing an os command that should return a JSON with
// credential information.
command *exec.Cmd
// MaxBufSize limits memory usage from growing to an enormous
// amount due to a faulty process.
MaxBufSize int
// Timeout limits the time a process can run.
Timeout time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// ProcessProvider. The credentials will expire every 15 minutes by default.
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: exec.Command(command),
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsTimeout returns a pointer to a new Credentials object with
// the specified command and timeout, and default duration and max buffer size.
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
p := NewCredentials(command, func(opt *ProcessProvider) {
opt.Timeout = timeout
})
return p
}
// NewCredentialsCommand returns a pointer to a new Credentials object with
// the specified command, and default timeout, duration and max buffer size.
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: command,
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
type credentialProcessResponse struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
SessionToken string
Expiration *time.Time
}
// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// Serialize and validate response
resp := &credentialProcessResponse{}
if err = json.Unmarshal(out, resp); err != nil {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderParse,
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
err)
}
if resp.Version != 1 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderVersion,
errMsgProcessProviderVersion,
nil)
}
if len(resp.AccessKeyID) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissKey,
nil)
}
if len(resp.SecretAccessKey) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissSecret,
nil)
}
// Handle expiration
p.staticCreds = resp.Expiration == nil
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
}
return credentials.Value{
ProviderName: ProviderName,
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
}, nil
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *ProcessProvider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// prepareCommand prepares the command to be executed.
func (p *ProcessProvider) prepareCommand() error {
var cmdArgs []string
if runtime.GOOS == "windows" {
cmdArgs = []string{"cmd.exe", "/C"}
} else {
cmdArgs = []string{"sh", "-c"}
}
if len(p.originalCommand) == 0 {
p.originalCommand = make([]string, len(p.command.Args))
copy(p.originalCommand, p.command.Args)
// check for empty command because it succeeds
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
return awserr.New(
ErrCodeProcessProviderExecution,
fmt.Sprintf(
"%s: %s",
errMsgProcessProviderPrepareCmd,
errMsgProcessProviderEmptyCmd),
nil)
}
}
cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
p.command.Env = os.Environ()
return nil
}
// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if err := p.prepareCommand(); err != nil {
return nil, err
}
// Setup the pipes
outReadPipe, outWritePipe, err := os.Pipe()
if err != nil {
return nil, awserr.New(
ErrCodeProcessProviderExecution,
errMsgProcessProviderPipe,
err)
}
p.command.Stderr = os.Stderr // display stderr on console for MFA
p.command.Stdout = outWritePipe // get creds json on process's stdout
p.command.Stdin = os.Stdin // enable stdin for MFA
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
stdoutCh := make(chan error, 1)
go readInput(
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
output,
stdoutCh)
execCh := make(chan error, 1)
go executeCommand(*p.command, execCh)
finished := false
var errors []error
for !finished {
select {
case readError := <-stdoutCh:
errors = appendError(errors, readError)
finished = true
case execError := <-execCh:
err := outWritePipe.Close()
errors = appendError(errors, err)
errors = appendError(errors, execError)
if errors != nil {
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderProcess,
errors)
}
case <-time.After(p.Timeout):
finished = true
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderTimeout,
errors) // errors can be nil
}
}
out := output.Bytes()
if runtime.GOOS == "windows" {
// windows adds slashes to quotes
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
}
return out, nil
}
// appendError conveniently checks for nil before appending slice
func appendError(errors []error, err error) []error {
if err != nil {
return append(errors, err)
}
return errors
}
func executeCommand(cmd exec.Cmd, exec chan error) {
// Start the command
err := cmd.Start()
if err == nil {
err = cmd.Wait()
}
exec <- err
}
func readInput(r io.Reader, w io.Writer, read chan error) {
tee := io.TeeReader(r, w)
_, err := ioutil.ReadAll(tee)
if err == io.EOF {
err = nil
}
read <- err // will only arrive here when write end of pipe is closed
}
@@ -0,0 +1,555 @@
package processcreds_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/internal/sdktesting"
)
func TestProcessProviderFromSessionCfg(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "tokenDefault", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_PROFILE", "non_expire")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "nonDefaultToken", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderNotFromCredProcCfg(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_PROFILE", "not_alone")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionCrd(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "tokenDefault", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_PROFILE", "non_expire")
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "nonDefaultToken", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderNotFromCredProcCrd(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_PROFILE", "not_alone")
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderBadCommand(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
creds := processcreds.NewCredentials("/bad/process")
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderMoreEmptyCommands(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
creds := processcreds.NewCredentials("")
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderExpectErrors(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "malformed.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf("%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "wrongversion.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "missingkey.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "missingsecret.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
}
}
func TestProcessProviderTimeout(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
command := "/bin/sleep 2"
if runtime.GOOS == "windows" {
// "timeout" command does not work due to pipe redirection
command = "ping -n 2 127.0.0.1>nul"
}
creds := processcreds.NewCredentialsTimeout(
command,
time.Duration(1)*time.Second)
if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderWithLongSessionToken(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "longsessiontoken.json"},
string(os.PathSeparator))))
v, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
// Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity
e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
if a := v.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
type credentialTest struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
Expiration string
}
func TestProcessProviderStatic(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
// static
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
}
}
func TestProcessProviderNotExpired(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
// non-static, not expired
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_expiring.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
_, err = creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
}
func TestProcessProviderExpired(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
// non-static, expired
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_expired.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
_, err = creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if !creds.IsExpired() {
t.Errorf("expected %v, got %v", "expired", "not expired")
}
}
func TestProcessProviderForceExpire(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
// non-static, not expired
// setup test credentials file
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_force_expire.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
// get credentials from file
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
if _, err = creds.Get(); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
// force expire creds
creds.Expire()
if !creds.IsExpired() {
t.Errorf("expected %v, got %v", "expired", "not expired")
}
// renew creds
if _, err = creds.Get(); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
}
func TestProcessProviderAltConstruct(t *testing.T) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
// constructing with exec.Cmd instead of string
myCommand := exec.Command(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
_, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
}
}
func BenchmarkProcessProvider(b *testing.B) {
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := creds.Get()
if err != nil {
b.Fatal(err)
}
}
}
func getOSCat() string {
if runtime.GOOS == "windows" {
return "type"
}
return "cat"
}
@@ -0,0 +1,7 @@
{
"Version": 1,
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"SessionToken": "tokenDefault",
"Expiration": "2000-01-01T00:00:00-00:00"
}
@@ -0,0 +1,8 @@
{
"Version": 1,
"AccessKeyId": "ASIAXXXXXXXXXXXXXXXX",
"Expiration": "2199-01-01T00:00:00Z",
"SecretAccessKey": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
"SessionToken":
"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
}
@@ -0,0 +1,2 @@
{
"Version": 1
@@ -0,0 +1,4 @@
{
"Version": 1,
"AccessKeyId": "accesskey"
}
@@ -0,0 +1,4 @@
{
"Version": 1,
"SecretAccessKey": "secretkey"
}
@@ -0,0 +1,6 @@
{
"Version": 1,
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"SessionToken": "nonDefaultToken"
}
@@ -0,0 +1,10 @@
[default]
credential_process = cat ./testdata/expired.json
[profile non_expire]
credential_process = cat ./testdata/nonexpire.json
[profile not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = cat ./testdata/verybad.json
@@ -0,0 +1,10 @@
[default]
credential_process = type .\testdata\expired.json
[profile non_expire]
credential_process = type .\testdata\nonexpire.json
[profile not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = type .\testdata\verybad.json
@@ -0,0 +1,10 @@
[default]
credential_process = cat ./testdata/expired.json
[non_expire]
credential_process = cat ./testdata/nonexpire.json
[not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = cat ./testdata/verybad.json
@@ -0,0 +1,10 @@
[default]
credential_process = type .\testdata\expired.json
[non_expire]
credential_process = type .\testdata\nonexpire.json
[not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = type .\testdata\verybad.json
@@ -0,0 +1,5 @@
{
"Version":1,
"AccessKeyId":"accesskey",
"SecretAccessKey":"secretkey"
}
@@ -0,0 +1,5 @@
{
"Version":1,
"AccessKeyId":"veryBadAccessKeyID",
"SecretAccessKey":"veryBadSecretAccessKey"
}
@@ -0,0 +1,3 @@
{
"Version": 2
}
@@ -4,9 +4,8 @@ import (
"fmt"
"os"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
@@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename)
config, err := ini.OpenFile(filename)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile, err := config.GetSection(profile)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err)
iniProfile, ok := config.GetSection(profile)
if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
}
id, err := iniProfile.GetKey("aws_access_key_id")
if err != nil {
id := iniProfile.String("aws_access_key_id")
if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err)
nil)
}
secret, err := iniProfile.GetKey("aws_secret_access_key")
if err != nil {
secret := iniProfile.String("aws_secret_access_key")
if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
// Default to empty string if not found
token := iniProfile.Key("aws_session_token")
token := iniProfile.String("aws_session_token")
return Value{
AccessKeyID: id.String(),
SecretAccessKey: secret.String(),
SessionToken: token.String(),
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
ProviderName: SharedCredsProviderName,
}, nil
}
@@ -5,101 +5,169 @@ import (
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/internal/sdktesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/stretchr/testify/assert"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve")
}
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "example.ini")
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILEAbsPath(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
wd, err := os.Getwd()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "example.ini"))
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "with_colon"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("USERPROFILE", "profile_dir")
os.Setenv("HOME", "home_dir")
@@ -118,7 +186,8 @@ func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
-2
View File
@@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)
+15 -6
View File
@@ -1,7 +1,6 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
@@ -15,10 +14,18 @@ func TestStaticProviderGet(t *testing.T) {
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no session token, %v", v)
}
}
func TestStaticProviderIsExpired(t *testing.T) {
@@ -30,5 +37,7 @@ func TestStaticProviderIsExpired(t *testing.T) {
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
if s.IsExpired() {
t.Errorf("Expect static credentials to never expire")
}
}
@@ -80,16 +80,18 @@ package stscreds
import (
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
// StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails.
//
// Use this function go read MFA tokens from stdin. The function makes no attempt
@@ -102,7 +104,7 @@ import (
// Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) {
var v string
fmt.Printf("Assume Role MFA token code: ")
fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
_, err := fmt.Scanln(&v)
return v, err
@@ -193,6 +195,18 @@ type AssumeRoleProvider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// MaxJitterFrac reduces the effective Duration of each credential requested
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
// have a value between 0 and 1. Any other value may lead to expected behavior.
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
//
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
// AssumeRole call will be made with an arbitrary Duration between 27m and
// 30m.
//
// MaxJitterFrac should not be negative.
MaxJitterFrac float64
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
@@ -244,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
@@ -254,8 +267,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits.
p.Duration = DefaultDuration
}
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
@@ -7,7 +7,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
)
type stubSTS struct {
@@ -38,18 +37,30 @@ func TestAssumeRoleProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
if e, a := "0123456789", *in.SerialNumber; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "code", *in.TokenCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
}
p := &AssumeRoleProvider{
@@ -60,18 +71,30 @@ func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
if e, a := "0123456789", *in.SerialNumber; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "code", *in.TokenCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
}
p := &AssumeRoleProvider{
@@ -84,17 +107,25 @@ func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
t.Errorf("API request should not of been called")
},
}
p := &AssumeRoleProvider{
@@ -107,17 +138,25 @@ func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error")
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
t.Errorf("API request should not of been called")
},
}
p := &AssumeRoleProvider{
@@ -127,11 +166,19 @@ func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error")
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
+119
View File
@@ -0,0 +1,119 @@
package crr
import (
"sync/atomic"
)
// EndpointCache is an LRU cache that holds a series of endpoints
// based on some key. The datastructure makes use of a read write
// mutex to enable asynchronous use.
type EndpointCache struct {
endpoints syncMap
endpointLimit int64
// size is used to count the number elements in the cache.
// The atomic package is used to ensure this size is accurate when
// using multiple goroutines.
size int64
}
// NewEndpointCache will return a newly initialized cache with a limit
// of endpointLimit entries.
func NewEndpointCache(endpointLimit int64) *EndpointCache {
return &EndpointCache{
endpointLimit: endpointLimit,
endpoints: newSyncMap(),
}
}
// get is a concurrent safe get operation that will retrieve an endpoint
// based on endpointKey. A boolean will also be returned to illustrate whether
// or not the endpoint had been found.
func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
endpoint, ok := c.endpoints.Load(endpointKey)
if !ok {
return Endpoint{}, false
}
c.endpoints.Store(endpointKey, endpoint)
return endpoint.(Endpoint), true
}
// Has returns if the enpoint cache contains a valid entry for the endpoint key
// provided.
func (c *EndpointCache) Has(endpointKey string) bool {
endpoint, ok := c.get(endpointKey)
_, found := endpoint.GetValidAddress()
return ok && found
}
// Get will retrieve a weighted address based off of the endpoint key. If an endpoint
// should be retrieved, due to not existing or the current endpoint has expired
// the Discoverer object that was passed in will attempt to discover a new endpoint
// and add that to the cache.
func (c *EndpointCache) Get(d Discoverer, endpointKey string, required bool) (WeightedAddress, error) {
var err error
endpoint, ok := c.get(endpointKey)
weighted, found := endpoint.GetValidAddress()
shouldGet := !ok || !found
if required && shouldGet {
if endpoint, err = c.discover(d, endpointKey); err != nil {
return WeightedAddress{}, err
}
weighted, _ = endpoint.GetValidAddress()
} else if shouldGet {
go c.discover(d, endpointKey)
}
return weighted, nil
}
// Add is a concurrent safe operation that will allow new endpoints to be added
// to the cache. If the cache is full, the number of endpoints equal endpointLimit,
// then this will remove the oldest entry before adding the new endpoint.
func (c *EndpointCache) Add(endpoint Endpoint) {
// de-dups multiple adds of an endpoint with a pre-existing key
if iface, ok := c.endpoints.Load(endpoint.Key); ok {
e := iface.(Endpoint)
if e.Len() > 0 {
return
}
}
c.endpoints.Store(endpoint.Key, endpoint)
size := atomic.AddInt64(&c.size, 1)
if size > 0 && size > c.endpointLimit {
c.deleteRandomKey()
}
}
// deleteRandomKey will delete a random key from the cache. If
// no key was deleted false will be returned.
func (c *EndpointCache) deleteRandomKey() bool {
atomic.AddInt64(&c.size, -1)
found := false
c.endpoints.Range(func(key, value interface{}) bool {
found = true
c.endpoints.Delete(key)
return false
})
return found
}
// discover will get and store and endpoint using the Discoverer.
func (c *EndpointCache) discover(d Discoverer, endpointKey string) (Endpoint, error) {
endpoint, err := d.Discover()
if err != nil {
return Endpoint{}, err
}
endpoint.Key = endpointKey
c.Add(endpoint)
return endpoint, nil
}
+452
View File
@@ -0,0 +1,452 @@
package crr
import (
"net/url"
"reflect"
"testing"
)
func urlParse(uri string) *url.URL {
u, _ := url.Parse(uri)
return u
}
func TestCacheAdd(t *testing.T) {
cases := []struct {
limit int64
endpoints []Endpoint
validKeys map[string]Endpoint
expectedSize int
}{
{
limit: 5,
endpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 5,
},
{
limit: 2,
endpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 2,
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.endpoints {
cache.Add(endpoint)
}
count := 0
endpoints := map[string]Endpoint{}
cache.endpoints.Range(func(key, value interface{}) bool {
count++
endpoints[key.(string)] = value.(Endpoint)
return true
})
if e, a := c.expectedSize, cache.size; int64(e) != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := c.expectedSize, count; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
for k, ep := range endpoints {
endpoint, ok := c.validKeys[k]
if !ok {
t.Errorf("unrecognized key %q in cache", k)
}
if e, a := endpoint, ep; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}
func TestCacheGet(t *testing.T) {
cases := []struct {
addEndpoints []Endpoint
validKeys map[string]Endpoint
limit int64
}{
{
limit: 5,
addEndpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
{
limit: 2,
addEndpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.addEndpoints {
cache.Add(endpoint)
}
keys := []string{}
cache.endpoints.Range(func(key, value interface{}) bool {
a := value.(Endpoint)
e, ok := c.validKeys[key.(string)]
if !ok {
t.Errorf("unrecognized key %q in cache", key.(string))
}
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
keys = append(keys, key.(string))
return true
})
for _, key := range keys {
a, ok := cache.get(key)
if !ok {
t.Errorf("expected key to be present: %q", key)
}
e := c.validKeys[key]
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}
+99
View File
@@ -0,0 +1,99 @@
package crr
import (
"net/url"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
)
// Endpoint represents an endpoint used in endpoint discovery.
type Endpoint struct {
Key string
Addresses WeightedAddresses
}
// WeightedAddresses represents a list of WeightedAddress.
type WeightedAddresses []WeightedAddress
// WeightedAddress represents an address with a given weight.
type WeightedAddress struct {
URL *url.URL
Expired time.Time
}
// HasExpired will return whether or not the endpoint has expired with
// the exception of a zero expiry meaning does not expire.
func (e WeightedAddress) HasExpired() bool {
return e.Expired.Before(time.Now())
}
// Add will add a given WeightedAddress to the address list of Endpoint.
func (e *Endpoint) Add(addr WeightedAddress) {
e.Addresses = append(e.Addresses, addr)
}
// Len returns the number of valid endpoints where valid means the endpoint
// has not expired.
func (e *Endpoint) Len() int {
validEndpoints := 0
for _, endpoint := range e.Addresses {
if endpoint.HasExpired() {
continue
}
validEndpoints++
}
return validEndpoints
}
// GetValidAddress will return a non-expired weight endpoint
func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
for i := 0; i < len(e.Addresses); i++ {
we := e.Addresses[i]
if we.HasExpired() {
e.Addresses = append(e.Addresses[:i], e.Addresses[i+1:]...)
i--
continue
}
return we, true
}
return WeightedAddress{}, false
}
// Discoverer is an interface used to discovery which endpoint hit. This
// allows for specifics about what parameters need to be used to be contained
// in the Discoverer implementor.
type Discoverer interface {
Discover() (Endpoint, error)
}
// BuildEndpointKey will sort the keys in alphabetical order and then retrieve
// the values in that order. Those values are then concatenated together to form
// the endpoint key.
func BuildEndpointKey(params map[string]*string) string {
keys := make([]string, len(params))
i := 0
for k := range params {
keys[i] = k
i++
}
sort.Strings(keys)
values := make([]string, len(params))
for i, k := range keys {
if params[k] == nil {
continue
}
values[i] = aws.StringValue(params[k])
}
return strings.Join(values, ".")
}
+29
View File
@@ -0,0 +1,29 @@
// +build go1.9
package crr
import (
"sync"
)
type syncMap sync.Map
func newSyncMap() syncMap {
return syncMap{}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
return (*sync.Map)(m).Load(key)
}
func (m *syncMap) Store(key interface{}, value interface{}) {
(*sync.Map)(m).Store(key, value)
}
func (m *syncMap) Delete(key interface{}) {
(*sync.Map)(m).Delete(key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
(*sync.Map)(m).Range(f)
}
+48
View File
@@ -0,0 +1,48 @@
// +build !go1.9
package crr
import (
"sync"
)
type syncMap struct {
container map[interface{}]interface{}
lock sync.RWMutex
}
func newSyncMap() syncMap {
return syncMap{
container: map[interface{}]interface{}{},
}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.container[key]
return v, ok
}
func (m *syncMap) Store(key interface{}, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
m.container[key] = value
}
func (m *syncMap) Delete(key interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.container, key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
for k, v := range m.container {
if !f(k, v) {
return
}
}
}
+110
View File
@@ -0,0 +1,110 @@
package crr
import (
"reflect"
"testing"
)
func TestRangeDelete(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Delete(key)
return true
})
expectedMap := map[interface{}]interface{}{}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeStore(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
v := value.(int)
m.Store(key, v+1)
return true
})
expectedMap := map[interface{}]interface{}{
0: 1,
1: 11,
2: 21,
3: 31,
4: 41,
5: 51,
6: 61,
7: 71,
8: 81,
9: 91,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeGet(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Load(key)
return true
})
expectedMap := map[interface{}]interface{}{
0: 0,
1: 10,
2: 20,
3: 30,
4: 40,
5: 50,
6: 60,
7: 70,
8: 80,
9: 90,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
+40
View File
@@ -0,0 +1,40 @@
// +build go1.7
package csm
import "testing"
func TestAddressWithDefaults(t *testing.T) {
cases := map[string]struct {
Host, Port string
Expect string
}{
"ip": {
Host: "127.0.0.2", Port: "", Expect: "127.0.0.2:31000",
},
"localhost": {
Host: "localhost", Port: "", Expect: "127.0.0.1:31000",
},
"uppercase localhost": {
Host: "LOCALHOST", Port: "", Expect: "127.0.0.1:31000",
},
"port": {
Host: "localhost", Port: "32000", Expect: "127.0.0.1:32000",
},
"ip6": {
Host: "::1", Port: "", Expect: "[::1]:31000",
},
"unset": {
Host: "", Port: "", Expect: "127.0.0.1:31000",
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
actual := AddressWithDefaults(c.Host, c.Port)
if e, a := c.Expect, actual; e != a {
t.Errorf("expect %v, got %v", e, a)
}
})
}
}
+69
View File
@@ -0,0 +1,69 @@
// Package csm provides the Client Side Monitoring (CSM) client which enables
// sending metrics via UDP connection to the CSM agent. This package provides
// control options, and configuration for the CSM client. The client can be
// controlled manually, or automatically via the SDK's Session configuration.
//
// Enabling CSM client via SDK's Session configuration
//
// The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
//
// r, err := csm.Start("clientID", ":31000")
// if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err))
// }
//
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{})
// if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err))
// }
//
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers)
//
// Controlling CSM client
//
// Once the CSM client has been enabled the Get function will return a Reporter
// value that you can use to pause and resume the metrics published to the CSM
// agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
//
// // Will pause monitoring
// r.Pause()
// resp, err = client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
//
// // Resume monitoring
// r.Continue()
package csm
+89
View File
@@ -0,0 +1,89 @@
package csm
import (
"fmt"
"strings"
"sync"
)
var (
lock sync.Mutex
)
const (
// DefaultPort is used when no port is specified.
DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
)
// AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different
// client ID or port is passed in.
//
// r, err := csm.Start("clientID", "127.0.0.1:31000")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
// sess := session.NewSession()
// r.InjectHandlers(sess.Handlers)
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
func Start(clientID string, url string) (*Reporter, error) {
lock.Lock()
defer lock.Unlock()
if sender == nil {
sender = newReporter(clientID, url)
} else {
if sender.clientID != clientID {
panic(fmt.Errorf("inconsistent client IDs. %q was expected, but received %q", sender.clientID, clientID))
}
if sender.url != url {
panic(fmt.Errorf("inconsistent URLs. %q was expected, but received %q", sender.url, url))
}
}
if err := connect(url); err != nil {
sender = nil
return nil, err
}
return sender, nil
}
// Get will return a reporter if one exists, if one does not exist, nil will
// be returned.
func Get() *Reporter {
lock.Lock()
defer lock.Unlock()
return sender
}
+74
View File
@@ -0,0 +1,74 @@
package csm
import (
"encoding/json"
"fmt"
"net"
"testing"
)
func startUDPServer(done chan struct{}, fn func([]byte)) (string, error) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return "", err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return "", err
}
buf := make([]byte, 1024)
go func() {
defer conn.Close()
for {
select {
case <-done:
return
default:
}
n, _, err := conn.ReadFromUDP(buf)
fn(buf[:n])
if err != nil {
panic(err)
}
}
}()
return conn.LocalAddr().String(), nil
}
func TestDifferentParams(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic with different parameters")
}
}()
Start("clientID2", ":0")
}
var MetricsCh = make(chan map[string]interface{}, 1)
var Done = make(chan struct{})
func init() {
url, err := startUDPServer(Done, func(b []byte) {
m := map[string]interface{}{}
if err := json.Unmarshal(b, &m); err != nil {
panic(fmt.Sprintf("expected no error, but received %v", err))
}
MetricsCh <- m
})
if err != nil {
panic(err)
}
_, err = Start("clientID", url)
if err != nil {
panic(err)
}
}
+40
View File
@@ -0,0 +1,40 @@
package csm_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
func ExampleStart() {
r, err := csm.Start("clientID", ":31000")
if err != nil {
panic(fmt.Errorf("failed starting CSM: %v", err))
}
sess, err := session.NewSession(&aws.Config{})
if err != nil {
panic(fmt.Errorf("failed loading session: %v", err))
}
r.InjectHandlers(&sess.Handlers)
client := s3.New(sess)
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Pauses monitoring
r.Pause()
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Resume monitoring
r.Continue()
}
+109
View File
@@ -0,0 +1,109 @@
package csm
import (
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
)
type metricTime time.Time
func (t metricTime) MarshalJSON() ([]byte, error) {
ns := time.Duration(time.Time(t).UnixNano())
return []byte(strconv.FormatInt(int64(ns/time.Millisecond), 10)), nil
}
type metric struct {
ClientID *string `json:"ClientId,omitempty"`
API *string `json:"Api,omitempty"`
Service *string `json:"Service,omitempty"`
Timestamp *metricTime `json:"Timestamp,omitempty"`
Type *string `json:"Type,omitempty"`
Version *int `json:"Version,omitempty"`
AttemptCount *int `json:"AttemptCount,omitempty"`
Latency *int `json:"Latency,omitempty"`
Fqdn *string `json:"Fqdn,omitempty"`
UserAgent *string `json:"UserAgent,omitempty"`
AttemptLatency *int `json:"AttemptLatency,omitempty"`
SessionToken *string `json:"SessionToken,omitempty"`
Region *string `json:"Region,omitempty"`
AccessKey *string `json:"AccessKey,omitempty"`
HTTPStatusCode *int `json:"HttpStatusCode,omitempty"`
XAmzID2 *string `json:"XAmzId2,omitempty"`
XAmzRequestID *string `json:"XAmznRequestId,omitempty"`
AWSException *string `json:"AwsException,omitempty"`
AWSExceptionMessage *string `json:"AwsExceptionMessage,omitempty"`
SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
FinalAWSException *string `json:"FinalAwsException,omitempty"`
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
FinalSDKException *string `json:"FinalSdkException,omitempty"`
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"`
AcquireConnectionLatency *int `json:"AcquireConnectionLatency,omitempty"`
ConnectLatency *int `json:"ConnectLatency,omitempty"`
RequestLatency *int `json:"RequestLatency,omitempty"`
DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}
func (m *metric) TruncateFields() {
m.ClientID = truncateString(m.ClientID, 255)
m.UserAgent = truncateString(m.UserAgent, 256)
m.AWSException = truncateString(m.AWSException, 128)
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
m.SDKException = truncateString(m.SDKException, 128)
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
}
func truncateString(v *string, l int) *string {
if v != nil && len(*v) > l {
nv := (*v)[:l]
return &nv
}
return v
}
func (m *metric) SetException(e metricException) {
switch te := e.(type) {
case awsException:
m.AWSException = aws.String(te.exception)
m.AWSExceptionMessage = aws.String(te.message)
case sdkException:
m.SDKException = aws.String(te.exception)
m.SDKExceptionMessage = aws.String(te.message)
}
}
func (m *metric) SetFinalException(e metricException) {
switch te := e.(type) {
case awsException:
m.FinalAWSException = aws.String(te.exception)
m.FinalAWSExceptionMessage = aws.String(te.message)
case sdkException:
m.FinalSDKException = aws.String(te.exception)
m.FinalSDKExceptionMessage = aws.String(te.message)
}
}
+54
View File
@@ -0,0 +1,54 @@
package csm
import (
"sync/atomic"
)
const (
runningEnum = iota
pausedEnum
)
var (
// MetricsChannelSize of metrics to hold in the channel
MetricsChannelSize = 100
)
type metricChan struct {
ch chan metric
paused int64
}
func newMetricChan(size int) metricChan {
return metricChan{
ch: make(chan metric, size),
}
}
func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum)
}
func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum)
}
func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused)
return v == pausedEnum
}
// Push will push metrics to the metric channel if the channel
// is not paused
func (ch *metricChan) Push(m metric) bool {
if ch.IsPaused() {
return false
}
select {
case ch.ch <- m:
return true
default:
return false
}
}
+72
View File
@@ -0,0 +1,72 @@
package csm
import (
"testing"
)
func TestMetricChanPush(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPauseContinue(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
if !ch.IsPaused() {
t.Errorf("expected to be paused, but did not pause properly")
}
ch.Continue()
if ch.IsPaused() {
t.Errorf("expected to be not paused, but did not continue properly")
}
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPushWhenPaused(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to not be pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanNonBlocking(t *testing.T) {
ch := newMetricChan(0)
defer close(ch.ch)
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to be not pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
+26
View File
@@ -0,0 +1,26 @@
package csm
type metricException interface {
Exception() string
Message() string
}
type requestException struct {
exception string
message string
}
func (e requestException) Exception() string {
return e.exception
}
func (e requestException) Message() string {
return e.message
}
type awsException struct {
requestException
}
type sdkException struct {
requestException
}
+106
View File
@@ -0,0 +1,106 @@
// +build go1.7
package csm
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
func TestTruncateString(t *testing.T) {
cases := map[string]struct {
Val string
Len int
Expect string
}{
"no change": {
Val: "123456789", Len: 10,
Expect: "123456789",
},
"max len": {
Val: "1234567890", Len: 10,
Expect: "1234567890",
},
"too long": {
Val: "12345678901", Len: 10,
Expect: "1234567890",
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
v := c.Val
actual := truncateString(&v, c.Len)
if e, a := c.Val, v; e != a {
t.Errorf("expect input value not to change, %v, %v", e, a)
}
if e, a := c.Expect, *actual; e != a {
t.Errorf("expect %v, got %v", e, a)
}
})
}
}
func TestMetric_SetException(t *testing.T) {
cases := map[string]struct {
Exc metricException
Expect metric
Final bool
}{
"aws exc": {
Exc: awsException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
AWSException: aws.String("abc"),
AWSExceptionMessage: aws.String("123"),
},
},
"sdk exc": {
Exc: sdkException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
SDKException: aws.String("abc"),
SDKExceptionMessage: aws.String("123"),
},
},
"final aws exc": {
Exc: awsException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
FinalAWSException: aws.String("abc"),
FinalAWSExceptionMessage: aws.String("123"),
},
Final: true,
},
"final sdk exc": {
Exc: sdkException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
FinalSDKException: aws.String("abc"),
FinalSDKExceptionMessage: aws.String("123"),
},
Final: true,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
var m metric
if c.Final {
m.SetFinalException(c.Exc)
} else {
m.SetException(c.Exc)
}
if e, a := c.Expect, m; !reflect.DeepEqual(e, a) {
t.Errorf("expect:\n%#v\nactual:\n%#v\n", e, a)
}
})
}
}
+265
View File
@@ -0,0 +1,265 @@
package csm
import (
"encoding/json"
"net"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint.
type Reporter struct {
clientID string
url string
conn net.Conn
metricsCh metricChan
done chan struct{}
}
var (
sender *Reporter
)
func connect(url string) error {
const network = "udp"
if err := sender.connect(network, url); err != nil {
return err
}
if sender.done == nil {
sender.done = make(chan struct{})
go sender.start()
}
return nil
}
func newReporter(clientID, url string) *Reporter {
return &Reporter{
clientID: clientID,
url: url,
metricsCh: newMetricChan(MetricsChannelSize),
}
}
func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
creds, _ := r.Config.Credentials.Get()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Region: r.Config.Region,
Type: aws.String("ApiCallAttempt"),
Version: aws.Int(1),
XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID),
}
if r.HTTPResponse != nil {
m.HTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetException(getMetricException(awserr))
}
}
m.TruncateFields()
rep.metricsCh.Push(m)
}
func getMetricException(err awserr.Error) metricException {
msg := err.Error()
code := err.Code()
switch code {
case "RequestError",
request.ErrCodeSerialization,
request.CanceledErrorCode:
return sdkException{
requestException{exception: code, message: msg},
}
default:
return awsException{
requestException{exception: code, message: msg},
}
}
}
func (rep *Reporter) sendAPICallMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
}
if r.HTTPResponse != nil {
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetFinalException(getMetricException(awserr))
}
}
m.TruncateFields()
// TODO: Probably want to figure something out for logging dropped
// metrics
rep.metricsCh.Push(m)
}
func (rep *Reporter) connect(network, url string) error {
if rep.conn != nil {
rep.conn.Close()
}
conn, err := net.Dial(network, url)
if err != nil {
return awserr.New("UDPError", "Could not connect", err)
}
rep.conn = conn
return nil
}
func (rep *Reporter) close() {
if rep.done != nil {
close(rep.done)
}
rep.metricsCh.Pause()
}
func (rep *Reporter) start() {
defer func() {
rep.metricsCh.Pause()
}()
for {
select {
case <-rep.done:
rep.done = nil
return
case m := <-rep.metricsCh.ch:
// TODO: What to do with this error? Probably should just log
b, err := json.Marshal(m)
if err != nil {
continue
}
rep.conn.Write(b)
}
}
}
// Pause will pause the metric channel preventing any new metrics from being
// added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
rep.close()
}
// Continue will reopen the metric channel and allow for monitoring to be
// resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
if !rep.metricsCh.IsPaused() {
return
}
rep.metricsCh.Continue()
}
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent.
//
// InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
//
// sess := session.NewSession()
// r.InjectHandlers(&sess.Handlers)
//
// // create a new service client with our client side metric session
// svc := s3.New(sess)
func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
if rep == nil {
return
}
handlers.Complete.PushFrontNamed(request.NamedHandler{
Name: APICallMetricHandlerName,
Fn: rep.sendAPICallMetric,
})
handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
Name: APICallAttemptMetricHandlerName,
Fn: rep.sendAPICallAttemptMetric,
})
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
}
+72
View File
@@ -0,0 +1,72 @@
package csm
import (
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestMaxRetriesExceeded(t *testing.T) {
md := metadata.ClientInfo{
Endpoint: "http://127.0.0.1",
}
cfg := aws.Config{
Region: aws.String("foo"),
Credentials: credentials.NewStaticCredentials("", "", ""),
}
op := &request.Operation{}
cases := []struct {
name string
httpStatusCode int
expectedMaxRetriesValue int
expectedMetrics int
}{
{
name: "max retry reached",
httpStatusCode: http.StatusBadGateway,
expectedMaxRetriesValue: 1,
},
{
name: "status ok",
httpStatusCode: http.StatusOK,
expectedMaxRetriesValue: 0,
},
}
for _, c := range cases {
r := request.New(cfg, md, defaults.Handlers(), client.DefaultRetryer{NumMaxRetries: 2}, op, nil, nil)
reporter := newReporter("", "")
r.Handlers.Send.Clear()
reporter.InjectHandlers(&r.Handlers)
r.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: c.httpStatusCode,
}
})
r.Send()
for {
m := <-reporter.metricsCh.ch
if *m.Type != "ApiCall" {
// ignore non-ApiCall metrics since MaxRetriesExceeded is only on ApiCall events
continue
}
if val := *m.MaxRetriesExceeded; val != c.expectedMaxRetriesValue {
t.Errorf("%s: expected %d, but received %d", c.name, c.expectedMaxRetriesValue, val)
}
break
}
}
}
+414
View File
@@ -0,0 +1,414 @@
// +build go1.7
package csm_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"sort"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func TestReportingMetrics(t *testing.T) {
sess := unit.Session.Copy(&aws.Config{
SleepDelay: func(time.Duration) {},
})
sess.Handlers.Validate.Clear()
sess.Handlers.Sign.Clear()
sess.Handlers.Send.Clear()
reporter := csm.Get()
if reporter == nil {
t.Errorf("expected non-nil reporter")
}
reporter.InjectHandlers(&sess.Handlers)
cases := map[string]struct {
Request *request.Request
ExpectMetrics []map[string]interface{}
}{
"successful request": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{},
}
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(200),
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(200),
},
},
},
"failed request, no retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = &http.Response{
StatusCode: 400,
Header: http.Header{},
}
req.Retryable = aws.Bool(false)
req.Error = awserr.New("Error", "Message", nil)
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(400),
"AwsException": "Error",
"AwsExceptionMessage": "Error: Message",
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(400),
"FinalAwsException": "Error",
"FinalAwsExceptionMessage": "Error: Message",
"AttemptCount": float64(1),
},
},
},
"failed request, with retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 1}, op, nil, nil)
resps := []*http.Response{
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 500,
Header: http.Header{},
},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = resps[0]
resps = resps[1:]
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(500),
"AwsException": "UnknownError",
"AwsExceptionMessage": "UnknownError: unknown error",
},
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(500),
"AwsException": "UnknownError",
"AwsExceptionMessage": "UnknownError: unknown error",
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(500),
"FinalAwsException": "UnknownError",
"FinalAwsExceptionMessage": "UnknownError: unknown error",
"AttemptCount": float64(2),
},
},
},
"success request, with retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
errs := []error{
awserr.New("AWSError", "aws error", nil),
awserr.New("RequestError", "sdk error", nil),
nil,
}
resps := []*http.Response{
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 200,
Header: http.Header{},
},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = resps[0]
resps = resps[1:]
req.Error = errs[0]
errs = errs[1:]
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"AwsException": "AWSError",
"AwsExceptionMessage": "AWSError: aws error",
"HttpStatusCode": float64(500),
},
{
"Type": "ApiCallAttempt",
"SdkException": "RequestError",
"SdkExceptionMessage": "RequestError: sdk error",
"HttpStatusCode": float64(500),
},
{
"Type": "ApiCallAttempt",
"AwsException": nil,
"AwsExceptionMessage": nil,
"SdkException": nil,
"SdkExceptionMessage": nil,
"HttpStatusCode": float64(200),
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(200),
"FinalAwsException": nil,
"FinalAwsExceptionMessage": nil,
"FinalSdkException": nil,
"FinalSdkExceptionMessage": nil,
"AttemptCount": float64(3),
},
},
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
c.Request.Send()
for i := 0; i < len(c.ExpectMetrics); i++ {
select {
case m := <-csm.MetricsCh:
for ek, ev := range c.ExpectMetrics[i] {
if ev == nil {
// must not be set
if _, ok := m[ek]; ok {
t.Errorf("%d, expect %v metric member, not to be set, %v", i, ek, m[ek])
}
continue
}
if _, ok := m[ek]; !ok {
t.Errorf("%d, expect %v metric member, keys: %v", i, ek, keys(m))
}
if e, a := ev, m[ek]; e != a {
t.Errorf("%d, expect %v:%v(%T), metric value, got %v(%T)", i, ek, e, e, a, a)
}
}
case <-ctx.Done():
t.Errorf("timeout waiting for metrics")
return
}
}
var extraMetrics []map[string]interface{}
Loop:
for {
select {
case m := <-csm.MetricsCh:
extraMetrics = append(extraMetrics, m)
default:
break Loop
}
}
if len(extraMetrics) != 0 {
t.Fatalf("unexpected metrics, %#v", extraMetrics)
}
})
}
}
type mockService struct {
*client.Client
}
type input struct{}
type output struct{}
func (s *mockService) Request(i input) *request.Request {
op := &request.Operation{
Name: "foo",
HTTPMethod: "POST",
HTTPPath: "/",
}
o := output{}
req := s.NewRequest(op, &i, &o)
return req
}
func BenchmarkWithCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
r := csm.Get()
r.InjectHandlers(&sess.Handlers)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithCSMNoUDPConnection(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
r := csm.Get()
r.Pause()
r.InjectHandlers(&sess.Handlers)
defer r.Pause()
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithoutCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func keys(m map[string]interface{}) []string {
ks := make([]string, 0, len(m))
for k := range m {
ks = append(ks, k)
}
sort.Strings(ks)
return ks
}
+22 -9
View File
@@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// A Defaults provides a collection of default values for SDK clients.
@@ -92,17 +93,28 @@ func Handlers() request.Handlers {
func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credentials {
return credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
RemoteCredProvider(*cfg, handlers),
},
Providers: CredProviders(cfg, handlers),
})
}
// CredProviders returns the slice of providers used in
// the default credential chain.
//
// For applications that need to use some other provider (for example use
// different environment variables for legacy reasons) but still fall back
// on the default chain of providers. This allows that default chaint to be
// automatically updated
func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Provider {
return []credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
RemoteCredProvider(*cfg, handlers),
}
}
const (
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
)
// RemoteCredProvider returns a credentials provider for the default remote
@@ -112,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return localHTTPCredProvider(cfg, handlers, u)
}
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri)
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u)
}
@@ -176,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
},
)
}
+22 -12
View File
@@ -10,6 +10,8 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdktesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestHTTPCredProvider(t *testing.T) {
@@ -37,23 +39,27 @@ func TestHTTPCredProvider(t *testing.T) {
}
cases := []struct {
Host string
Fail bool
Host string
AuthToken string
Fail bool
}{
{"localhost", false},
{"actuallylocal", false},
{"127.0.0.1", false},
{"127.1.1.1", false},
{"[::1]", false},
{"www.example.com", true},
{"169.254.170.2", true},
{Host: "localhost", Fail: false},
{Host: "actuallylocal", Fail: false},
{Host: "127.0.0.1", Fail: false},
{Host: "127.1.1.1", Fail: false},
{Host: "[::1]", Fail: false},
{Host: "www.example.com", Fail: true},
{Host: "169.254.170.2", Fail: true},
{Host: "localhost", Fail: false, AuthToken: "Basic abc123"},
}
defer os.Clearenv()
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
for i, c := range cases {
u := fmt.Sprintf("http://%s/abc/123", c.Host)
os.Setenv(httpProviderEnvVar, u)
os.Setenv(httpProviderAuthorizationEnvVar, c.AuthToken)
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
@@ -78,13 +84,17 @@ func TestHTTPCredProvider(t *testing.T) {
if e, a := u, httpProvider.Client.Endpoint; e != a {
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
}
if e, a := c.AuthToken, httpProvider.AuthorizationToken; e != a {
t.Errorf("%d, expect %q auth token, got %q", i, e, a)
}
}
}
}
func TestECSCredProvider(t *testing.T) {
defer os.Clearenv()
os.Setenv(ecsCredsProviderEnvVar, "/abc/123")
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv(shareddefaults.ECSCredsProviderEnvVar, "/abc/123")
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
+16 -9
View File
@@ -4,12 +4,12 @@ import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// GetMetadata uses the path provided to request information from the EC2
@@ -19,13 +19,14 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
HTTPPath: sdkuri.PathJoin("/meta-data", p),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetUserData returns the userdata that was configured for the service. If
@@ -35,7 +36,7 @@ func (c *EC2Metadata) GetUserData() (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "user-data"),
HTTPPath: "/user-data",
}
output := &metadataOutput{}
@@ -45,8 +46,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetDynamicData uses the path provided to request information from the EC2
@@ -56,13 +58,14 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
op := &request.Operation{
Name: "GetDynamicData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "dynamic", p),
HTTPPath: sdkuri.PathJoin("/dynamic", p),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetInstanceIdentityDocument retrieves an identity document describing an
@@ -79,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 instance identity document", err)
}
@@ -98,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 IAM info", err)
}
@@ -118,6 +121,10 @@ func (c *EC2Metadata) Region() (string, error) {
return "", err
}
if len(resp) == 0 {
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}
+17
View File
@@ -167,6 +167,23 @@ func TestGetRegion(t *testing.T) {
}
}
func TestGetRegion_invalidResponse(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"", // no data in response
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
if err == nil {
t.Errorf("expected error, got %v", err)
}
if e, a := "", region; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",
+7 -3
View File
@@ -4,7 +4,7 @@
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
// be used while the environment variable is set to true, (case insensitive).
package ec2metadata
import (
@@ -72,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint,
APIVersion: "latest",
},
@@ -91,6 +92,9 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
@@ -119,7 +123,7 @@ func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata respose", err)
return
}
@@ -132,7 +136,7 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error respose", err)
return
}
+6 -4
View File
@@ -13,8 +13,8 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/internal/sdktesting"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
@@ -80,12 +80,14 @@ func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
}
func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
restoreEnvFn := sdktesting.StashEnv()
defer restoreEnvFn()
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
svc := ec2metadata.New(unit.Session)
svc := ec2metadata.New(unit.Session, &aws.Config{
LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody),
})
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
+57 -2
View File
@@ -84,6 +84,8 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRmIotDataService(p)
custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
}
return ps, nil
@@ -94,7 +96,12 @@ func custAddS3DualStack(p *partition) {
return
}
s, ok := p.Services["s3"]
custAddDualstack(p, "s3")
custAddDualstack(p, "s3-control")
}
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok {
return
}
@@ -102,7 +109,7 @@ func custAddS3DualStack(p *partition) {
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s
p.Services[svcName] = s
}
func custAddEC2Metadata(p *partition) {
@@ -122,6 +129,54 @@ func custRmIotDataService(p *partition) {
delete(p.Services, "data.iot")
}
func custFixAppAutoscalingChina(p *partition) {
if p.ID != "aws-cn" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
const expectHostname = `autoscaling.{region}.amazonaws.com`
if e, a := s.Defaults.Hostname, expectHostname; e != a {
fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a)
return
}
s.Defaults.Hostname = expectHostname + ".cn"
p.Services[serviceName] = s
}
func custFixAppAutoscalingUsGov(p *partition) {
if p.ID != "aws-us-gov" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
if a := s.Defaults.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return
}
if a := s.Defaults.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return
}
s.Defaults.CredentialScope.Service = "application-autoscaling"
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
p.Services[serviceName] = s
}
type decodeModelError struct {
awsError
}
+107
View File
@@ -115,3 +115,110 @@ func TestDecodeModelOptionsSet(t *testing.T) {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestCustFixAppAutoscalingChina(t *testing.T) {
const doc = `
{
"version": 3,
"partitions": [{
"defaults" : {
"hostname" : "{service}.{region}.{dnsSuffix}",
"protocols" : [ "https" ],
"signatureVersions" : [ "v4" ]
},
"dnsSuffix" : "amazonaws.com.cn",
"partition" : "aws-cn",
"partitionName" : "AWS China",
"regionRegex" : "^cn\\-\\w+\\-\\d+$",
"regions" : {
"cn-north-1" : {
"description" : "China (Beijing)"
},
"cn-northwest-1" : {
"description" : "China (Ningxia)"
}
},
"services" : {
"application-autoscaling" : {
"defaults" : {
"credentialScope" : {
"service" : "application-autoscaling"
},
"hostname" : "autoscaling.{region}.amazonaws.com",
"protocols" : [ "http", "https" ]
},
"endpoints" : {
"cn-north-1" : { },
"cn-northwest-1" : { }
}
}
}
}]
}`
resolver, err := DecodeModel(strings.NewReader(doc))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
endpoint, err := resolver.EndpointFor(
"application-autoscaling", "cn-northwest-1",
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := `https://autoscaling.cn-northwest-1.amazonaws.com.cn`, endpoint.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestCustFixAppAutoscalingUsGov(t *testing.T) {
const doc = `
{
"version": 3,
"partitions": [{
"defaults" : {
"hostname" : "{service}.{region}.{dnsSuffix}",
"protocols" : [ "https" ],
"signatureVersions" : [ "v4" ]
},
"dnsSuffix" : "amazonaws.com",
"partition" : "aws-us-gov",
"partitionName" : "AWS GovCloud (US)",
"regionRegex" : "^us\\-gov\\-\\w+\\-\\d+$",
"regions" : {
"us-gov-east-1" : {
"description" : "AWS GovCloud (US-East)"
},
"us-gov-west-1" : {
"description" : "AWS GovCloud (US)"
}
},
"services" : {
"application-autoscaling" : {
"endpoints" : {
"us-gov-east-1" : { },
"us-gov-west-1" : { }
}
}
}
}]
}`
resolver, err := DecodeModel(strings.NewReader(doc))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
endpoint, err := resolver.EndpointFor(
"application-autoscaling", "us-gov-west-1",
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := `https://autoscaling.us-gov-west-1.amazonaws.com`, endpoint.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
File diff suppressed because it is too large Load Diff
+141
View File
@@ -0,0 +1,141 @@
package endpoints
// Service identifiers
//
// Deprecated: Use client package's EndpointsID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date.
const (
A4bServiceID = "a4b" // A4b.
AcmServiceID = "acm" // Acm.
AcmPcaServiceID = "acm-pca" // AcmPca.
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
ApiPricingServiceID = "api.pricing" // ApiPricing.
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
ApigatewayServiceID = "apigateway" // Apigateway.
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
Appstream2ServiceID = "appstream2" // Appstream2.
AppsyncServiceID = "appsync" // Appsync.
AthenaServiceID = "athena" // Athena.
AutoscalingServiceID = "autoscaling" // Autoscaling.
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
BatchServiceID = "batch" // Batch.
BudgetsServiceID = "budgets" // Budgets.
CeServiceID = "ce" // Ce.
ChimeServiceID = "chime" // Chime.
Cloud9ServiceID = "cloud9" // Cloud9.
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
CloudformationServiceID = "cloudformation" // Cloudformation.
CloudfrontServiceID = "cloudfront" // Cloudfront.
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
CodebuildServiceID = "codebuild" // Codebuild.
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
ComprehendServiceID = "comprehend" // Comprehend.
ConfigServiceID = "config" // Config.
CurServiceID = "cur" // Cur.
DatapipelineServiceID = "datapipeline" // Datapipeline.
DaxServiceID = "dax" // Dax.
DevicefarmServiceID = "devicefarm" // Devicefarm.
DirectconnectServiceID = "directconnect" // Directconnect.
DiscoveryServiceID = "discovery" // Discovery.
DmsServiceID = "dms" // Dms.
DsServiceID = "ds" // Ds.
DynamodbServiceID = "dynamodb" // Dynamodb.
Ec2ServiceID = "ec2" // Ec2.
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
EcrServiceID = "ecr" // Ecr.
EcsServiceID = "ecs" // Ecs.
ElasticacheServiceID = "elasticache" // Elasticache.
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
FmsServiceID = "fms" // Fms.
GameliftServiceID = "gamelift" // Gamelift.
GlacierServiceID = "glacier" // Glacier.
GlueServiceID = "glue" // Glue.
GreengrassServiceID = "greengrass" // Greengrass.
GuarddutyServiceID = "guardduty" // Guardduty.
HealthServiceID = "health" // Health.
IamServiceID = "iam" // Iam.
ImportexportServiceID = "importexport" // Importexport.
InspectorServiceID = "inspector" // Inspector.
IotServiceID = "iot" // Iot.
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
KinesisServiceID = "kinesis" // Kinesis.
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
KmsServiceID = "kms" // Kms.
LambdaServiceID = "lambda" // Lambda.
LightsailServiceID = "lightsail" // Lightsail.
LogsServiceID = "logs" // Logs.
MachinelearningServiceID = "machinelearning" // Machinelearning.
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
MedialiveServiceID = "medialive" // Medialive.
MediapackageServiceID = "mediapackage" // Mediapackage.
MediastoreServiceID = "mediastore" // Mediastore.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MghServiceID = "mgh" // Mgh.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
NeptuneServiceID = "neptune" // Neptune.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
RedshiftServiceID = "redshift" // Redshift.
RekognitionServiceID = "rekognition" // Rekognition.
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
Route53ServiceID = "route53" // Route53.
Route53domainsServiceID = "route53domains" // Route53domains.
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
S3ServiceID = "s3" // S3.
S3ControlServiceID = "s3-control" // S3Control.
SagemakerServiceID = "api.sagemaker" // Sagemaker.
SdbServiceID = "sdb" // Sdb.
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
ShieldServiceID = "shield" // Shield.
SmsServiceID = "sms" // Sms.
SnowballServiceID = "snowball" // Snowball.
SnsServiceID = "sns" // Sns.
SqsServiceID = "sqs" // Sqs.
SsmServiceID = "ssm" // Ssm.
StatesServiceID = "states" // States.
StoragegatewayServiceID = "storagegateway" // Storagegateway.
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
TransferServiceID = "transfer" // Transfer.
TranslateServiceID = "translate" // Translate.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkmailServiceID = "workmail" // Workmail.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)
+13 -7
View File
@@ -35,7 +35,7 @@ type Options struct {
//
// If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving
// endpoint ID with. If both the service and region are unknown and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned.
//
// If resolving and endpoint on a partition specific resolver that partition's
@@ -206,10 +206,11 @@ func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (
// enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region {
rs := map[string]Region{}
for id := range p.p.Regions {
for id, r := range p.p.Regions {
rs[id] = Region{
id: id,
p: p.p,
id: id,
desc: r.Description,
p: p.p,
}
}
@@ -240,6 +241,10 @@ type Region struct {
// ID returns the region's identifier.
func (r Region) ID() string { return r.id }
// Description returns the region's description. The region description
// is free text, it can be empty, and it may change between SDK releases.
func (r Region) Description() string { return r.desc }
// ResolveEndpoint resolves an endpoint from the context of the region given
// a service. See Partition.EndpointFor for usage and errors that can be returned.
func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
@@ -284,10 +289,11 @@ func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (Resolve
func (s Service) Regions() map[string]Region {
rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if _, ok := s.p.Regions[id]; ok {
if r, ok := s.p.Regions[id]; ok {
rs[id] = Region{
id: id,
p: s.p,
id: id,
desc: r.Description,
p: s.p,
}
}
}
+7
View File
@@ -65,6 +65,10 @@ func TestEnumRegionServices(t *testing.T) {
t.Errorf("expect %q region ID, got %q", e, a)
}
if a, e := r.Description(), "region description"; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
ss := r.Services()
if a, e := len(ss), 1; a != e {
t.Errorf("expect %d services for us-east-1, got %d", e, a)
@@ -291,6 +295,9 @@ func TestRegionsForService(t *testing.T) {
if _, ok := expect[id]; !ok {
t.Errorf("expect %s region to be found", id)
}
if a, e := r.Description(), expect[id].desc; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
}
}

Some files were not shown because too many files have changed in this diff Show More