Update Go AWS SDK to the latest version

This commit is contained in:
Andrey Smirnov
2019-07-13 00:03:55 +03:00
committed by Andrey Smirnov
parent d08be990ef
commit 94a72b23ff
2183 changed files with 885887 additions and 228114 deletions
+127
View File
@@ -0,0 +1,127 @@
package rdsutils
import (
"fmt"
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// ConnectionFormat is the type of connection that will be
// used to connect to the database
type ConnectionFormat string
// ConnectionFormat enums
const (
NoConnectionFormat ConnectionFormat = ""
TCPFormat ConnectionFormat = "tcp"
)
// ErrNoConnectionFormat will be returned during build if no format had been
// specified
var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)
// ConnectionStringBuilder is a builder that will construct a connection
// string with the provided parameters. params field is required to have
// a tls specification and allowCleartextPasswords must be set to true.
type ConnectionStringBuilder struct {
dbName string
endpoint string
region string
user string
creds *credentials.Credentials
connectFormat ConnectionFormat
params url.Values
}
// NewConnectionStringBuilder will return an ConnectionStringBuilder
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder {
return ConnectionStringBuilder{
dbName: dbName,
endpoint: endpoint,
region: region,
user: dbUser,
creds: creds,
}
}
// WithEndpoint will return a builder with the given endpoint
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
b.endpoint = endpoint
return b
}
// WithRegion will return a builder with the given region
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
b.region = region
return b
}
// WithUser will return a builder with the given user
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
b.user = user
return b
}
// WithDBName will return a builder with the given database name
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
b.dbName = dbName
return b
}
// WithParams will return a builder with the given params. The parameters
// will be included in the connection query string
//
// Example:
// v := url.Values{}
// v.Add("tls", "rds")
// b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithParams(v).WithTCPFormat().Build()
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
b.params = params
return b
}
// WithFormat will return a builder with the given connection format
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
b.connectFormat = f
return b
}
// WithTCPFormat will set the format to TCP and return the modified builder
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
return b.WithFormat(TCPFormat)
}
// Build will return a new connection string that can be used to open a connection
// to the desired database.
//
// Example:
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithTCPFormat().Build()
// if err != nil {
// panic(err)
// }
// const dbType = "mysql"
// db, err := sql.Open(dbType, connectStr)
func (b ConnectionStringBuilder) Build() (string, error) {
if b.connectFormat == NoConnectionFormat {
return "", ErrNoConnectionFormat
}
authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds)
if err != nil {
return "", err
}
connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s",
b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName,
)
if len(b.params) > 0 {
connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode())
}
return connectionStr, nil
}
+58
View File
@@ -0,0 +1,58 @@
package rdsutils_test
import (
"net/url"
"regexp"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
)
func TestConnectionStringBuilder(t *testing.T) {
cases := []struct {
user string
endpoint string
region string
dbName string
values url.Values
format rdsutils.ConnectionFormat
creds *credentials.Credentials
expectedErr error
expectedConnectRegex string
}{
{
user: "foo",
endpoint: "foo.bar",
region: "region",
dbName: "name",
format: rdsutils.NoConnectionFormat,
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
expectedErr: rdsutils.ErrNoConnectionFormat,
expectedConnectRegex: "",
},
{
user: "foo",
endpoint: "foo.bar",
region: "region",
dbName: "name",
format: rdsutils.TCPFormat,
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
expectedConnectRegex: `^foo:foo.bar\?Action=connect\&DBUser=foo.*\@tcp\(foo.bar\)/name`,
},
}
for _, c := range cases {
b := rdsutils.NewConnectionStringBuilder(c.endpoint, c.region, c.user, c.dbName, c.creds)
connectStr, err := b.WithFormat(c.format).Build()
if e, a := c.expectedErr, err; e != a {
t.Errorf("expected %v error, but received %v", e, a)
}
if re, a := regexp.MustCompile(c.expectedConnectRegex), connectStr; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
}
}
+8 -11
View File
@@ -9,16 +9,13 @@ import (
"github.com/aws/aws-sdk-go/aws/signer/v4"
)
// BuildAuthToken will return a authentication token for the database's connect
// based on the RDS database endpoint, AWS region, IAM user or role, and AWS credentials.
// BuildAuthToken will return an authorization token used as the password for a DB
// connection.
//
// Endpoint consists of the hostname and port, IE hostname:port, of the RDS database.
// Region is the AWS region the RDS database is in and where the authentication token
// will be generated for. DbUser is the IAM user or role the request will be authenticated
// for. The creds is the AWS credentials the authentication token is signed with.
//
// An error is returned if the authentication token is unable to be signed with
// the credentials, or the endpoint is not a valid URL.
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
// * region - Region is the location of where the DB is
// * dbUser - User account within the database to sign in with
// * creds - Credentials to be signed with
//
// The following example shows how to use BuildAuthToken to create an authentication
// token for connecting to a MySQL database in RDS.
@@ -27,12 +24,12 @@ import (
//
// // Create the MySQL DNS string for the DB connection
// // user:password@protocol(endpoint)/dbname?<params>
// dnsStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=true",
// connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds",
// dbUser, authToken, dbEndpoint, dbName,
// )
//
// // Use db to perform SQL operations on database
// db, err := sql.Open("mysql", dnsStr)
// db, err := sql.Open("mysql", connectStr)
//
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
+18
View File
@@ -0,0 +1,18 @@
// Package rdsutils is used to generate authentication tokens used to
// connect to a givent Amazon Relational Database Service (RDS) database.
//
// Before using the authentication please visit the docs here to ensure
// the database has the proper policies to allow for IAM token authentication.
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html#UsingWithRDS.IAMDBAuth.Availability
//
// When building the connection string, there are two required parameters that are needed to be set on the query.
// * tls
// * allowCleartextPasswords must be set to true
//
// Example creating a basic auth token with the builder:
// v := url.Values{}
// v.Add("tls", "tls_profile_name")
// v.Add("allowCleartextPasswords", "true")
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithTCPFormat().WithParams(v).Build()
package rdsutils
+119
View File
@@ -0,0 +1,119 @@
// +build example,exclude
package rdsutils_test
import (
"crypto/tls"
"crypto/x509"
"database/sql"
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"github.com/go-sql-driver/mysql"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
)
// ExampleConnectionStringBuilder contains usage of assuming a role and using
// that to build the auth token.
// Usage:
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
func ExampleConnectionStringBuilder() {
userPtr := flag.String("user", "", "user of the credentials")
regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
portPtr := flag.Int("port", 3306, "DB port to be connected to")
tablePtr := flag.String("table", "test_table", "DB table to query against")
dbNamePtr := flag.String("dbname", "", "DB name to query against")
flag.Parse()
// Check required flags. Will exit with status code 1 if
// required field isn't set.
if err := requiredFlags(
userPtr,
regionPtr,
roleArnPtr,
endpointPtr,
portPtr,
dbNamePtr,
); err != nil {
fmt.Printf("Error: %v\n\n", err)
flag.PrintDefaults()
os.Exit(1)
}
err := registerRDSMysqlCerts(http.DefaultClient)
if err != nil {
panic(err)
}
sess := session.Must(session.NewSession())
creds := stscreds.NewCredentials(sess, *roleArnPtr)
v := url.Values{}
// required fields for DB connection
v.Add("tls", "rds")
v.Add("allowCleartextPasswords", "true")
endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)
b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds)
connectStr, err := b.WithTCPFormat().WithParams(v).Build()
const dbType = "mysql"
db, err := sql.Open(dbType, connectStr)
// if an error is encountered here, then most likely security groups are incorrect
// in the database.
if err != nil {
panic(fmt.Errorf("failed to open connection to the database"))
}
rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
if err != nil {
panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
}
for rows.Next() {
columns, err := rows.Columns()
if err != nil {
panic(fmt.Errorf("failed to read columns from row: %v", err))
}
fmt.Printf("rows colums:\n%d\n", len(columns))
}
}
func requiredFlags(flags ...interface{}) error {
for _, f := range flags {
switch f.(type) {
case nil:
return fmt.Errorf("one or more required flags were not set")
}
}
return nil
}
func registerRDSMysqlCerts(c *http.Client) error {
resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem")
if err != nil {
return err
}
pem, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
rootCertPool := x509.NewCertPool()
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append cert to cert pool!")
}
return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true})
}