mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-06-02 04:50:49 +00:00
Update Go AWS SDK to the latest version
This commit is contained in:
committed by
Andrey Smirnov
parent
d08be990ef
commit
94a72b23ff
+127
@@ -0,0 +1,127 @@
|
||||
package rdsutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
// ConnectionFormat is the type of connection that will be
|
||||
// used to connect to the database
|
||||
type ConnectionFormat string
|
||||
|
||||
// ConnectionFormat enums
|
||||
const (
|
||||
NoConnectionFormat ConnectionFormat = ""
|
||||
TCPFormat ConnectionFormat = "tcp"
|
||||
)
|
||||
|
||||
// ErrNoConnectionFormat will be returned during build if no format had been
|
||||
// specified
|
||||
var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)
|
||||
|
||||
// ConnectionStringBuilder is a builder that will construct a connection
|
||||
// string with the provided parameters. params field is required to have
|
||||
// a tls specification and allowCleartextPasswords must be set to true.
|
||||
type ConnectionStringBuilder struct {
|
||||
dbName string
|
||||
endpoint string
|
||||
region string
|
||||
user string
|
||||
creds *credentials.Credentials
|
||||
|
||||
connectFormat ConnectionFormat
|
||||
params url.Values
|
||||
}
|
||||
|
||||
// NewConnectionStringBuilder will return an ConnectionStringBuilder
|
||||
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder {
|
||||
return ConnectionStringBuilder{
|
||||
dbName: dbName,
|
||||
endpoint: endpoint,
|
||||
region: region,
|
||||
user: dbUser,
|
||||
creds: creds,
|
||||
}
|
||||
}
|
||||
|
||||
// WithEndpoint will return a builder with the given endpoint
|
||||
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
|
||||
b.endpoint = endpoint
|
||||
return b
|
||||
}
|
||||
|
||||
// WithRegion will return a builder with the given region
|
||||
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
|
||||
b.region = region
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUser will return a builder with the given user
|
||||
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
|
||||
b.user = user
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDBName will return a builder with the given database name
|
||||
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
|
||||
b.dbName = dbName
|
||||
return b
|
||||
}
|
||||
|
||||
// WithParams will return a builder with the given params. The parameters
|
||||
// will be included in the connection query string
|
||||
//
|
||||
// Example:
|
||||
// v := url.Values{}
|
||||
// v.Add("tls", "rds")
|
||||
// b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, creds)
|
||||
// connectStr, err := b.WithParams(v).WithTCPFormat().Build()
|
||||
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
|
||||
b.params = params
|
||||
return b
|
||||
}
|
||||
|
||||
// WithFormat will return a builder with the given connection format
|
||||
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
|
||||
b.connectFormat = f
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTCPFormat will set the format to TCP and return the modified builder
|
||||
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
|
||||
return b.WithFormat(TCPFormat)
|
||||
}
|
||||
|
||||
// Build will return a new connection string that can be used to open a connection
|
||||
// to the desired database.
|
||||
//
|
||||
// Example:
|
||||
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
|
||||
// connectStr, err := b.WithTCPFormat().Build()
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// const dbType = "mysql"
|
||||
// db, err := sql.Open(dbType, connectStr)
|
||||
func (b ConnectionStringBuilder) Build() (string, error) {
|
||||
if b.connectFormat == NoConnectionFormat {
|
||||
return "", ErrNoConnectionFormat
|
||||
}
|
||||
|
||||
authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s",
|
||||
b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName,
|
||||
)
|
||||
|
||||
if len(b.params) > 0 {
|
||||
connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode())
|
||||
}
|
||||
return connectionStr, nil
|
||||
}
|
||||
+58
@@ -0,0 +1,58 @@
|
||||
package rdsutils_test
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
|
||||
)
|
||||
|
||||
func TestConnectionStringBuilder(t *testing.T) {
|
||||
cases := []struct {
|
||||
user string
|
||||
endpoint string
|
||||
region string
|
||||
dbName string
|
||||
values url.Values
|
||||
format rdsutils.ConnectionFormat
|
||||
creds *credentials.Credentials
|
||||
|
||||
expectedErr error
|
||||
expectedConnectRegex string
|
||||
}{
|
||||
{
|
||||
user: "foo",
|
||||
endpoint: "foo.bar",
|
||||
region: "region",
|
||||
dbName: "name",
|
||||
format: rdsutils.NoConnectionFormat,
|
||||
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
|
||||
expectedErr: rdsutils.ErrNoConnectionFormat,
|
||||
expectedConnectRegex: "",
|
||||
},
|
||||
{
|
||||
user: "foo",
|
||||
endpoint: "foo.bar",
|
||||
region: "region",
|
||||
dbName: "name",
|
||||
format: rdsutils.TCPFormat,
|
||||
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
|
||||
expectedConnectRegex: `^foo:foo.bar\?Action=connect\&DBUser=foo.*\@tcp\(foo.bar\)/name`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
b := rdsutils.NewConnectionStringBuilder(c.endpoint, c.region, c.user, c.dbName, c.creds)
|
||||
connectStr, err := b.WithFormat(c.format).Build()
|
||||
|
||||
if e, a := c.expectedErr, err; e != a {
|
||||
t.Errorf("expected %v error, but received %v", e, a)
|
||||
}
|
||||
|
||||
if re, a := regexp.MustCompile(c.expectedConnectRegex), connectStr; !re.MatchString(a) {
|
||||
t.Errorf("expect %s to match %s", re, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
+8
-11
@@ -9,16 +9,13 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
)
|
||||
|
||||
// BuildAuthToken will return a authentication token for the database's connect
|
||||
// based on the RDS database endpoint, AWS region, IAM user or role, and AWS credentials.
|
||||
// BuildAuthToken will return an authorization token used as the password for a DB
|
||||
// connection.
|
||||
//
|
||||
// Endpoint consists of the hostname and port, IE hostname:port, of the RDS database.
|
||||
// Region is the AWS region the RDS database is in and where the authentication token
|
||||
// will be generated for. DbUser is the IAM user or role the request will be authenticated
|
||||
// for. The creds is the AWS credentials the authentication token is signed with.
|
||||
//
|
||||
// An error is returned if the authentication token is unable to be signed with
|
||||
// the credentials, or the endpoint is not a valid URL.
|
||||
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
|
||||
// * region - Region is the location of where the DB is
|
||||
// * dbUser - User account within the database to sign in with
|
||||
// * creds - Credentials to be signed with
|
||||
//
|
||||
// The following example shows how to use BuildAuthToken to create an authentication
|
||||
// token for connecting to a MySQL database in RDS.
|
||||
@@ -27,12 +24,12 @@ import (
|
||||
//
|
||||
// // Create the MySQL DNS string for the DB connection
|
||||
// // user:password@protocol(endpoint)/dbname?<params>
|
||||
// dnsStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=true",
|
||||
// connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds",
|
||||
// dbUser, authToken, dbEndpoint, dbName,
|
||||
// )
|
||||
//
|
||||
// // Use db to perform SQL operations on database
|
||||
// db, err := sql.Open("mysql", dnsStr)
|
||||
// db, err := sql.Open("mysql", connectStr)
|
||||
//
|
||||
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
|
||||
// for more information on using IAM database authentication with RDS.
|
||||
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
// Package rdsutils is used to generate authentication tokens used to
|
||||
// connect to a givent Amazon Relational Database Service (RDS) database.
|
||||
//
|
||||
// Before using the authentication please visit the docs here to ensure
|
||||
// the database has the proper policies to allow for IAM token authentication.
|
||||
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html#UsingWithRDS.IAMDBAuth.Availability
|
||||
//
|
||||
// When building the connection string, there are two required parameters that are needed to be set on the query.
|
||||
// * tls
|
||||
// * allowCleartextPasswords must be set to true
|
||||
//
|
||||
// Example creating a basic auth token with the builder:
|
||||
// v := url.Values{}
|
||||
// v.Add("tls", "tls_profile_name")
|
||||
// v.Add("allowCleartextPasswords", "true")
|
||||
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
|
||||
// connectStr, err := b.WithTCPFormat().WithParams(v).Build()
|
||||
package rdsutils
|
||||
+119
@@ -0,0 +1,119 @@
|
||||
// +build example,exclude
|
||||
|
||||
package rdsutils_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
|
||||
)
|
||||
|
||||
// ExampleConnectionStringBuilder contains usage of assuming a role and using
|
||||
// that to build the auth token.
|
||||
// Usage:
|
||||
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
|
||||
func ExampleConnectionStringBuilder() {
|
||||
userPtr := flag.String("user", "", "user of the credentials")
|
||||
regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
|
||||
roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
|
||||
endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
|
||||
portPtr := flag.Int("port", 3306, "DB port to be connected to")
|
||||
tablePtr := flag.String("table", "test_table", "DB table to query against")
|
||||
dbNamePtr := flag.String("dbname", "", "DB name to query against")
|
||||
flag.Parse()
|
||||
|
||||
// Check required flags. Will exit with status code 1 if
|
||||
// required field isn't set.
|
||||
if err := requiredFlags(
|
||||
userPtr,
|
||||
regionPtr,
|
||||
roleArnPtr,
|
||||
endpointPtr,
|
||||
portPtr,
|
||||
dbNamePtr,
|
||||
); err != nil {
|
||||
fmt.Printf("Error: %v\n\n", err)
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err := registerRDSMysqlCerts(http.DefaultClient)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sess := session.Must(session.NewSession())
|
||||
creds := stscreds.NewCredentials(sess, *roleArnPtr)
|
||||
|
||||
v := url.Values{}
|
||||
// required fields for DB connection
|
||||
v.Add("tls", "rds")
|
||||
v.Add("allowCleartextPasswords", "true")
|
||||
endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)
|
||||
|
||||
b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds)
|
||||
connectStr, err := b.WithTCPFormat().WithParams(v).Build()
|
||||
|
||||
const dbType = "mysql"
|
||||
db, err := sql.Open(dbType, connectStr)
|
||||
// if an error is encountered here, then most likely security groups are incorrect
|
||||
// in the database.
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to open connection to the database"))
|
||||
}
|
||||
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to read columns from row: %v", err))
|
||||
}
|
||||
|
||||
fmt.Printf("rows colums:\n%d\n", len(columns))
|
||||
}
|
||||
}
|
||||
|
||||
func requiredFlags(flags ...interface{}) error {
|
||||
for _, f := range flags {
|
||||
switch f.(type) {
|
||||
case nil:
|
||||
return fmt.Errorf("one or more required flags were not set")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerRDSMysqlCerts(c *http.Client) error {
|
||||
resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pem, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rootCertPool := x509.NewCertPool()
|
||||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
return fmt.Errorf("failed to append cert to cert pool!")
|
||||
}
|
||||
|
||||
return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true})
|
||||
}
|
||||
Reference in New Issue
Block a user