mirror of
https://github.com/aptly-dev/aptly.git
synced 2026-05-31 04:30:44 +00:00
Conver to regular Go vendor + dep tool
This commit is contained in:
+4
@@ -0,0 +1,4 @@
|
||||
## AWS SDK for Go Private packages ##
|
||||
`private` is a collection of packages used internally by the SDK, and is subject to have breaking changes. This package is not `internal` so that if you really need to use its functionality, and understand breaking changes will be made, you are able to.
|
||||
|
||||
These packages will be refactored in the future so that the API generator and model parsers are exposed cleanly on their own. Making it easier for you to generate your own code based on the API models.
|
||||
+694
@@ -0,0 +1,694 @@
|
||||
// +build codegen
|
||||
|
||||
// Package api represents API abstractions for rendering service generated files.
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// An API defines a service API's definition. and logic to serialize the definition.
|
||||
type API struct {
|
||||
Metadata Metadata
|
||||
Operations map[string]*Operation
|
||||
Shapes map[string]*Shape
|
||||
Waiters []Waiter
|
||||
Documentation string
|
||||
|
||||
// Set to true to avoid removing unused shapes
|
||||
NoRemoveUnusedShapes bool
|
||||
|
||||
// Set to true to avoid renaming to 'Input/Output' postfixed shapes
|
||||
NoRenameToplevelShapes bool
|
||||
|
||||
// Set to true to ignore service/request init methods (for testing)
|
||||
NoInitMethods bool
|
||||
|
||||
// Set to true to ignore String() and GoString methods (for generated tests)
|
||||
NoStringerMethods bool
|
||||
|
||||
// Set to true to not generate API service name constants
|
||||
NoConstServiceNames bool
|
||||
|
||||
// Set to true to not generate validation shapes
|
||||
NoValidataShapeMethods bool
|
||||
|
||||
// Set to true to not generate struct field accessors
|
||||
NoGenStructFieldAccessors bool
|
||||
|
||||
SvcClientImportPath string
|
||||
|
||||
initialized bool
|
||||
imports map[string]bool
|
||||
name string
|
||||
path string
|
||||
|
||||
BaseCrosslinkURL string
|
||||
}
|
||||
|
||||
// A Metadata is the metadata about an API's definition.
|
||||
type Metadata struct {
|
||||
APIVersion string
|
||||
EndpointPrefix string
|
||||
SigningName string
|
||||
ServiceAbbreviation string
|
||||
ServiceFullName string
|
||||
SignatureVersion string
|
||||
JSONVersion string
|
||||
TargetPrefix string
|
||||
Protocol string
|
||||
UID string
|
||||
EndpointsID string
|
||||
|
||||
NoResolveEndpoint bool
|
||||
}
|
||||
|
||||
var serviceAliases map[string]string
|
||||
|
||||
func Bootstrap() error {
|
||||
b, err := ioutil.ReadFile(filepath.Join("..", "models", "customizations", "service-aliases.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, &serviceAliases)
|
||||
}
|
||||
|
||||
// PackageName name of the API package
|
||||
func (a *API) PackageName() string {
|
||||
return strings.ToLower(a.StructName())
|
||||
}
|
||||
|
||||
// InterfacePackageName returns the package name for the interface.
|
||||
func (a *API) InterfacePackageName() string {
|
||||
return a.PackageName() + "iface"
|
||||
}
|
||||
|
||||
var nameRegex = regexp.MustCompile(`^Amazon|AWS\s*|\(.*|\s+|\W+`)
|
||||
|
||||
// StructName returns the struct name for a given API.
|
||||
func (a *API) StructName() string {
|
||||
if a.name == "" {
|
||||
name := a.Metadata.ServiceAbbreviation
|
||||
if name == "" {
|
||||
name = a.Metadata.ServiceFullName
|
||||
}
|
||||
|
||||
name = nameRegex.ReplaceAllString(name, "")
|
||||
|
||||
a.name = name
|
||||
if name, ok := serviceAliases[strings.ToLower(name)]; ok {
|
||||
a.name = name
|
||||
}
|
||||
}
|
||||
return a.name
|
||||
}
|
||||
|
||||
// UseInitMethods returns if the service's init method should be rendered.
|
||||
func (a *API) UseInitMethods() bool {
|
||||
return !a.NoInitMethods
|
||||
}
|
||||
|
||||
// NiceName returns the human friendly API name.
|
||||
func (a *API) NiceName() string {
|
||||
if a.Metadata.ServiceAbbreviation != "" {
|
||||
return a.Metadata.ServiceAbbreviation
|
||||
}
|
||||
return a.Metadata.ServiceFullName
|
||||
}
|
||||
|
||||
// ProtocolPackage returns the package name of the protocol this API uses.
|
||||
func (a *API) ProtocolPackage() string {
|
||||
switch a.Metadata.Protocol {
|
||||
case "json":
|
||||
return "jsonrpc"
|
||||
case "ec2":
|
||||
return "ec2query"
|
||||
default:
|
||||
return strings.Replace(a.Metadata.Protocol, "-", "", -1)
|
||||
}
|
||||
}
|
||||
|
||||
// OperationNames returns a slice of API operations supported.
|
||||
func (a *API) OperationNames() []string {
|
||||
i, names := 0, make([]string, len(a.Operations))
|
||||
for n := range a.Operations {
|
||||
names[i] = n
|
||||
i++
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// OperationList returns a slice of API operation pointers
|
||||
func (a *API) OperationList() []*Operation {
|
||||
list := make([]*Operation, len(a.Operations))
|
||||
for i, n := range a.OperationNames() {
|
||||
list[i] = a.Operations[n]
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// OperationHasOutputPlaceholder returns if any of the API operation input
|
||||
// or output shapes are place holders.
|
||||
func (a *API) OperationHasOutputPlaceholder() bool {
|
||||
for _, op := range a.Operations {
|
||||
if op.OutputRef.Shape.Placeholder {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ShapeNames returns a slice of names for each shape used by the API.
|
||||
func (a *API) ShapeNames() []string {
|
||||
i, names := 0, make([]string, len(a.Shapes))
|
||||
for n := range a.Shapes {
|
||||
names[i] = n
|
||||
i++
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// ShapeList returns a slice of shape pointers used by the API.
|
||||
//
|
||||
// Will exclude error shapes from the list of shapes returned.
|
||||
func (a *API) ShapeList() []*Shape {
|
||||
list := make([]*Shape, 0, len(a.Shapes))
|
||||
for _, n := range a.ShapeNames() {
|
||||
// Ignore error shapes in list
|
||||
if s := a.Shapes[n]; !s.IsError {
|
||||
list = append(list, s)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// ShapeListErrors returns a list of the errors defined by the API model
|
||||
func (a *API) ShapeListErrors() []*Shape {
|
||||
list := []*Shape{}
|
||||
for _, n := range a.ShapeNames() {
|
||||
// Ignore error shapes in list
|
||||
if s := a.Shapes[n]; s.IsError {
|
||||
list = append(list, s)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// resetImports resets the import map to default values.
|
||||
func (a *API) resetImports() {
|
||||
a.imports = map[string]bool{
|
||||
"github.com/aws/aws-sdk-go/aws": true,
|
||||
}
|
||||
}
|
||||
|
||||
// importsGoCode returns the generated Go import code.
|
||||
func (a *API) importsGoCode() string {
|
||||
if len(a.imports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
corePkgs, extPkgs := []string{}, []string{}
|
||||
for i := range a.imports {
|
||||
if strings.Contains(i, ".") {
|
||||
extPkgs = append(extPkgs, i)
|
||||
} else {
|
||||
corePkgs = append(corePkgs, i)
|
||||
}
|
||||
}
|
||||
sort.Strings(corePkgs)
|
||||
sort.Strings(extPkgs)
|
||||
|
||||
code := "import (\n"
|
||||
for _, i := range corePkgs {
|
||||
code += fmt.Sprintf("\t%q\n", i)
|
||||
}
|
||||
if len(corePkgs) > 0 {
|
||||
code += "\n"
|
||||
}
|
||||
for _, i := range extPkgs {
|
||||
code += fmt.Sprintf("\t%q\n", i)
|
||||
}
|
||||
code += ")\n\n"
|
||||
return code
|
||||
}
|
||||
|
||||
// A tplAPI is the top level template for the API
|
||||
var tplAPI = template.Must(template.New("api").Parse(`
|
||||
{{ range $_, $o := .OperationList }}
|
||||
{{ $o.GoCode }}
|
||||
|
||||
{{ end }}
|
||||
|
||||
{{ range $_, $s := .ShapeList }}
|
||||
{{ if and $s.IsInternal (eq $s.Type "structure") }}{{ $s.GoCode }}{{ end }}
|
||||
|
||||
{{ end }}
|
||||
|
||||
{{ range $_, $s := .ShapeList }}
|
||||
{{ if $s.IsEnum }}{{ $s.GoCode }}{{ end }}
|
||||
|
||||
{{ end }}
|
||||
`))
|
||||
|
||||
// APIGoCode renders the API in Go code. Returning it as a string
|
||||
func (a *API) APIGoCode() string {
|
||||
a.resetImports()
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/awsutil"] = true
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/request"] = true
|
||||
if a.OperationHasOutputPlaceholder() {
|
||||
a.imports["github.com/aws/aws-sdk-go/private/protocol/"+a.ProtocolPackage()] = true
|
||||
a.imports["github.com/aws/aws-sdk-go/private/protocol"] = true
|
||||
}
|
||||
|
||||
for _, op := range a.Operations {
|
||||
if op.AuthType == "none" {
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/credentials"] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := tplAPI.Execute(&buf, a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := a.importsGoCode() + strings.TrimSpace(buf.String())
|
||||
return code
|
||||
}
|
||||
|
||||
var noCrossLinkServices = map[string]struct{}{
|
||||
"apigateway": struct{}{},
|
||||
"budgets": struct{}{},
|
||||
"cloudsearch": struct{}{},
|
||||
"cloudsearchdomain": struct{}{},
|
||||
"discovery": struct{}{},
|
||||
"elastictranscoder": struct{}{},
|
||||
"es": struct{}{},
|
||||
"glacier": struct{}{},
|
||||
"importexport": struct{}{},
|
||||
"iot": struct{}{},
|
||||
"iot-data": struct{}{},
|
||||
"lambda": struct{}{},
|
||||
"machinelearning": struct{}{},
|
||||
"rekognition": struct{}{},
|
||||
"sdb": struct{}{},
|
||||
"swf": struct{}{},
|
||||
}
|
||||
|
||||
func GetCrosslinkURL(baseURL, name, uid string, params ...string) string {
|
||||
_, ok := noCrossLinkServices[strings.ToLower(name)]
|
||||
if uid != "" && baseURL != "" && !ok {
|
||||
return strings.Join(append([]string{baseURL, "goto", "WebAPI", uid}, params...), "/")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *API) APIName() string {
|
||||
return a.name
|
||||
}
|
||||
|
||||
// A tplService defines the template for the service generated code.
|
||||
var tplService = template.Must(template.New("service").Funcs(template.FuncMap{
|
||||
"ServiceNameValue": func(a *API) string {
|
||||
if a.NoConstServiceNames {
|
||||
return fmt.Sprintf("%q", a.Metadata.EndpointPrefix)
|
||||
}
|
||||
return "ServiceName"
|
||||
},
|
||||
"GetCrosslinkURL": GetCrosslinkURL,
|
||||
"EndpointsIDConstValue": func(a *API) string {
|
||||
if a.NoConstServiceNames {
|
||||
return fmt.Sprintf("%q", a.Metadata.EndpointPrefix)
|
||||
}
|
||||
if a.Metadata.EndpointPrefix == a.Metadata.EndpointsID {
|
||||
return "ServiceName"
|
||||
}
|
||||
return fmt.Sprintf("%q", a.Metadata.EndpointsID)
|
||||
},
|
||||
"EndpointsIDValue": func(a *API) string {
|
||||
if a.NoConstServiceNames {
|
||||
return fmt.Sprintf("%q", a.Metadata.EndpointPrefix)
|
||||
}
|
||||
|
||||
return "EndpointsID"
|
||||
},
|
||||
}).Parse(`
|
||||
{{ .Documentation }}// The service client's operations are safe to be used concurrently.
|
||||
// It is not safe to mutate any of the client's properties though.
|
||||
{{ $crosslinkURL := GetCrosslinkURL $.BaseCrosslinkURL $.APIName $.Metadata.UID -}}
|
||||
{{ if ne $crosslinkURL "" -}}
|
||||
// Please also see {{ $crosslinkURL }}
|
||||
{{ end -}}
|
||||
type {{ .StructName }} struct {
|
||||
*client.Client
|
||||
}
|
||||
|
||||
{{ if .UseInitMethods }}// Used for custom client initialization logic
|
||||
var initClient func(*client.Client)
|
||||
|
||||
// Used for custom request initialization logic
|
||||
var initRequest func(*request.Request)
|
||||
{{ end }}
|
||||
|
||||
|
||||
{{ if not .NoConstServiceNames -}}
|
||||
// Service information constants
|
||||
const (
|
||||
ServiceName = "{{ .Metadata.EndpointPrefix }}" // Service endpoint prefix API calls made to.
|
||||
EndpointsID = {{ EndpointsIDConstValue . }} // Service ID for Regions and Endpoints metadata.
|
||||
)
|
||||
{{- end }}
|
||||
|
||||
// New creates a new instance of the {{ .StructName }} client with a session.
|
||||
// If additional configuration is needed for the client instance use the optional
|
||||
// aws.Config parameter to add your extra config.
|
||||
//
|
||||
// Example:
|
||||
// // Create a {{ .StructName }} client from just a session.
|
||||
// svc := {{ .PackageName }}.New(mySession)
|
||||
//
|
||||
// // Create a {{ .StructName }} client with additional configuration
|
||||
// svc := {{ .PackageName }}.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
|
||||
func New(p client.ConfigProvider, cfgs ...*aws.Config) *{{ .StructName }} {
|
||||
{{ if .Metadata.NoResolveEndpoint -}}
|
||||
var c client.Config
|
||||
if v, ok := p.(client.ConfigNoResolveEndpointProvider); ok {
|
||||
c = v.ClientConfigNoResolveEndpoint(cfgs...)
|
||||
} else {
|
||||
c = p.ClientConfig({{ EndpointsIDValue . }}, cfgs...)
|
||||
}
|
||||
{{- else -}}
|
||||
c := p.ClientConfig({{ EndpointsIDValue . }}, cfgs...)
|
||||
{{- end }}
|
||||
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
|
||||
}
|
||||
|
||||
// newClient creates, initializes and returns a new service client instance.
|
||||
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *{{ .StructName }} {
|
||||
{{- if .Metadata.SigningName }}
|
||||
if len(signingName) == 0 {
|
||||
signingName = "{{ .Metadata.SigningName }}"
|
||||
}
|
||||
{{- end }}
|
||||
svc := &{{ .StructName }}{
|
||||
Client: client.New(
|
||||
cfg,
|
||||
metadata.ClientInfo{
|
||||
ServiceName: {{ ServiceNameValue . }},
|
||||
SigningName: signingName,
|
||||
SigningRegion: signingRegion,
|
||||
Endpoint: endpoint,
|
||||
APIVersion: "{{ .Metadata.APIVersion }}",
|
||||
{{ if .Metadata.JSONVersion -}}
|
||||
JSONVersion: "{{ .Metadata.JSONVersion }}",
|
||||
{{- end }}
|
||||
{{ if .Metadata.TargetPrefix -}}
|
||||
TargetPrefix: "{{ .Metadata.TargetPrefix }}",
|
||||
{{- end }}
|
||||
},
|
||||
handlers,
|
||||
),
|
||||
}
|
||||
|
||||
// Handlers
|
||||
svc.Handlers.Sign.PushBackNamed({{if eq .Metadata.SignatureVersion "v2"}}v2{{else}}v4{{end}}.SignRequestHandler)
|
||||
{{- if eq .Metadata.SignatureVersion "v2" }}
|
||||
svc.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
|
||||
{{- end }}
|
||||
svc.Handlers.Build.PushBackNamed({{ .ProtocolPackage }}.BuildHandler)
|
||||
svc.Handlers.Unmarshal.PushBackNamed({{ .ProtocolPackage }}.UnmarshalHandler)
|
||||
svc.Handlers.UnmarshalMeta.PushBackNamed({{ .ProtocolPackage }}.UnmarshalMetaHandler)
|
||||
svc.Handlers.UnmarshalError.PushBackNamed({{ .ProtocolPackage }}.UnmarshalErrorHandler)
|
||||
|
||||
{{ if .UseInitMethods }}// Run custom client initialization if present
|
||||
if initClient != nil {
|
||||
initClient(svc.Client)
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// newRequest creates a new request for a {{ .StructName }} operation and runs any
|
||||
// custom request initialization.
|
||||
func (c *{{ .StructName }}) newRequest(op *request.Operation, params, data interface{}) *request.Request {
|
||||
req := c.NewRequest(op, params, data)
|
||||
|
||||
{{ if .UseInitMethods }}// Run custom request initialization if present
|
||||
if initRequest != nil {
|
||||
initRequest(req)
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
return req
|
||||
}
|
||||
`))
|
||||
|
||||
// ServiceGoCode renders service go code. Returning it as a string.
|
||||
func (a *API) ServiceGoCode() string {
|
||||
a.resetImports()
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/client"] = true
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/client/metadata"] = true
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/request"] = true
|
||||
if a.Metadata.SignatureVersion == "v2" {
|
||||
a.imports["github.com/aws/aws-sdk-go/private/signer/v2"] = true
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/corehandlers"] = true
|
||||
} else {
|
||||
a.imports["github.com/aws/aws-sdk-go/aws/signer/v4"] = true
|
||||
}
|
||||
a.imports["github.com/aws/aws-sdk-go/private/protocol/"+a.ProtocolPackage()] = true
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := tplService.Execute(&buf, a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := a.importsGoCode() + buf.String()
|
||||
return code
|
||||
}
|
||||
|
||||
// ExampleGoCode renders service example code. Returning it as a string.
|
||||
func (a *API) ExampleGoCode() string {
|
||||
exs := []string{}
|
||||
imports := map[string]bool{}
|
||||
for _, o := range a.OperationList() {
|
||||
o.imports = map[string]bool{}
|
||||
exs = append(exs, o.Example())
|
||||
for k, v := range o.imports {
|
||||
imports[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
code := fmt.Sprintf("import (\n%q\n%q\n%q\n\n%q\n%q\n%q\n",
|
||||
"bytes",
|
||||
"fmt",
|
||||
"time",
|
||||
"github.com/aws/aws-sdk-go/aws",
|
||||
"github.com/aws/aws-sdk-go/aws/session",
|
||||
path.Join(a.SvcClientImportPath, a.PackageName()),
|
||||
)
|
||||
for k, _ := range imports {
|
||||
code += fmt.Sprintf("%q\n", k)
|
||||
}
|
||||
code += ")\n\n"
|
||||
code += "var _ time.Duration\nvar _ bytes.Buffer\n\n"
|
||||
code += strings.Join(exs, "\n\n")
|
||||
return code
|
||||
}
|
||||
|
||||
// A tplInterface defines the template for the service interface type.
|
||||
var tplInterface = template.Must(template.New("interface").Parse(`
|
||||
// {{ .StructName }}API provides an interface to enable mocking the
|
||||
// {{ .PackageName }}.{{ .StructName }} service client's API operation,
|
||||
// paginators, and waiters. This make unit testing your code that calls out
|
||||
// to the SDK's service client's calls easier.
|
||||
//
|
||||
// The best way to use this interface is so the SDK's service client's calls
|
||||
// can be stubbed out for unit testing your code with the SDK without needing
|
||||
// to inject custom request handlers into the the SDK's request pipeline.
|
||||
//
|
||||
// // myFunc uses an SDK service client to make a request to
|
||||
// // {{.Metadata.ServiceFullName}}. {{ $opts := .OperationList }}{{ $opt := index $opts 0 }}
|
||||
// func myFunc(svc {{ .InterfacePackageName }}.{{ .StructName }}API) bool {
|
||||
// // Make svc.{{ $opt.ExportedName }} request
|
||||
// }
|
||||
//
|
||||
// func main() {
|
||||
// sess := session.New()
|
||||
// svc := {{ .PackageName }}.New(sess)
|
||||
//
|
||||
// myFunc(svc)
|
||||
// }
|
||||
//
|
||||
// In your _test.go file:
|
||||
//
|
||||
// // Define a mock struct to be used in your unit tests of myFunc.
|
||||
// type mock{{ .StructName }}Client struct {
|
||||
// {{ .InterfacePackageName }}.{{ .StructName }}API
|
||||
// }
|
||||
// func (m *mock{{ .StructName }}Client) {{ $opt.ExportedName }}(input {{ $opt.InputRef.GoTypeWithPkgName }}) ({{ $opt.OutputRef.GoTypeWithPkgName }}, error) {
|
||||
// // mock response/functionality
|
||||
// }
|
||||
//
|
||||
// func TestMyFunc(t *testing.T) {
|
||||
// // Setup Test
|
||||
// mockSvc := &mock{{ .StructName }}Client{}
|
||||
//
|
||||
// myfunc(mockSvc)
|
||||
//
|
||||
// // Verify myFunc's functionality
|
||||
// }
|
||||
//
|
||||
// It is important to note that this interface will have breaking changes
|
||||
// when the service model is updated and adds new API operations, paginators,
|
||||
// and waiters. Its suggested to use the pattern above for testing, or using
|
||||
// tooling to generate mocks to satisfy the interfaces.
|
||||
type {{ .StructName }}API interface {
|
||||
{{ range $_, $o := .OperationList }}
|
||||
{{ $o.InterfaceSignature }}
|
||||
{{ end }}
|
||||
{{ range $_, $w := .Waiters }}
|
||||
{{ $w.InterfaceSignature }}
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
var _ {{ .StructName }}API = (*{{ .PackageName }}.{{ .StructName }})(nil)
|
||||
`))
|
||||
|
||||
// InterfaceGoCode returns the go code for the service's API operations as an
|
||||
// interface{}. Assumes that the interface is being created in a different
|
||||
// package than the service API's package.
|
||||
func (a *API) InterfaceGoCode() string {
|
||||
a.resetImports()
|
||||
a.imports = map[string]bool{
|
||||
"github.com/aws/aws-sdk-go/aws": true,
|
||||
"github.com/aws/aws-sdk-go/aws/request": true,
|
||||
path.Join(a.SvcClientImportPath, a.PackageName()): true,
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := tplInterface.Execute(&buf, a)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := a.importsGoCode() + strings.TrimSpace(buf.String())
|
||||
return code
|
||||
}
|
||||
|
||||
// NewAPIGoCodeWithPkgName returns a string of instantiating the API prefixed
|
||||
// with its package name. Takes a string depicting the Config.
|
||||
func (a *API) NewAPIGoCodeWithPkgName(cfg string) string {
|
||||
return fmt.Sprintf("%s.New(%s)", a.PackageName(), cfg)
|
||||
}
|
||||
|
||||
// computes the validation chain for all input shapes
|
||||
func (a *API) addShapeValidations() {
|
||||
for _, o := range a.Operations {
|
||||
resolveShapeValidations(o.InputRef.Shape)
|
||||
}
|
||||
}
|
||||
|
||||
// Updates the source shape and all nested shapes with the validations that
|
||||
// could possibly be needed.
|
||||
func resolveShapeValidations(s *Shape, ancestry ...*Shape) {
|
||||
for _, a := range ancestry {
|
||||
if a == s {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
children := []string{}
|
||||
for _, name := range s.MemberNames() {
|
||||
ref := s.MemberRefs[name]
|
||||
|
||||
if s.IsRequired(name) && !s.Validations.Has(ref, ShapeValidationRequired) {
|
||||
s.Validations = append(s.Validations, ShapeValidation{
|
||||
Name: name, Ref: ref, Type: ShapeValidationRequired,
|
||||
})
|
||||
}
|
||||
|
||||
if ref.Shape.Min != 0 && !s.Validations.Has(ref, ShapeValidationMinVal) {
|
||||
s.Validations = append(s.Validations, ShapeValidation{
|
||||
Name: name, Ref: ref, Type: ShapeValidationMinVal,
|
||||
})
|
||||
}
|
||||
|
||||
switch ref.Shape.Type {
|
||||
case "map", "list", "structure":
|
||||
children = append(children, name)
|
||||
}
|
||||
}
|
||||
|
||||
ancestry = append(ancestry, s)
|
||||
for _, name := range children {
|
||||
ref := s.MemberRefs[name]
|
||||
// Since this is a grab bag we will just continue since
|
||||
// we can't validate because we don't know the valued shape.
|
||||
if ref.JSONValue {
|
||||
continue
|
||||
}
|
||||
|
||||
nestedShape := ref.Shape.NestedShape()
|
||||
|
||||
var v *ShapeValidation
|
||||
if len(nestedShape.Validations) > 0 {
|
||||
v = &ShapeValidation{
|
||||
Name: name, Ref: ref, Type: ShapeValidationNested,
|
||||
}
|
||||
} else {
|
||||
resolveShapeValidations(nestedShape, ancestry...)
|
||||
if len(nestedShape.Validations) > 0 {
|
||||
v = &ShapeValidation{
|
||||
Name: name, Ref: ref, Type: ShapeValidationNested,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v != nil && !s.Validations.Has(v.Ref, v.Type) {
|
||||
s.Validations = append(s.Validations, *v)
|
||||
}
|
||||
}
|
||||
ancestry = ancestry[:len(ancestry)-1]
|
||||
}
|
||||
|
||||
// A tplAPIErrors is the top level template for the API
|
||||
var tplAPIErrors = template.Must(template.New("api").Parse(`
|
||||
const (
|
||||
{{ range $_, $s := $.ShapeListErrors }}
|
||||
// {{ $s.ErrorCodeName }} for service response error code
|
||||
// {{ printf "%q" $s.ErrorName }}.
|
||||
{{ if $s.Docstring -}}
|
||||
//
|
||||
{{ $s.Docstring }}
|
||||
{{ end -}}
|
||||
{{ $s.ErrorCodeName }} = {{ printf "%q" $s.ErrorName }}
|
||||
{{ end }}
|
||||
)
|
||||
`))
|
||||
|
||||
func (a *API) APIErrorsGoCode() string {
|
||||
var buf bytes.Buffer
|
||||
err := tplAPIErrors.Execute(&buf, a)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
// +build 1.6,codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStructNameWithFullName(t *testing.T) {
|
||||
a := API{
|
||||
Metadata: Metadata{
|
||||
ServiceFullName: "Amazon Service Name-100",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, a.StructName(), "ServiceName100")
|
||||
}
|
||||
|
||||
func TestStructNameWithAbbreviation(t *testing.T) {
|
||||
a := API{
|
||||
Metadata: Metadata{
|
||||
ServiceFullName: "AWS Service Name-100",
|
||||
ServiceAbbreviation: "AWS SN100",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, a.StructName(), "SN100")
|
||||
}
|
||||
|
||||
func TestStructNameForExceptions(t *testing.T) {
|
||||
a := API{
|
||||
Metadata: Metadata{
|
||||
ServiceFullName: "Elastic Load Balancing",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, a.StructName(), "ELB")
|
||||
|
||||
a = API{
|
||||
Metadata: Metadata{
|
||||
ServiceFullName: "AWS Config",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, a.StructName(), "ConfigService")
|
||||
}
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
srcName string
|
||||
dstName string
|
||||
|
||||
serviceVersion string
|
||||
}
|
||||
|
||||
var mergeServices = map[string]service{
|
||||
"dynamodbstreams": service{
|
||||
dstName: "dynamodb",
|
||||
srcName: "streams.dynamodb",
|
||||
},
|
||||
"wafregional": service{
|
||||
dstName: "waf",
|
||||
srcName: "waf-regional",
|
||||
serviceVersion: "2015-08-24",
|
||||
},
|
||||
}
|
||||
|
||||
// customizationPasses Executes customization logic for the API by package name.
|
||||
func (a *API) customizationPasses() {
|
||||
var svcCustomizations = map[string]func(*API){
|
||||
"s3": s3Customizations,
|
||||
"cloudfront": cloudfrontCustomizations,
|
||||
"rds": rdsCustomizations,
|
||||
|
||||
// Disable endpoint resolving for services that require customer
|
||||
// to provide endpoint them selves.
|
||||
"cloudsearchdomain": disableEndpointResolving,
|
||||
"iotdataplane": disableEndpointResolving,
|
||||
}
|
||||
|
||||
for k, _ := range mergeServices {
|
||||
svcCustomizations[k] = mergeServicesCustomizations
|
||||
}
|
||||
|
||||
if fn := svcCustomizations[a.PackageName()]; fn != nil {
|
||||
fn(a)
|
||||
}
|
||||
|
||||
blobDocStringCustomizations(a)
|
||||
}
|
||||
|
||||
const base64MarshalDocStr = "// %s is automatically base64 encoded/decoded by the SDK.\n"
|
||||
|
||||
func blobDocStringCustomizations(a *API) {
|
||||
for _, s := range a.Shapes {
|
||||
payloadMemberName := s.Payload
|
||||
|
||||
for refName, ref := range s.MemberRefs {
|
||||
if refName == payloadMemberName {
|
||||
// Payload members have their own encoding and may
|
||||
// be raw bytes or io.Reader
|
||||
continue
|
||||
}
|
||||
if ref.Shape.Type == "blob" {
|
||||
docStr := fmt.Sprintf(base64MarshalDocStr, refName)
|
||||
if len(strings.TrimSpace(ref.Shape.Documentation)) != 0 {
|
||||
ref.Shape.Documentation += "//\n" + docStr
|
||||
} else if len(strings.TrimSpace(ref.Documentation)) != 0 {
|
||||
ref.Documentation += "//\n" + docStr
|
||||
} else {
|
||||
ref.Documentation = docStr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// s3Customizations customizes the API generation to replace values specific to S3.
|
||||
func s3Customizations(a *API) {
|
||||
var strExpires *Shape
|
||||
|
||||
for name, s := range a.Shapes {
|
||||
// Remove ContentMD5 members
|
||||
if _, ok := s.MemberRefs["ContentMD5"]; ok {
|
||||
delete(s.MemberRefs, "ContentMD5")
|
||||
}
|
||||
|
||||
// Expires should be a string not time.Time since the format is not
|
||||
// enforced by S3, and any value can be set to this field outside of the SDK.
|
||||
if strings.HasSuffix(name, "Output") {
|
||||
if ref, ok := s.MemberRefs["Expires"]; ok {
|
||||
if strExpires == nil {
|
||||
newShape := *ref.Shape
|
||||
strExpires = &newShape
|
||||
strExpires.Type = "string"
|
||||
strExpires.refs = []*ShapeRef{}
|
||||
}
|
||||
ref.Shape.removeRef(ref)
|
||||
ref.Shape = strExpires
|
||||
ref.Shape.refs = append(ref.Shape.refs, &s.MemberRef)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cloudfrontCustomizations customized the API generation to replace values
|
||||
// specific to CloudFront.
|
||||
func cloudfrontCustomizations(a *API) {
|
||||
// MaxItems members should always be integers
|
||||
for _, s := range a.Shapes {
|
||||
if ref, ok := s.MemberRefs["MaxItems"]; ok {
|
||||
ref.ShapeName = "Integer"
|
||||
ref.Shape = a.Shapes["Integer"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mergeServicesCustomizations references any duplicate shapes from DynamoDB
|
||||
func mergeServicesCustomizations(a *API) {
|
||||
info := mergeServices[a.PackageName()]
|
||||
|
||||
p := strings.Replace(a.path, info.srcName, info.dstName, -1)
|
||||
|
||||
if info.serviceVersion != "" {
|
||||
index := strings.LastIndex(p, "/")
|
||||
files, _ := ioutil.ReadDir(p[:index])
|
||||
if len(files) > 1 {
|
||||
panic("New version was introduced")
|
||||
}
|
||||
p = p[:index] + "/" + info.serviceVersion
|
||||
}
|
||||
|
||||
file := filepath.Join(p, "api-2.json")
|
||||
|
||||
serviceAPI := API{}
|
||||
serviceAPI.Attach(file)
|
||||
serviceAPI.Setup()
|
||||
|
||||
for n := range a.Shapes {
|
||||
if _, ok := serviceAPI.Shapes[n]; ok {
|
||||
a.Shapes[n].resolvePkg = "github.com/aws/aws-sdk-go/service/" + info.dstName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rdsCustomizations are customization for the service/rds. This adds non-modeled fields used for presigning.
|
||||
func rdsCustomizations(a *API) {
|
||||
inputs := []string{
|
||||
"CopyDBSnapshotInput",
|
||||
"CreateDBInstanceReadReplicaInput",
|
||||
"CopyDBClusterSnapshotInput",
|
||||
"CreateDBClusterInput",
|
||||
}
|
||||
for _, input := range inputs {
|
||||
if ref, ok := a.Shapes[input]; ok {
|
||||
ref.MemberRefs["SourceRegion"] = &ShapeRef{
|
||||
Documentation: docstring(`SourceRegion is the source region where the resource exists. This is not sent over the wire and is only used for presigning. This value should always have the same region as the source ARN.`),
|
||||
ShapeName: "String",
|
||||
Shape: a.Shapes["String"],
|
||||
Ignore: true,
|
||||
}
|
||||
ref.MemberRefs["DestinationRegion"] = &ShapeRef{
|
||||
Documentation: docstring(`DestinationRegion is used for presigning the request to a given region.`),
|
||||
ShapeName: "String",
|
||||
Shape: a.Shapes["String"],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func disableEndpointResolving(a *API) {
|
||||
a.Metadata.NoResolveEndpoint = true
|
||||
}
|
||||
+384
@@ -0,0 +1,384 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
xhtml "golang.org/x/net/html"
|
||||
)
|
||||
|
||||
type apiDocumentation struct {
|
||||
*API
|
||||
Operations map[string]string
|
||||
Service string
|
||||
Shapes map[string]shapeDocumentation
|
||||
}
|
||||
|
||||
type shapeDocumentation struct {
|
||||
Base string
|
||||
Refs map[string]string
|
||||
}
|
||||
|
||||
// AttachDocs attaches documentation from a JSON filename.
|
||||
func (a *API) AttachDocs(filename string) {
|
||||
d := apiDocumentation{API: a}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
defer f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = json.NewDecoder(f).Decode(&d)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
d.setup()
|
||||
|
||||
}
|
||||
|
||||
func (d *apiDocumentation) setup() {
|
||||
d.API.Documentation = docstring(d.Service)
|
||||
if d.Service == "" {
|
||||
d.API.Documentation =
|
||||
fmt.Sprintf("// %s is a client for %s.\n", d.API.StructName(), d.API.NiceName())
|
||||
}
|
||||
|
||||
for op, doc := range d.Operations {
|
||||
d.API.Operations[op].Documentation = strings.TrimSpace(docstring(doc))
|
||||
}
|
||||
|
||||
for shape, info := range d.Shapes {
|
||||
if sh := d.API.Shapes[shape]; sh != nil {
|
||||
sh.Documentation = docstring(info.Base)
|
||||
}
|
||||
|
||||
for ref, doc := range info.Refs {
|
||||
if doc == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(ref, "$")
|
||||
if sh := d.API.Shapes[parts[0]]; sh != nil {
|
||||
if m := sh.MemberRefs[parts[1]]; m != nil {
|
||||
m.Documentation = docstring(doc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var reNewline = regexp.MustCompile(`\r?\n`)
|
||||
var reMultiSpace = regexp.MustCompile(`\s+`)
|
||||
var reComments = regexp.MustCompile(`<!--.*?-->`)
|
||||
var reFullname = regexp.MustCompile(`\s*<fullname?>.+?<\/fullname?>\s*`)
|
||||
var reExamples = regexp.MustCompile(`<examples?>.+?<\/examples?>`)
|
||||
var reEndNL = regexp.MustCompile(`\n+$`)
|
||||
|
||||
// docstring rewrites a string to insert godocs formatting.
|
||||
func docstring(doc string) string {
|
||||
doc = reNewline.ReplaceAllString(doc, "")
|
||||
doc = reMultiSpace.ReplaceAllString(doc, " ")
|
||||
doc = reComments.ReplaceAllString(doc, "")
|
||||
doc = reFullname.ReplaceAllString(doc, "")
|
||||
doc = reExamples.ReplaceAllString(doc, "")
|
||||
doc = generateDoc(doc)
|
||||
doc = reEndNL.ReplaceAllString(doc, "")
|
||||
if doc == "" {
|
||||
return "\n"
|
||||
}
|
||||
|
||||
doc = html.UnescapeString(doc)
|
||||
return commentify(doc)
|
||||
}
|
||||
|
||||
const (
|
||||
indent = " "
|
||||
)
|
||||
|
||||
// style is what we want to prefix a string with.
|
||||
// For instance, <li>Foo</li><li>Bar</li>, will generate
|
||||
// * Foo
|
||||
// * Bar
|
||||
var style = map[string]string{
|
||||
"ul": indent + "* ",
|
||||
"li": indent + "* ",
|
||||
"code": indent,
|
||||
"pre": indent,
|
||||
}
|
||||
|
||||
// commentify converts a string to a Go comment
|
||||
func commentify(doc string) string {
|
||||
lines := strings.Split(doc, "\n")
|
||||
out := []string{}
|
||||
for i, line := range lines {
|
||||
if i > 0 && line == "" && lines[i-1] == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, "// "+line)
|
||||
}
|
||||
|
||||
return strings.Join(out, "\n") + "\n"
|
||||
}
|
||||
|
||||
// wrap returns a rewritten version of text to have line breaks
|
||||
// at approximately length characters. Line breaks will only be
|
||||
// inserted into whitespace.
|
||||
func wrap(text string, length int, isIndented bool) string {
|
||||
var buf bytes.Buffer
|
||||
var last rune
|
||||
var lastNL bool
|
||||
var col int
|
||||
|
||||
for _, c := range text {
|
||||
switch c {
|
||||
case '\r': // ignore this
|
||||
continue // and also don't track `last`
|
||||
case '\n': // ignore this too, but reset col
|
||||
if col >= length || last == '\n' {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
buf.WriteString("\n")
|
||||
col = 0
|
||||
case ' ', '\t': // opportunity to split
|
||||
if col >= length {
|
||||
buf.WriteByte('\n')
|
||||
col = 0
|
||||
if isIndented {
|
||||
buf.WriteString(indent)
|
||||
col += 3
|
||||
}
|
||||
} else {
|
||||
// We only want to write a leading space if the col is greater than zero.
|
||||
// This will provide the proper spacing for documentation.
|
||||
buf.WriteRune(c)
|
||||
col++ // count column
|
||||
}
|
||||
default:
|
||||
buf.WriteRune(c)
|
||||
col++
|
||||
}
|
||||
lastNL = c == '\n'
|
||||
_ = lastNL
|
||||
last = c
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type tagInfo struct {
|
||||
tag string
|
||||
key string
|
||||
val string
|
||||
txt string
|
||||
raw string
|
||||
closingTag bool
|
||||
}
|
||||
|
||||
// generateDoc will generate the proper doc string for html encoded or plain text doc entries.
|
||||
func generateDoc(htmlSrc string) string {
|
||||
tokenizer := xhtml.NewTokenizer(strings.NewReader(htmlSrc))
|
||||
tokens := buildTokenArray(tokenizer)
|
||||
scopes := findScopes(tokens)
|
||||
return walk(scopes)
|
||||
}
|
||||
|
||||
func buildTokenArray(tokenizer *xhtml.Tokenizer) []tagInfo {
|
||||
tokens := []tagInfo{}
|
||||
for tt := tokenizer.Next(); tt != xhtml.ErrorToken; tt = tokenizer.Next() {
|
||||
switch tt {
|
||||
case xhtml.TextToken:
|
||||
txt := string(tokenizer.Text())
|
||||
if len(tokens) == 0 {
|
||||
info := tagInfo{
|
||||
raw: txt,
|
||||
}
|
||||
tokens = append(tokens, info)
|
||||
}
|
||||
tn, _ := tokenizer.TagName()
|
||||
key, val, _ := tokenizer.TagAttr()
|
||||
info := tagInfo{
|
||||
tag: string(tn),
|
||||
key: string(key),
|
||||
val: string(val),
|
||||
txt: txt,
|
||||
}
|
||||
tokens = append(tokens, info)
|
||||
case xhtml.StartTagToken:
|
||||
tn, _ := tokenizer.TagName()
|
||||
key, val, _ := tokenizer.TagAttr()
|
||||
info := tagInfo{
|
||||
tag: string(tn),
|
||||
key: string(key),
|
||||
val: string(val),
|
||||
}
|
||||
tokens = append(tokens, info)
|
||||
case xhtml.SelfClosingTagToken, xhtml.EndTagToken:
|
||||
tn, _ := tokenizer.TagName()
|
||||
key, val, _ := tokenizer.TagAttr()
|
||||
info := tagInfo{
|
||||
tag: string(tn),
|
||||
key: string(key),
|
||||
val: string(val),
|
||||
closingTag: true,
|
||||
}
|
||||
tokens = append(tokens, info)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// walk is used to traverse each scoped block. These scoped
|
||||
// blocks will act as blocked text where we do most of our
|
||||
// text manipulation.
|
||||
func walk(scopes [][]tagInfo) string {
|
||||
doc := ""
|
||||
// Documentation will be chunked by scopes.
|
||||
// Meaning, for each scope will be divided by one or more newlines.
|
||||
for _, scope := range scopes {
|
||||
indentStr, isIndented := priorityIndentation(scope)
|
||||
block := ""
|
||||
href := ""
|
||||
after := false
|
||||
level := 0
|
||||
lastTag := ""
|
||||
for _, token := range scope {
|
||||
if token.closingTag {
|
||||
endl := closeTag(token, level)
|
||||
block += endl
|
||||
level--
|
||||
lastTag = ""
|
||||
} else if token.txt == "" {
|
||||
if token.val != "" {
|
||||
href, after = formatText(token, "")
|
||||
}
|
||||
if level == 1 && isIndented {
|
||||
block += indentStr
|
||||
}
|
||||
level++
|
||||
lastTag = token.tag
|
||||
} else {
|
||||
if token.txt != " " {
|
||||
str, _ := formatText(token, lastTag)
|
||||
block += str
|
||||
if after {
|
||||
block += href
|
||||
after = false
|
||||
}
|
||||
} else {
|
||||
fmt.Println(token.tag)
|
||||
str, _ := formatText(tagInfo{}, lastTag)
|
||||
block += str
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isIndented {
|
||||
block = strings.TrimPrefix(block, " ")
|
||||
}
|
||||
block = wrap(block, 72, isIndented)
|
||||
doc += block
|
||||
}
|
||||
return doc
|
||||
}
|
||||
|
||||
// closeTag will divide up the blocks of documentation to be formated properly.
|
||||
func closeTag(token tagInfo, level int) string {
|
||||
switch token.tag {
|
||||
case "pre", "li", "div":
|
||||
return "\n"
|
||||
case "p", "h1", "h2", "h3", "h4", "h5", "h6":
|
||||
return "\n\n"
|
||||
case "code":
|
||||
// indented code is only at the 0th level.
|
||||
if level == 0 {
|
||||
return "\n"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// formatText will format any sort of text based off of a tag. It will also return
|
||||
// a boolean to add the string after the text token.
|
||||
func formatText(token tagInfo, lastTag string) (string, bool) {
|
||||
switch token.tag {
|
||||
case "a":
|
||||
if token.val != "" {
|
||||
return fmt.Sprintf(" (%s)", token.val), true
|
||||
}
|
||||
}
|
||||
|
||||
// We don't care about a single space nor no text.
|
||||
if len(token.txt) == 0 || token.txt == " " {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Here we want to indent code blocks that are newlines
|
||||
if lastTag == "code" {
|
||||
// Greater than one, because we don't care about newlines in the beginning
|
||||
block := ""
|
||||
if lines := strings.Split(token.txt, "\n"); len(lines) > 1 {
|
||||
for _, line := range lines {
|
||||
block += indent + line
|
||||
}
|
||||
block += "\n"
|
||||
return block, false
|
||||
}
|
||||
}
|
||||
return token.txt, false
|
||||
}
|
||||
|
||||
// This is a parser to check what type of indention is needed.
|
||||
func priorityIndentation(blocks []tagInfo) (string, bool) {
|
||||
if len(blocks) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
v, ok := style[blocks[0].tag]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Divides into scopes based off levels.
|
||||
// For instance,
|
||||
// <p>Testing<code>123</code></p><ul><li>Foo</li></ul>
|
||||
// This has 2 scopes, the <p> and <ul>
|
||||
func findScopes(tokens []tagInfo) [][]tagInfo {
|
||||
level := 0
|
||||
scope := []tagInfo{}
|
||||
scopes := [][]tagInfo{}
|
||||
for _, token := range tokens {
|
||||
// we will clear empty tagged tokens from the array
|
||||
txt := strings.TrimSpace(token.txt)
|
||||
tag := strings.TrimSpace(token.tag)
|
||||
if len(txt) == 0 && len(tag) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
scope = append(scope, token)
|
||||
|
||||
// If it is a closing tag then we check what level
|
||||
// we are on. If it is 0, then that means we have found a
|
||||
// scoped block.
|
||||
if token.closingTag {
|
||||
level--
|
||||
if level == 0 {
|
||||
scopes = append(scopes, scope)
|
||||
scope = []tagInfo{}
|
||||
}
|
||||
// Check opening tags and increment the level
|
||||
} else if token.txt == "" {
|
||||
level++
|
||||
}
|
||||
}
|
||||
// In this case, we did not run into a closing tag. This would mean
|
||||
// we have plaintext for documentation.
|
||||
if len(scopes) == 0 {
|
||||
scopes = append(scopes, scope)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
+82
@@ -0,0 +1,82 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNonHTMLDocGen(t *testing.T) {
|
||||
doc := "Testing 1 2 3"
|
||||
expected := "// Testing 1 2 3\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestListsHTMLDocGen(t *testing.T) {
|
||||
doc := "<ul><li>Testing 1 2 3</li> <li>FooBar</li></ul>"
|
||||
expected := "// * Testing 1 2 3\n// * FooBar\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
|
||||
doc = "<ul> <li>Testing 1 2 3</li> <li>FooBar</li> </ul>"
|
||||
expected = "// * Testing 1 2 3\n// * FooBar\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
|
||||
// Test leading spaces
|
||||
doc = " <ul> <li>Testing 1 2 3</li> <li>FooBar</li> </ul>"
|
||||
doc = docstring(doc)
|
||||
assert.Equal(t, expected, doc)
|
||||
|
||||
// Paragraph check
|
||||
doc = "<ul> <li> <p>Testing 1 2 3</p> </li><li> <p>FooBar</p></li></ul>"
|
||||
expected = "// * Testing 1 2 3\n// \n// * FooBar\n"
|
||||
doc = docstring(doc)
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestInlineCodeHTMLDocGen(t *testing.T) {
|
||||
doc := "<ul> <li><code>Testing</code>: 1 2 3</li> <li>FooBar</li> </ul>"
|
||||
expected := "// * Testing: 1 2 3\n// * FooBar\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestInlineCodeInParagraphHTMLDocGen(t *testing.T) {
|
||||
doc := "<p><code>Testing</code>: 1 2 3</p>"
|
||||
expected := "// Testing: 1 2 3\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestEmptyPREInlineCodeHTMLDocGen(t *testing.T) {
|
||||
doc := "<pre><code>Testing</code></pre>"
|
||||
expected := "// Testing\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestParagraph(t *testing.T) {
|
||||
doc := "<p>Testing 1 2 3</p>"
|
||||
expected := "// Testing 1 2 3\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
|
||||
func TestComplexListParagraphCode(t *testing.T) {
|
||||
doc := "<ul> <li><p><code>FOO</code> Bar</p></li><li><p><code>Xyz</code> ABC</p></li></ul>"
|
||||
expected := "// * FOO Bar\n// \n// * Xyz ABC\n"
|
||||
doc = docstring(doc)
|
||||
|
||||
assert.Equal(t, expected, doc)
|
||||
}
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import "strings"
|
||||
|
||||
// ExportableName a name which is exportable as a value or name in Go code
|
||||
func (a *API) ExportableName(name string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
|
||||
return strings.ToUpper(name[0:1]) + name[1:]
|
||||
}
|
||||
+74
@@ -0,0 +1,74 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Load takes a set of files for each filetype and returns an API pointer.
|
||||
// The API will be initialized once all files have been loaded and parsed.
|
||||
//
|
||||
// Will panic if any failure opening the definition JSON files, or there
|
||||
// are unrecognized exported names.
|
||||
func Load(api, docs, paginators, waiters string) *API {
|
||||
a := API{}
|
||||
a.Attach(api)
|
||||
a.Attach(docs)
|
||||
a.Attach(paginators)
|
||||
a.Attach(waiters)
|
||||
a.Setup()
|
||||
return &a
|
||||
}
|
||||
|
||||
// Attach opens a file by name, and unmarshal its JSON data.
|
||||
// Will proceed to setup the API if not already done so.
|
||||
func (a *API) Attach(filename string) {
|
||||
a.path = filepath.Dir(filename)
|
||||
f, err := os.Open(filename)
|
||||
defer f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := json.NewDecoder(f).Decode(a); err != nil {
|
||||
panic(fmt.Errorf("failed to decode %s, err: %v", filename, err))
|
||||
}
|
||||
}
|
||||
|
||||
// AttachString will unmarshal a raw JSON string, and setup the
|
||||
// API if not already done so.
|
||||
func (a *API) AttachString(str string) {
|
||||
json.Unmarshal([]byte(str), a)
|
||||
|
||||
if !a.initialized {
|
||||
a.Setup()
|
||||
}
|
||||
}
|
||||
|
||||
// Setup initializes the API.
|
||||
func (a *API) Setup() {
|
||||
a.setMetadataEndpointsKey()
|
||||
a.writeShapeNames()
|
||||
a.resolveReferences()
|
||||
a.fixStutterNames()
|
||||
a.renameExportable()
|
||||
if !a.NoRenameToplevelShapes {
|
||||
a.renameToplevelShapes()
|
||||
}
|
||||
a.updateTopLevelShapeReferences()
|
||||
a.createInputOutputShapes()
|
||||
a.customizationPasses()
|
||||
|
||||
if !a.NoRemoveUnusedShapes {
|
||||
a.removeUnusedShapes()
|
||||
}
|
||||
|
||||
if !a.NoValidataShapeMethods {
|
||||
a.addShapeValidations()
|
||||
}
|
||||
|
||||
a.initialized = true
|
||||
}
|
||||
+32
@@ -0,0 +1,32 @@
|
||||
// +build 1.6,codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResolvedReferences(t *testing.T) {
|
||||
json := `{
|
||||
"operations": {
|
||||
"OperationName": {
|
||||
"input": { "shape": "TestName" }
|
||||
}
|
||||
},
|
||||
"shapes": {
|
||||
"TestName": {
|
||||
"type": "structure",
|
||||
"members": {
|
||||
"memberName1": { "shape": "OtherTest" },
|
||||
"memberName2": { "shape": "OtherTest" }
|
||||
}
|
||||
},
|
||||
"OtherTest": { "type": "string" }
|
||||
}
|
||||
}`
|
||||
a := API{}
|
||||
a.AttachString(json)
|
||||
assert.Equal(t, len(a.Shapes["OtherTest"].refs), 2)
|
||||
}
|
||||
+489
@@ -0,0 +1,489 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// An Operation defines a specific API Operation.
|
||||
type Operation struct {
|
||||
API *API `json:"-"`
|
||||
ExportedName string
|
||||
Name string
|
||||
Documentation string
|
||||
HTTP HTTPInfo
|
||||
InputRef ShapeRef `json:"input"`
|
||||
OutputRef ShapeRef `json:"output"`
|
||||
ErrorRefs []ShapeRef `json:"errors"`
|
||||
Paginator *Paginator
|
||||
Deprecated bool `json:"deprecated"`
|
||||
AuthType string `json:"authtype"`
|
||||
imports map[string]bool
|
||||
}
|
||||
|
||||
// A HTTPInfo defines the method of HTTP request for the Operation.
|
||||
type HTTPInfo struct {
|
||||
Method string
|
||||
RequestURI string
|
||||
ResponseCode uint
|
||||
}
|
||||
|
||||
// HasInput returns if the Operation accepts an input paramater
|
||||
func (o *Operation) HasInput() bool {
|
||||
return o.InputRef.ShapeName != ""
|
||||
}
|
||||
|
||||
// HasOutput returns if the Operation accepts an output parameter
|
||||
func (o *Operation) HasOutput() bool {
|
||||
return o.OutputRef.ShapeName != ""
|
||||
}
|
||||
|
||||
func (o *Operation) GetSigner() string {
|
||||
if o.AuthType == "v4-unsigned-body" {
|
||||
o.API.imports["github.com/aws/aws-sdk-go/aws/signer/v4"] = true
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
|
||||
switch o.AuthType {
|
||||
case "none":
|
||||
buf.WriteString("req.Config.Credentials = credentials.AnonymousCredentials")
|
||||
case "v4-unsigned-body":
|
||||
buf.WriteString("req.Handlers.Sign.Remove(v4.SignRequestHandler)\n")
|
||||
buf.WriteString("handler := v4.BuildNamedHandler(\"v4.CustomSignerHandler\", v4.WithUnsignedPayload)\n")
|
||||
buf.WriteString("req.Handlers.Sign.PushFrontNamed(handler)")
|
||||
}
|
||||
|
||||
buf.WriteString("\n")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// tplOperation defines a template for rendering an API Operation
|
||||
var tplOperation = template.Must(template.New("operation").Funcs(template.FuncMap{
|
||||
"GetCrosslinkURL": GetCrosslinkURL,
|
||||
}).Parse(`
|
||||
const op{{ .ExportedName }} = "{{ .Name }}"
|
||||
|
||||
// {{ .ExportedName }}Request generates a "aws/request.Request" representing the
|
||||
// client's request for the {{ .ExportedName }} operation. The "output" return
|
||||
// value can be used to capture response data after the request's "Send" method
|
||||
// is called.
|
||||
//
|
||||
// See {{ .ExportedName }} for usage and error information.
|
||||
//
|
||||
// Creating a request object using this method should be used when you want to inject
|
||||
// custom logic into the request's lifecycle using a custom handler, or if you want to
|
||||
// access properties on the request object before or after sending the request. If
|
||||
// you just want the service response, call the {{ .ExportedName }} method directly
|
||||
// instead.
|
||||
//
|
||||
// Note: You must call the "Send" method on the returned request object in order
|
||||
// to execute the request.
|
||||
//
|
||||
// // Example sending a request using the {{ .ExportedName }}Request method.
|
||||
// req, resp := client.{{ .ExportedName }}Request(params)
|
||||
//
|
||||
// err := req.Send()
|
||||
// if err == nil { // resp is now filled
|
||||
// fmt.Println(resp)
|
||||
// }
|
||||
{{ $crosslinkURL := GetCrosslinkURL $.API.BaseCrosslinkURL $.API.APIName $.API.Metadata.UID $.ExportedName -}}
|
||||
{{ if ne $crosslinkURL "" -}}
|
||||
//
|
||||
// Please also see {{ $crosslinkURL }}
|
||||
{{ end -}}
|
||||
func (c *{{ .API.StructName }}) {{ .ExportedName }}Request(` +
|
||||
`input {{ .InputRef.GoType }}) (req *request.Request, output {{ .OutputRef.GoType }}) {
|
||||
{{ if (or .Deprecated (or .InputRef.Deprecated .OutputRef.Deprecated)) }}if c.Client.Config.Logger != nil {
|
||||
c.Client.Config.Logger.Log("This operation, {{ .ExportedName }}, has been deprecated")
|
||||
}
|
||||
op := &request.Operation{ {{ else }} op := &request.Operation{ {{ end }}
|
||||
Name: op{{ .ExportedName }},
|
||||
{{ if ne .HTTP.Method "" }}HTTPMethod: "{{ .HTTP.Method }}",
|
||||
{{ end }}HTTPPath: {{ if ne .HTTP.RequestURI "" }}"{{ .HTTP.RequestURI }}"{{ else }}"/"{{ end }},
|
||||
{{ if .Paginator }}Paginator: &request.Paginator{
|
||||
InputTokens: {{ .Paginator.InputTokensString }},
|
||||
OutputTokens: {{ .Paginator.OutputTokensString }},
|
||||
LimitToken: "{{ .Paginator.LimitKey }}",
|
||||
TruncationToken: "{{ .Paginator.MoreResults }}",
|
||||
},
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
if input == nil {
|
||||
input = &{{ .InputRef.GoTypeElem }}{}
|
||||
}
|
||||
|
||||
output = &{{ .OutputRef.GoTypeElem }}{}
|
||||
req = c.newRequest(op, input, output){{ if eq .OutputRef.Shape.Placeholder true }}
|
||||
req.Handlers.Unmarshal.Remove({{ .API.ProtocolPackage }}.UnmarshalHandler)
|
||||
req.Handlers.Unmarshal.PushBackNamed(protocol.UnmarshalDiscardBodyHandler){{ end }}
|
||||
{{ if ne .AuthType "" }}{{ .GetSigner }}{{ end -}}
|
||||
return
|
||||
}
|
||||
|
||||
// {{ .ExportedName }} API operation for {{ .API.Metadata.ServiceFullName }}.
|
||||
{{ if .Documentation -}}
|
||||
//
|
||||
{{ .Documentation }}
|
||||
{{ end -}}
|
||||
//
|
||||
// Returns awserr.Error for service API and SDK errors. Use runtime type assertions
|
||||
// with awserr.Error's Code and Message methods to get detailed information about
|
||||
// the error.
|
||||
//
|
||||
// See the AWS API reference guide for {{ .API.Metadata.ServiceFullName }}'s
|
||||
// API operation {{ .ExportedName }} for usage and error information.
|
||||
{{ if .ErrorRefs -}}
|
||||
//
|
||||
// Returned Error Codes:
|
||||
{{ range $_, $err := .ErrorRefs -}}
|
||||
// * {{ $err.Shape.ErrorCodeName }} "{{ $err.Shape.ErrorName}}"
|
||||
{{ if $err.Docstring -}}
|
||||
{{ $err.IndentedDocstring }}
|
||||
{{ end -}}
|
||||
//
|
||||
{{ end -}}
|
||||
{{ end -}}
|
||||
{{ $crosslinkURL := GetCrosslinkURL $.API.BaseCrosslinkURL $.API.APIName $.API.Metadata.UID $.ExportedName -}}
|
||||
{{ if ne $crosslinkURL "" -}}
|
||||
// Please also see {{ $crosslinkURL }}
|
||||
{{ end -}}
|
||||
func (c *{{ .API.StructName }}) {{ .ExportedName }}(` +
|
||||
`input {{ .InputRef.GoType }}) ({{ .OutputRef.GoType }}, error) {
|
||||
req, out := c.{{ .ExportedName }}Request(input)
|
||||
return out, req.Send()
|
||||
}
|
||||
|
||||
// {{ .ExportedName }}WithContext is the same as {{ .ExportedName }} with the addition of
|
||||
// the ability to pass a context and additional request options.
|
||||
//
|
||||
// See {{ .ExportedName }} for details on how to use this API operation.
|
||||
//
|
||||
// The context must be non-nil and will be used for request cancellation. If
|
||||
// the context is nil a panic will occur. In the future the SDK may create
|
||||
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
|
||||
// for more information on using Contexts.
|
||||
func (c *{{ .API.StructName }}) {{ .ExportedName }}WithContext(` +
|
||||
`ctx aws.Context, input {{ .InputRef.GoType }}, opts ...request.Option) ` +
|
||||
`({{ .OutputRef.GoType }}, error) {
|
||||
req, out := c.{{ .ExportedName }}Request(input)
|
||||
req.SetContext(ctx)
|
||||
req.ApplyOptions(opts...)
|
||||
return out, req.Send()
|
||||
}
|
||||
|
||||
{{ if .Paginator }}
|
||||
// {{ .ExportedName }}Pages iterates over the pages of a {{ .ExportedName }} operation,
|
||||
// calling the "fn" function with the response data for each page. To stop
|
||||
// iterating, return false from the fn function.
|
||||
//
|
||||
// See {{ .ExportedName }} method for more information on how to use this operation.
|
||||
//
|
||||
// Note: This operation can generate multiple requests to a service.
|
||||
//
|
||||
// // Example iterating over at most 3 pages of a {{ .ExportedName }} operation.
|
||||
// pageNum := 0
|
||||
// err := client.{{ .ExportedName }}Pages(params,
|
||||
// func(page {{ .OutputRef.GoType }}, lastPage bool) bool {
|
||||
// pageNum++
|
||||
// fmt.Println(page)
|
||||
// return pageNum <= 3
|
||||
// })
|
||||
//
|
||||
func (c *{{ .API.StructName }}) {{ .ExportedName }}Pages(` +
|
||||
`input {{ .InputRef.GoType }}, fn func({{ .OutputRef.GoType }}, bool) bool) error {
|
||||
return c.{{ .ExportedName }}PagesWithContext(aws.BackgroundContext(), input, fn)
|
||||
}
|
||||
|
||||
// {{ .ExportedName }}PagesWithContext same as {{ .ExportedName }}Pages except
|
||||
// it takes a Context and allows setting request options on the pages.
|
||||
//
|
||||
// The context must be non-nil and will be used for request cancellation. If
|
||||
// the context is nil a panic will occur. In the future the SDK may create
|
||||
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
|
||||
// for more information on using Contexts.
|
||||
func (c *{{ .API.StructName }}) {{ .ExportedName }}PagesWithContext(` +
|
||||
`ctx aws.Context, ` +
|
||||
`input {{ .InputRef.GoType }}, ` +
|
||||
`fn func({{ .OutputRef.GoType }}, bool) bool, ` +
|
||||
`opts ...request.Option) error {
|
||||
p := request.Pagination {
|
||||
NewRequest: func() (*request.Request, error) {
|
||||
inCpy := *input
|
||||
req, _ := c.{{ .ExportedName }}Request(&inCpy)
|
||||
req.SetContext(ctx)
|
||||
req.ApplyOptions(opts...)
|
||||
return req, nil
|
||||
},
|
||||
}
|
||||
|
||||
cont := true
|
||||
for p.Next() && cont {
|
||||
cont = fn(p.Page().({{ .OutputRef.GoType }}), !p.HasNextPage())
|
||||
}
|
||||
return p.Err()
|
||||
}
|
||||
{{ end }}
|
||||
`))
|
||||
|
||||
// GoCode returns a string of rendered GoCode for this Operation
|
||||
func (o *Operation) GoCode() string {
|
||||
var buf bytes.Buffer
|
||||
err := tplOperation.Execute(&buf, o)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
// tplInfSig defines the template for rendering an Operation's signature within an Interface definition.
|
||||
var tplInfSig = template.Must(template.New("opsig").Parse(`
|
||||
{{ .ExportedName }}({{ .InputRef.GoTypeWithPkgName }}) ({{ .OutputRef.GoTypeWithPkgName }}, error)
|
||||
{{ .ExportedName }}WithContext(aws.Context, {{ .InputRef.GoTypeWithPkgName }}, ...request.Option) ({{ .OutputRef.GoTypeWithPkgName }}, error)
|
||||
{{ .ExportedName }}Request({{ .InputRef.GoTypeWithPkgName }}) (*request.Request, {{ .OutputRef.GoTypeWithPkgName }})
|
||||
|
||||
{{ if .Paginator -}}
|
||||
{{ .ExportedName }}Pages({{ .InputRef.GoTypeWithPkgName }}, func({{ .OutputRef.GoTypeWithPkgName }}, bool) bool) error
|
||||
{{ .ExportedName }}PagesWithContext(aws.Context, {{ .InputRef.GoTypeWithPkgName }}, func({{ .OutputRef.GoTypeWithPkgName }}, bool) bool, ...request.Option) error
|
||||
{{- end }}
|
||||
`))
|
||||
|
||||
// InterfaceSignature returns a string representing the Operation's interface{}
|
||||
// functional signature.
|
||||
func (o *Operation) InterfaceSignature() string {
|
||||
var buf bytes.Buffer
|
||||
err := tplInfSig.Execute(&buf, o)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
// tplExample defines the template for rendering an Operation example
|
||||
var tplExample = template.Must(template.New("operationExample").Parse(`
|
||||
func Example{{ .API.StructName }}_{{ .ExportedName }}() {
|
||||
sess := session.Must(session.NewSession())
|
||||
|
||||
svc := {{ .API.PackageName }}.New(sess)
|
||||
|
||||
{{ .ExampleInput }}
|
||||
resp, err := svc.{{ .ExportedName }}(params)
|
||||
|
||||
if err != nil {
|
||||
// Print the error, cast err to awserr.Error to get the Code and
|
||||
// Message from an error.
|
||||
fmt.Println(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Pretty-print the response data.
|
||||
fmt.Println(resp)
|
||||
}
|
||||
`))
|
||||
|
||||
// Example returns a string of the rendered Go code for the Operation
|
||||
func (o *Operation) Example() string {
|
||||
var buf bytes.Buffer
|
||||
err := tplExample.Execute(&buf, o)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
// ExampleInput return a string of the rendered Go code for an example's input parameters
|
||||
func (o *Operation) ExampleInput() string {
|
||||
if len(o.InputRef.Shape.MemberRefs) == 0 {
|
||||
if strings.Contains(o.InputRef.GoTypeElem(), ".") {
|
||||
o.imports["github.com/aws/aws-sdk-go/service/"+strings.Split(o.InputRef.GoTypeElem(), ".")[0]] = true
|
||||
return fmt.Sprintf("var params *%s", o.InputRef.GoTypeElem())
|
||||
}
|
||||
return fmt.Sprintf("var params *%s.%s",
|
||||
o.API.PackageName(), o.InputRef.GoTypeElem())
|
||||
}
|
||||
e := example{o, map[string]int{}}
|
||||
return "params := " + e.traverseAny(o.InputRef.Shape, false, false)
|
||||
}
|
||||
|
||||
// A example provides
|
||||
type example struct {
|
||||
*Operation
|
||||
visited map[string]int
|
||||
}
|
||||
|
||||
// traverseAny returns rendered Go code for the shape.
|
||||
func (e *example) traverseAny(s *Shape, required, payload bool) string {
|
||||
str := ""
|
||||
e.visited[s.ShapeName]++
|
||||
|
||||
switch s.Type {
|
||||
case "structure":
|
||||
str = e.traverseStruct(s, required, payload)
|
||||
case "list":
|
||||
str = e.traverseList(s, required, payload)
|
||||
case "map":
|
||||
str = e.traverseMap(s, required, payload)
|
||||
case "jsonvalue":
|
||||
str = "aws.JSONValue{\"key\": \"value\"}"
|
||||
if required {
|
||||
str += " // Required"
|
||||
}
|
||||
default:
|
||||
str = e.traverseScalar(s, required, payload)
|
||||
}
|
||||
|
||||
e.visited[s.ShapeName]--
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
var reType = regexp.MustCompile(`\b([A-Z])`)
|
||||
|
||||
// traverseStruct returns rendered Go code for a structure type shape.
|
||||
func (e *example) traverseStruct(s *Shape, required, payload bool) string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if s.resolvePkg != "" {
|
||||
e.imports[s.resolvePkg] = true
|
||||
buf.WriteString("&" + s.GoTypeElem() + "{")
|
||||
} else {
|
||||
buf.WriteString("&" + s.API.PackageName() + "." + s.GoTypeElem() + "{")
|
||||
}
|
||||
|
||||
if required {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n")
|
||||
|
||||
req := make([]string, len(s.Required))
|
||||
copy(req, s.Required)
|
||||
sort.Strings(req)
|
||||
|
||||
if e.visited[s.ShapeName] < 2 {
|
||||
for _, n := range req {
|
||||
m := s.MemberRefs[n].Shape
|
||||
p := n == s.Payload && (s.MemberRefs[n].Streaming || m.Streaming)
|
||||
buf.WriteString(n + ": " + e.traverseAny(m, true, p) + ",")
|
||||
if m.Type != "list" && m.Type != "structure" && m.Type != "map" {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, n := range s.MemberNames() {
|
||||
if s.IsRequired(n) {
|
||||
continue
|
||||
}
|
||||
m := s.MemberRefs[n].Shape
|
||||
p := n == s.Payload && (s.MemberRefs[n].Streaming || m.Streaming)
|
||||
buf.WriteString(n + ": " + e.traverseAny(m, false, p) + ",\n")
|
||||
}
|
||||
} else {
|
||||
buf.WriteString("// Recursive values...\n")
|
||||
}
|
||||
|
||||
buf.WriteString("}")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// traverseMap returns rendered Go code for a map type shape.
|
||||
func (e *example) traverseMap(s *Shape, required, payload bool) string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
t := ""
|
||||
if s.resolvePkg != "" {
|
||||
e.imports[s.resolvePkg] = true
|
||||
t = s.GoTypeElem()
|
||||
} else {
|
||||
t = reType.ReplaceAllString(s.GoTypeElem(), s.API.PackageName()+".$1")
|
||||
}
|
||||
buf.WriteString(t + "{")
|
||||
if required {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n")
|
||||
|
||||
if e.visited[s.ShapeName] < 2 {
|
||||
m := s.ValueRef.Shape
|
||||
buf.WriteString("\"Key\": " + e.traverseAny(m, true, false) + ",")
|
||||
if m.Type != "list" && m.Type != "structure" && m.Type != "map" {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n// More values...\n")
|
||||
} else {
|
||||
buf.WriteString("// Recursive values...\n")
|
||||
}
|
||||
buf.WriteString("}")
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// traverseList returns rendered Go code for a list type shape.
|
||||
func (e *example) traverseList(s *Shape, required, payload bool) string {
|
||||
var buf bytes.Buffer
|
||||
t := ""
|
||||
if s.resolvePkg != "" {
|
||||
e.imports[s.resolvePkg] = true
|
||||
t = s.GoTypeElem()
|
||||
} else {
|
||||
t = reType.ReplaceAllString(s.GoTypeElem(), s.API.PackageName()+".$1")
|
||||
}
|
||||
|
||||
buf.WriteString(t + "{")
|
||||
if required {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n")
|
||||
|
||||
if e.visited[s.ShapeName] < 2 {
|
||||
m := s.MemberRef.Shape
|
||||
buf.WriteString(e.traverseAny(m, true, false) + ",")
|
||||
if m.Type != "list" && m.Type != "structure" && m.Type != "map" {
|
||||
buf.WriteString(" // Required")
|
||||
}
|
||||
buf.WriteString("\n// More values...\n")
|
||||
} else {
|
||||
buf.WriteString("// Recursive values...\n")
|
||||
}
|
||||
buf.WriteString("}")
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// traverseScalar returns an AWS Type string representation initialized to a value.
|
||||
// Will panic if s is an unsupported shape type.
|
||||
func (e *example) traverseScalar(s *Shape, required, payload bool) string {
|
||||
str := ""
|
||||
switch s.Type {
|
||||
case "integer", "long":
|
||||
str = `aws.Int64(1)`
|
||||
case "float", "double":
|
||||
str = `aws.Float64(1.0)`
|
||||
case "string", "character":
|
||||
str = `aws.String("` + s.ShapeName + `")`
|
||||
case "blob":
|
||||
if payload {
|
||||
str = `bytes.NewReader([]byte("PAYLOAD"))`
|
||||
} else {
|
||||
str = `[]byte("PAYLOAD")`
|
||||
}
|
||||
case "boolean":
|
||||
str = `aws.Bool(true)`
|
||||
case "timestamp":
|
||||
str = `aws.Time(time.Now())`
|
||||
default:
|
||||
panic("unsupported shape " + s.Type)
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
+91
@@ -0,0 +1,91 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Paginator keeps track of pagination configuration for an API operation.
|
||||
type Paginator struct {
|
||||
InputTokens interface{} `json:"input_token"`
|
||||
OutputTokens interface{} `json:"output_token"`
|
||||
LimitKey string `json:"limit_key"`
|
||||
MoreResults string `json:"more_results"`
|
||||
}
|
||||
|
||||
// InputTokensString returns output tokens formatted as a list
|
||||
func (p *Paginator) InputTokensString() string {
|
||||
str := p.InputTokens.([]string)
|
||||
return fmt.Sprintf("%#v", str)
|
||||
}
|
||||
|
||||
// OutputTokensString returns output tokens formatted as a list
|
||||
func (p *Paginator) OutputTokensString() string {
|
||||
str := p.OutputTokens.([]string)
|
||||
return fmt.Sprintf("%#v", str)
|
||||
}
|
||||
|
||||
// used for unmarshaling from the paginators JSON file
|
||||
type paginationDefinitions struct {
|
||||
*API
|
||||
Pagination map[string]Paginator
|
||||
}
|
||||
|
||||
// AttachPaginators attaches pagination configuration from filename to the API.
|
||||
func (a *API) AttachPaginators(filename string) {
|
||||
p := paginationDefinitions{API: a}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
defer f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = json.NewDecoder(f).Decode(&p)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
p.setup()
|
||||
}
|
||||
|
||||
// setup runs post-processing on the paginator configuration.
|
||||
func (p *paginationDefinitions) setup() {
|
||||
for n, e := range p.Pagination {
|
||||
if e.InputTokens == nil || e.OutputTokens == nil {
|
||||
continue
|
||||
}
|
||||
paginator := e
|
||||
|
||||
switch t := paginator.InputTokens.(type) {
|
||||
case string:
|
||||
paginator.InputTokens = []string{t}
|
||||
case []interface{}:
|
||||
toks := []string{}
|
||||
for _, e := range t {
|
||||
s := e.(string)
|
||||
toks = append(toks, s)
|
||||
}
|
||||
paginator.InputTokens = toks
|
||||
}
|
||||
switch t := paginator.OutputTokens.(type) {
|
||||
case string:
|
||||
paginator.OutputTokens = []string{t}
|
||||
case []interface{}:
|
||||
toks := []string{}
|
||||
for _, e := range t {
|
||||
s := e.(string)
|
||||
toks = append(toks, s)
|
||||
}
|
||||
paginator.OutputTokens = toks
|
||||
}
|
||||
|
||||
if o, ok := p.Operations[n]; ok {
|
||||
o.Paginator = &paginator
|
||||
} else {
|
||||
panic("unknown operation for paginator " + n)
|
||||
}
|
||||
}
|
||||
}
|
||||
+133
@@ -0,0 +1,133 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/util"
|
||||
)
|
||||
|
||||
// A paramFiller provides string formatting for a shape and its types.
|
||||
type paramFiller struct {
|
||||
prefixPackageName bool
|
||||
}
|
||||
|
||||
// typeName returns the type name of a shape.
|
||||
func (f paramFiller) typeName(shape *Shape) string {
|
||||
if f.prefixPackageName && shape.Type == "structure" {
|
||||
return "*" + shape.API.PackageName() + "." + shape.GoTypeElem()
|
||||
}
|
||||
return shape.GoType()
|
||||
}
|
||||
|
||||
// ParamsStructFromJSON returns a JSON string representation of a structure.
|
||||
func ParamsStructFromJSON(value interface{}, shape *Shape, prefixPackageName bool) string {
|
||||
f := paramFiller{prefixPackageName: prefixPackageName}
|
||||
return util.GoFmt(f.paramsStructAny(value, shape))
|
||||
}
|
||||
|
||||
// paramsStructAny returns the string representation of any value.
|
||||
func (f paramFiller) paramsStructAny(value interface{}, shape *Shape) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch shape.Type {
|
||||
case "structure":
|
||||
if value != nil {
|
||||
vmap := value.(map[string]interface{})
|
||||
return f.paramsStructStruct(vmap, shape)
|
||||
}
|
||||
case "list":
|
||||
vlist := value.([]interface{})
|
||||
return f.paramsStructList(vlist, shape)
|
||||
case "map":
|
||||
vmap := value.(map[string]interface{})
|
||||
return f.paramsStructMap(vmap, shape)
|
||||
case "string", "character":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() {
|
||||
return fmt.Sprintf("aws.String(%#v)", v.Interface())
|
||||
}
|
||||
case "blob":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() && shape.Streaming {
|
||||
return fmt.Sprintf("bytes.NewReader([]byte(%#v))", v.Interface())
|
||||
} else if v.IsValid() {
|
||||
return fmt.Sprintf("[]byte(%#v)", v.Interface())
|
||||
}
|
||||
case "boolean":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() {
|
||||
return fmt.Sprintf("aws.Bool(%#v)", v.Interface())
|
||||
}
|
||||
case "integer", "long":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() {
|
||||
return fmt.Sprintf("aws.Int64(%v)", v.Interface())
|
||||
}
|
||||
case "float", "double":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() {
|
||||
return fmt.Sprintf("aws.Float64(%v)", v.Interface())
|
||||
}
|
||||
case "timestamp":
|
||||
v := reflect.Indirect(reflect.ValueOf(value))
|
||||
if v.IsValid() {
|
||||
return fmt.Sprintf("aws.Time(time.Unix(%d, 0))", int(v.Float()))
|
||||
}
|
||||
default:
|
||||
panic("Unhandled type " + shape.Type)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// paramsStructStruct returns the string representation of a structure
|
||||
func (f paramFiller) paramsStructStruct(value map[string]interface{}, shape *Shape) string {
|
||||
out := "&" + f.typeName(shape)[1:] + "{\n"
|
||||
for _, n := range shape.MemberNames() {
|
||||
ref := shape.MemberRefs[n]
|
||||
name := findParamMember(value, n)
|
||||
|
||||
if val := f.paramsStructAny(value[name], ref.Shape); val != "" {
|
||||
out += fmt.Sprintf("%s: %s,\n", n, val)
|
||||
}
|
||||
}
|
||||
out += "}"
|
||||
return out
|
||||
}
|
||||
|
||||
// paramsStructMap returns the string representation of a map of values
|
||||
func (f paramFiller) paramsStructMap(value map[string]interface{}, shape *Shape) string {
|
||||
out := f.typeName(shape) + "{\n"
|
||||
keys := util.SortedKeys(value)
|
||||
for _, k := range keys {
|
||||
v := value[k]
|
||||
out += fmt.Sprintf("%q: %s,\n", k, f.paramsStructAny(v, shape.ValueRef.Shape))
|
||||
}
|
||||
out += "}"
|
||||
return out
|
||||
}
|
||||
|
||||
// paramsStructList returns the string representation of slice of values
|
||||
func (f paramFiller) paramsStructList(value []interface{}, shape *Shape) string {
|
||||
out := f.typeName(shape) + "{\n"
|
||||
for _, v := range value {
|
||||
out += fmt.Sprintf("%s,\n", f.paramsStructAny(v, shape.MemberRef.Shape))
|
||||
}
|
||||
out += "}"
|
||||
return out
|
||||
}
|
||||
|
||||
// findParamMember searches a map for a key ignoring case. Returns the map key if found.
|
||||
func findParamMember(value map[string]interface{}, key string) string {
|
||||
for actualKey := range value {
|
||||
if strings.ToLower(key) == strings.ToLower(actualKey) {
|
||||
return actualKey
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+296
@@ -0,0 +1,296 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// updateTopLevelShapeReferences moves resultWrapper, locationName, and
|
||||
// xmlNamespace traits from toplevel shape references to the toplevel
|
||||
// shapes for easier code generation
|
||||
func (a *API) updateTopLevelShapeReferences() {
|
||||
for _, o := range a.Operations {
|
||||
// these are for REST-XML services
|
||||
if o.InputRef.LocationName != "" {
|
||||
o.InputRef.Shape.LocationName = o.InputRef.LocationName
|
||||
}
|
||||
if o.InputRef.Location != "" {
|
||||
o.InputRef.Shape.Location = o.InputRef.Location
|
||||
}
|
||||
if o.InputRef.Payload != "" {
|
||||
o.InputRef.Shape.Payload = o.InputRef.Payload
|
||||
}
|
||||
if o.InputRef.XMLNamespace.Prefix != "" {
|
||||
o.InputRef.Shape.XMLNamespace.Prefix = o.InputRef.XMLNamespace.Prefix
|
||||
}
|
||||
if o.InputRef.XMLNamespace.URI != "" {
|
||||
o.InputRef.Shape.XMLNamespace.URI = o.InputRef.XMLNamespace.URI
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// writeShapeNames sets each shape's API and shape name values. Binding the
|
||||
// shape to its parent API.
|
||||
func (a *API) writeShapeNames() {
|
||||
for n, s := range a.Shapes {
|
||||
s.API = a
|
||||
s.ShapeName = n
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) resolveReferences() {
|
||||
resolver := referenceResolver{API: a, visited: map[*ShapeRef]bool{}}
|
||||
|
||||
for _, s := range a.Shapes {
|
||||
resolver.resolveShape(s)
|
||||
}
|
||||
|
||||
for _, o := range a.Operations {
|
||||
o.API = a // resolve parent reference
|
||||
|
||||
resolver.resolveReference(&o.InputRef)
|
||||
resolver.resolveReference(&o.OutputRef)
|
||||
|
||||
// Resolve references for errors also
|
||||
for i := range o.ErrorRefs {
|
||||
resolver.resolveReference(&o.ErrorRefs[i])
|
||||
o.ErrorRefs[i].Shape.IsError = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A referenceResolver provides a way to resolve shape references to
|
||||
// shape definitions.
|
||||
type referenceResolver struct {
|
||||
*API
|
||||
visited map[*ShapeRef]bool
|
||||
}
|
||||
|
||||
var jsonvalueShape = &Shape{
|
||||
ShapeName: "JSONValue",
|
||||
Type: "jsonvalue",
|
||||
ValueRef: ShapeRef{
|
||||
JSONValue: true,
|
||||
},
|
||||
}
|
||||
|
||||
// resolveReference updates a shape reference to reference the API and
|
||||
// its shape definition. All other nested references are also resolved.
|
||||
func (r *referenceResolver) resolveReference(ref *ShapeRef) {
|
||||
if ref.ShapeName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if shape, ok := r.API.Shapes[ref.ShapeName]; ok {
|
||||
if ref.JSONValue {
|
||||
ref.ShapeName = "JSONValue"
|
||||
r.API.Shapes[ref.ShapeName] = jsonvalueShape
|
||||
}
|
||||
|
||||
ref.API = r.API // resolve reference back to API
|
||||
ref.Shape = shape // resolve shape reference
|
||||
|
||||
if r.visited[ref] {
|
||||
return
|
||||
}
|
||||
r.visited[ref] = true
|
||||
|
||||
shape.refs = append(shape.refs, ref) // register the ref
|
||||
|
||||
// resolve shape's references, if it has any
|
||||
r.resolveShape(shape)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveShape resolves a shape's Member Key Value, and nested member
|
||||
// shape references.
|
||||
func (r *referenceResolver) resolveShape(shape *Shape) {
|
||||
r.resolveReference(&shape.MemberRef)
|
||||
r.resolveReference(&shape.KeyRef)
|
||||
r.resolveReference(&shape.ValueRef)
|
||||
for _, m := range shape.MemberRefs {
|
||||
r.resolveReference(m)
|
||||
}
|
||||
}
|
||||
|
||||
// renameToplevelShapes renames all top level shapes of an API to their
|
||||
// exportable variant. The shapes are also updated to include notations
|
||||
// if they are Input or Outputs.
|
||||
func (a *API) renameToplevelShapes() {
|
||||
for _, v := range a.Operations {
|
||||
if v.HasInput() {
|
||||
name := v.ExportedName + "Input"
|
||||
switch n := len(v.InputRef.Shape.refs); {
|
||||
case n == 1 && a.Shapes[name] == nil:
|
||||
v.InputRef.Shape.Rename(name)
|
||||
}
|
||||
}
|
||||
if v.HasOutput() {
|
||||
name := v.ExportedName + "Output"
|
||||
switch n := len(v.OutputRef.Shape.refs); {
|
||||
case n == 1 && a.Shapes[name] == nil:
|
||||
v.OutputRef.Shape.Rename(name)
|
||||
}
|
||||
}
|
||||
v.InputRef.Payload = a.ExportableName(v.InputRef.Payload)
|
||||
v.OutputRef.Payload = a.ExportableName(v.OutputRef.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
// fixStutterNames fixes all name struttering based on Go naming conventions.
|
||||
// "Stuttering" is when the prefix of a structure or function matches the
|
||||
// package name (case insensitive).
|
||||
func (a *API) fixStutterNames() {
|
||||
str, end := a.StructName(), ""
|
||||
if len(str) > 1 {
|
||||
l := len(str) - 1
|
||||
str, end = str[0:l], str[l:]
|
||||
}
|
||||
re := regexp.MustCompile(fmt.Sprintf(`\A(?i:%s)%s`, str, end))
|
||||
|
||||
for name, op := range a.Operations {
|
||||
newName := re.ReplaceAllString(name, "")
|
||||
if newName != name {
|
||||
delete(a.Operations, name)
|
||||
a.Operations[newName] = op
|
||||
}
|
||||
op.ExportedName = newName
|
||||
}
|
||||
|
||||
for k, s := range a.Shapes {
|
||||
newName := re.ReplaceAllString(k, "")
|
||||
if newName != s.ShapeName {
|
||||
s.Rename(newName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renameExportable renames all operation names to be exportable names.
|
||||
// All nested Shape names are also updated to the exportable variant.
|
||||
func (a *API) renameExportable() {
|
||||
for name, op := range a.Operations {
|
||||
newName := a.ExportableName(name)
|
||||
if newName != name {
|
||||
delete(a.Operations, name)
|
||||
a.Operations[newName] = op
|
||||
}
|
||||
op.ExportedName = newName
|
||||
}
|
||||
|
||||
for k, s := range a.Shapes {
|
||||
// FIXME SNS has lower and uppercased shape names with the same name,
|
||||
// except the lowercased variant is used exclusively for string and
|
||||
// other primitive types. Renaming both would cause a collision.
|
||||
// We work around this by only renaming the structure shapes.
|
||||
if s.Type == "string" {
|
||||
continue
|
||||
}
|
||||
|
||||
for mName, member := range s.MemberRefs {
|
||||
newName := a.ExportableName(mName)
|
||||
if newName != mName {
|
||||
delete(s.MemberRefs, mName)
|
||||
s.MemberRefs[newName] = member
|
||||
|
||||
// also apply locationName trait so we keep the old one
|
||||
// but only if there's no locationName trait on ref or shape
|
||||
if member.LocationName == "" && member.Shape.LocationName == "" {
|
||||
member.LocationName = mName
|
||||
}
|
||||
}
|
||||
|
||||
if newName == "_" {
|
||||
panic("Shape " + s.ShapeName + " uses reserved member name '_'")
|
||||
}
|
||||
}
|
||||
|
||||
newName := a.ExportableName(k)
|
||||
if newName != s.ShapeName {
|
||||
s.Rename(newName)
|
||||
}
|
||||
|
||||
s.Payload = a.ExportableName(s.Payload)
|
||||
|
||||
// fix required trait names
|
||||
for i, n := range s.Required {
|
||||
s.Required[i] = a.ExportableName(n)
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range a.Shapes {
|
||||
// fix enum names
|
||||
if s.IsEnum() {
|
||||
s.EnumConsts = make([]string, len(s.Enum))
|
||||
for i := range s.Enum {
|
||||
shape := s.ShapeName
|
||||
shape = strings.ToUpper(shape[0:1]) + shape[1:]
|
||||
s.EnumConsts[i] = shape + s.EnumName(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createInputOutputShapes creates toplevel input/output shapes if they
|
||||
// have not been defined in the API. This normalizes all APIs to always
|
||||
// have an input and output structure in the signature.
|
||||
func (a *API) createInputOutputShapes() {
|
||||
for _, op := range a.Operations {
|
||||
if !op.HasInput() {
|
||||
setAsPlacholderShape(&op.InputRef, op.ExportedName+"Input", a)
|
||||
}
|
||||
if !op.HasOutput() {
|
||||
setAsPlacholderShape(&op.OutputRef, op.ExportedName+"Output", a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setAsPlacholderShape(tgtShapeRef *ShapeRef, name string, a *API) {
|
||||
shape := a.makeIOShape(name)
|
||||
shape.Placeholder = true
|
||||
*tgtShapeRef = ShapeRef{API: a, ShapeName: shape.ShapeName, Shape: shape}
|
||||
shape.refs = append(shape.refs, tgtShapeRef)
|
||||
}
|
||||
|
||||
// makeIOShape returns a pointer to a new Shape initialized by the name provided.
|
||||
func (a *API) makeIOShape(name string) *Shape {
|
||||
shape := &Shape{
|
||||
API: a, ShapeName: name, Type: "structure",
|
||||
MemberRefs: map[string]*ShapeRef{},
|
||||
}
|
||||
a.Shapes[name] = shape
|
||||
return shape
|
||||
}
|
||||
|
||||
// removeUnusedShapes removes shapes from the API which are not referenced by any
|
||||
// other shape in the API.
|
||||
func (a *API) removeUnusedShapes() {
|
||||
for n, s := range a.Shapes {
|
||||
if len(s.refs) == 0 {
|
||||
delete(a.Shapes, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Represents the service package name to EndpointsID mapping
|
||||
var custEndpointsKey = map[string]string{
|
||||
"applicationautoscaling": "application-autoscaling",
|
||||
}
|
||||
|
||||
// Sents the EndpointsID field of Metadata with the value of the
|
||||
// EndpointPrefix if EndpointsID is not set. Also adds
|
||||
// customizations for services if EndpointPrefix is not a valid key.
|
||||
func (a *API) setMetadataEndpointsKey() {
|
||||
if len(a.Metadata.EndpointsID) != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := custEndpointsKey[a.PackageName()]; ok {
|
||||
a.Metadata.EndpointsID = v
|
||||
} else {
|
||||
a.Metadata.EndpointsID = a.Metadata.EndpointPrefix
|
||||
}
|
||||
}
|
||||
+636
@@ -0,0 +1,636 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"path"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// A ShapeRef defines the usage of a shape within the API.
|
||||
type ShapeRef struct {
|
||||
API *API `json:"-"`
|
||||
Shape *Shape `json:"-"`
|
||||
Documentation string
|
||||
ShapeName string `json:"shape"`
|
||||
Location string
|
||||
LocationName string
|
||||
QueryName string
|
||||
Flattened bool
|
||||
Streaming bool
|
||||
XMLAttribute bool
|
||||
// Ignore, if set, will not be sent over the wire
|
||||
Ignore bool
|
||||
XMLNamespace XMLInfo
|
||||
Payload string
|
||||
IdempotencyToken bool `json:"idempotencyToken"`
|
||||
JSONValue bool `json:"jsonvalue"`
|
||||
Deprecated bool `json:"deprecated"`
|
||||
|
||||
OrigShapeName string `json:"-"`
|
||||
}
|
||||
|
||||
// ErrorInfo represents the error block of a shape's structure
|
||||
type ErrorInfo struct {
|
||||
Code string
|
||||
HTTPStatusCode int
|
||||
}
|
||||
|
||||
// A XMLInfo defines URL and prefix for Shapes when rendered as XML
|
||||
type XMLInfo struct {
|
||||
Prefix string
|
||||
URI string
|
||||
}
|
||||
|
||||
// A Shape defines the definition of a shape type
|
||||
type Shape struct {
|
||||
API *API `json:"-"`
|
||||
ShapeName string
|
||||
Documentation string
|
||||
MemberRefs map[string]*ShapeRef `json:"members"`
|
||||
MemberRef ShapeRef `json:"member"`
|
||||
KeyRef ShapeRef `json:"key"`
|
||||
ValueRef ShapeRef `json:"value"`
|
||||
Required []string
|
||||
Payload string
|
||||
Type string
|
||||
Exception bool
|
||||
Enum []string
|
||||
EnumConsts []string
|
||||
Flattened bool
|
||||
Streaming bool
|
||||
Location string
|
||||
LocationName string
|
||||
IdempotencyToken bool `json:"idempotencyToken"`
|
||||
XMLNamespace XMLInfo
|
||||
Min float64 // optional Minimum length (string, list) or value (number)
|
||||
Max float64 // optional Maximum length (string, list) or value (number)
|
||||
|
||||
refs []*ShapeRef // References to this shape
|
||||
resolvePkg string // use this package in the goType() if present
|
||||
|
||||
OrigShapeName string `json:"-"`
|
||||
|
||||
// Defines if the shape is a placeholder and should not be used directly
|
||||
Placeholder bool
|
||||
|
||||
Deprecated bool `json:"deprecated"`
|
||||
|
||||
Validations ShapeValidations
|
||||
|
||||
// Error information that is set if the shape is an error shape.
|
||||
IsError bool
|
||||
ErrorInfo ErrorInfo `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorCodeName will return the error shape's name formated for
|
||||
// error code const.
|
||||
func (s *Shape) ErrorCodeName() string {
|
||||
return "ErrCode" + s.ShapeName
|
||||
}
|
||||
|
||||
// ErrorName will return the shape's name or error code if available based
|
||||
// on the API's protocol. This is the error code string returned by the service.
|
||||
func (s *Shape) ErrorName() string {
|
||||
name := s.ShapeName
|
||||
switch s.API.Metadata.Protocol {
|
||||
case "query", "ec2query", "rest-xml":
|
||||
if len(s.ErrorInfo.Code) > 0 {
|
||||
name = s.ErrorInfo.Code
|
||||
}
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// GoTags returns the struct tags for a shape.
|
||||
func (s *Shape) GoTags(root, required bool) string {
|
||||
ref := &ShapeRef{ShapeName: s.ShapeName, API: s.API, Shape: s}
|
||||
return ref.GoTags(root, required)
|
||||
}
|
||||
|
||||
// Rename changes the name of the Shape to newName. Also updates
|
||||
// the associated API's reference to use newName.
|
||||
func (s *Shape) Rename(newName string) {
|
||||
for _, r := range s.refs {
|
||||
r.OrigShapeName = r.ShapeName
|
||||
r.ShapeName = newName
|
||||
}
|
||||
|
||||
delete(s.API.Shapes, s.ShapeName)
|
||||
s.OrigShapeName = s.ShapeName
|
||||
s.API.Shapes[newName] = s
|
||||
s.ShapeName = newName
|
||||
}
|
||||
|
||||
// MemberNames returns a slice of struct member names.
|
||||
func (s *Shape) MemberNames() []string {
|
||||
i, names := 0, make([]string, len(s.MemberRefs))
|
||||
for n := range s.MemberRefs {
|
||||
names[i] = n
|
||||
i++
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// GoTypeWithPkgName returns a shape's type as a string with the package name in
|
||||
// <packageName>.<type> format. Package naming only applies to structures.
|
||||
func (s *Shape) GoTypeWithPkgName() string {
|
||||
return goType(s, true)
|
||||
}
|
||||
|
||||
// GenAccessors returns if the shape's reference should have setters generated.
|
||||
func (s *ShapeRef) UseIndirection() bool {
|
||||
switch s.Shape.Type {
|
||||
case "map", "list", "blob", "structure", "jsonvalue":
|
||||
return false
|
||||
}
|
||||
|
||||
if s.Streaming || s.Shape.Streaming {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.JSONValue {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GoStructValueType returns the Shape's Go type value instead of a pointer
|
||||
// for the type.
|
||||
func (s *Shape) GoStructValueType(name string, ref *ShapeRef) string {
|
||||
v := s.GoStructType(name, ref)
|
||||
|
||||
if ref.UseIndirection() && v[0] == '*' {
|
||||
return v[1:]
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// GoStructType returns the type of a struct field based on the API
|
||||
// model definition.
|
||||
func (s *Shape) GoStructType(name string, ref *ShapeRef) string {
|
||||
if (ref.Streaming || ref.Shape.Streaming) && s.Payload == name {
|
||||
rtype := "io.ReadSeeker"
|
||||
if strings.HasSuffix(s.ShapeName, "Output") {
|
||||
rtype = "io.ReadCloser"
|
||||
}
|
||||
|
||||
s.API.imports["io"] = true
|
||||
return rtype
|
||||
}
|
||||
|
||||
if ref.JSONValue {
|
||||
s.API.imports["github.com/aws/aws-sdk-go/aws"] = true
|
||||
return "aws.JSONValue"
|
||||
}
|
||||
|
||||
for _, v := range s.Validations {
|
||||
// TODO move this to shape validation resolution
|
||||
if (v.Ref.Shape.Type == "map" || v.Ref.Shape.Type == "list") && v.Type == ShapeValidationNested {
|
||||
s.API.imports["fmt"] = true
|
||||
}
|
||||
}
|
||||
|
||||
return ref.GoType()
|
||||
}
|
||||
|
||||
// GoType returns a shape's Go type
|
||||
func (s *Shape) GoType() string {
|
||||
return goType(s, false)
|
||||
}
|
||||
|
||||
// GoType returns a shape ref's Go type.
|
||||
func (ref *ShapeRef) GoType() string {
|
||||
if ref.Shape == nil {
|
||||
panic(fmt.Errorf("missing shape definition on reference for %#v", ref))
|
||||
}
|
||||
|
||||
return ref.Shape.GoType()
|
||||
}
|
||||
|
||||
// GoTypeWithPkgName returns a shape's type as a string with the package name in
|
||||
// <packageName>.<type> format. Package naming only applies to structures.
|
||||
func (ref *ShapeRef) GoTypeWithPkgName() string {
|
||||
if ref.Shape == nil {
|
||||
panic(fmt.Errorf("missing shape definition on reference for %#v", ref))
|
||||
}
|
||||
|
||||
return ref.Shape.GoTypeWithPkgName()
|
||||
}
|
||||
|
||||
// Returns a string version of the Shape's type.
|
||||
// If withPkgName is true, the package name will be added as a prefix
|
||||
func goType(s *Shape, withPkgName bool) string {
|
||||
switch s.Type {
|
||||
case "structure":
|
||||
if withPkgName || s.resolvePkg != "" {
|
||||
pkg := s.resolvePkg
|
||||
if pkg != "" {
|
||||
s.API.imports[pkg] = true
|
||||
pkg = path.Base(pkg)
|
||||
} else {
|
||||
pkg = s.API.PackageName()
|
||||
}
|
||||
return fmt.Sprintf("*%s.%s", pkg, s.ShapeName)
|
||||
}
|
||||
return "*" + s.ShapeName
|
||||
case "map":
|
||||
return "map[string]" + s.ValueRef.GoType()
|
||||
case "jsonvalue":
|
||||
return "aws.JSONValue"
|
||||
case "list":
|
||||
return "[]" + s.MemberRef.GoType()
|
||||
case "boolean":
|
||||
return "*bool"
|
||||
case "string", "character":
|
||||
return "*string"
|
||||
case "blob":
|
||||
return "[]byte"
|
||||
case "integer", "long":
|
||||
return "*int64"
|
||||
case "float", "double":
|
||||
return "*float64"
|
||||
case "timestamp":
|
||||
s.API.imports["time"] = true
|
||||
return "*time.Time"
|
||||
default:
|
||||
panic("Unsupported shape type: " + s.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GoTypeElem returns the Go type for the Shape. If the shape type is a pointer just
|
||||
// the type will be returned minus the pointer *.
|
||||
func (s *Shape) GoTypeElem() string {
|
||||
t := s.GoType()
|
||||
if strings.HasPrefix(t, "*") {
|
||||
return t[1:]
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// GoTypeElem returns the Go type for the Shape. If the shape type is a pointer just
|
||||
// the type will be returned minus the pointer *.
|
||||
func (ref *ShapeRef) GoTypeElem() string {
|
||||
if ref.Shape == nil {
|
||||
panic(fmt.Errorf("missing shape definition on reference for %#v", ref))
|
||||
}
|
||||
|
||||
return ref.Shape.GoTypeElem()
|
||||
}
|
||||
|
||||
// ShapeTag is a struct tag that will be applied to a shape's generated code
|
||||
type ShapeTag struct {
|
||||
Key, Val string
|
||||
}
|
||||
|
||||
// String returns the string representation of the shape tag
|
||||
func (s ShapeTag) String() string {
|
||||
return fmt.Sprintf(`%s:"%s"`, s.Key, s.Val)
|
||||
}
|
||||
|
||||
// ShapeTags is a collection of shape tags and provides serialization of the
|
||||
// tags in an ordered list.
|
||||
type ShapeTags []ShapeTag
|
||||
|
||||
// Join returns an ordered serialization of the shape tags with the provided
|
||||
// separator.
|
||||
func (s ShapeTags) Join(sep string) string {
|
||||
o := &bytes.Buffer{}
|
||||
for i, t := range s {
|
||||
o.WriteString(t.String())
|
||||
if i < len(s)-1 {
|
||||
o.WriteString(sep)
|
||||
}
|
||||
}
|
||||
|
||||
return o.String()
|
||||
}
|
||||
|
||||
// String is an alias for Join with the empty space separator.
|
||||
func (s ShapeTags) String() string {
|
||||
return s.Join(" ")
|
||||
}
|
||||
|
||||
// GoTags returns the rendered tags string for the ShapeRef
|
||||
func (ref *ShapeRef) GoTags(toplevel bool, isRequired bool) string {
|
||||
tags := ShapeTags{}
|
||||
|
||||
if ref.Location != "" {
|
||||
tags = append(tags, ShapeTag{"location", ref.Location})
|
||||
} else if ref.Shape.Location != "" {
|
||||
tags = append(tags, ShapeTag{"location", ref.Shape.Location})
|
||||
}
|
||||
|
||||
if ref.LocationName != "" {
|
||||
tags = append(tags, ShapeTag{"locationName", ref.LocationName})
|
||||
} else if ref.Shape.LocationName != "" {
|
||||
tags = append(tags, ShapeTag{"locationName", ref.Shape.LocationName})
|
||||
}
|
||||
|
||||
if ref.QueryName != "" {
|
||||
tags = append(tags, ShapeTag{"queryName", ref.QueryName})
|
||||
}
|
||||
if ref.Shape.MemberRef.LocationName != "" {
|
||||
tags = append(tags, ShapeTag{"locationNameList", ref.Shape.MemberRef.LocationName})
|
||||
}
|
||||
if ref.Shape.KeyRef.LocationName != "" {
|
||||
tags = append(tags, ShapeTag{"locationNameKey", ref.Shape.KeyRef.LocationName})
|
||||
}
|
||||
if ref.Shape.ValueRef.LocationName != "" {
|
||||
tags = append(tags, ShapeTag{"locationNameValue", ref.Shape.ValueRef.LocationName})
|
||||
}
|
||||
if ref.Shape.Min > 0 {
|
||||
tags = append(tags, ShapeTag{"min", fmt.Sprintf("%v", ref.Shape.Min)})
|
||||
}
|
||||
|
||||
if ref.Deprecated || ref.Shape.Deprecated {
|
||||
tags = append(tags, ShapeTag{"deprecated", "true"})
|
||||
}
|
||||
|
||||
// All shapes have a type
|
||||
tags = append(tags, ShapeTag{"type", ref.Shape.Type})
|
||||
|
||||
// embed the timestamp type for easier lookups
|
||||
if ref.Shape.Type == "timestamp" {
|
||||
t := ShapeTag{Key: "timestampFormat"}
|
||||
if ref.Location == "header" {
|
||||
t.Val = "rfc822"
|
||||
} else {
|
||||
switch ref.API.Metadata.Protocol {
|
||||
case "json", "rest-json":
|
||||
t.Val = "unix"
|
||||
case "rest-xml", "ec2", "query":
|
||||
t.Val = "iso8601"
|
||||
}
|
||||
}
|
||||
tags = append(tags, t)
|
||||
}
|
||||
|
||||
if ref.Shape.Flattened || ref.Flattened {
|
||||
tags = append(tags, ShapeTag{"flattened", "true"})
|
||||
}
|
||||
if ref.XMLAttribute {
|
||||
tags = append(tags, ShapeTag{"xmlAttribute", "true"})
|
||||
}
|
||||
if isRequired {
|
||||
tags = append(tags, ShapeTag{"required", "true"})
|
||||
}
|
||||
if ref.Shape.IsEnum() {
|
||||
tags = append(tags, ShapeTag{"enum", ref.ShapeName})
|
||||
}
|
||||
|
||||
if toplevel {
|
||||
if ref.Shape.Payload != "" {
|
||||
tags = append(tags, ShapeTag{"payload", ref.Shape.Payload})
|
||||
}
|
||||
if ref.XMLNamespace.Prefix != "" {
|
||||
tags = append(tags, ShapeTag{"xmlPrefix", ref.XMLNamespace.Prefix})
|
||||
} else if ref.Shape.XMLNamespace.Prefix != "" {
|
||||
tags = append(tags, ShapeTag{"xmlPrefix", ref.Shape.XMLNamespace.Prefix})
|
||||
}
|
||||
if ref.XMLNamespace.URI != "" {
|
||||
tags = append(tags, ShapeTag{"xmlURI", ref.XMLNamespace.URI})
|
||||
} else if ref.Shape.XMLNamespace.URI != "" {
|
||||
tags = append(tags, ShapeTag{"xmlURI", ref.Shape.XMLNamespace.URI})
|
||||
}
|
||||
}
|
||||
|
||||
if ref.IdempotencyToken || ref.Shape.IdempotencyToken {
|
||||
tags = append(tags, ShapeTag{"idempotencyToken", "true"})
|
||||
}
|
||||
|
||||
if ref.Ignore {
|
||||
tags = append(tags, ShapeTag{"ignore", "true"})
|
||||
}
|
||||
|
||||
return fmt.Sprintf("`%s`", tags)
|
||||
}
|
||||
|
||||
// Docstring returns the godocs formated documentation
|
||||
func (ref *ShapeRef) Docstring() string {
|
||||
if ref.Documentation != "" {
|
||||
return strings.Trim(ref.Documentation, "\n ")
|
||||
}
|
||||
return ref.Shape.Docstring()
|
||||
}
|
||||
|
||||
// Docstring returns the godocs formated documentation
|
||||
func (s *Shape) Docstring() string {
|
||||
return strings.Trim(s.Documentation, "\n ")
|
||||
}
|
||||
|
||||
// IndentedDocstring is the indented form of the doc string.
|
||||
func (ref *ShapeRef) IndentedDocstring() string {
|
||||
doc := ref.Docstring()
|
||||
return strings.Replace(doc, "// ", "// ", -1)
|
||||
}
|
||||
|
||||
var goCodeStringerTmpl = template.Must(template.New("goCodeStringerTmpl").Parse(`
|
||||
// String returns the string representation
|
||||
func (s {{ .ShapeName }}) String() string {
|
||||
return awsutil.Prettify(s)
|
||||
}
|
||||
// GoString returns the string representation
|
||||
func (s {{ .ShapeName }}) GoString() string {
|
||||
return s.String()
|
||||
}
|
||||
`))
|
||||
|
||||
// GoCodeStringers renders the Stringers for API input/output shapes
|
||||
func (s *Shape) GoCodeStringers() string {
|
||||
w := bytes.Buffer{}
|
||||
if err := goCodeStringerTmpl.Execute(&w, s); err != nil {
|
||||
panic(fmt.Sprintln("Unexpected error executing GoCodeStringers template", err))
|
||||
}
|
||||
|
||||
return w.String()
|
||||
}
|
||||
|
||||
var enumStrip = regexp.MustCompile(`[^a-zA-Z0-9_:\./-]`)
|
||||
var enumDelims = regexp.MustCompile(`[-_:\./]+`)
|
||||
var enumCamelCase = regexp.MustCompile(`([a-z])([A-Z])`)
|
||||
|
||||
// EnumName returns the Nth enum in the shapes Enum list
|
||||
func (s *Shape) EnumName(n int) string {
|
||||
enum := s.Enum[n]
|
||||
enum = enumStrip.ReplaceAllLiteralString(enum, "")
|
||||
enum = enumCamelCase.ReplaceAllString(enum, "$1-$2")
|
||||
parts := enumDelims.Split(enum, -1)
|
||||
for i, v := range parts {
|
||||
v = strings.ToLower(v)
|
||||
parts[i] = ""
|
||||
if len(v) > 0 {
|
||||
parts[i] = strings.ToUpper(v[0:1])
|
||||
}
|
||||
if len(v) > 1 {
|
||||
parts[i] += v[1:]
|
||||
}
|
||||
}
|
||||
enum = strings.Join(parts, "")
|
||||
enum = strings.ToUpper(enum[0:1]) + enum[1:]
|
||||
return enum
|
||||
}
|
||||
|
||||
// NestedShape returns the shape pointer value for the shape which is nested
|
||||
// under the current shape. If the shape is not nested nil will be returned.
|
||||
//
|
||||
// strucutures, the current shape is returned
|
||||
// map: the value shape of the map is returned
|
||||
// list: the element shape of the list is returned
|
||||
func (s *Shape) NestedShape() *Shape {
|
||||
var nestedShape *Shape
|
||||
switch s.Type {
|
||||
case "structure":
|
||||
nestedShape = s
|
||||
case "map":
|
||||
nestedShape = s.ValueRef.Shape
|
||||
case "list":
|
||||
nestedShape = s.MemberRef.Shape
|
||||
}
|
||||
|
||||
return nestedShape
|
||||
}
|
||||
|
||||
var structShapeTmpl = template.Must(template.New("StructShape").Funcs(template.FuncMap{
|
||||
"GetCrosslinkURL": GetCrosslinkURL,
|
||||
}).Parse(`
|
||||
{{ .Docstring }}
|
||||
{{ if ne $.OrigShapeName "" -}}
|
||||
{{ $crosslinkURL := GetCrosslinkURL $.API.BaseCrosslinkURL $.API.APIName $.API.Metadata.UID $.OrigShapeName -}}
|
||||
{{ if ne $crosslinkURL "" -}}
|
||||
// Please also see {{ $crosslinkURL }}
|
||||
{{ end -}}
|
||||
{{ else -}}
|
||||
{{ $crosslinkURL := GetCrosslinkURL $.API.BaseCrosslinkURL $.API.APIName $.API.Metadata.UID $.ShapeName -}}
|
||||
{{ if ne $crosslinkURL "" -}}
|
||||
// Please also see {{ $crosslinkURL }}
|
||||
{{ end -}}
|
||||
{{ end -}}
|
||||
{{ $context := . -}}
|
||||
type {{ .ShapeName }} struct {
|
||||
_ struct{} {{ .GoTags true false }}
|
||||
|
||||
{{ range $_, $name := $context.MemberNames -}}
|
||||
{{ $elem := index $context.MemberRefs $name -}}
|
||||
{{ $isRequired := $context.IsRequired $name -}}
|
||||
{{ $doc := $elem.Docstring -}}
|
||||
|
||||
{{ $doc }}
|
||||
{{ if $isRequired -}}
|
||||
{{ if $doc -}}
|
||||
//
|
||||
{{ end -}}
|
||||
// {{ $name }} is a required field
|
||||
{{ end -}}
|
||||
{{ $name }} {{ $context.GoStructType $name $elem }} {{ $elem.GoTags false $isRequired }}
|
||||
|
||||
{{ end }}
|
||||
}
|
||||
{{ if not .API.NoStringerMethods }}
|
||||
{{ .GoCodeStringers }}
|
||||
{{ end }}
|
||||
{{ if not .API.NoValidataShapeMethods }}
|
||||
{{ if .Validations -}}
|
||||
{{ .Validations.GoCode . }}
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
|
||||
{{ if not .API.NoGenStructFieldAccessors }}
|
||||
|
||||
{{ $builderShapeName := print .ShapeName -}}
|
||||
|
||||
{{ range $_, $name := $context.MemberNames -}}
|
||||
{{ $elem := index $context.MemberRefs $name -}}
|
||||
|
||||
// Set{{ $name }} sets the {{ $name }} field's value.
|
||||
func (s *{{ $builderShapeName }}) Set{{ $name }}(v {{ $context.GoStructValueType $name $elem }}) *{{ $builderShapeName }} {
|
||||
{{ if $elem.UseIndirection -}}
|
||||
s.{{ $name }} = &v
|
||||
{{ else -}}
|
||||
s.{{ $name }} = v
|
||||
{{ end -}}
|
||||
return s
|
||||
}
|
||||
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`))
|
||||
|
||||
var enumShapeTmpl = template.Must(template.New("EnumShape").Parse(`
|
||||
{{ .Docstring }}
|
||||
const (
|
||||
{{ $context := . -}}
|
||||
{{ range $index, $elem := .Enum -}}
|
||||
{{ $name := index $context.EnumConsts $index -}}
|
||||
// {{ $name }} is a {{ $context.ShapeName }} enum value
|
||||
{{ $name }} = "{{ $elem }}"
|
||||
|
||||
{{ end }}
|
||||
)
|
||||
`))
|
||||
|
||||
// GoCode returns the rendered Go code for the Shape.
|
||||
func (s *Shape) GoCode() string {
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
switch {
|
||||
case s.Type == "structure":
|
||||
if err := structShapeTmpl.Execute(b, s); err != nil {
|
||||
panic(fmt.Sprintf("Failed to generate struct shape %s, %v\n", s.ShapeName, err))
|
||||
}
|
||||
case s.IsEnum():
|
||||
if err := enumShapeTmpl.Execute(b, s); err != nil {
|
||||
panic(fmt.Sprintf("Failed to generate enum shape %s, %v\n", s.ShapeName, err))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintln("Cannot generate toplevel shape for", s.Type))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// IsEnum returns whether this shape is an enum list
|
||||
func (s *Shape) IsEnum() bool {
|
||||
return s.Type == "string" && len(s.Enum) > 0
|
||||
}
|
||||
|
||||
// IsRequired returns if member is a required field.
|
||||
func (s *Shape) IsRequired(member string) bool {
|
||||
for _, n := range s.Required {
|
||||
if n == member {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInternal returns whether the shape was defined in this package
|
||||
func (s *Shape) IsInternal() bool {
|
||||
return s.resolvePkg == ""
|
||||
}
|
||||
|
||||
// removeRef removes a shape reference from the list of references this
|
||||
// shape is used in.
|
||||
func (s *Shape) removeRef(ref *ShapeRef) {
|
||||
r := s.refs
|
||||
for i := 0; i < len(r); i++ {
|
||||
if r[i] == ref {
|
||||
j := i + 1
|
||||
copy(r[i:], r[j:])
|
||||
for k, n := len(r)-j+i, len(r); k < n; k++ {
|
||||
r[k] = nil // free up the end of the list
|
||||
} // for k
|
||||
s.refs = r[:len(r)-j+i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
+155
@@ -0,0 +1,155 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// A ShapeValidationType is the type of validation that a shape needs
|
||||
type ShapeValidationType int
|
||||
|
||||
const (
|
||||
// ShapeValidationRequired states the shape must be set
|
||||
ShapeValidationRequired = iota
|
||||
|
||||
// ShapeValidationMinVal states the shape must have at least a number of
|
||||
// elements, or for numbers a minimum value
|
||||
ShapeValidationMinVal
|
||||
|
||||
// ShapeValidationNested states the shape has nested values that need
|
||||
// to be validated
|
||||
ShapeValidationNested
|
||||
)
|
||||
|
||||
// A ShapeValidation contains information about a shape and the type of validation
|
||||
// that is needed
|
||||
type ShapeValidation struct {
|
||||
// Name of the shape to be validated
|
||||
Name string
|
||||
// Reference to the shape within the context the shape is referenced
|
||||
Ref *ShapeRef
|
||||
// Type of validation needed
|
||||
Type ShapeValidationType
|
||||
}
|
||||
|
||||
var validationGoCodeTmpls = template.Must(template.New("validationGoCodeTmpls").Parse(`
|
||||
{{ define "requiredValue" -}}
|
||||
if s.{{ .Name }} == nil {
|
||||
invalidParams.Add(request.NewErrParamRequired("{{ .Name }}"))
|
||||
}
|
||||
{{- end }}
|
||||
{{ define "minLen" -}}
|
||||
if s.{{ .Name }} != nil && len(s.{{ .Name }}) < {{ .Ref.Shape.Min }} {
|
||||
invalidParams.Add(request.NewErrParamMinLen("{{ .Name }}", {{ .Ref.Shape.Min }}))
|
||||
}
|
||||
{{- end }}
|
||||
{{ define "minLenString" -}}
|
||||
if s.{{ .Name }} != nil && len(*s.{{ .Name }}) < {{ .Ref.Shape.Min }} {
|
||||
invalidParams.Add(request.NewErrParamMinLen("{{ .Name }}", {{ .Ref.Shape.Min }}))
|
||||
}
|
||||
{{- end }}
|
||||
{{ define "minVal" -}}
|
||||
if s.{{ .Name }} != nil && *s.{{ .Name }} < {{ .Ref.Shape.Min }} {
|
||||
invalidParams.Add(request.NewErrParamMinValue("{{ .Name }}", {{ .Ref.Shape.Min }}))
|
||||
}
|
||||
{{- end }}
|
||||
{{ define "nestedMapList" -}}
|
||||
if s.{{ .Name }} != nil {
|
||||
for i, v := range s.{{ .Name }} {
|
||||
if v == nil { continue }
|
||||
if err := v.Validate(); err != nil {
|
||||
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "{{ .Name }}", i), err.(request.ErrInvalidParams))
|
||||
}
|
||||
}
|
||||
}
|
||||
{{- end }}
|
||||
{{ define "nestedStruct" -}}
|
||||
if s.{{ .Name }} != nil {
|
||||
if err := s.{{ .Name }}.Validate(); err != nil {
|
||||
invalidParams.AddNested("{{ .Name }}", err.(request.ErrInvalidParams))
|
||||
}
|
||||
}
|
||||
{{- end }}
|
||||
`))
|
||||
|
||||
// GoCode returns the generated Go code for the Shape with its validation type.
|
||||
func (sv ShapeValidation) GoCode() string {
|
||||
var err error
|
||||
|
||||
w := &bytes.Buffer{}
|
||||
switch sv.Type {
|
||||
case ShapeValidationRequired:
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "requiredValue", sv)
|
||||
case ShapeValidationMinVal:
|
||||
switch sv.Ref.Shape.Type {
|
||||
case "list", "map", "blob":
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "minLen", sv)
|
||||
case "string":
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "minLenString", sv)
|
||||
case "integer", "long", "float", "double":
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "minVal", sv)
|
||||
default:
|
||||
panic(fmt.Sprintf("ShapeValidation.GoCode, %s's type %s, no min value handling",
|
||||
sv.Name, sv.Ref.Shape.Type))
|
||||
}
|
||||
case ShapeValidationNested:
|
||||
switch sv.Ref.Shape.Type {
|
||||
case "map", "list":
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "nestedMapList", sv)
|
||||
default:
|
||||
err = validationGoCodeTmpls.ExecuteTemplate(w, "nestedStruct", sv)
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("ShapeValidation.GoCode, %s's type %d, unknown validation type",
|
||||
sv.Name, sv.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ShapeValidation.GoCode failed, err: %v", err))
|
||||
}
|
||||
|
||||
return w.String()
|
||||
}
|
||||
|
||||
// A ShapeValidations is a collection of shape validations needed nested within
|
||||
// a parent shape
|
||||
type ShapeValidations []ShapeValidation
|
||||
|
||||
var validateShapeTmpl = template.Must(template.New("ValidateShape").Parse(`
|
||||
// Validate inspects the fields of the type to determine if they are valid.
|
||||
func (s *{{ .Shape.ShapeName }}) Validate() error {
|
||||
invalidParams := request.ErrInvalidParams{Context: "{{ .Shape.ShapeName }}"}
|
||||
{{ range $_, $v := .Validations -}}
|
||||
{{ $v.GoCode }}
|
||||
{{ end }}
|
||||
if invalidParams.Len() > 0 {
|
||||
return invalidParams
|
||||
}
|
||||
return nil
|
||||
}
|
||||
`))
|
||||
|
||||
// GoCode generates the Go code needed to perform validations for the
|
||||
// shape and its nested fields.
|
||||
func (vs ShapeValidations) GoCode(shape *Shape) string {
|
||||
buf := &bytes.Buffer{}
|
||||
validateShapeTmpl.Execute(buf, map[string]interface{}{
|
||||
"Shape": shape,
|
||||
"Validations": vs,
|
||||
})
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Has returns true or false if the ShapeValidations already contains the
|
||||
// the reference and validation type.
|
||||
func (vs ShapeValidations) Has(ref *ShapeRef, typ ShapeValidationType) bool {
|
||||
for _, v := range vs {
|
||||
if v.Ref == ref && v.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
// +build 1.6,codegen
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/model/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShapeTagJoin(t *testing.T) {
|
||||
s := api.ShapeTags{
|
||||
{Key: "location", Val: "query"},
|
||||
{Key: "locationName", Val: "abc"},
|
||||
{Key: "type", Val: "string"},
|
||||
}
|
||||
|
||||
expected := `location:"query" locationName:"abc" type:"string"`
|
||||
|
||||
o := s.Join(" ")
|
||||
o2 := s.String()
|
||||
assert.Equal(t, expected, o)
|
||||
assert.Equal(t, expected, o2)
|
||||
}
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
// +build codegen
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// WaiterAcceptor is the acceptors defined in the model the SDK will use
|
||||
// to wait on resource states with.
|
||||
type WaiterAcceptor struct {
|
||||
State string
|
||||
Matcher string
|
||||
Argument string
|
||||
Expected interface{}
|
||||
}
|
||||
|
||||
// ExpectedString returns the string that was expected by the WaiterAcceptor
|
||||
func (a *WaiterAcceptor) ExpectedString() string {
|
||||
switch a.Expected.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf("%q", a.Expected)
|
||||
default:
|
||||
return fmt.Sprintf("%v", a.Expected)
|
||||
}
|
||||
}
|
||||
|
||||
// A Waiter is an individual waiter definition.
|
||||
type Waiter struct {
|
||||
Name string
|
||||
Delay int
|
||||
MaxAttempts int
|
||||
OperationName string `json:"operation"`
|
||||
Operation *Operation
|
||||
Acceptors []WaiterAcceptor
|
||||
}
|
||||
|
||||
// WaitersGoCode generates and returns Go code for each of the waiters of
|
||||
// this API.
|
||||
func (a *API) WaitersGoCode() string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "import (\n%q\n\n%q\n%q\n)",
|
||||
"time",
|
||||
"github.com/aws/aws-sdk-go/aws",
|
||||
"github.com/aws/aws-sdk-go/aws/request",
|
||||
)
|
||||
|
||||
for _, w := range a.Waiters {
|
||||
buf.WriteString(w.GoCode())
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// used for unmarshaling from the waiter JSON file
|
||||
type waiterDefinitions struct {
|
||||
*API
|
||||
Waiters map[string]Waiter
|
||||
}
|
||||
|
||||
// AttachWaiters reads a file of waiter definitions, and adds those to the API.
|
||||
// Will panic if an error occurs.
|
||||
func (a *API) AttachWaiters(filename string) {
|
||||
p := waiterDefinitions{API: a}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
defer f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = json.NewDecoder(f).Decode(&p)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
p.setup()
|
||||
}
|
||||
|
||||
func (p *waiterDefinitions) setup() {
|
||||
p.API.Waiters = []Waiter{}
|
||||
i, keys := 0, make([]string, len(p.Waiters))
|
||||
for k := range p.Waiters {
|
||||
keys[i] = k
|
||||
i++
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, n := range keys {
|
||||
e := p.Waiters[n]
|
||||
n = p.ExportableName(n)
|
||||
e.Name = n
|
||||
e.OperationName = p.ExportableName(e.OperationName)
|
||||
e.Operation = p.API.Operations[e.OperationName]
|
||||
if e.Operation == nil {
|
||||
panic("unknown operation " + e.OperationName + " for waiter " + n)
|
||||
}
|
||||
p.API.Waiters = append(p.API.Waiters, e)
|
||||
}
|
||||
}
|
||||
|
||||
var waiterTmpls = template.Must(template.New("waiterTmpls").Funcs(
|
||||
template.FuncMap{
|
||||
"titleCase": func(v string) string {
|
||||
return strings.Title(v)
|
||||
},
|
||||
},
|
||||
).Parse(`
|
||||
{{ define "waiter"}}
|
||||
// WaitUntil{{ .Name }} uses the {{ .Operation.API.NiceName }} API operation
|
||||
// {{ .OperationName }} to wait for a condition to be met before returning.
|
||||
// If the condition is not meet within the max attempt window an error will
|
||||
// be returned.
|
||||
func (c *{{ .Operation.API.StructName }}) WaitUntil{{ .Name }}(input {{ .Operation.InputRef.GoType }}) error {
|
||||
return c.WaitUntil{{ .Name }}WithContext(aws.BackgroundContext(), input)
|
||||
}
|
||||
|
||||
// WaitUntil{{ .Name }}WithContext is an extended version of WaitUntil{{ .Name }}.
|
||||
// With the support for passing in a context and options to configure the
|
||||
// Waiter and the underlying request options.
|
||||
//
|
||||
// The context must be non-nil and will be used for request cancellation. If
|
||||
// the context is nil a panic will occur. In the future the SDK may create
|
||||
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
|
||||
// for more information on using Contexts.
|
||||
func (c *{{ .Operation.API.StructName }}) WaitUntil{{ .Name }}WithContext(` +
|
||||
`ctx aws.Context, input {{ .Operation.InputRef.GoType }}, opts ...request.WaiterOption) error {
|
||||
w := request.Waiter{
|
||||
Name: "WaitUntil{{ .Name }}",
|
||||
MaxAttempts: {{ .MaxAttempts }},
|
||||
Delay: request.ConstantWaiterDelay({{ .Delay }} * time.Second),
|
||||
Acceptors: []request.WaiterAcceptor{
|
||||
{{ range $_, $a := .Acceptors }}{
|
||||
State: request.{{ titleCase .State }}WaiterState,
|
||||
Matcher: request.{{ titleCase .Matcher }}WaiterMatch,
|
||||
{{- if .Argument }}Argument: "{{ .Argument }}",{{ end }}
|
||||
Expected: {{ .ExpectedString }},
|
||||
},
|
||||
{{ end }}
|
||||
},
|
||||
Logger: c.Config.Logger,
|
||||
NewRequest: func(opts []request.Option) (*request.Request, error) {
|
||||
req, _ := c.{{ .OperationName }}Request(input)
|
||||
req.SetContext(ctx)
|
||||
req.ApplyOptions(opts...)
|
||||
return req, nil
|
||||
},
|
||||
}
|
||||
w.ApplyOptions(opts...)
|
||||
|
||||
return w.WaitWithContext(ctx)
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
{{ define "waiter interface" }}
|
||||
WaitUntil{{ .Name }}({{ .Operation.InputRef.GoTypeWithPkgName }}) error
|
||||
WaitUntil{{ .Name }}WithContext(aws.Context, {{ .Operation.InputRef.GoTypeWithPkgName }}, ...request.WaiterOption) error
|
||||
{{- end }}
|
||||
`))
|
||||
|
||||
// InterfaceSignature returns a string representing the Waiter's interface
|
||||
// function signature.
|
||||
func (w *Waiter) InterfaceSignature() string {
|
||||
var buf bytes.Buffer
|
||||
if err := waiterTmpls.ExecuteTemplate(&buf, "waiter interface", w); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
// GoCode returns the generated Go code for an individual waiter.
|
||||
func (w *Waiter) GoCode() string {
|
||||
var buf bytes.Buffer
|
||||
if err := waiterTmpls.ExecuteTemplate(&buf, "waiter", w); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
+29
@@ -0,0 +1,29 @@
|
||||
// +build codegen
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/model/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dir, _ := os.Open(filepath.Join("models", "apis"))
|
||||
names, _ := dir.Readdirnames(0)
|
||||
for _, name := range names {
|
||||
m, _ := filepath.Glob(filepath.Join("models", "apis", name, "*", "api-2.json"))
|
||||
if len(m) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
sort.Strings(m)
|
||||
f := m[len(m)-1]
|
||||
a := api.API{}
|
||||
a.Attach(f)
|
||||
fmt.Printf("%s\t%s\n", a.Metadata.ServiceFullName, a.Metadata.APIVersion)
|
||||
}
|
||||
}
|
||||
+274
@@ -0,0 +1,274 @@
|
||||
// +build codegen
|
||||
|
||||
// Command aws-gen-gocli parses a JSON description of an AWS API and generates a
|
||||
// Go file containing a client for the API.
|
||||
//
|
||||
// aws-gen-gocli apis/s3/2006-03-03/api-2.json
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/model/api"
|
||||
"github.com/aws/aws-sdk-go/private/util"
|
||||
)
|
||||
|
||||
type generateInfo struct {
|
||||
*api.API
|
||||
PackageDir string
|
||||
}
|
||||
|
||||
var excludeServices = map[string]struct{}{
|
||||
"importexport": {},
|
||||
}
|
||||
|
||||
// newGenerateInfo initializes the service API's folder structure for a specific service.
|
||||
// If the SERVICES environment variable is set, and this service is not apart of the list
|
||||
// this service will be skipped.
|
||||
func newGenerateInfo(modelFile, svcPath, svcImportPath string) *generateInfo {
|
||||
g := &generateInfo{API: &api.API{SvcClientImportPath: svcImportPath, BaseCrosslinkURL: "https://docs.aws.amazon.com"}}
|
||||
g.API.Attach(modelFile)
|
||||
|
||||
if _, ok := excludeServices[g.API.PackageName()]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
paginatorsFile := strings.Replace(modelFile, "api-2.json", "paginators-1.json", -1)
|
||||
if _, err := os.Stat(paginatorsFile); err == nil {
|
||||
g.API.AttachPaginators(paginatorsFile)
|
||||
} else if !os.IsNotExist(err) {
|
||||
fmt.Println("api-2.json error:", err)
|
||||
}
|
||||
|
||||
docsFile := strings.Replace(modelFile, "api-2.json", "docs-2.json", -1)
|
||||
if _, err := os.Stat(docsFile); err == nil {
|
||||
g.API.AttachDocs(docsFile)
|
||||
} else {
|
||||
fmt.Println("docs-2.json error:", err)
|
||||
}
|
||||
|
||||
waitersFile := strings.Replace(modelFile, "api-2.json", "waiters-2.json", -1)
|
||||
if _, err := os.Stat(waitersFile); err == nil {
|
||||
g.API.AttachWaiters(waitersFile)
|
||||
} else if !os.IsNotExist(err) {
|
||||
fmt.Println("waiters-2.json error:", err)
|
||||
}
|
||||
|
||||
g.API.Setup()
|
||||
|
||||
if svc := os.Getenv("SERVICES"); svc != "" {
|
||||
svcs := strings.Split(svc, ",")
|
||||
|
||||
included := false
|
||||
for _, s := range svcs {
|
||||
if s == g.API.PackageName() {
|
||||
included = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !included {
|
||||
// skip this non-included service
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ensure the directory exists
|
||||
pkgDir := filepath.Join(svcPath, g.API.PackageName())
|
||||
os.MkdirAll(pkgDir, 0775)
|
||||
os.MkdirAll(filepath.Join(pkgDir, g.API.InterfacePackageName()), 0775)
|
||||
|
||||
g.PackageDir = pkgDir
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
// Generates service api, examples, and interface from api json definition files.
|
||||
//
|
||||
// Flags:
|
||||
// -path alternative service path to write generated files to for each service.
|
||||
//
|
||||
// Env:
|
||||
// SERVICES comma separated list of services to generate.
|
||||
func main() {
|
||||
var svcPath, sessionPath, svcImportPath string
|
||||
flag.StringVar(&svcPath, "path", "service", "directory to generate service clients in")
|
||||
flag.StringVar(&sessionPath, "sessionPath", filepath.Join("aws", "session"), "generate session service client factories")
|
||||
flag.StringVar(&svcImportPath, "svc-import-path", "github.com/aws/aws-sdk-go/service", "namespace to generate service client Go code import path under")
|
||||
flag.Parse()
|
||||
api.Bootstrap()
|
||||
|
||||
files := []string{}
|
||||
for i := 0; i < flag.NArg(); i++ {
|
||||
file := flag.Arg(i)
|
||||
if strings.Contains(file, "*") {
|
||||
paths, _ := filepath.Glob(file)
|
||||
files = append(files, paths...)
|
||||
} else {
|
||||
files = append(files, file)
|
||||
}
|
||||
}
|
||||
|
||||
for svcName := range excludeServices {
|
||||
if strings.Contains(os.Getenv("SERVICES"), svcName) {
|
||||
fmt.Printf("Service %s is not supported\n", svcName)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(files)
|
||||
|
||||
// Remove old API versions from list
|
||||
m := map[string]bool{}
|
||||
for i := range files {
|
||||
idx := len(files) - 1 - i
|
||||
parts := strings.Split(files[idx], string(filepath.Separator))
|
||||
svc := parts[len(parts)-3] // service name is 2nd-to-last component
|
||||
|
||||
if m[svc] {
|
||||
files[idx] = "" // wipe this one out if we already saw the service
|
||||
}
|
||||
m[svc] = true
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := range files {
|
||||
filename := files[i]
|
||||
if filename == "" { // empty file
|
||||
continue
|
||||
}
|
||||
|
||||
genInfo := newGenerateInfo(filename, svcPath, svcImportPath)
|
||||
if genInfo == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := excludeServices[genInfo.API.PackageName()]; ok {
|
||||
// Skip services not yet supported.
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(g *generateInfo, filename string) {
|
||||
defer wg.Done()
|
||||
writeServiceFiles(g, filename)
|
||||
}(genInfo, filename)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func writeServiceFiles(g *generateInfo, filename string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating %s\n%s\n%s\n",
|
||||
filename, r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Printf("Generating %s (%s)...\n",
|
||||
g.API.PackageName(), g.API.Metadata.APIVersion)
|
||||
|
||||
// write api.go and service.go files
|
||||
Must(writeAPIFile(g))
|
||||
Must(writeExamplesFile(g))
|
||||
Must(writeServiceFile(g))
|
||||
Must(writeInterfaceFile(g))
|
||||
Must(writeWaitersFile(g))
|
||||
Must(writeAPIErrorsFile(g))
|
||||
}
|
||||
|
||||
// Must will panic if the error passed in is not nil.
|
||||
func Must(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
const codeLayout = `// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
|
||||
%s
|
||||
package %s
|
||||
|
||||
%s
|
||||
`
|
||||
|
||||
func writeGoFile(file string, layout string, args ...interface{}) error {
|
||||
return ioutil.WriteFile(file, []byte(util.GoFmt(fmt.Sprintf(layout, args...))), 0664)
|
||||
}
|
||||
|
||||
// writeExamplesFile writes out the service example file.
|
||||
func writeExamplesFile(g *generateInfo) error {
|
||||
return writeGoFile(filepath.Join(g.PackageDir, "examples_test.go"),
|
||||
codeLayout,
|
||||
"",
|
||||
g.API.PackageName()+"_test",
|
||||
g.API.ExampleGoCode(),
|
||||
)
|
||||
}
|
||||
|
||||
// writeServiceFile writes out the service initialization file.
|
||||
func writeServiceFile(g *generateInfo) error {
|
||||
return writeGoFile(filepath.Join(g.PackageDir, "service.go"),
|
||||
codeLayout,
|
||||
"",
|
||||
g.API.PackageName(),
|
||||
g.API.ServiceGoCode(),
|
||||
)
|
||||
}
|
||||
|
||||
// writeInterfaceFile writes out the service interface file.
|
||||
func writeInterfaceFile(g *generateInfo) error {
|
||||
const pkgDoc = `
|
||||
// Package %s provides an interface to enable mocking the %s service client
|
||||
// for testing your code.
|
||||
//
|
||||
// It is important to note that this interface will have breaking changes
|
||||
// when the service model is updated and adds new API operations, paginators,
|
||||
// and waiters.`
|
||||
return writeGoFile(filepath.Join(g.PackageDir, g.API.InterfacePackageName(), "interface.go"),
|
||||
codeLayout,
|
||||
fmt.Sprintf(pkgDoc, g.API.InterfacePackageName(), g.API.Metadata.ServiceFullName),
|
||||
g.API.InterfacePackageName(),
|
||||
g.API.InterfaceGoCode(),
|
||||
)
|
||||
}
|
||||
|
||||
func writeWaitersFile(g *generateInfo) error {
|
||||
if len(g.API.Waiters) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return writeGoFile(filepath.Join(g.PackageDir, "waiters.go"),
|
||||
codeLayout,
|
||||
"",
|
||||
g.API.PackageName(),
|
||||
g.API.WaitersGoCode(),
|
||||
)
|
||||
}
|
||||
|
||||
// writeAPIFile writes out the service api file.
|
||||
func writeAPIFile(g *generateInfo) error {
|
||||
return writeGoFile(filepath.Join(g.PackageDir, "api.go"),
|
||||
codeLayout,
|
||||
fmt.Sprintf("\n// Package %s provides a client for %s.",
|
||||
g.API.PackageName(), g.API.Metadata.ServiceFullName),
|
||||
g.API.PackageName(),
|
||||
g.API.APIGoCode(),
|
||||
)
|
||||
}
|
||||
|
||||
// writeAPIErrorsFile writes out the service api errors file.
|
||||
func writeAPIErrorsFile(g *generateInfo) error {
|
||||
return writeGoFile(filepath.Join(g.PackageDir, "errors.go"),
|
||||
codeLayout,
|
||||
"",
|
||||
g.API.PackageName(),
|
||||
g.API.APIErrorsGoCode(),
|
||||
)
|
||||
}
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
// +build codegen
|
||||
|
||||
// Command gen-endpoints parses a JSON description of the AWS endpoint
|
||||
// discovery logic and generates a Go file which returns an endpoint.
|
||||
//
|
||||
// aws-gen-goendpoints apis/_endpoints.json aws/endpoints_map.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
)
|
||||
|
||||
// Generates the endpoints from json description
|
||||
//
|
||||
// Args:
|
||||
// -model The definition file to use
|
||||
// -out The output file to generate
|
||||
func main() {
|
||||
var modelName, outName string
|
||||
flag.StringVar(&modelName, "model", "", "Endpoints definition model")
|
||||
flag.StringVar(&outName, "out", "", "File to write generated endpoints to.")
|
||||
flag.Parse()
|
||||
|
||||
if len(modelName) == 0 || len(outName) == 0 {
|
||||
exitErrorf("model and out both required.")
|
||||
}
|
||||
|
||||
modelFile, err := os.Open(modelName)
|
||||
if err != nil {
|
||||
exitErrorf("failed to open model definition, %v.", err)
|
||||
}
|
||||
defer modelFile.Close()
|
||||
|
||||
outFile, err := os.Create(outName)
|
||||
if err != nil {
|
||||
exitErrorf("failed to open out file, %v.", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := outFile.Close(); err != nil {
|
||||
exitErrorf("failed to successfully write %q file, %v", outName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := endpoints.CodeGenModel(modelFile, outFile); err != nil {
|
||||
exitErrorf("failed to codegen model, %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func exitErrorf(msg string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, msg+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
// Package ec2query provides serialization of AWS EC2 requests and responses.
|
||||
package ec2query
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
|
||||
)
|
||||
|
||||
// BuildHandler is a named request handler for building ec2query protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.ec2query.Build", Fn: Build}
|
||||
|
||||
// Build builds a request for the EC2 protocol.
|
||||
func Build(r *request.Request) {
|
||||
body := url.Values{
|
||||
"Action": {r.Operation.Name},
|
||||
"Version": {r.ClientInfo.APIVersion},
|
||||
}
|
||||
if err := queryutil.Parse(body, r.Params, true); err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err)
|
||||
}
|
||||
|
||||
if r.ExpireTime == 0 {
|
||||
r.HTTPRequest.Method = "POST"
|
||||
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
|
||||
r.SetBufferBody([]byte(body.Encode()))
|
||||
} else { // This is a pre-signed request
|
||||
r.HTTPRequest.Method = "GET"
|
||||
r.HTTPRequest.URL.RawQuery = body.Encode()
|
||||
}
|
||||
}
|
||||
+85
@@ -0,0 +1,85 @@
|
||||
// +build bench
|
||||
|
||||
package ec2query_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
)
|
||||
|
||||
func BenchmarkEC2QueryBuild_Complex_ec2AuthorizeSecurityGroupEgress(b *testing.B) {
|
||||
params := &ec2.AuthorizeSecurityGroupEgressInput{
|
||||
GroupId: aws.String("String"), // Required
|
||||
CidrIp: aws.String("String"),
|
||||
DryRun: aws.Bool(true),
|
||||
FromPort: aws.Int64(1),
|
||||
IpPermissions: []*ec2.IpPermission{
|
||||
{ // Required
|
||||
FromPort: aws.Int64(1),
|
||||
IpProtocol: aws.String("String"),
|
||||
IpRanges: []*ec2.IpRange{
|
||||
{ // Required
|
||||
CidrIp: aws.String("String"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
PrefixListIds: []*ec2.PrefixListId{
|
||||
{ // Required
|
||||
PrefixListId: aws.String("String"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
ToPort: aws.Int64(1),
|
||||
UserIdGroupPairs: []*ec2.UserIdGroupPair{
|
||||
{ // Required
|
||||
GroupId: aws.String("String"),
|
||||
GroupName: aws.String("String"),
|
||||
UserId: aws.String("String"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
IpProtocol: aws.String("String"),
|
||||
SourceSecurityGroupName: aws.String("String"),
|
||||
SourceSecurityGroupOwnerId: aws.String("String"),
|
||||
ToPort: aws.Int64(1),
|
||||
}
|
||||
|
||||
benchEC2QueryBuild(b, "AuthorizeSecurityGroupEgress", params)
|
||||
}
|
||||
|
||||
func BenchmarkEC2QueryBuild_Simple_ec2AttachNetworkInterface(b *testing.B) {
|
||||
params := &ec2.AttachNetworkInterfaceInput{
|
||||
DeviceIndex: aws.Int64(1), // Required
|
||||
InstanceId: aws.String("String"), // Required
|
||||
NetworkInterfaceId: aws.String("String"), // Required
|
||||
DryRun: aws.Bool(true),
|
||||
}
|
||||
|
||||
benchEC2QueryBuild(b, "AttachNetworkInterface", params)
|
||||
}
|
||||
|
||||
func benchEC2QueryBuild(b *testing.B, opName string, params interface{}) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "ec2"
|
||||
svc.APIVersion = "2015-04-15"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{
|
||||
Name: opName,
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/",
|
||||
}, params, nil)
|
||||
ec2query.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
+1695
File diff suppressed because it is too large
Load Diff
+63
@@ -0,0 +1,63 @@
|
||||
package ec2query
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"io"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
)
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling ec2query protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.ec2query.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling ec2query protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.ec2query.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// UnmarshalErrorHandler is a named request handler for unmarshaling ec2query protocol request errors
|
||||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.ec2query.UnmarshalError", Fn: UnmarshalError}
|
||||
|
||||
// Unmarshal unmarshals a response body for the EC2 protocol.
|
||||
func Unmarshal(r *request.Request) {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
if r.DataFilled() {
|
||||
decoder := xml.NewDecoder(r.HTTPResponse.Body)
|
||||
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query response", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals response headers for the EC2 protocol.
|
||||
func UnmarshalMeta(r *request.Request) {
|
||||
// TODO implement unmarshaling of request IDs
|
||||
}
|
||||
|
||||
type xmlErrorResponse struct {
|
||||
XMLName xml.Name `xml:"Response"`
|
||||
Code string `xml:"Errors>Error>Code"`
|
||||
Message string `xml:"Errors>Error>Message"`
|
||||
RequestID string `xml:"RequestID"`
|
||||
}
|
||||
|
||||
// UnmarshalError unmarshals a response error for the EC2 protocol.
|
||||
func UnmarshalError(r *request.Request) {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
|
||||
resp := &xmlErrorResponse{}
|
||||
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
|
||||
if err != nil && err != io.EOF {
|
||||
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query error response", err)
|
||||
} else {
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New(resp.Code, resp.Message, nil),
|
||||
r.HTTPResponse.StatusCode,
|
||||
resp.RequestID,
|
||||
)
|
||||
}
|
||||
}
|
||||
+1572
File diff suppressed because it is too large
Load Diff
+75
@@ -0,0 +1,75 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// RandReader is the random reader the protocol package will use to read
|
||||
// random bytes from. This is exported for testing, and should not be used.
|
||||
var RandReader = rand.Reader
|
||||
|
||||
const idempotencyTokenFillTag = `idempotencyToken`
|
||||
|
||||
// CanSetIdempotencyToken returns true if the struct field should be
|
||||
// automatically populated with a Idempotency token.
|
||||
//
|
||||
// Only *string and string type fields that are tagged with idempotencyToken
|
||||
// which are not already set can be auto filled.
|
||||
func CanSetIdempotencyToken(v reflect.Value, f reflect.StructField) bool {
|
||||
switch u := v.Interface().(type) {
|
||||
// To auto fill an Idempotency token the field must be a string,
|
||||
// tagged for auto fill, and have a zero value.
|
||||
case *string:
|
||||
return u == nil && len(f.Tag.Get(idempotencyTokenFillTag)) != 0
|
||||
case string:
|
||||
return len(u) == 0 && len(f.Tag.Get(idempotencyTokenFillTag)) != 0
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetIdempotencyToken returns a randomly generated idempotency token.
|
||||
func GetIdempotencyToken() string {
|
||||
b := make([]byte, 16)
|
||||
RandReader.Read(b)
|
||||
|
||||
return UUIDVersion4(b)
|
||||
}
|
||||
|
||||
// SetIdempotencyToken will set the value provided with a Idempotency Token.
|
||||
// Given that the value can be set. Will panic if value is not setable.
|
||||
func SetIdempotencyToken(v reflect.Value) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() && v.CanSet() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
v = reflect.Indirect(v)
|
||||
|
||||
if !v.CanSet() {
|
||||
panic(fmt.Sprintf("unable to set idempotnecy token %v", v))
|
||||
}
|
||||
|
||||
b := make([]byte, 16)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
return
|
||||
}
|
||||
|
||||
v.Set(reflect.ValueOf(UUIDVersion4(b)))
|
||||
}
|
||||
|
||||
// UUIDVersion4 returns a Version 4 random UUID from the byte slice provided
|
||||
func UUIDVersion4(u []byte) string {
|
||||
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_.28random.29
|
||||
// 13th character is "4"
|
||||
u[6] = (u[6] | 0x40) & 0x4F
|
||||
// 17th character is "8", "9", "a", or "b"
|
||||
u[8] = (u[8] | 0x80) & 0xBF
|
||||
|
||||
return fmt.Sprintf(`%X-%X-%X-%X-%X`, u[0:4], u[4:6], u[6:8], u[8:10], u[10:])
|
||||
}
|
||||
+106
@@ -0,0 +1,106 @@
|
||||
package protocol_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCanSetIdempotencyToken(t *testing.T) {
|
||||
cases := []struct {
|
||||
CanSet bool
|
||||
Case interface{}
|
||||
}{
|
||||
{
|
||||
true,
|
||||
struct {
|
||||
Field *string `idempotencyToken:"true"`
|
||||
}{},
|
||||
},
|
||||
{
|
||||
true,
|
||||
struct {
|
||||
Field string `idempotencyToken:"true"`
|
||||
}{},
|
||||
},
|
||||
{
|
||||
false,
|
||||
struct {
|
||||
Field *string `idempotencyToken:"true"`
|
||||
}{Field: new(string)},
|
||||
},
|
||||
{
|
||||
false,
|
||||
struct {
|
||||
Field string `idempotencyToken:"true"`
|
||||
}{Field: "value"},
|
||||
},
|
||||
{
|
||||
false,
|
||||
struct {
|
||||
Field *int `idempotencyToken:"true"`
|
||||
}{},
|
||||
},
|
||||
{
|
||||
false,
|
||||
struct {
|
||||
Field *string
|
||||
}{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v := reflect.Indirect(reflect.ValueOf(c.Case))
|
||||
ty := v.Type()
|
||||
canSet := protocol.CanSetIdempotencyToken(v.Field(0), ty.Field(0))
|
||||
assert.Equal(t, c.CanSet, canSet, "Expect case %d can set to match", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIdempotencyToken(t *testing.T) {
|
||||
cases := []struct {
|
||||
Case interface{}
|
||||
}{
|
||||
{
|
||||
&struct {
|
||||
Field *string `idempotencyToken:"true"`
|
||||
}{},
|
||||
},
|
||||
{
|
||||
&struct {
|
||||
Field string `idempotencyToken:"true"`
|
||||
}{},
|
||||
},
|
||||
{
|
||||
&struct {
|
||||
Field *string `idempotencyToken:"true"`
|
||||
}{Field: new(string)},
|
||||
},
|
||||
{
|
||||
&struct {
|
||||
Field string `idempotencyToken:"true"`
|
||||
}{Field: ""},
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
v := reflect.Indirect(reflect.ValueOf(c.Case))
|
||||
|
||||
protocol.SetIdempotencyToken(v.Field(0))
|
||||
assert.NotEmpty(t, v.Field(0).Interface(), "Expect case %d to be set", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDVersion4(t *testing.T) {
|
||||
uuid := protocol.UUIDVersion4(make([]byte, 16))
|
||||
assert.Equal(t, `00000000-0000-4000-8000-000000000000`, uuid)
|
||||
|
||||
b := make([]byte, 16)
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = 1
|
||||
}
|
||||
uuid = protocol.UUIDVersion4(b)
|
||||
assert.Equal(t, `01010101-0101-4101-8101-010101010101`, uuid)
|
||||
}
|
||||
+279
@@ -0,0 +1,279 @@
|
||||
// Package jsonutil provides JSON serialization of AWS requests and responses.
|
||||
package jsonutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
)
|
||||
|
||||
var timeType = reflect.ValueOf(time.Time{}).Type()
|
||||
var byteSliceType = reflect.ValueOf([]byte{}).Type()
|
||||
|
||||
// BuildJSON builds a JSON string for a given object v.
|
||||
func BuildJSON(v interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
err := buildAny(reflect.ValueOf(v), &buf, "")
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
func buildAny(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
|
||||
origVal := value
|
||||
value = reflect.Indirect(value)
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
vtype := value.Type()
|
||||
|
||||
t := tag.Get("type")
|
||||
if t == "" {
|
||||
switch vtype.Kind() {
|
||||
case reflect.Struct:
|
||||
// also it can't be a time object
|
||||
if value.Type() != timeType {
|
||||
t = "structure"
|
||||
}
|
||||
case reflect.Slice:
|
||||
// also it can't be a byte slice
|
||||
if _, ok := value.Interface().([]byte); !ok {
|
||||
t = "list"
|
||||
}
|
||||
case reflect.Map:
|
||||
t = "map"
|
||||
}
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "structure":
|
||||
if field, ok := vtype.FieldByName("_"); ok {
|
||||
tag = field.Tag
|
||||
}
|
||||
return buildStruct(value, buf, tag)
|
||||
case "list":
|
||||
return buildList(value, buf, tag)
|
||||
case "map":
|
||||
return buildMap(value, buf, tag)
|
||||
default:
|
||||
return buildScalar(origVal, buf, tag)
|
||||
}
|
||||
}
|
||||
|
||||
func buildStruct(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// unwrap payloads
|
||||
if payload := tag.Get("payload"); payload != "" {
|
||||
field, _ := value.Type().FieldByName(payload)
|
||||
tag = field.Tag
|
||||
value = elemOf(value.FieldByName(payload))
|
||||
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteByte('{')
|
||||
|
||||
t := value.Type()
|
||||
first := true
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
member := value.Field(i)
|
||||
|
||||
// This allocates the most memory.
|
||||
// Additionally, we cannot skip nil fields due to
|
||||
// idempotency auto filling.
|
||||
field := t.Field(i)
|
||||
|
||||
if field.PkgPath != "" {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
if field.Tag.Get("json") == "-" {
|
||||
continue
|
||||
}
|
||||
if field.Tag.Get("location") != "" {
|
||||
continue // ignore non-body elements
|
||||
}
|
||||
if field.Tag.Get("ignore") != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if protocol.CanSetIdempotencyToken(member, field) {
|
||||
token := protocol.GetIdempotencyToken()
|
||||
member = reflect.ValueOf(&token)
|
||||
}
|
||||
|
||||
if (member.Kind() == reflect.Ptr || member.Kind() == reflect.Slice || member.Kind() == reflect.Map) && member.IsNil() {
|
||||
continue // ignore unset fields
|
||||
}
|
||||
|
||||
if first {
|
||||
first = false
|
||||
} else {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
|
||||
// figure out what this field is called
|
||||
name := field.Name
|
||||
if locName := field.Tag.Get("locationName"); locName != "" {
|
||||
name = locName
|
||||
}
|
||||
|
||||
writeString(name, buf)
|
||||
buf.WriteString(`:`)
|
||||
|
||||
err := buildAny(member, buf, field.Tag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
buf.WriteString("}")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildList(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
|
||||
buf.WriteString("[")
|
||||
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
buildAny(value.Index(i), buf, "")
|
||||
|
||||
if i < value.Len()-1 {
|
||||
buf.WriteString(",")
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("]")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type sortedValues []reflect.Value
|
||||
|
||||
func (sv sortedValues) Len() int { return len(sv) }
|
||||
func (sv sortedValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
|
||||
func (sv sortedValues) Less(i, j int) bool { return sv[i].String() < sv[j].String() }
|
||||
|
||||
func buildMap(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
|
||||
buf.WriteString("{")
|
||||
|
||||
sv := sortedValues(value.MapKeys())
|
||||
sort.Sort(sv)
|
||||
|
||||
for i, k := range sv {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
|
||||
writeString(k.String(), buf)
|
||||
buf.WriteString(`:`)
|
||||
|
||||
buildAny(value.MapIndex(k), buf, "")
|
||||
}
|
||||
|
||||
buf.WriteString("}")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildScalar(v reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
|
||||
// prevents allocation on the heap.
|
||||
scratch := [64]byte{}
|
||||
switch value := reflect.Indirect(v); value.Kind() {
|
||||
case reflect.String:
|
||||
writeString(value.String(), buf)
|
||||
case reflect.Bool:
|
||||
if value.Bool() {
|
||||
buf.WriteString("true")
|
||||
} else {
|
||||
buf.WriteString("false")
|
||||
}
|
||||
case reflect.Int64:
|
||||
buf.Write(strconv.AppendInt(scratch[:0], value.Int(), 10))
|
||||
case reflect.Float64:
|
||||
f := value.Float()
|
||||
if math.IsInf(f, 0) || math.IsNaN(f) {
|
||||
return &json.UnsupportedValueError{Value: v, Str: strconv.FormatFloat(f, 'f', -1, 64)}
|
||||
}
|
||||
buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
|
||||
default:
|
||||
switch value.Type() {
|
||||
case timeType:
|
||||
converted := v.Interface().(*time.Time)
|
||||
|
||||
buf.Write(strconv.AppendInt(scratch[:0], converted.UTC().Unix(), 10))
|
||||
case byteSliceType:
|
||||
if !value.IsNil() {
|
||||
converted := value.Interface().([]byte)
|
||||
buf.WriteByte('"')
|
||||
if len(converted) < 1024 {
|
||||
// for small buffers, using Encode directly is much faster.
|
||||
dst := make([]byte, base64.StdEncoding.EncodedLen(len(converted)))
|
||||
base64.StdEncoding.Encode(dst, converted)
|
||||
buf.Write(dst)
|
||||
} else {
|
||||
// for large buffers, avoid unnecessary extra temporary
|
||||
// buffer space.
|
||||
enc := base64.NewEncoder(base64.StdEncoding, buf)
|
||||
enc.Write(converted)
|
||||
enc.Close()
|
||||
}
|
||||
buf.WriteByte('"')
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported JSON value %v (%s)", value.Interface(), value.Type())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var hex = "0123456789abcdef"
|
||||
|
||||
func writeString(s string, buf *bytes.Buffer) {
|
||||
buf.WriteByte('"')
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '"' {
|
||||
buf.WriteString(`\"`)
|
||||
} else if s[i] == '\\' {
|
||||
buf.WriteString(`\\`)
|
||||
} else if s[i] == '\b' {
|
||||
buf.WriteString(`\b`)
|
||||
} else if s[i] == '\f' {
|
||||
buf.WriteString(`\f`)
|
||||
} else if s[i] == '\r' {
|
||||
buf.WriteString(`\r`)
|
||||
} else if s[i] == '\t' {
|
||||
buf.WriteString(`\t`)
|
||||
} else if s[i] == '\n' {
|
||||
buf.WriteString(`\n`)
|
||||
} else if s[i] < 32 {
|
||||
buf.WriteString("\\u00")
|
||||
buf.WriteByte(hex[s[i]>>4])
|
||||
buf.WriteByte(hex[s[i]&0xF])
|
||||
} else {
|
||||
buf.WriteByte(s[i])
|
||||
}
|
||||
}
|
||||
buf.WriteByte('"')
|
||||
}
|
||||
|
||||
// Returns the reflection element of a value, if it is a pointer.
|
||||
func elemOf(value reflect.Value) reflect.Value {
|
||||
for value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
return value
|
||||
}
|
||||
+109
@@ -0,0 +1,109 @@
|
||||
package jsonutil_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func S(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func D(s int64) *int64 {
|
||||
return &s
|
||||
}
|
||||
|
||||
func F(s float64) *float64 {
|
||||
return &s
|
||||
}
|
||||
|
||||
func T(s time.Time) *time.Time {
|
||||
return &s
|
||||
}
|
||||
|
||||
type J struct {
|
||||
S *string
|
||||
SS []string
|
||||
D *int64
|
||||
F *float64
|
||||
T *time.Time
|
||||
}
|
||||
|
||||
var zero = 0.0
|
||||
|
||||
var jsonTests = []struct {
|
||||
in interface{}
|
||||
out string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
J{},
|
||||
`{}`,
|
||||
``,
|
||||
},
|
||||
{
|
||||
J{
|
||||
S: S("str"),
|
||||
SS: []string{"A", "B", "C"},
|
||||
D: D(123),
|
||||
F: F(4.56),
|
||||
T: T(time.Unix(987, 0)),
|
||||
},
|
||||
`{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
|
||||
``,
|
||||
},
|
||||
{
|
||||
J{
|
||||
S: S(`"''"`),
|
||||
},
|
||||
`{"S":"\"''\""}`,
|
||||
``,
|
||||
},
|
||||
{
|
||||
J{
|
||||
S: S("\x00føø\u00FF\n\\\"\r\t\b\f"),
|
||||
},
|
||||
`{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
|
||||
``,
|
||||
},
|
||||
{
|
||||
J{
|
||||
F: F(4.56 / zero),
|
||||
},
|
||||
"",
|
||||
`json: unsupported value: +Inf`,
|
||||
},
|
||||
}
|
||||
|
||||
func TestBuildJSON(t *testing.T) {
|
||||
for _, test := range jsonTests {
|
||||
out, err := jsonutil.BuildJSON(test.in)
|
||||
if test.err != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), test.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, string(out), test.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBuildJSON(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range jsonTests {
|
||||
jsonutil.BuildJSON(test.in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStdlibJSON(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range jsonTests {
|
||||
json.Marshal(test.in)
|
||||
}
|
||||
}
|
||||
}
|
||||
+213
@@ -0,0 +1,213 @@
|
||||
package jsonutil
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnmarshalJSON reads a stream and unmarshals the results in object v.
|
||||
func UnmarshalJSON(v interface{}, stream io.Reader) error {
|
||||
var out interface{}
|
||||
|
||||
b, err := ioutil.ReadAll(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unmarshalAny(reflect.ValueOf(v), out, "")
|
||||
}
|
||||
|
||||
func unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
|
||||
vtype := value.Type()
|
||||
if vtype.Kind() == reflect.Ptr {
|
||||
vtype = vtype.Elem() // check kind of actual element type
|
||||
}
|
||||
|
||||
t := tag.Get("type")
|
||||
if t == "" {
|
||||
switch vtype.Kind() {
|
||||
case reflect.Struct:
|
||||
// also it can't be a time object
|
||||
if _, ok := value.Interface().(*time.Time); !ok {
|
||||
t = "structure"
|
||||
}
|
||||
case reflect.Slice:
|
||||
// also it can't be a byte slice
|
||||
if _, ok := value.Interface().([]byte); !ok {
|
||||
t = "list"
|
||||
}
|
||||
case reflect.Map:
|
||||
t = "map"
|
||||
}
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "structure":
|
||||
if field, ok := vtype.FieldByName("_"); ok {
|
||||
tag = field.Tag
|
||||
}
|
||||
return unmarshalStruct(value, data, tag)
|
||||
case "list":
|
||||
return unmarshalList(value, data, tag)
|
||||
case "map":
|
||||
return unmarshalMap(value, data, tag)
|
||||
default:
|
||||
return unmarshalScalar(value, data, tag)
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
mapData, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("JSON value is not a structure (%#v)", data)
|
||||
}
|
||||
|
||||
t := value.Type()
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() { // create the structure if it's nil
|
||||
s := reflect.New(value.Type().Elem())
|
||||
value.Set(s)
|
||||
value = s
|
||||
}
|
||||
|
||||
value = value.Elem()
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
// unwrap any payloads
|
||||
if payload := tag.Get("payload"); payload != "" {
|
||||
field, _ := t.FieldByName(payload)
|
||||
return unmarshalAny(value.FieldByName(payload), data, field.Tag)
|
||||
}
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if field.PkgPath != "" {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
|
||||
// figure out what this field is called
|
||||
name := field.Name
|
||||
if locName := field.Tag.Get("locationName"); locName != "" {
|
||||
name = locName
|
||||
}
|
||||
|
||||
member := value.FieldByIndex(field.Index)
|
||||
err := unmarshalAny(member, mapData[name], field.Tag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
listData, ok := data.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("JSON value is not a list (%#v)", data)
|
||||
}
|
||||
|
||||
if value.IsNil() {
|
||||
l := len(listData)
|
||||
value.Set(reflect.MakeSlice(value.Type(), l, l))
|
||||
}
|
||||
|
||||
for i, c := range listData {
|
||||
err := unmarshalAny(value.Index(i), c, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
mapData, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("JSON value is not a map (%#v)", data)
|
||||
}
|
||||
|
||||
if value.IsNil() {
|
||||
value.Set(reflect.MakeMap(value.Type()))
|
||||
}
|
||||
|
||||
for k, v := range mapData {
|
||||
kvalue := reflect.ValueOf(k)
|
||||
vvalue := reflect.New(value.Type().Elem()).Elem()
|
||||
|
||||
unmarshalAny(vvalue, v, "")
|
||||
value.SetMapIndex(kvalue, vvalue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
|
||||
errf := func() error {
|
||||
return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
|
||||
}
|
||||
|
||||
switch d := data.(type) {
|
||||
case nil:
|
||||
return nil // nothing to do here
|
||||
case string:
|
||||
switch value.Interface().(type) {
|
||||
case *string:
|
||||
value.Set(reflect.ValueOf(&d))
|
||||
case []byte:
|
||||
b, err := base64.StdEncoding.DecodeString(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.Set(reflect.ValueOf(b))
|
||||
default:
|
||||
return errf()
|
||||
}
|
||||
case float64:
|
||||
switch value.Interface().(type) {
|
||||
case *int64:
|
||||
di := int64(d)
|
||||
value.Set(reflect.ValueOf(&di))
|
||||
case *float64:
|
||||
value.Set(reflect.ValueOf(&d))
|
||||
case *time.Time:
|
||||
t := time.Unix(int64(d), 0).UTC()
|
||||
value.Set(reflect.ValueOf(&t))
|
||||
default:
|
||||
return errf()
|
||||
}
|
||||
case bool:
|
||||
switch value.Interface().(type) {
|
||||
case *bool:
|
||||
value.Set(reflect.ValueOf(&d))
|
||||
default:
|
||||
return errf()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported JSON value (%v)", data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+71
@@ -0,0 +1,71 @@
|
||||
// +build bench
|
||||
|
||||
package jsonrpc_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
|
||||
"github.com/aws/aws-sdk-go/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
|
||||
)
|
||||
|
||||
func BenchmarkJSONRPCBuild_Simple_dynamodbPutItem(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
|
||||
params := getDynamodbPutItemParams(b)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "Operation"}, params, nil)
|
||||
jsonrpc.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJSONUtilBuild_Simple_dynamodbPutItem(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
|
||||
params := getDynamodbPutItemParams(b)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "Operation"}, params, nil)
|
||||
_, err := jsonutil.BuildJSON(r.Params)
|
||||
if err != nil {
|
||||
b.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodingJSONMarshal_Simple_dynamodbPutItem(b *testing.B) {
|
||||
params := getDynamodbPutItemParams(b)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := &bytes.Buffer{}
|
||||
encoder := json.NewEncoder(buf)
|
||||
if err := encoder.Encode(params); err != nil {
|
||||
b.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getDynamodbPutItemParams(b *testing.B) *dynamodb.PutItemInput {
|
||||
av, err := dynamodbattribute.ConvertToMap(struct {
|
||||
Key string
|
||||
Data string
|
||||
}{Key: "MyKey", Data: "MyData"})
|
||||
if err != nil {
|
||||
b.Fatal("benchPutItem, expect no ConvertToMap errors", err)
|
||||
}
|
||||
return &dynamodb.PutItemInput{
|
||||
Item: av,
|
||||
TableName: aws.String("tablename"),
|
||||
}
|
||||
}
|
||||
+2038
File diff suppressed because it is too large
Load Diff
+111
@@ -0,0 +1,111 @@
|
||||
// Package jsonrpc provides JSON RPC utilities for serialization of AWS
|
||||
// requests and responses.
|
||||
package jsonrpc
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
var emptyJSON = []byte("{}")
|
||||
|
||||
// BuildHandler is a named request handler for building jsonrpc protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Build", Fn: Build}
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling jsonrpc protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc protocol request errors
|
||||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalError", Fn: UnmarshalError}
|
||||
|
||||
// Build builds a JSON payload for a JSON RPC request.
|
||||
func Build(req *request.Request) {
|
||||
var buf []byte
|
||||
var err error
|
||||
if req.ParamsFilled() {
|
||||
buf, err = jsonutil.BuildJSON(req.Params)
|
||||
if err != nil {
|
||||
req.Error = awserr.New("SerializationError", "failed encoding JSON RPC request", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
buf = emptyJSON
|
||||
}
|
||||
|
||||
if req.ClientInfo.TargetPrefix != "" || string(buf) != "{}" {
|
||||
req.SetBufferBody(buf)
|
||||
}
|
||||
|
||||
if req.ClientInfo.TargetPrefix != "" {
|
||||
target := req.ClientInfo.TargetPrefix + "." + req.Operation.Name
|
||||
req.HTTPRequest.Header.Add("X-Amz-Target", target)
|
||||
}
|
||||
if req.ClientInfo.JSONVersion != "" {
|
||||
jsonVersion := req.ClientInfo.JSONVersion
|
||||
req.HTTPRequest.Header.Add("Content-Type", "application/x-amz-json-"+jsonVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals a response for a JSON RPC service.
|
||||
func Unmarshal(req *request.Request) {
|
||||
defer req.HTTPResponse.Body.Close()
|
||||
if req.DataFilled() {
|
||||
err := jsonutil.UnmarshalJSON(req.Data, req.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
req.Error = awserr.New("SerializationError", "failed decoding JSON RPC response", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals headers from a response for a JSON RPC service.
|
||||
func UnmarshalMeta(req *request.Request) {
|
||||
rest.UnmarshalMeta(req)
|
||||
}
|
||||
|
||||
// UnmarshalError unmarshals an error response for a JSON RPC service.
|
||||
func UnmarshalError(req *request.Request) {
|
||||
defer req.HTTPResponse.Body.Close()
|
||||
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
req.Error = awserr.New("SerializationError", "failed reading JSON RPC error response", err)
|
||||
return
|
||||
}
|
||||
if len(bodyBytes) == 0 {
|
||||
req.Error = awserr.NewRequestFailure(
|
||||
awserr.New("SerializationError", req.HTTPResponse.Status, nil),
|
||||
req.HTTPResponse.StatusCode,
|
||||
"",
|
||||
)
|
||||
return
|
||||
}
|
||||
var jsonErr jsonErrorResponse
|
||||
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
|
||||
req.Error = awserr.New("SerializationError", "failed decoding JSON RPC error response", err)
|
||||
return
|
||||
}
|
||||
|
||||
codes := strings.SplitN(jsonErr.Code, "#", 2)
|
||||
req.Error = awserr.NewRequestFailure(
|
||||
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
|
||||
req.HTTPResponse.StatusCode,
|
||||
req.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
type jsonErrorResponse struct {
|
||||
Code string `json:"__type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
+1238
File diff suppressed because it is too large
Load Diff
+203
@@ -0,0 +1,203 @@
|
||||
package protocol_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/query"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/restjson"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/restxml"
|
||||
)
|
||||
|
||||
func xmlData(set bool, b []byte, size, delta int) {
|
||||
if !set {
|
||||
copy(b, []byte("<B><A>"))
|
||||
}
|
||||
if size == 0 {
|
||||
copy(b[delta-len("</B></A>"):], []byte("</B></A>"))
|
||||
}
|
||||
}
|
||||
|
||||
func jsonData(set bool, b []byte, size, delta int) {
|
||||
if !set {
|
||||
copy(b, []byte("{\"A\": \""))
|
||||
}
|
||||
if size == 0 {
|
||||
copy(b[delta-len("\"}"):], []byte("\"}"))
|
||||
}
|
||||
}
|
||||
|
||||
func buildNewRequest(data interface{}) *request.Request {
|
||||
v := url.Values{}
|
||||
v.Set("test", "TEST")
|
||||
v.Add("test1", "TEST1")
|
||||
|
||||
req := &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
Header: make(http.Header),
|
||||
Body: &awstesting.ReadCloser{Size: 2048},
|
||||
URL: &url.URL{
|
||||
RawQuery: v.Encode(),
|
||||
},
|
||||
},
|
||||
Params: &struct {
|
||||
LocationName string `locationName:"test"`
|
||||
}{
|
||||
"Test",
|
||||
},
|
||||
ClientInfo: metadata.ClientInfo{
|
||||
ServiceName: "test",
|
||||
TargetPrefix: "test",
|
||||
JSONVersion: "test",
|
||||
APIVersion: "test",
|
||||
Endpoint: "test",
|
||||
SigningName: "test",
|
||||
SigningRegion: "test",
|
||||
},
|
||||
Operation: &request.Operation{
|
||||
Name: "test",
|
||||
},
|
||||
}
|
||||
req.HTTPResponse = &http.Response{
|
||||
Body: &awstesting.ReadCloser{Size: 2048},
|
||||
Header: http.Header{
|
||||
"X-Amzn-Requestid": []string{"1"},
|
||||
},
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
data = &struct {
|
||||
_ struct{} `type:"structure"`
|
||||
LocationName *string `locationName:"testName"`
|
||||
Location *string `location:"statusCode"`
|
||||
A *string `type:"string"`
|
||||
}{}
|
||||
}
|
||||
|
||||
req.Data = data
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
type expected struct {
|
||||
dataType int
|
||||
closed bool
|
||||
size int
|
||||
errExists bool
|
||||
}
|
||||
|
||||
const (
|
||||
jsonType = iota
|
||||
xmlType
|
||||
)
|
||||
|
||||
func checkForLeak(data interface{}, build, fn func(*request.Request), t *testing.T, result expected) {
|
||||
req := buildNewRequest(data)
|
||||
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
|
||||
switch result.dataType {
|
||||
case jsonType:
|
||||
reader.FillData = jsonData
|
||||
case xmlType:
|
||||
reader.FillData = xmlData
|
||||
}
|
||||
build(req)
|
||||
fn(req)
|
||||
|
||||
if result.errExists {
|
||||
assert.NotNil(t, req.Error)
|
||||
} else {
|
||||
fmt.Println(req.Error)
|
||||
assert.Nil(t, req.Error)
|
||||
}
|
||||
|
||||
assert.Equal(t, reader.Closed, result.closed)
|
||||
assert.Equal(t, reader.Size, result.size)
|
||||
}
|
||||
|
||||
func TestJSONRpc(t *testing.T) {
|
||||
checkForLeak(nil, jsonrpc.Build, jsonrpc.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalError, t, expected{jsonType, true, 0, true})
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
checkForLeak(nil, query.Build, query.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
checkForLeak(nil, query.Build, query.UnmarshalError, t, expected{jsonType, true, 0, true})
|
||||
}
|
||||
|
||||
func TestRest(t *testing.T) {
|
||||
// case 1: Payload io.ReadSeeker
|
||||
checkForLeak(nil, rest.Build, rest.Unmarshal, t, expected{jsonType, false, 2048, false})
|
||||
checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
|
||||
// case 2: Payload *string
|
||||
// should close the body
|
||||
dataStr := struct {
|
||||
_ struct{} `type:"structure" payload:"Payload"`
|
||||
LocationName *string `locationName:"testName"`
|
||||
Location *string `location:"statusCode"`
|
||||
A *string `type:"string"`
|
||||
Payload *string `locationName:"payload" type:"blob" required:"true"`
|
||||
}{}
|
||||
checkForLeak(&dataStr, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(&dataStr, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
|
||||
// case 3: Payload []byte
|
||||
// should close the body
|
||||
dataBytes := struct {
|
||||
_ struct{} `type:"structure" payload:"Payload"`
|
||||
LocationName *string `locationName:"testName"`
|
||||
Location *string `location:"statusCode"`
|
||||
A *string `type:"string"`
|
||||
Payload []byte `locationName:"payload" type:"blob" required:"true"`
|
||||
}{}
|
||||
checkForLeak(&dataBytes, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(&dataBytes, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
|
||||
// case 4: Payload unsupported type
|
||||
// should close the body
|
||||
dataUnsupported := struct {
|
||||
_ struct{} `type:"structure" payload:"Payload"`
|
||||
LocationName *string `locationName:"testName"`
|
||||
Location *string `location:"statusCode"`
|
||||
A *string `type:"string"`
|
||||
Payload string `locationName:"payload" type:"blob" required:"true"`
|
||||
}{}
|
||||
checkForLeak(&dataUnsupported, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, true})
|
||||
checkForLeak(&dataUnsupported, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
}
|
||||
|
||||
func TestRestJSON(t *testing.T) {
|
||||
checkForLeak(nil, restjson.Build, restjson.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(nil, restjson.Build, restjson.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
checkForLeak(nil, restjson.Build, restjson.UnmarshalError, t, expected{jsonType, true, 0, true})
|
||||
}
|
||||
|
||||
func TestRestXML(t *testing.T) {
|
||||
checkForLeak(nil, restxml.Build, restxml.Unmarshal, t, expected{xmlType, true, 0, false})
|
||||
checkForLeak(nil, restxml.Build, restxml.UnmarshalMeta, t, expected{xmlType, false, 2048, false})
|
||||
checkForLeak(nil, restxml.Build, restxml.UnmarshalError, t, expected{xmlType, true, 0, true})
|
||||
}
|
||||
|
||||
func TestXML(t *testing.T) {
|
||||
checkForLeak(nil, ec2query.Build, ec2query.Unmarshal, t, expected{jsonType, true, 0, false})
|
||||
checkForLeak(nil, ec2query.Build, ec2query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
|
||||
checkForLeak(nil, ec2query.Build, ec2query.UnmarshalError, t, expected{jsonType, true, 0, true})
|
||||
}
|
||||
|
||||
func TestProtocol(t *testing.T) {
|
||||
checkForLeak(nil, restxml.Build, protocol.UnmarshalDiscardBody, t, expected{xmlType, true, 0, false})
|
||||
}
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
// Package query provides serialization of AWS query requests, and responses.
|
||||
package query
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
|
||||
)
|
||||
|
||||
// BuildHandler is a named request handler for building query protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.query.Build", Fn: Build}
|
||||
|
||||
// Build builds a request for an AWS Query service.
|
||||
func Build(r *request.Request) {
|
||||
body := url.Values{
|
||||
"Action": {r.Operation.Name},
|
||||
"Version": {r.ClientInfo.APIVersion},
|
||||
}
|
||||
if err := queryutil.Parse(body, r.Params, false); err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed encoding Query request", err)
|
||||
return
|
||||
}
|
||||
|
||||
if r.ExpireTime == 0 {
|
||||
r.HTTPRequest.Method = "POST"
|
||||
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
|
||||
r.SetBufferBody([]byte(body.Encode()))
|
||||
} else { // This is a pre-signed request
|
||||
r.HTTPRequest.Method = "GET"
|
||||
r.HTTPRequest.URL.RawQuery = body.Encode()
|
||||
}
|
||||
}
|
||||
+3362
File diff suppressed because it is too large
Load Diff
+237
@@ -0,0 +1,237 @@
|
||||
package queryutil
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
)
|
||||
|
||||
// Parse parses an object i and fills a url.Values object. The isEC2 flag
|
||||
// indicates if this is the EC2 Query sub-protocol.
|
||||
func Parse(body url.Values, i interface{}, isEC2 bool) error {
|
||||
q := queryParser{isEC2: isEC2}
|
||||
return q.parseValue(body, reflect.ValueOf(i), "", "")
|
||||
}
|
||||
|
||||
func elemOf(value reflect.Value) reflect.Value {
|
||||
for value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
type queryParser struct {
|
||||
isEC2 bool
|
||||
}
|
||||
|
||||
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
|
||||
value = elemOf(value)
|
||||
|
||||
// no need to handle zero values
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := tag.Get("type")
|
||||
if t == "" {
|
||||
switch value.Kind() {
|
||||
case reflect.Struct:
|
||||
t = "structure"
|
||||
case reflect.Slice:
|
||||
t = "list"
|
||||
case reflect.Map:
|
||||
t = "map"
|
||||
}
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "structure":
|
||||
return q.parseStruct(v, value, prefix)
|
||||
case "list":
|
||||
return q.parseList(v, value, prefix, tag)
|
||||
case "map":
|
||||
return q.parseMap(v, value, prefix, tag)
|
||||
default:
|
||||
return q.parseScalar(v, value, prefix, tag)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := value.Type()
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
elemValue := elemOf(value.Field(i))
|
||||
field := t.Field(i)
|
||||
|
||||
if field.PkgPath != "" {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
if field.Tag.Get("ignore") != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if protocol.CanSetIdempotencyToken(value.Field(i), field) {
|
||||
token := protocol.GetIdempotencyToken()
|
||||
elemValue = reflect.ValueOf(token)
|
||||
}
|
||||
|
||||
var name string
|
||||
if q.isEC2 {
|
||||
name = field.Tag.Get("queryName")
|
||||
}
|
||||
if name == "" {
|
||||
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
|
||||
name = field.Tag.Get("locationNameList")
|
||||
} else if locName := field.Tag.Get("locationName"); locName != "" {
|
||||
name = locName
|
||||
}
|
||||
if name != "" && q.isEC2 {
|
||||
name = strings.ToUpper(name[0:1]) + name[1:]
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
if prefix != "" {
|
||||
name = prefix + "." + name
|
||||
}
|
||||
|
||||
if err := q.parseValue(v, elemValue, name, field.Tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
|
||||
// If it's empty, generate an empty value
|
||||
if !value.IsNil() && value.Len() == 0 {
|
||||
v.Set(prefix, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// check for unflattened list member
|
||||
if !q.isEC2 && tag.Get("flattened") == "" {
|
||||
if listName := tag.Get("locationNameList"); listName == "" {
|
||||
prefix += ".member"
|
||||
} else {
|
||||
prefix += "." + listName
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
slicePrefix := prefix
|
||||
if slicePrefix == "" {
|
||||
slicePrefix = strconv.Itoa(i + 1)
|
||||
} else {
|
||||
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1)
|
||||
}
|
||||
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
|
||||
// If it's empty, generate an empty value
|
||||
if !value.IsNil() && value.Len() == 0 {
|
||||
v.Set(prefix, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// check for unflattened list member
|
||||
if !q.isEC2 && tag.Get("flattened") == "" {
|
||||
prefix += ".entry"
|
||||
}
|
||||
|
||||
// sort keys for improved serialization consistency.
|
||||
// this is not strictly necessary for protocol support.
|
||||
mapKeyValues := value.MapKeys()
|
||||
mapKeys := map[string]reflect.Value{}
|
||||
mapKeyNames := make([]string, len(mapKeyValues))
|
||||
for i, mapKey := range mapKeyValues {
|
||||
name := mapKey.String()
|
||||
mapKeys[name] = mapKey
|
||||
mapKeyNames[i] = name
|
||||
}
|
||||
sort.Strings(mapKeyNames)
|
||||
|
||||
for i, mapKeyName := range mapKeyNames {
|
||||
mapKey := mapKeys[mapKeyName]
|
||||
mapValue := value.MapIndex(mapKey)
|
||||
|
||||
kname := tag.Get("locationNameKey")
|
||||
if kname == "" {
|
||||
kname = "key"
|
||||
}
|
||||
vname := tag.Get("locationNameValue")
|
||||
if vname == "" {
|
||||
vname = "value"
|
||||
}
|
||||
|
||||
// serialize key
|
||||
var keyName string
|
||||
if prefix == "" {
|
||||
keyName = strconv.Itoa(i+1) + "." + kname
|
||||
} else {
|
||||
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
|
||||
}
|
||||
|
||||
if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// serialize value
|
||||
var valueName string
|
||||
if prefix == "" {
|
||||
valueName = strconv.Itoa(i+1) + "." + vname
|
||||
} else {
|
||||
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
|
||||
}
|
||||
|
||||
if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
|
||||
switch value := r.Interface().(type) {
|
||||
case string:
|
||||
v.Set(name, value)
|
||||
case []byte:
|
||||
if !r.IsNil() {
|
||||
v.Set(name, base64.StdEncoding.EncodeToString(value))
|
||||
}
|
||||
case bool:
|
||||
v.Set(name, strconv.FormatBool(value))
|
||||
case int64:
|
||||
v.Set(name, strconv.FormatInt(value, 10))
|
||||
case int:
|
||||
v.Set(name, strconv.Itoa(value))
|
||||
case float64:
|
||||
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
|
||||
case float32:
|
||||
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
|
||||
case time.Time:
|
||||
const ISO8601UTC = "2006-01-02T15:04:05Z"
|
||||
v.Set(name, value.UTC().Format(ISO8601UTC))
|
||||
default:
|
||||
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
package query
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
)
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling query protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.query.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling query protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// Unmarshal unmarshals a response for an AWS Query service.
|
||||
func Unmarshal(r *request.Request) {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
if r.DataFilled() {
|
||||
decoder := xml.NewDecoder(r.HTTPResponse.Body)
|
||||
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals header response values for an AWS Query service.
|
||||
func UnmarshalMeta(r *request.Request) {
|
||||
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
|
||||
}
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
type xmlErrorResponse struct {
|
||||
XMLName xml.Name `xml:"ErrorResponse"`
|
||||
Code string `xml:"Error>Code"`
|
||||
Message string `xml:"Error>Message"`
|
||||
RequestID string `xml:"RequestId"`
|
||||
}
|
||||
|
||||
type xmlServiceUnavailableResponse struct {
|
||||
XMLName xml.Name `xml:"ServiceUnavailableException"`
|
||||
}
|
||||
|
||||
// UnmarshalErrorHandler is a name request handler to unmarshal request errors
|
||||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError}
|
||||
|
||||
// UnmarshalError unmarshals an error response for an AWS Query service.
|
||||
func UnmarshalError(r *request.Request) {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
|
||||
bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to read from query HTTP response body", err)
|
||||
return
|
||||
}
|
||||
|
||||
// First check for specific error
|
||||
resp := xmlErrorResponse{}
|
||||
decodeErr := xml.Unmarshal(bodyBytes, &resp)
|
||||
if decodeErr == nil {
|
||||
reqID := resp.RequestID
|
||||
if reqID == "" {
|
||||
reqID = r.RequestID
|
||||
}
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New(resp.Code, resp.Message, nil),
|
||||
r.HTTPResponse.StatusCode,
|
||||
reqID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for unhandled error
|
||||
servUnavailResp := xmlServiceUnavailableResponse{}
|
||||
unavailErr := xml.Unmarshal(bodyBytes, &servUnavailResp)
|
||||
if unavailErr == nil {
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New("ServiceUnavailableException", "service is unavailable", nil),
|
||||
r.HTTPResponse.StatusCode,
|
||||
r.RequestID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Failed to retrieve any error message from the response body
|
||||
r.Error = awserr.New("SerializationError",
|
||||
"failed to decode query XML error response", decodeErr)
|
||||
}
|
||||
+2616
File diff suppressed because it is too large
Load Diff
+290
@@ -0,0 +1,290 @@
|
||||
// Package rest provides RESTful serialization of AWS requests and responses.
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
|
||||
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"
|
||||
|
||||
// Whether the byte value can be sent without escaping in AWS URLs
|
||||
var noEscape [256]bool
|
||||
|
||||
var errValueNotSet = fmt.Errorf("value not set")
|
||||
|
||||
func init() {
|
||||
for i := 0; i < len(noEscape); i++ {
|
||||
// AWS expects every character except these to be escaped
|
||||
noEscape[i] = (i >= 'A' && i <= 'Z') ||
|
||||
(i >= 'a' && i <= 'z') ||
|
||||
(i >= '0' && i <= '9') ||
|
||||
i == '-' ||
|
||||
i == '.' ||
|
||||
i == '_' ||
|
||||
i == '~'
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHandler is a named request handler for building rest protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build}
|
||||
|
||||
// Build builds the REST component of a service request.
|
||||
func Build(r *request.Request) {
|
||||
if r.ParamsFilled() {
|
||||
v := reflect.ValueOf(r.Params).Elem()
|
||||
buildLocationElements(r, v, false)
|
||||
buildBody(r, v)
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAsGET builds the REST component of a service request with the ability to hoist
|
||||
// data from the body.
|
||||
func BuildAsGET(r *request.Request) {
|
||||
if r.ParamsFilled() {
|
||||
v := reflect.ValueOf(r.Params).Elem()
|
||||
buildLocationElements(r, v, true)
|
||||
buildBody(r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
|
||||
query := r.HTTPRequest.URL.Query()
|
||||
|
||||
// Setup the raw path to match the base path pattern. This is needed
|
||||
// so that when the path is mutated a custom escaped version can be
|
||||
// stored in RawPath that will be used by the Go client.
|
||||
r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
m := v.Field(i)
|
||||
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.IsValid() {
|
||||
field := v.Type().Field(i)
|
||||
name := field.Tag.Get("locationName")
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
if kind := m.Kind(); kind == reflect.Ptr {
|
||||
m = m.Elem()
|
||||
} else if kind == reflect.Interface {
|
||||
if !m.Elem().IsValid() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !m.IsValid() {
|
||||
continue
|
||||
}
|
||||
if field.Tag.Get("ignore") != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
switch field.Tag.Get("location") {
|
||||
case "headers": // header maps
|
||||
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag)
|
||||
case "header":
|
||||
err = buildHeader(&r.HTTPRequest.Header, m, name, field.Tag)
|
||||
case "uri":
|
||||
err = buildURI(r.HTTPRequest.URL, m, name, field.Tag)
|
||||
case "querystring":
|
||||
err = buildQueryString(query, m, name, field.Tag)
|
||||
default:
|
||||
if buildGETQuery {
|
||||
err = buildQueryString(query, m, name, field.Tag)
|
||||
}
|
||||
}
|
||||
r.Error = err
|
||||
}
|
||||
if r.Error != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.HTTPRequest.URL.RawQuery = query.Encode()
|
||||
if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
|
||||
cleanPath(r.HTTPRequest.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func buildBody(r *request.Request, v reflect.Value) {
|
||||
if field, ok := v.Type().FieldByName("_"); ok {
|
||||
if payloadName := field.Tag.Get("payload"); payloadName != "" {
|
||||
pfield, _ := v.Type().FieldByName(payloadName)
|
||||
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
|
||||
payload := reflect.Indirect(v.FieldByName(payloadName))
|
||||
if payload.IsValid() && payload.Interface() != nil {
|
||||
switch reader := payload.Interface().(type) {
|
||||
case io.ReadSeeker:
|
||||
r.SetReaderBody(reader)
|
||||
case []byte:
|
||||
r.SetBufferBody(reader)
|
||||
case string:
|
||||
r.SetStringBody(reader)
|
||||
default:
|
||||
r.Error = awserr.New("SerializationError",
|
||||
"failed to encode REST request",
|
||||
fmt.Errorf("unknown payload type %s", payload.Type()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) error {
|
||||
str, err := convertType(v, tag)
|
||||
if err == errValueNotSet {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return awserr.New("SerializationError", "failed to encode REST request", err)
|
||||
}
|
||||
|
||||
header.Add(name, str)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) error {
|
||||
prefix := tag.Get("locationName")
|
||||
for _, key := range v.MapKeys() {
|
||||
str, err := convertType(v.MapIndex(key), tag)
|
||||
if err == errValueNotSet {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return awserr.New("SerializationError", "failed to encode REST request", err)
|
||||
|
||||
}
|
||||
|
||||
header.Add(prefix+key.String(), str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) error {
|
||||
value, err := convertType(v, tag)
|
||||
if err == errValueNotSet {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return awserr.New("SerializationError", "failed to encode REST request", err)
|
||||
}
|
||||
|
||||
u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
|
||||
u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
|
||||
|
||||
u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
|
||||
u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildQueryString(query url.Values, v reflect.Value, name string, tag reflect.StructTag) error {
|
||||
switch value := v.Interface().(type) {
|
||||
case []*string:
|
||||
for _, item := range value {
|
||||
query.Add(name, *item)
|
||||
}
|
||||
case map[string]*string:
|
||||
for key, item := range value {
|
||||
query.Add(key, *item)
|
||||
}
|
||||
case map[string][]*string:
|
||||
for key, items := range value {
|
||||
for _, item := range items {
|
||||
query.Add(key, *item)
|
||||
}
|
||||
}
|
||||
default:
|
||||
str, err := convertType(v, tag)
|
||||
if err == errValueNotSet {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return awserr.New("SerializationError", "failed to encode REST request", err)
|
||||
}
|
||||
query.Set(name, str)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanPath(u *url.URL) {
|
||||
hasSlash := strings.HasSuffix(u.Path, "/")
|
||||
|
||||
// clean up path, removing duplicate `/`
|
||||
u.Path = path.Clean(u.Path)
|
||||
u.RawPath = path.Clean(u.RawPath)
|
||||
|
||||
if hasSlash && !strings.HasSuffix(u.Path, "/") {
|
||||
u.Path += "/"
|
||||
u.RawPath += "/"
|
||||
}
|
||||
}
|
||||
|
||||
// EscapePath escapes part of a URL path in Amazon style
|
||||
func EscapePath(path string, encodeSep bool) string {
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if noEscape[c] || (c == '/' && !encodeSep) {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func convertType(v reflect.Value, tag reflect.StructTag) (string, error) {
|
||||
v = reflect.Indirect(v)
|
||||
if !v.IsValid() {
|
||||
return "", errValueNotSet
|
||||
}
|
||||
|
||||
var str string
|
||||
switch value := v.Interface().(type) {
|
||||
case string:
|
||||
str = value
|
||||
case []byte:
|
||||
str = base64.StdEncoding.EncodeToString(value)
|
||||
case bool:
|
||||
str = strconv.FormatBool(value)
|
||||
case int64:
|
||||
str = strconv.FormatInt(value, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(value, 'f', -1, 64)
|
||||
case time.Time:
|
||||
str = value.UTC().Format(RFC822)
|
||||
case aws.JSONValue:
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if tag.Get("location") == "header" {
|
||||
str = base64.StdEncoding.EncodeToString(b)
|
||||
} else {
|
||||
str = string(b)
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
|
||||
return "", err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
func TestCleanPath(t *testing.T) {
|
||||
uri := &url.URL{
|
||||
Path: "//foo//bar",
|
||||
Scheme: "https",
|
||||
Host: "host",
|
||||
}
|
||||
cleanPath(uri)
|
||||
|
||||
expected := "https://host/foo/bar"
|
||||
if a, e := uri.String(), expected; a != e {
|
||||
t.Errorf("expect %q URI, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPath(t *testing.T) {
|
||||
in := struct {
|
||||
Bucket *string `location:"uri" locationName:"bucket"`
|
||||
Key *string `location:"uri" locationName:"key"`
|
||||
}{
|
||||
Bucket: aws.String("mybucket"),
|
||||
Key: aws.String("my/cool+thing space/object世界"),
|
||||
}
|
||||
|
||||
expectURL := `/mybucket/my/cool+thing space/object世界`
|
||||
expectEscapedURL := `/mybucket/my/cool%2Bthing%20space/object%E4%B8%96%E7%95%8C`
|
||||
|
||||
req := &request.Request{
|
||||
HTTPRequest: &http.Request{
|
||||
URL: &url.URL{Scheme: "https", Host: "exmaple.com", Path: "/{bucket}/{key+}"},
|
||||
},
|
||||
Params: &in,
|
||||
}
|
||||
|
||||
Build(req)
|
||||
|
||||
if req.Error != nil {
|
||||
t.Fatalf("unexpected error, %v", req.Error)
|
||||
}
|
||||
|
||||
if a, e := req.HTTPRequest.URL.Path, expectURL; a != e {
|
||||
t.Errorf("expect %q URI, got %q", e, a)
|
||||
}
|
||||
|
||||
if a, e := req.HTTPRequest.URL.RawPath, expectEscapedURL; a != e {
|
||||
t.Errorf("expect %q escaped URI, got %q", e, a)
|
||||
}
|
||||
|
||||
if a, e := req.HTTPRequest.URL.EscapedPath(), expectEscapedURL; a != e {
|
||||
t.Errorf("expect %q escaped URI, got %q", e, a)
|
||||
}
|
||||
|
||||
}
|
||||
+45
@@ -0,0 +1,45 @@
|
||||
package rest
|
||||
|
||||
import "reflect"
|
||||
|
||||
// PayloadMember returns the payload field member of i if there is one, or nil.
|
||||
func PayloadMember(i interface{}) interface{} {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(i).Elem()
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
if field, ok := v.Type().FieldByName("_"); ok {
|
||||
if payloadName := field.Tag.Get("payload"); payloadName != "" {
|
||||
field, _ := v.Type().FieldByName(payloadName)
|
||||
if field.Tag.Get("type") != "structure" {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := v.FieldByName(payloadName)
|
||||
if payload.IsValid() || (payload.Kind() == reflect.Ptr && !payload.IsNil()) {
|
||||
return payload.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PayloadType returns the type of a payload field member of i if there is one, or "".
|
||||
func PayloadType(i interface{}) string {
|
||||
v := reflect.Indirect(reflect.ValueOf(i))
|
||||
if !v.IsValid() {
|
||||
return ""
|
||||
}
|
||||
if field, ok := v.Type().FieldByName("_"); ok {
|
||||
if payloadName := field.Tag.Get("payload"); payloadName != "" {
|
||||
if member, ok := v.Type().FieldByName(payloadName); ok {
|
||||
return member.Tag.Get("type")
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
func TestUnsetHeaders(t *testing.T) {
|
||||
cfg := &aws.Config{Region: aws.String("us-west-2")}
|
||||
c := unit.Session.ClientConfig("testService", cfg)
|
||||
svc := client.New(
|
||||
*cfg,
|
||||
metadata.ClientInfo{
|
||||
ServiceName: "testService",
|
||||
SigningName: c.SigningName,
|
||||
SigningRegion: c.SigningRegion,
|
||||
Endpoint: c.Endpoint,
|
||||
APIVersion: "",
|
||||
},
|
||||
c.Handlers,
|
||||
)
|
||||
|
||||
// Handlers
|
||||
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
|
||||
svc.Handlers.Build.PushBackNamed(rest.BuildHandler)
|
||||
svc.Handlers.Unmarshal.PushBackNamed(rest.UnmarshalHandler)
|
||||
svc.Handlers.UnmarshalMeta.PushBackNamed(rest.UnmarshalMetaHandler)
|
||||
op := &request.Operation{
|
||||
Name: "test-operation",
|
||||
HTTPPath: "/",
|
||||
}
|
||||
|
||||
input := &struct {
|
||||
Foo aws.JSONValue `location:"header" locationName:"x-amz-foo" type:"jsonvalue"`
|
||||
Bar aws.JSONValue `location:"header" locationName:"x-amz-bar" type:"jsonvalue"`
|
||||
}{}
|
||||
|
||||
output := &struct {
|
||||
Foo aws.JSONValue `location:"header" locationName:"x-amz-foo" type:"jsonvalue"`
|
||||
Bar aws.JSONValue `location:"header" locationName:"x-amz-bar" type:"jsonvalue"`
|
||||
}{}
|
||||
|
||||
req := svc.NewRequest(op, input, output)
|
||||
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewBuffer(nil)), Header: http.Header{}}
|
||||
req.HTTPResponse.Header.Set("X-Amz-Foo", "e30=")
|
||||
|
||||
// unmarshal response
|
||||
rest.UnmarshalMeta(req)
|
||||
rest.Unmarshal(req)
|
||||
if req.Error != nil {
|
||||
t.Fatal(req.Error)
|
||||
}
|
||||
}
|
||||
+227
@@ -0,0 +1,227 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// Unmarshal unmarshals the REST component of a response in a REST service.
|
||||
func Unmarshal(r *request.Request) {
|
||||
if r.DataFilled() {
|
||||
v := reflect.Indirect(reflect.ValueOf(r.Data))
|
||||
unmarshalBody(r, v)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
|
||||
func UnmarshalMeta(r *request.Request) {
|
||||
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
|
||||
if r.RequestID == "" {
|
||||
// Alternative version of request id in the header
|
||||
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
|
||||
}
|
||||
if r.DataFilled() {
|
||||
v := reflect.Indirect(reflect.ValueOf(r.Data))
|
||||
unmarshalLocationElements(r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalBody(r *request.Request, v reflect.Value) {
|
||||
if field, ok := v.Type().FieldByName("_"); ok {
|
||||
if payloadName := field.Tag.Get("payload"); payloadName != "" {
|
||||
pfield, _ := v.Type().FieldByName(payloadName)
|
||||
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
|
||||
payload := v.FieldByName(payloadName)
|
||||
if payload.IsValid() {
|
||||
switch payload.Interface().(type) {
|
||||
case []byte:
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
|
||||
} else {
|
||||
payload.Set(reflect.ValueOf(b))
|
||||
}
|
||||
case *string:
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
|
||||
} else {
|
||||
str := string(b)
|
||||
payload.Set(reflect.ValueOf(&str))
|
||||
}
|
||||
default:
|
||||
switch payload.Type().String() {
|
||||
case "io.ReadCloser":
|
||||
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
|
||||
case "io.ReadSeeker":
|
||||
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError",
|
||||
"failed to read response body", err)
|
||||
return
|
||||
}
|
||||
payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
|
||||
default:
|
||||
io.Copy(ioutil.Discard, r.HTTPResponse.Body)
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
r.Error = awserr.New("SerializationError",
|
||||
"failed to decode REST response",
|
||||
fmt.Errorf("unknown payload type %s", payload.Type()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalLocationElements(r *request.Request, v reflect.Value) {
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
m, field := v.Field(i), v.Type().Field(i)
|
||||
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.IsValid() {
|
||||
name := field.Tag.Get("locationName")
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
switch field.Tag.Get("location") {
|
||||
case "statusCode":
|
||||
unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
|
||||
case "header":
|
||||
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
|
||||
break
|
||||
}
|
||||
case "headers":
|
||||
prefix := field.Tag.Get("locationName")
|
||||
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.Error != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalStatusCode(v reflect.Value, statusCode int) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
switch v.Interface().(type) {
|
||||
case *int64:
|
||||
s := int64(statusCode)
|
||||
v.Set(reflect.ValueOf(&s))
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
|
||||
switch r.Interface().(type) {
|
||||
case map[string]*string: // we only support string map value types
|
||||
out := map[string]*string{}
|
||||
for k, v := range headers {
|
||||
k = http.CanonicalHeaderKey(k)
|
||||
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
|
||||
out[k[len(prefix):]] = &v[0]
|
||||
}
|
||||
}
|
||||
r.Set(reflect.ValueOf(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
|
||||
isJSONValue := tag.Get("type") == "jsonvalue"
|
||||
if isJSONValue {
|
||||
if len(header) == 0 {
|
||||
return nil
|
||||
}
|
||||
} else if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v.Interface().(type) {
|
||||
case *string:
|
||||
v.Set(reflect.ValueOf(&header))
|
||||
case []byte:
|
||||
b, err := base64.StdEncoding.DecodeString(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(&b))
|
||||
case *bool:
|
||||
b, err := strconv.ParseBool(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(&b))
|
||||
case *int64:
|
||||
i, err := strconv.ParseInt(header, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(&i))
|
||||
case *float64:
|
||||
f, err := strconv.ParseFloat(header, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(&f))
|
||||
case *time.Time:
|
||||
t, err := time.Parse(RFC822, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(&t))
|
||||
case aws.JSONValue:
|
||||
b := []byte(header)
|
||||
var err error
|
||||
if tag.Get("location") == "header" {
|
||||
b, err = base64.StdEncoding.DecodeString(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m := aws.JSONValue{}
|
||||
err = json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Set(reflect.ValueOf(m))
|
||||
default:
|
||||
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+356
@@ -0,0 +1,356 @@
|
||||
// +build bench
|
||||
|
||||
package restjson_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/restjson"
|
||||
"github.com/aws/aws-sdk-go/service/elastictranscoder"
|
||||
)
|
||||
|
||||
func BenchmarkRESTJSONBuild_Complex_elastictranscoderCreateJobInput(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "elastictranscoder"
|
||||
svc.APIVersion = "2012-09-25"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "CreateJobInput"}, restjsonBuildParms, nil)
|
||||
restjson.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRESTBuild_Complex_elastictranscoderCreateJobInput(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "elastictranscoder"
|
||||
svc.APIVersion = "2012-09-25"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "CreateJobInput"}, restjsonBuildParms, nil)
|
||||
rest.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodingJSONMarshal_Complex_elastictranscoderCreateJobInput(b *testing.B) {
|
||||
params := restjsonBuildParms
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := &bytes.Buffer{}
|
||||
encoder := json.NewEncoder(buf)
|
||||
if err := encoder.Encode(params); err != nil {
|
||||
b.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRESTJSONBuild_Simple_elastictranscoderListJobsByPipeline(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "elastictranscoder"
|
||||
svc.APIVersion = "2012-09-25"
|
||||
|
||||
params := &elastictranscoder.ListJobsByPipelineInput{
|
||||
PipelineId: aws.String("Id"), // Required
|
||||
Ascending: aws.String("Ascending"),
|
||||
PageToken: aws.String("Id"),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "ListJobsByPipeline"}, params, nil)
|
||||
restjson.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRESTBuild_Simple_elastictranscoderListJobsByPipeline(b *testing.B) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "elastictranscoder"
|
||||
svc.APIVersion = "2012-09-25"
|
||||
|
||||
params := &elastictranscoder.ListJobsByPipelineInput{
|
||||
PipelineId: aws.String("Id"), // Required
|
||||
Ascending: aws.String("Ascending"),
|
||||
PageToken: aws.String("Id"),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(&request.Operation{Name: "ListJobsByPipeline"}, params, nil)
|
||||
rest.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodingJSONMarshal_Simple_elastictranscoderListJobsByPipeline(b *testing.B) {
|
||||
params := &elastictranscoder.ListJobsByPipelineInput{
|
||||
PipelineId: aws.String("Id"), // Required
|
||||
Ascending: aws.String("Ascending"),
|
||||
PageToken: aws.String("Id"),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := &bytes.Buffer{}
|
||||
encoder := json.NewEncoder(buf)
|
||||
if err := encoder.Encode(params); err != nil {
|
||||
b.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var restjsonBuildParms = &elastictranscoder.CreateJobInput{
|
||||
Input: &elastictranscoder.JobInput{ // Required
|
||||
AspectRatio: aws.String("AspectRatio"),
|
||||
Container: aws.String("JobContainer"),
|
||||
DetectedProperties: &elastictranscoder.DetectedProperties{
|
||||
DurationMillis: aws.Int64(1),
|
||||
FileSize: aws.Int64(1),
|
||||
FrameRate: aws.String("FloatString"),
|
||||
Height: aws.Int64(1),
|
||||
Width: aws.Int64(1),
|
||||
},
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
FrameRate: aws.String("FrameRate"),
|
||||
Interlaced: aws.String("Interlaced"),
|
||||
Key: aws.String("Key"),
|
||||
Resolution: aws.String("Resolution"),
|
||||
},
|
||||
PipelineId: aws.String("Id"), // Required
|
||||
Output: &elastictranscoder.CreateJobOutput{
|
||||
AlbumArt: &elastictranscoder.JobAlbumArt{
|
||||
Artwork: []*elastictranscoder.Artwork{
|
||||
{ // Required
|
||||
AlbumArtFormat: aws.String("JpgOrPng"),
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
InputKey: aws.String("WatermarkKey"),
|
||||
MaxHeight: aws.String("DigitsOrAuto"),
|
||||
MaxWidth: aws.String("DigitsOrAuto"),
|
||||
PaddingPolicy: aws.String("PaddingPolicy"),
|
||||
SizingPolicy: aws.String("SizingPolicy"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
MergePolicy: aws.String("MergePolicy"),
|
||||
},
|
||||
Captions: &elastictranscoder.Captions{
|
||||
CaptionFormats: []*elastictranscoder.CaptionFormat{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Format: aws.String("CaptionFormatFormat"),
|
||||
Pattern: aws.String("CaptionFormatPattern"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
CaptionSources: []*elastictranscoder.CaptionSource{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Key: aws.String("Key"),
|
||||
Label: aws.String("Name"),
|
||||
Language: aws.String("Key"),
|
||||
TimeOffset: aws.String("TimeOffset"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
MergePolicy: aws.String("CaptionMergePolicy"),
|
||||
},
|
||||
Composition: []*elastictranscoder.Clip{
|
||||
{ // Required
|
||||
TimeSpan: &elastictranscoder.TimeSpan{
|
||||
Duration: aws.String("Time"),
|
||||
StartTime: aws.String("Time"),
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Key: aws.String("Key"),
|
||||
PresetId: aws.String("Id"),
|
||||
Rotate: aws.String("Rotate"),
|
||||
SegmentDuration: aws.String("FloatString"),
|
||||
ThumbnailEncryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
ThumbnailPattern: aws.String("ThumbnailPattern"),
|
||||
Watermarks: []*elastictranscoder.JobWatermark{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
InputKey: aws.String("WatermarkKey"),
|
||||
PresetWatermarkId: aws.String("PresetWatermarkId"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
OutputKeyPrefix: aws.String("Key"),
|
||||
Outputs: []*elastictranscoder.CreateJobOutput{
|
||||
{ // Required
|
||||
AlbumArt: &elastictranscoder.JobAlbumArt{
|
||||
Artwork: []*elastictranscoder.Artwork{
|
||||
{ // Required
|
||||
AlbumArtFormat: aws.String("JpgOrPng"),
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
InputKey: aws.String("WatermarkKey"),
|
||||
MaxHeight: aws.String("DigitsOrAuto"),
|
||||
MaxWidth: aws.String("DigitsOrAuto"),
|
||||
PaddingPolicy: aws.String("PaddingPolicy"),
|
||||
SizingPolicy: aws.String("SizingPolicy"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
MergePolicy: aws.String("MergePolicy"),
|
||||
},
|
||||
Captions: &elastictranscoder.Captions{
|
||||
CaptionFormats: []*elastictranscoder.CaptionFormat{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Format: aws.String("CaptionFormatFormat"),
|
||||
Pattern: aws.String("CaptionFormatPattern"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
CaptionSources: []*elastictranscoder.CaptionSource{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Key: aws.String("Key"),
|
||||
Label: aws.String("Name"),
|
||||
Language: aws.String("Key"),
|
||||
TimeOffset: aws.String("TimeOffset"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
MergePolicy: aws.String("CaptionMergePolicy"),
|
||||
},
|
||||
Composition: []*elastictranscoder.Clip{
|
||||
{ // Required
|
||||
TimeSpan: &elastictranscoder.TimeSpan{
|
||||
Duration: aws.String("Time"),
|
||||
StartTime: aws.String("Time"),
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
Key: aws.String("Key"),
|
||||
PresetId: aws.String("Id"),
|
||||
Rotate: aws.String("Rotate"),
|
||||
SegmentDuration: aws.String("FloatString"),
|
||||
ThumbnailEncryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
ThumbnailPattern: aws.String("ThumbnailPattern"),
|
||||
Watermarks: []*elastictranscoder.JobWatermark{
|
||||
{ // Required
|
||||
Encryption: &elastictranscoder.Encryption{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
Mode: aws.String("EncryptionMode"),
|
||||
},
|
||||
InputKey: aws.String("WatermarkKey"),
|
||||
PresetWatermarkId: aws.String("PresetWatermarkId"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
Playlists: []*elastictranscoder.CreateJobPlaylist{
|
||||
{ // Required
|
||||
Format: aws.String("PlaylistFormat"),
|
||||
HlsContentProtection: &elastictranscoder.HlsContentProtection{
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("Base64EncodedString"),
|
||||
KeyMd5: aws.String("Base64EncodedString"),
|
||||
KeyStoragePolicy: aws.String("KeyStoragePolicy"),
|
||||
LicenseAcquisitionUrl: aws.String("ZeroTo512String"),
|
||||
Method: aws.String("HlsContentProtectionMethod"),
|
||||
},
|
||||
Name: aws.String("Filename"),
|
||||
OutputKeys: []*string{
|
||||
aws.String("Key"), // Required
|
||||
// More values...
|
||||
},
|
||||
PlayReadyDrm: &elastictranscoder.PlayReadyDrm{
|
||||
Format: aws.String("PlayReadyDrmFormatString"),
|
||||
InitializationVector: aws.String("ZeroTo255String"),
|
||||
Key: aws.String("NonEmptyBase64EncodedString"),
|
||||
KeyId: aws.String("KeyIdGuid"),
|
||||
KeyMd5: aws.String("NonEmptyBase64EncodedString"),
|
||||
LicenseAcquisitionUrl: aws.String("OneTo512String"),
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
UserMetadata: map[string]*string{
|
||||
"Key": aws.String("String"), // Required
|
||||
// More values...
|
||||
},
|
||||
}
|
||||
+4976
File diff suppressed because it is too large
Load Diff
+92
@@ -0,0 +1,92 @@
|
||||
// Package restjson provides RESTful JSON serialization of AWS
|
||||
// requests and responses.
|
||||
package restjson
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-json.json build_test.go
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-json.json unmarshal_test.go
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
// BuildHandler is a named request handler for building restjson protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.restjson.Build", Fn: Build}
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling restjson protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.restjson.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling restjson protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.restjson.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// UnmarshalErrorHandler is a named request handler for unmarshaling restjson protocol request errors
|
||||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.restjson.UnmarshalError", Fn: UnmarshalError}
|
||||
|
||||
// Build builds a request for the REST JSON protocol.
|
||||
func Build(r *request.Request) {
|
||||
rest.Build(r)
|
||||
|
||||
if t := rest.PayloadType(r.Params); t == "structure" || t == "" {
|
||||
jsonrpc.Build(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals a response body for the REST JSON protocol.
|
||||
func Unmarshal(r *request.Request) {
|
||||
if t := rest.PayloadType(r.Data); t == "structure" || t == "" {
|
||||
jsonrpc.Unmarshal(r)
|
||||
} else {
|
||||
rest.Unmarshal(r)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals response headers for the REST JSON protocol.
|
||||
func UnmarshalMeta(r *request.Request) {
|
||||
rest.UnmarshalMeta(r)
|
||||
}
|
||||
|
||||
// UnmarshalError unmarshals a response error for the REST JSON protocol.
|
||||
func UnmarshalError(r *request.Request) {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
code := r.HTTPResponse.Header.Get("X-Amzn-Errortype")
|
||||
bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed reading REST JSON error response", err)
|
||||
return
|
||||
}
|
||||
if len(bodyBytes) == 0 {
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New("SerializationError", r.HTTPResponse.Status, nil),
|
||||
r.HTTPResponse.StatusCode,
|
||||
"",
|
||||
)
|
||||
return
|
||||
}
|
||||
var jsonErr jsonErrorResponse
|
||||
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed decoding REST JSON error response", err)
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
code = jsonErr.Code
|
||||
}
|
||||
|
||||
code = strings.SplitN(code, ":", 2)[0]
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New(code, jsonErr.Message, nil),
|
||||
r.HTTPResponse.StatusCode,
|
||||
r.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
type jsonErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
+2160
File diff suppressed because it is too large
Load Diff
+246
@@ -0,0 +1,246 @@
|
||||
// +build bench
|
||||
|
||||
package restxml_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/restxml"
|
||||
"github.com/aws/aws-sdk-go/service/cloudfront"
|
||||
)
|
||||
|
||||
func BenchmarkRESTXMLBuild_Complex_cloudfrontCreateDistribution(b *testing.B) {
|
||||
params := restxmlBuildCreateDistroParms
|
||||
|
||||
op := &request.Operation{
|
||||
Name: "CreateDistribution",
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/2015-04-17/distribution/{DistributionId}/invalidation",
|
||||
}
|
||||
|
||||
benchRESTXMLBuild(b, op, params)
|
||||
}
|
||||
|
||||
func BenchmarkRESTXMLBuild_Simple_cloudfrontDeleteStreamingDistribution(b *testing.B) {
|
||||
params := &cloudfront.DeleteDistributionInput{
|
||||
Id: aws.String("string"), // Required
|
||||
IfMatch: aws.String("string"),
|
||||
}
|
||||
op := &request.Operation{
|
||||
Name: "DeleteStreamingDistribution",
|
||||
HTTPMethod: "DELETE",
|
||||
HTTPPath: "/2015-04-17/streaming-distribution/{Id}",
|
||||
}
|
||||
benchRESTXMLBuild(b, op, params)
|
||||
}
|
||||
|
||||
func BenchmarkEncodingXMLMarshal_Simple_cloudfrontDeleteStreamingDistribution(b *testing.B) {
|
||||
params := &cloudfront.DeleteDistributionInput{
|
||||
Id: aws.String("string"), // Required
|
||||
IfMatch: aws.String("string"),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := &bytes.Buffer{}
|
||||
encoder := xml.NewEncoder(buf)
|
||||
if err := encoder.Encode(params); err != nil {
|
||||
b.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchRESTXMLBuild(b *testing.B, op *request.Operation, params interface{}) {
|
||||
svc := awstesting.NewClient()
|
||||
svc.ServiceName = "cloudfront"
|
||||
svc.APIVersion = "2015-04-17"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := svc.NewRequest(op, params, nil)
|
||||
restxml.Build(r)
|
||||
if r.Error != nil {
|
||||
b.Fatal("Unexpected error", r.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var restxmlBuildCreateDistroParms = &cloudfront.CreateDistributionInput{
|
||||
DistributionConfig: &cloudfront.DistributionConfig{ // Required
|
||||
CallerReference: aws.String("string"), // Required
|
||||
Comment: aws.String("string"), // Required
|
||||
DefaultCacheBehavior: &cloudfront.DefaultCacheBehavior{ // Required
|
||||
ForwardedValues: &cloudfront.ForwardedValues{ // Required
|
||||
Cookies: &cloudfront.CookiePreference{ // Required
|
||||
Forward: aws.String("ItemSelection"), // Required
|
||||
WhitelistedNames: &cloudfront.CookieNames{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
},
|
||||
QueryString: aws.Bool(true), // Required
|
||||
Headers: &cloudfront.Headers{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
},
|
||||
MinTTL: aws.Int64(1), // Required
|
||||
TargetOriginId: aws.String("string"), // Required
|
||||
TrustedSigners: &cloudfront.TrustedSigners{ // Required
|
||||
Enabled: aws.Bool(true), // Required
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
ViewerProtocolPolicy: aws.String("ViewerProtocolPolicy"), // Required
|
||||
AllowedMethods: &cloudfront.AllowedMethods{
|
||||
Items: []*string{ // Required
|
||||
aws.String("Method"), // Required
|
||||
// More values...
|
||||
},
|
||||
Quantity: aws.Int64(1), // Required
|
||||
CachedMethods: &cloudfront.CachedMethods{
|
||||
Items: []*string{ // Required
|
||||
aws.String("Method"), // Required
|
||||
// More values...
|
||||
},
|
||||
Quantity: aws.Int64(1), // Required
|
||||
},
|
||||
},
|
||||
DefaultTTL: aws.Int64(1),
|
||||
MaxTTL: aws.Int64(1),
|
||||
SmoothStreaming: aws.Bool(true),
|
||||
},
|
||||
Enabled: aws.Bool(true), // Required
|
||||
Origins: &cloudfront.Origins{ // Required
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*cloudfront.Origin{
|
||||
{ // Required
|
||||
DomainName: aws.String("string"), // Required
|
||||
Id: aws.String("string"), // Required
|
||||
CustomOriginConfig: &cloudfront.CustomOriginConfig{
|
||||
HTTPPort: aws.Int64(1), // Required
|
||||
HTTPSPort: aws.Int64(1), // Required
|
||||
OriginProtocolPolicy: aws.String("OriginProtocolPolicy"), // Required
|
||||
},
|
||||
OriginPath: aws.String("string"),
|
||||
S3OriginConfig: &cloudfront.S3OriginConfig{
|
||||
OriginAccessIdentity: aws.String("string"), // Required
|
||||
},
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
Aliases: &cloudfront.Aliases{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
CacheBehaviors: &cloudfront.CacheBehaviors{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*cloudfront.CacheBehavior{
|
||||
{ // Required
|
||||
ForwardedValues: &cloudfront.ForwardedValues{ // Required
|
||||
Cookies: &cloudfront.CookiePreference{ // Required
|
||||
Forward: aws.String("ItemSelection"), // Required
|
||||
WhitelistedNames: &cloudfront.CookieNames{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
},
|
||||
QueryString: aws.Bool(true), // Required
|
||||
Headers: &cloudfront.Headers{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
},
|
||||
MinTTL: aws.Int64(1), // Required
|
||||
PathPattern: aws.String("string"), // Required
|
||||
TargetOriginId: aws.String("string"), // Required
|
||||
TrustedSigners: &cloudfront.TrustedSigners{ // Required
|
||||
Enabled: aws.Bool(true), // Required
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
ViewerProtocolPolicy: aws.String("ViewerProtocolPolicy"), // Required
|
||||
AllowedMethods: &cloudfront.AllowedMethods{
|
||||
Items: []*string{ // Required
|
||||
aws.String("Method"), // Required
|
||||
// More values...
|
||||
},
|
||||
Quantity: aws.Int64(1), // Required
|
||||
CachedMethods: &cloudfront.CachedMethods{
|
||||
Items: []*string{ // Required
|
||||
aws.String("Method"), // Required
|
||||
// More values...
|
||||
},
|
||||
Quantity: aws.Int64(1), // Required
|
||||
},
|
||||
},
|
||||
DefaultTTL: aws.Int64(1),
|
||||
MaxTTL: aws.Int64(1),
|
||||
SmoothStreaming: aws.Bool(true),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
CustomErrorResponses: &cloudfront.CustomErrorResponses{
|
||||
Quantity: aws.Int64(1), // Required
|
||||
Items: []*cloudfront.CustomErrorResponse{
|
||||
{ // Required
|
||||
ErrorCode: aws.Int64(1), // Required
|
||||
ErrorCachingMinTTL: aws.Int64(1),
|
||||
ResponseCode: aws.String("string"),
|
||||
ResponsePagePath: aws.String("string"),
|
||||
},
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
DefaultRootObject: aws.String("string"),
|
||||
Logging: &cloudfront.LoggingConfig{
|
||||
Bucket: aws.String("string"), // Required
|
||||
Enabled: aws.Bool(true), // Required
|
||||
IncludeCookies: aws.Bool(true), // Required
|
||||
Prefix: aws.String("string"), // Required
|
||||
},
|
||||
PriceClass: aws.String("PriceClass"),
|
||||
Restrictions: &cloudfront.Restrictions{
|
||||
GeoRestriction: &cloudfront.GeoRestriction{ // Required
|
||||
Quantity: aws.Int64(1), // Required
|
||||
RestrictionType: aws.String("GeoRestrictionType"), // Required
|
||||
Items: []*string{
|
||||
aws.String("string"), // Required
|
||||
// More values...
|
||||
},
|
||||
},
|
||||
},
|
||||
ViewerCertificate: &cloudfront.ViewerCertificate{
|
||||
CloudFrontDefaultCertificate: aws.Bool(true),
|
||||
IAMCertificateId: aws.String("string"),
|
||||
MinimumProtocolVersion: aws.String("MinimumProtocolVersion"),
|
||||
SSLSupportMethod: aws.String("SSLSupportMethod"),
|
||||
},
|
||||
},
|
||||
}
|
||||
+5821
File diff suppressed because it is too large
Load Diff
+69
@@ -0,0 +1,69 @@
|
||||
// Package restxml provides RESTful XML serialization of AWS
|
||||
// requests and responses.
|
||||
package restxml
|
||||
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
|
||||
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/query"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
)
|
||||
|
||||
// BuildHandler is a named request handler for building restxml protocol requests
|
||||
var BuildHandler = request.NamedHandler{Name: "awssdk.restxml.Build", Fn: Build}
|
||||
|
||||
// UnmarshalHandler is a named request handler for unmarshaling restxml protocol requests
|
||||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.restxml.Unmarshal", Fn: Unmarshal}
|
||||
|
||||
// UnmarshalMetaHandler is a named request handler for unmarshaling restxml protocol request metadata
|
||||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.restxml.UnmarshalMeta", Fn: UnmarshalMeta}
|
||||
|
||||
// UnmarshalErrorHandler is a named request handler for unmarshaling restxml protocol request errors
|
||||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.restxml.UnmarshalError", Fn: UnmarshalError}
|
||||
|
||||
// Build builds a request payload for the REST XML protocol.
|
||||
func Build(r *request.Request) {
|
||||
rest.Build(r)
|
||||
|
||||
if t := rest.PayloadType(r.Params); t == "structure" || t == "" {
|
||||
var buf bytes.Buffer
|
||||
err := xmlutil.BuildXML(r.Params, xml.NewEncoder(&buf))
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to encode rest XML request", err)
|
||||
return
|
||||
}
|
||||
r.SetBufferBody(buf.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals a payload response for the REST XML protocol.
|
||||
func Unmarshal(r *request.Request) {
|
||||
if t := rest.PayloadType(r.Data); t == "structure" || t == "" {
|
||||
defer r.HTTPResponse.Body.Close()
|
||||
decoder := xml.NewDecoder(r.HTTPResponse.Body)
|
||||
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
|
||||
if err != nil {
|
||||
r.Error = awserr.New("SerializationError", "failed to decode REST XML response", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rest.Unmarshal(r)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMeta unmarshals response headers for the REST XML protocol.
|
||||
func UnmarshalMeta(r *request.Request) {
|
||||
rest.UnmarshalMeta(r)
|
||||
}
|
||||
|
||||
// UnmarshalError unmarshals a response error for the REST XML protocol.
|
||||
func UnmarshalError(r *request.Request) {
|
||||
query.UnmarshalError(r)
|
||||
}
|
||||
+2289
File diff suppressed because it is too large
Load Diff
+21
@@ -0,0 +1,21 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// UnmarshalDiscardBodyHandler is a named request handler to empty and close a response's body
|
||||
var UnmarshalDiscardBodyHandler = request.NamedHandler{Name: "awssdk.shared.UnmarshalDiscardBody", Fn: UnmarshalDiscardBody}
|
||||
|
||||
// UnmarshalDiscardBody is a request handler to empty a response's body and closing it.
|
||||
func UnmarshalDiscardBody(r *request.Request) {
|
||||
if r.HTTPResponse == nil || r.HTTPResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
io.Copy(ioutil.Discard, r.HTTPResponse.Body)
|
||||
r.HTTPResponse.Body.Close()
|
||||
}
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
package protocol_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockCloser struct {
|
||||
*strings.Reader
|
||||
Closed bool
|
||||
}
|
||||
|
||||
func (m *mockCloser) Close() error {
|
||||
m.Closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUnmarshalDrainBody(t *testing.T) {
|
||||
b := &mockCloser{Reader: strings.NewReader("example body")}
|
||||
r := &request.Request{HTTPResponse: &http.Response{
|
||||
Body: b,
|
||||
}}
|
||||
|
||||
protocol.UnmarshalDiscardBody(r)
|
||||
assert.NoError(t, r.Error)
|
||||
assert.Equal(t, 0, b.Len())
|
||||
assert.True(t, b.Closed)
|
||||
}
|
||||
|
||||
func TestUnmarshalDrainBodyNoBody(t *testing.T) {
|
||||
r := &request.Request{HTTPResponse: &http.Response{}}
|
||||
|
||||
protocol.UnmarshalDiscardBody(r)
|
||||
assert.NoError(t, r.Error)
|
||||
}
|
||||
+297
@@ -0,0 +1,297 @@
|
||||
// Package xmlutil provides XML serialization of AWS requests and responses.
|
||||
package xmlutil
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol"
|
||||
)
|
||||
|
||||
// BuildXML will serialize params into an xml.Encoder.
|
||||
// Error will be returned if the serialization of any of the params or nested values fails.
|
||||
func BuildXML(params interface{}, e *xml.Encoder) error {
|
||||
b := xmlBuilder{encoder: e, namespaces: map[string]string{}}
|
||||
root := NewXMLElement(xml.Name{})
|
||||
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range root.Children {
|
||||
for _, v := range c {
|
||||
return StructToXML(e, v, false)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the reflection element of a value, if it is a pointer.
|
||||
func elemOf(value reflect.Value) reflect.Value {
|
||||
for value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// A xmlBuilder serializes values from Go code to XML
|
||||
type xmlBuilder struct {
|
||||
encoder *xml.Encoder
|
||||
namespaces map[string]string
|
||||
}
|
||||
|
||||
// buildValue generic XMLNode builder for any type. Will build value for their specific type
|
||||
// struct, list, map, scalar.
|
||||
//
|
||||
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
|
||||
// type is not provided reflect will be used to determine the value's type.
|
||||
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
|
||||
value = elemOf(value)
|
||||
if !value.IsValid() { // no need to handle zero values
|
||||
return nil
|
||||
} else if tag.Get("location") != "" { // don't handle non-body location values
|
||||
return nil
|
||||
}
|
||||
|
||||
t := tag.Get("type")
|
||||
if t == "" {
|
||||
switch value.Kind() {
|
||||
case reflect.Struct:
|
||||
t = "structure"
|
||||
case reflect.Slice:
|
||||
t = "list"
|
||||
case reflect.Map:
|
||||
t = "map"
|
||||
}
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "structure":
|
||||
if field, ok := value.Type().FieldByName("_"); ok {
|
||||
tag = tag + reflect.StructTag(" ") + field.Tag
|
||||
}
|
||||
return b.buildStruct(value, current, tag)
|
||||
case "list":
|
||||
return b.buildList(value, current, tag)
|
||||
case "map":
|
||||
return b.buildMap(value, current, tag)
|
||||
default:
|
||||
return b.buildScalar(value, current, tag)
|
||||
}
|
||||
}
|
||||
|
||||
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
|
||||
// types are converted to XMLNodes also.
|
||||
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
fieldAdded := false
|
||||
|
||||
// unwrap payloads
|
||||
if payload := tag.Get("payload"); payload != "" {
|
||||
field, _ := value.Type().FieldByName(payload)
|
||||
tag = field.Tag
|
||||
value = elemOf(value.FieldByName(payload))
|
||||
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
|
||||
|
||||
// there is an xmlNamespace associated with this struct
|
||||
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" {
|
||||
ns := xml.Attr{
|
||||
Name: xml.Name{Local: "xmlns"},
|
||||
Value: uri,
|
||||
}
|
||||
if prefix != "" {
|
||||
b.namespaces[prefix] = uri // register the namespace
|
||||
ns.Name.Local = "xmlns:" + prefix
|
||||
}
|
||||
|
||||
child.Attr = append(child.Attr, ns)
|
||||
}
|
||||
|
||||
t := value.Type()
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
member := elemOf(value.Field(i))
|
||||
field := t.Field(i)
|
||||
|
||||
if field.PkgPath != "" {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
if field.Tag.Get("ignore") != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
mTag := field.Tag
|
||||
if mTag.Get("location") != "" { // skip non-body members
|
||||
continue
|
||||
}
|
||||
|
||||
if protocol.CanSetIdempotencyToken(value.Field(i), field) {
|
||||
token := protocol.GetIdempotencyToken()
|
||||
member = reflect.ValueOf(token)
|
||||
}
|
||||
|
||||
memberName := mTag.Get("locationName")
|
||||
if memberName == "" {
|
||||
memberName = field.Name
|
||||
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`)
|
||||
}
|
||||
if err := b.buildValue(member, child, mTag); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldAdded = true
|
||||
}
|
||||
|
||||
if fieldAdded { // only append this child if we have one ore more valid members
|
||||
current.AddChild(child)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildList adds the value's list items to the current XMLNode as children nodes. All
|
||||
// nested values in the list are converted to XMLNodes also.
|
||||
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
|
||||
if value.IsNil() { // don't build omitted lists
|
||||
return nil
|
||||
}
|
||||
|
||||
// check for unflattened list member
|
||||
flattened := tag.Get("flattened") != ""
|
||||
|
||||
xname := xml.Name{Local: tag.Get("locationName")}
|
||||
if flattened {
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
child := NewXMLElement(xname)
|
||||
current.AddChild(child)
|
||||
if err := b.buildValue(value.Index(i), child, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
list := NewXMLElement(xname)
|
||||
current.AddChild(list)
|
||||
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
iname := tag.Get("locationNameList")
|
||||
if iname == "" {
|
||||
iname = "member"
|
||||
}
|
||||
|
||||
child := NewXMLElement(xml.Name{Local: iname})
|
||||
list.AddChild(child)
|
||||
if err := b.buildValue(value.Index(i), child, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
|
||||
// nested values in the map are converted to XMLNodes also.
|
||||
//
|
||||
// Error will be returned if it is unable to build the map's values into XMLNodes
|
||||
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
|
||||
if value.IsNil() { // don't build omitted maps
|
||||
return nil
|
||||
}
|
||||
|
||||
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
|
||||
current.AddChild(maproot)
|
||||
current = maproot
|
||||
|
||||
kname, vname := "key", "value"
|
||||
if n := tag.Get("locationNameKey"); n != "" {
|
||||
kname = n
|
||||
}
|
||||
if n := tag.Get("locationNameValue"); n != "" {
|
||||
vname = n
|
||||
}
|
||||
|
||||
// sorting is not required for compliance, but it makes testing easier
|
||||
keys := make([]string, value.Len())
|
||||
for i, k := range value.MapKeys() {
|
||||
keys[i] = k.String()
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
v := value.MapIndex(reflect.ValueOf(k))
|
||||
|
||||
mapcur := current
|
||||
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
|
||||
child := NewXMLElement(xml.Name{Local: "entry"})
|
||||
mapcur.AddChild(child)
|
||||
mapcur = child
|
||||
}
|
||||
|
||||
kchild := NewXMLElement(xml.Name{Local: kname})
|
||||
kchild.Text = k
|
||||
vchild := NewXMLElement(xml.Name{Local: vname})
|
||||
mapcur.AddChild(kchild)
|
||||
mapcur.AddChild(vchild)
|
||||
|
||||
if err := b.buildValue(v, vchild, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildScalar will convert the value into a string and append it as a attribute or child
|
||||
// of the current XMLNode.
|
||||
//
|
||||
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
|
||||
//
|
||||
// Error will be returned if the value type is unsupported.
|
||||
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
|
||||
var str string
|
||||
switch converted := value.Interface().(type) {
|
||||
case string:
|
||||
str = converted
|
||||
case []byte:
|
||||
if !value.IsNil() {
|
||||
str = base64.StdEncoding.EncodeToString(converted)
|
||||
}
|
||||
case bool:
|
||||
str = strconv.FormatBool(converted)
|
||||
case int64:
|
||||
str = strconv.FormatInt(converted, 10)
|
||||
case int:
|
||||
str = strconv.Itoa(converted)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(converted, 'f', -1, 64)
|
||||
case float32:
|
||||
str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
|
||||
case time.Time:
|
||||
const ISO8601UTC = "2006-01-02T15:04:05Z"
|
||||
str = converted.UTC().Format(ISO8601UTC)
|
||||
default:
|
||||
return fmt.Errorf("unsupported value for param %s: %v (%s)",
|
||||
tag.Get("locationName"), value.Interface(), value.Type().Name())
|
||||
}
|
||||
|
||||
xname := xml.Name{Local: tag.Get("locationName")}
|
||||
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
|
||||
attr := xml.Attr{Name: xname, Value: str}
|
||||
current.Attr = append(current.Attr, attr)
|
||||
} else { // regular text node
|
||||
current.AddChild(&XMLNode{Name: xname, Text: str})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+257
@@ -0,0 +1,257 @@
|
||||
package xmlutil
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnmarshalXML deserializes an xml.Decoder into the container v. V
|
||||
// needs to match the shape of the XML expected to be decoded.
|
||||
// If the shape doesn't match unmarshaling will fail.
|
||||
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
|
||||
n, _ := XMLToStruct(d, nil)
|
||||
if n.Children != nil {
|
||||
for _, root := range n.Children {
|
||||
for _, c := range root {
|
||||
if wrappedChild, ok := c.Children[wrapper]; ok {
|
||||
c = wrappedChild[0] // pull out wrapped element
|
||||
}
|
||||
|
||||
err := parse(reflect.ValueOf(v), c, "")
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
|
||||
// will be used to determine the type from r.
|
||||
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
rtype := r.Type()
|
||||
if rtype.Kind() == reflect.Ptr {
|
||||
rtype = rtype.Elem() // check kind of actual element type
|
||||
}
|
||||
|
||||
t := tag.Get("type")
|
||||
if t == "" {
|
||||
switch rtype.Kind() {
|
||||
case reflect.Struct:
|
||||
t = "structure"
|
||||
case reflect.Slice:
|
||||
t = "list"
|
||||
case reflect.Map:
|
||||
t = "map"
|
||||
}
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "structure":
|
||||
if field, ok := rtype.FieldByName("_"); ok {
|
||||
tag = field.Tag
|
||||
}
|
||||
return parseStruct(r, node, tag)
|
||||
case "list":
|
||||
return parseList(r, node, tag)
|
||||
case "map":
|
||||
return parseMap(r, node, tag)
|
||||
default:
|
||||
return parseScalar(r, node, tag)
|
||||
}
|
||||
}
|
||||
|
||||
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
|
||||
// types in the structure will also be deserialized.
|
||||
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
t := r.Type()
|
||||
if r.Kind() == reflect.Ptr {
|
||||
if r.IsNil() { // create the structure if it's nil
|
||||
s := reflect.New(r.Type().Elem())
|
||||
r.Set(s)
|
||||
r = s
|
||||
}
|
||||
|
||||
r = r.Elem()
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
// unwrap any payloads
|
||||
if payload := tag.Get("payload"); payload != "" {
|
||||
field, _ := t.FieldByName(payload)
|
||||
return parseStruct(r.FieldByName(payload), node, field.Tag)
|
||||
}
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if c := field.Name[0:1]; strings.ToLower(c) == c {
|
||||
continue // ignore unexported fields
|
||||
}
|
||||
|
||||
// figure out what this field is called
|
||||
name := field.Name
|
||||
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
|
||||
name = field.Tag.Get("locationNameList")
|
||||
} else if locName := field.Tag.Get("locationName"); locName != "" {
|
||||
name = locName
|
||||
}
|
||||
|
||||
// try to find the field by name in elements
|
||||
elems := node.Children[name]
|
||||
|
||||
if elems == nil { // try to find the field in attributes
|
||||
if val, ok := node.findElem(name); ok {
|
||||
elems = []*XMLNode{{Text: val}}
|
||||
}
|
||||
}
|
||||
|
||||
member := r.FieldByName(field.Name)
|
||||
for _, elem := range elems {
|
||||
err := parse(member, elem, field.Tag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseList deserializes a list of values from an XML node. Each list entry
|
||||
// will also be deserialized.
|
||||
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
t := r.Type()
|
||||
|
||||
if tag.Get("flattened") == "" { // look at all item entries
|
||||
mname := "member"
|
||||
if name := tag.Get("locationNameList"); name != "" {
|
||||
mname = name
|
||||
}
|
||||
|
||||
if Children, ok := node.Children[mname]; ok {
|
||||
if r.IsNil() {
|
||||
r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
|
||||
}
|
||||
|
||||
for i, c := range Children {
|
||||
err := parse(r.Index(i), c, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // flattened list means this is a single element
|
||||
if r.IsNil() {
|
||||
r.Set(reflect.MakeSlice(t, 0, 0))
|
||||
}
|
||||
|
||||
childR := reflect.Zero(t.Elem())
|
||||
r.Set(reflect.Append(r, childR))
|
||||
err := parse(r.Index(r.Len()-1), node, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
|
||||
// will also be deserialized as map entries.
|
||||
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
if r.IsNil() {
|
||||
r.Set(reflect.MakeMap(r.Type()))
|
||||
}
|
||||
|
||||
if tag.Get("flattened") == "" { // look at all child entries
|
||||
for _, entry := range node.Children["entry"] {
|
||||
parseMapEntry(r, entry, tag)
|
||||
}
|
||||
} else { // this element is itself an entry
|
||||
parseMapEntry(r, node, tag)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMapEntry deserializes a map entry from a XML node.
|
||||
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
kname, vname := "key", "value"
|
||||
if n := tag.Get("locationNameKey"); n != "" {
|
||||
kname = n
|
||||
}
|
||||
if n := tag.Get("locationNameValue"); n != "" {
|
||||
vname = n
|
||||
}
|
||||
|
||||
keys, ok := node.Children[kname]
|
||||
values := node.Children[vname]
|
||||
if ok {
|
||||
for i, key := range keys {
|
||||
keyR := reflect.ValueOf(key.Text)
|
||||
value := values[i]
|
||||
valueR := reflect.New(r.Type().Elem()).Elem()
|
||||
|
||||
parse(valueR, value, "")
|
||||
r.SetMapIndex(keyR, valueR)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseScaller deserializes an XMLNode value into a concrete type based on the
|
||||
// interface type of r.
|
||||
//
|
||||
// Error is returned if the deserialization fails due to invalid type conversion,
|
||||
// or unsupported interface type.
|
||||
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
|
||||
switch r.Interface().(type) {
|
||||
case *string:
|
||||
r.Set(reflect.ValueOf(&node.Text))
|
||||
return nil
|
||||
case []byte:
|
||||
b, err := base64.StdEncoding.DecodeString(node.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Set(reflect.ValueOf(b))
|
||||
case *bool:
|
||||
v, err := strconv.ParseBool(node.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Set(reflect.ValueOf(&v))
|
||||
case *int64:
|
||||
v, err := strconv.ParseInt(node.Text, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Set(reflect.ValueOf(&v))
|
||||
case *float64:
|
||||
v, err := strconv.ParseFloat(node.Text, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Set(reflect.ValueOf(&v))
|
||||
case *time.Time:
|
||||
const ISO8601UTC = "2006-01-02T15:04:05Z"
|
||||
t, err := time.Parse(ISO8601UTC, node.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Set(reflect.ValueOf(&t))
|
||||
default:
|
||||
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+142
@@ -0,0 +1,142 @@
|
||||
package xmlutil
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// A XMLNode contains the values to be encoded or decoded.
|
||||
type XMLNode struct {
|
||||
Name xml.Name `json:",omitempty"`
|
||||
Children map[string][]*XMLNode `json:",omitempty"`
|
||||
Text string `json:",omitempty"`
|
||||
Attr []xml.Attr `json:",omitempty"`
|
||||
|
||||
namespaces map[string]string
|
||||
parent *XMLNode
|
||||
}
|
||||
|
||||
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
|
||||
func NewXMLElement(name xml.Name) *XMLNode {
|
||||
return &XMLNode{
|
||||
Name: name,
|
||||
Children: map[string][]*XMLNode{},
|
||||
Attr: []xml.Attr{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddChild adds child to the XMLNode.
|
||||
func (n *XMLNode) AddChild(child *XMLNode) {
|
||||
if _, ok := n.Children[child.Name.Local]; !ok {
|
||||
n.Children[child.Name.Local] = []*XMLNode{}
|
||||
}
|
||||
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child)
|
||||
}
|
||||
|
||||
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
|
||||
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
|
||||
out := &XMLNode{}
|
||||
for {
|
||||
tok, err := d.Token()
|
||||
if tok == nil || err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
|
||||
switch typed := tok.(type) {
|
||||
case xml.CharData:
|
||||
out.Text = string(typed.Copy())
|
||||
case xml.StartElement:
|
||||
el := typed.Copy()
|
||||
out.Attr = el.Attr
|
||||
if out.Children == nil {
|
||||
out.Children = map[string][]*XMLNode{}
|
||||
}
|
||||
|
||||
name := typed.Name.Local
|
||||
slice := out.Children[name]
|
||||
if slice == nil {
|
||||
slice = []*XMLNode{}
|
||||
}
|
||||
node, e := XMLToStruct(d, &el)
|
||||
out.findNamespaces()
|
||||
if e != nil {
|
||||
return out, e
|
||||
}
|
||||
node.Name = typed.Name
|
||||
node.findNamespaces()
|
||||
tempOut := *out
|
||||
// Save into a temp variable, simply because out gets squashed during
|
||||
// loop iterations
|
||||
node.parent = &tempOut
|
||||
slice = append(slice, node)
|
||||
out.Children[name] = slice
|
||||
case xml.EndElement:
|
||||
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
|
||||
return out, nil
|
||||
}
|
||||
out = &XMLNode{}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (n *XMLNode) findNamespaces() {
|
||||
ns := map[string]string{}
|
||||
for _, a := range n.Attr {
|
||||
if a.Name.Space == "xmlns" {
|
||||
ns[a.Value] = a.Name.Local
|
||||
}
|
||||
}
|
||||
|
||||
n.namespaces = ns
|
||||
}
|
||||
|
||||
func (n *XMLNode) findElem(name string) (string, bool) {
|
||||
for node := n; node != nil; node = node.parent {
|
||||
for _, a := range node.Attr {
|
||||
namespace := a.Name.Space
|
||||
if v, ok := node.namespaces[namespace]; ok {
|
||||
namespace = v
|
||||
}
|
||||
if name == fmt.Sprintf("%s:%s", namespace, a.Name.Local) {
|
||||
return a.Value, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
|
||||
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
|
||||
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
|
||||
|
||||
if node.Text != "" {
|
||||
e.EncodeToken(xml.CharData([]byte(node.Text)))
|
||||
} else if sorted {
|
||||
sortedNames := []string{}
|
||||
for k := range node.Children {
|
||||
sortedNames = append(sortedNames, k)
|
||||
}
|
||||
sort.Strings(sortedNames)
|
||||
|
||||
for _, k := range sortedNames {
|
||||
for _, v := range node.Children[k] {
|
||||
StructToXML(e, v, sorted)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, c := range node.Children {
|
||||
for _, v := range c {
|
||||
StructToXML(e, v, sorted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
e.EncodeToken(xml.EndElement{Name: node.Name})
|
||||
return e.Flush()
|
||||
}
|
||||
Generated
Vendored
+53
@@ -0,0 +1,53 @@
|
||||
package xmlutil_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/awstesting/unit"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
xmlVal := []byte(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<AccessControlPolicy xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Owner><ID>foo-id</ID><DisplayName>user</DisplayName></Owner><AccessControlList><Grant><Grantee xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:type="type"><ID>foo-id</ID><DisplayName>user</DisplayName></Grantee><Permission>FULL_CONTROL</Permission></Grant></AccessControlList><
|
||||
/AccessControlPolicy>`)
|
||||
|
||||
var server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(xmlVal)
|
||||
}))
|
||||
|
||||
sess := unit.Session
|
||||
sess.Config.Endpoint = &server.URL
|
||||
sess.Config.S3ForcePathStyle = aws.Bool(true)
|
||||
svc := s3.New(sess)
|
||||
|
||||
out, err := svc.GetBucketAcl(&s3.GetBucketAclInput{
|
||||
Bucket: aws.String("foo"),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := &s3.GetBucketAclOutput{
|
||||
Grants: []*s3.Grant{
|
||||
{
|
||||
Grantee: &s3.Grantee{
|
||||
DisplayName: aws.String("user"),
|
||||
ID: aws.String("foo-id"),
|
||||
Type: aws.String("type"),
|
||||
},
|
||||
Permission: aws.String("FULL_CONTROL"),
|
||||
},
|
||||
},
|
||||
|
||||
Owner: &s3.Owner{
|
||||
DisplayName: aws.String("user"),
|
||||
ID: aws.String("foo-id"),
|
||||
},
|
||||
}
|
||||
assert.Equal(t, expected, out)
|
||||
}
|
||||
+180
@@ -0,0 +1,180 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidMethod = errors.New("v2 signer only handles HTTP POST")
|
||||
)
|
||||
|
||||
const (
|
||||
signatureVersion = "2"
|
||||
signatureMethod = "HmacSHA256"
|
||||
timeFormat = "2006-01-02T15:04:05Z"
|
||||
)
|
||||
|
||||
type signer struct {
|
||||
// Values that must be populated from the request
|
||||
Request *http.Request
|
||||
Time time.Time
|
||||
Credentials *credentials.Credentials
|
||||
Debug aws.LogLevelType
|
||||
Logger aws.Logger
|
||||
|
||||
Query url.Values
|
||||
stringToSign string
|
||||
signature string
|
||||
}
|
||||
|
||||
// SignRequestHandler is a named request handler the SDK will use to sign
|
||||
// service client request with using the V4 signature.
|
||||
var SignRequestHandler = request.NamedHandler{
|
||||
Name: "v2.SignRequestHandler", Fn: SignSDKRequest,
|
||||
}
|
||||
|
||||
// SignSDKRequest requests with signature version 2.
|
||||
//
|
||||
// Will sign the requests with the service config's Credentials object
|
||||
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
|
||||
// object.
|
||||
func SignSDKRequest(req *request.Request) {
|
||||
// If the request does not need to be signed ignore the signing of the
|
||||
// request if the AnonymousCredentials object is used.
|
||||
if req.Config.Credentials == credentials.AnonymousCredentials {
|
||||
return
|
||||
}
|
||||
|
||||
if req.HTTPRequest.Method != "POST" && req.HTTPRequest.Method != "GET" {
|
||||
// The V2 signer only supports GET and POST
|
||||
req.Error = errInvalidMethod
|
||||
return
|
||||
}
|
||||
|
||||
v2 := signer{
|
||||
Request: req.HTTPRequest,
|
||||
Time: req.Time,
|
||||
Credentials: req.Config.Credentials,
|
||||
Debug: req.Config.LogLevel.Value(),
|
||||
Logger: req.Config.Logger,
|
||||
}
|
||||
|
||||
req.Error = v2.Sign()
|
||||
|
||||
if req.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if req.HTTPRequest.Method == "POST" {
|
||||
// Set the body of the request based on the modified query parameters
|
||||
req.SetStringBody(v2.Query.Encode())
|
||||
|
||||
// Now that the body has changed, remove any Content-Length header,
|
||||
// because it will be incorrect
|
||||
req.HTTPRequest.ContentLength = 0
|
||||
req.HTTPRequest.Header.Del("Content-Length")
|
||||
} else {
|
||||
req.HTTPRequest.URL.RawQuery = v2.Query.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
func (v2 *signer) Sign() error {
|
||||
credValue, err := v2.Credentials.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v2.Request.Method == "POST" {
|
||||
// Parse the HTTP request to obtain the query parameters that will
|
||||
// be used to build the string to sign. Note that because the HTTP
|
||||
// request will need to be modified, the PostForm and Form properties
|
||||
// are reset to nil after parsing.
|
||||
v2.Request.ParseForm()
|
||||
v2.Query = v2.Request.PostForm
|
||||
v2.Request.PostForm = nil
|
||||
v2.Request.Form = nil
|
||||
} else {
|
||||
v2.Query = v2.Request.URL.Query()
|
||||
}
|
||||
|
||||
// Set new query parameters
|
||||
v2.Query.Set("AWSAccessKeyId", credValue.AccessKeyID)
|
||||
v2.Query.Set("SignatureVersion", signatureVersion)
|
||||
v2.Query.Set("SignatureMethod", signatureMethod)
|
||||
v2.Query.Set("Timestamp", v2.Time.UTC().Format(timeFormat))
|
||||
if credValue.SessionToken != "" {
|
||||
v2.Query.Set("SecurityToken", credValue.SessionToken)
|
||||
}
|
||||
|
||||
// in case this is a retry, ensure no signature present
|
||||
v2.Query.Del("Signature")
|
||||
|
||||
method := v2.Request.Method
|
||||
host := v2.Request.URL.Host
|
||||
path := v2.Request.URL.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
// obtain all of the query keys and sort them
|
||||
queryKeys := make([]string, 0, len(v2.Query))
|
||||
for key := range v2.Query {
|
||||
queryKeys = append(queryKeys, key)
|
||||
}
|
||||
sort.Strings(queryKeys)
|
||||
|
||||
// build URL-encoded query keys and values
|
||||
queryKeysAndValues := make([]string, len(queryKeys))
|
||||
for i, key := range queryKeys {
|
||||
k := strings.Replace(url.QueryEscape(key), "+", "%20", -1)
|
||||
v := strings.Replace(url.QueryEscape(v2.Query.Get(key)), "+", "%20", -1)
|
||||
queryKeysAndValues[i] = k + "=" + v
|
||||
}
|
||||
|
||||
// join into one query string
|
||||
query := strings.Join(queryKeysAndValues, "&")
|
||||
|
||||
// build the canonical string for the V2 signature
|
||||
v2.stringToSign = strings.Join([]string{
|
||||
method,
|
||||
host,
|
||||
path,
|
||||
query,
|
||||
}, "\n")
|
||||
|
||||
hash := hmac.New(sha256.New, []byte(credValue.SecretAccessKey))
|
||||
hash.Write([]byte(v2.stringToSign))
|
||||
v2.signature = base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
v2.Query.Set("Signature", v2.signature)
|
||||
|
||||
if v2.Debug.Matches(aws.LogDebugWithSigning) {
|
||||
v2.logSigningInfo()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const logSignInfoMsg = `DEBUG: Request Signature:
|
||||
---[ STRING TO SIGN ]--------------------------------
|
||||
%s
|
||||
---[ SIGNATURE ]-------------------------------------
|
||||
%s
|
||||
-----------------------------------------------------`
|
||||
|
||||
func (v2 *signer) logSigningInfo() {
|
||||
msg := fmt.Sprintf(logSignInfoMsg, v2.stringToSign, v2.Query.Get("Signature"))
|
||||
v2.Logger.Log(msg)
|
||||
}
|
||||
+195
@@ -0,0 +1,195 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/awstesting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type signerBuilder struct {
|
||||
ServiceName string
|
||||
Region string
|
||||
SignTime time.Time
|
||||
Query url.Values
|
||||
Method string
|
||||
SessionToken string
|
||||
}
|
||||
|
||||
func (sb signerBuilder) BuildSigner() signer {
|
||||
endpoint := "https://" + sb.ServiceName + "." + sb.Region + ".amazonaws.com"
|
||||
var req *http.Request
|
||||
if sb.Method == "POST" {
|
||||
body := []byte(sb.Query.Encode())
|
||||
reader := bytes.NewReader(body)
|
||||
req, _ = http.NewRequest(sb.Method, endpoint, reader)
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", string(len(body)))
|
||||
} else {
|
||||
req, _ = http.NewRequest(sb.Method, endpoint, nil)
|
||||
req.URL.RawQuery = sb.Query.Encode()
|
||||
}
|
||||
|
||||
sig := signer{
|
||||
Request: req,
|
||||
Time: sb.SignTime,
|
||||
Credentials: credentials.NewStaticCredentials(
|
||||
"AKID",
|
||||
"SECRET",
|
||||
sb.SessionToken),
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") != "" {
|
||||
sig.Debug = aws.LogDebug
|
||||
sig.Logger = aws.NewDefaultLogger()
|
||||
}
|
||||
|
||||
return sig
|
||||
}
|
||||
|
||||
func TestSignRequestWithAndWithoutSession(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// have to create more than once, so use a function
|
||||
newQuery := func() url.Values {
|
||||
query := make(url.Values)
|
||||
query.Add("Action", "CreateDomain")
|
||||
query.Add("DomainName", "TestDomain-1437033376")
|
||||
query.Add("Version", "2009-04-15")
|
||||
return query
|
||||
}
|
||||
|
||||
// create request without a SecurityToken (session) in the credentials
|
||||
|
||||
query := newQuery()
|
||||
timestamp := time.Date(2015, 7, 16, 7, 56, 16, 0, time.UTC)
|
||||
builder := signerBuilder{
|
||||
Method: "POST",
|
||||
ServiceName: "sdb",
|
||||
Region: "ap-southeast-2",
|
||||
SignTime: timestamp,
|
||||
Query: query,
|
||||
}
|
||||
|
||||
signer := builder.BuildSigner()
|
||||
|
||||
err := signer.Sign()
|
||||
assert.NoError(err)
|
||||
assert.Equal("tm4dX8Ks7pzFSVHz7qHdoJVXKRLuC4gWz9eti60d8ks=", signer.signature)
|
||||
assert.Equal(8, len(signer.Query))
|
||||
assert.Equal("AKID", signer.Query.Get("AWSAccessKeyId"))
|
||||
assert.Equal("2015-07-16T07:56:16Z", signer.Query.Get("Timestamp"))
|
||||
assert.Equal("HmacSHA256", signer.Query.Get("SignatureMethod"))
|
||||
assert.Equal("2", signer.Query.Get("SignatureVersion"))
|
||||
assert.Equal("tm4dX8Ks7pzFSVHz7qHdoJVXKRLuC4gWz9eti60d8ks=", signer.Query.Get("Signature"))
|
||||
assert.Equal("CreateDomain", signer.Query.Get("Action"))
|
||||
assert.Equal("TestDomain-1437033376", signer.Query.Get("DomainName"))
|
||||
assert.Equal("2009-04-15", signer.Query.Get("Version"))
|
||||
|
||||
// should not have a SecurityToken parameter
|
||||
_, ok := signer.Query["SecurityToken"]
|
||||
assert.False(ok)
|
||||
|
||||
// now sign again, this time with a security token (session)
|
||||
|
||||
query = newQuery()
|
||||
builder.SessionToken = "SESSION"
|
||||
signer = builder.BuildSigner()
|
||||
|
||||
err = signer.Sign()
|
||||
assert.NoError(err)
|
||||
assert.Equal("Ch6qv3rzXB1SLqY2vFhsgA1WQ9rnQIE2WJCigOvAJwI=", signer.signature)
|
||||
assert.Equal(9, len(signer.Query)) // expect one more parameter
|
||||
assert.Equal("Ch6qv3rzXB1SLqY2vFhsgA1WQ9rnQIE2WJCigOvAJwI=", signer.Query.Get("Signature"))
|
||||
assert.Equal("SESSION", signer.Query.Get("SecurityToken"))
|
||||
}
|
||||
|
||||
func TestMoreComplexSignRequest(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
query := make(url.Values)
|
||||
query.Add("Action", "PutAttributes")
|
||||
query.Add("DomainName", "TestDomain-1437041569")
|
||||
query.Add("Version", "2009-04-15")
|
||||
query.Add("Attribute.2.Name", "Attr2")
|
||||
query.Add("Attribute.2.Value", "Value2")
|
||||
query.Add("Attribute.2.Replace", "true")
|
||||
query.Add("Attribute.1.Name", "Attr1-%\\+ %")
|
||||
query.Add("Attribute.1.Value", " \tValue1 +!@#$%^&*(){}[]\"';:?/.>,<\x12\x00")
|
||||
query.Add("Attribute.1.Replace", "true")
|
||||
query.Add("ItemName", "Item 1")
|
||||
|
||||
timestamp := time.Date(2015, 7, 16, 10, 12, 51, 0, time.UTC)
|
||||
builder := signerBuilder{
|
||||
Method: "POST",
|
||||
ServiceName: "sdb",
|
||||
Region: "ap-southeast-2",
|
||||
SignTime: timestamp,
|
||||
Query: query,
|
||||
SessionToken: "SESSION",
|
||||
}
|
||||
|
||||
signer := builder.BuildSigner()
|
||||
|
||||
err := signer.Sign()
|
||||
assert.NoError(err)
|
||||
assert.Equal("WNdE62UJKLKoA6XncVY/9RDbrKmcVMdQPQOTAs8SgwQ=", signer.signature)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
svc := awstesting.NewClient(&aws.Config{
|
||||
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
|
||||
Region: aws.String("ap-southeast-2"),
|
||||
})
|
||||
r := svc.NewRequest(
|
||||
&request.Operation{
|
||||
Name: "OpName",
|
||||
HTTPMethod: "GET",
|
||||
HTTPPath: "/",
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
r.Build()
|
||||
assert.Equal("GET", r.HTTPRequest.Method)
|
||||
assert.Equal("", r.HTTPRequest.URL.Query().Get("Signature"))
|
||||
|
||||
SignSDKRequest(r)
|
||||
assert.NoError(r.Error)
|
||||
t.Logf("Signature: %s", r.HTTPRequest.URL.Query().Get("Signature"))
|
||||
assert.NotEqual("", r.HTTPRequest.URL.Query().Get("Signature"))
|
||||
}
|
||||
|
||||
func TestAnonymousCredentials(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
svc := awstesting.NewClient(&aws.Config{
|
||||
Credentials: credentials.AnonymousCredentials,
|
||||
Region: aws.String("ap-southeast-2"),
|
||||
})
|
||||
r := svc.NewRequest(
|
||||
&request.Operation{
|
||||
Name: "PutAttributes",
|
||||
HTTPMethod: "POST",
|
||||
HTTPPath: "/",
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
r.Build()
|
||||
|
||||
SignSDKRequest(r)
|
||||
|
||||
req := r.HTTPRequest
|
||||
req.ParseForm()
|
||||
|
||||
assert.Empty(req.PostForm.Get("Signature"))
|
||||
}
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
package util
|
||||
|
||||
import "sort"
|
||||
|
||||
// SortedKeys returns a sorted slice of keys of a map.
|
||||
func SortedKeys(m map[string]interface{}) []string {
|
||||
i, sorted := 0, make([]string, len(m))
|
||||
for k := range m {
|
||||
sorted[i] = k
|
||||
i++
|
||||
}
|
||||
sort.Strings(sorted)
|
||||
return sorted
|
||||
}
|
||||
+109
@@ -0,0 +1,109 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
)
|
||||
|
||||
// GoFmt returns the Go formated string of the input.
|
||||
//
|
||||
// Panics if the format fails.
|
||||
func GoFmt(buf string) string {
|
||||
formatted, err := format.Source([]byte(buf))
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("%s\nOriginal code:\n%s", err.Error(), buf))
|
||||
}
|
||||
return string(formatted)
|
||||
}
|
||||
|
||||
var reTrim = regexp.MustCompile(`\s{2,}`)
|
||||
|
||||
// Trim removes all leading and trailing white space.
|
||||
//
|
||||
// All consecutive spaces will be reduced to a single space.
|
||||
func Trim(s string) string {
|
||||
return strings.TrimSpace(reTrim.ReplaceAllString(s, " "))
|
||||
}
|
||||
|
||||
// Capitalize capitalizes the first character of the string.
|
||||
func Capitalize(s string) string {
|
||||
if len(s) == 1 {
|
||||
return strings.ToUpper(s)
|
||||
}
|
||||
return strings.ToUpper(s[0:1]) + s[1:]
|
||||
}
|
||||
|
||||
// SortXML sorts the reader's XML elements
|
||||
func SortXML(r io.Reader) string {
|
||||
var buf bytes.Buffer
|
||||
d := xml.NewDecoder(r)
|
||||
root, _ := xmlutil.XMLToStruct(d, nil)
|
||||
e := xml.NewEncoder(&buf)
|
||||
xmlutil.StructToXML(e, root, true)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// PrettyPrint generates a human readable representation of the value v.
|
||||
// All values of v are recursively found and pretty printed also.
|
||||
func PrettyPrint(v interface{}) string {
|
||||
value := reflect.ValueOf(v)
|
||||
switch value.Kind() {
|
||||
case reflect.Struct:
|
||||
str := fullName(value.Type()) + "{\n"
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
l := string(value.Type().Field(i).Name[0])
|
||||
if strings.ToUpper(l) == l {
|
||||
str += value.Type().Field(i).Name + ": "
|
||||
str += PrettyPrint(value.Field(i).Interface())
|
||||
str += ",\n"
|
||||
}
|
||||
}
|
||||
str += "}"
|
||||
return str
|
||||
case reflect.Map:
|
||||
str := "map[" + fullName(value.Type().Key()) + "]" + fullName(value.Type().Elem()) + "{\n"
|
||||
for _, k := range value.MapKeys() {
|
||||
str += "\"" + k.String() + "\": "
|
||||
str += PrettyPrint(value.MapIndex(k).Interface())
|
||||
str += ",\n"
|
||||
}
|
||||
str += "}"
|
||||
return str
|
||||
case reflect.Ptr:
|
||||
if e := value.Elem(); e.IsValid() {
|
||||
return "&" + PrettyPrint(e.Interface())
|
||||
}
|
||||
return "nil"
|
||||
case reflect.Slice:
|
||||
str := "[]" + fullName(value.Type().Elem()) + "{\n"
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
str += PrettyPrint(value.Index(i).Interface())
|
||||
str += ",\n"
|
||||
}
|
||||
str += "}"
|
||||
return str
|
||||
default:
|
||||
return fmt.Sprintf("%#v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func pkgName(t reflect.Type) string {
|
||||
pkg := t.PkgPath()
|
||||
c := strings.Split(pkg, "/")
|
||||
return c[len(c)-1]
|
||||
}
|
||||
|
||||
func fullName(t reflect.Type) string {
|
||||
if pkg := pkgName(t); pkg != "" {
|
||||
return pkg + "." + t.Name()
|
||||
}
|
||||
return t.Name()
|
||||
}
|
||||
Reference in New Issue
Block a user