Merge tag 'upstream/0.9.1'

Upstream version 0.9.1
This commit is contained in:
Sébastien Delafond
2015-05-02 12:26:06 +02:00
576 changed files with 34685 additions and 10054 deletions
+576
View File
@@ -0,0 +1,576 @@
#!/bin/bash
# (The MIT License)
#
# Copyright (c) 2014 Andrey Smirnov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
__aptly_mirror_list()
{
aptly mirror list -raw
}
__aptly_repo_list()
{
aptly repo list -raw
}
__aptly_snapshot_list()
{
aptly snapshot list -raw
}
__aptly_published_distributions()
{
aptly publish list -raw | cut -d ' ' -f 2 | sort | uniq
}
__aptly_published_prefixes()
{
aptly publish list -raw | cut -d ' ' -f 1 | sort | uniq
}
__aptly_prefixes_for_distribution()
{
aptly publish list -raw | awk -v dist="$1" '{ if (dist == $2) print $1 }' | sort | uniq
}
_aptly()
{
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
commands="api config db graph mirror package publish repo serve snapshot task version"
options="-architectures= -config= -dep-follow-all-variants -dep-follow-recommends -dep-follow-source -dep-follow-suggests"
db_subcommands="cleanup recover"
mirror_subcommands="create drop show list rename search update"
publish_subcommands="drop list repo snapshot switch update"
snapshot_subcommands="create diff drop filter list merge pull rename search show verify"
repo_subcommands="add copy create drop edit import list move remove rename search show"
package_subcommands="search show"
task_subcommands="run"
config_subcommands="show"
api_subcommands="serve"
local cmd subcmd numargs numoptions i
numargs=0
numoptions=0
for (( i=1; i < $COMP_CWORD; i++ )); do
if [[ -n "$cmd" ]]; then
if [[ ! -n "$subcmd" ]]; then
subcmd=${COMP_WORDS[i]}
numargs=$(( COMP_CWORD - i - 1 ))
else
if [[ "${COMP_WORDS[i]}" == -* ]]; then
numoptions=$(( numoptions + 1 ))
numargs=$(( numargs - 1 ))
fi
fi
else
if [[ ! "${COMP_WORDS[i]}" == -* ]]; then
cmd=${COMP_WORDS[i]}
fi
fi
done
if [[ ! -n "$cmd" ]];
then
case "$cur" in
-*)
COMPREPLY=($(compgen -W "${options}" -- ${cur}))
return 0
;;
*)
COMPREPLY=($(compgen -W "${commands}" -- ${cur}))
return 0
;;
esac
fi
if [[ ! -n "$subcmd" ]];
then
case "$prev" in
"db")
COMPREPLY=($(compgen -W "${db_subcommands}" -- ${cur}))
return 0
;;
"mirror")
COMPREPLY=($(compgen -W "${mirror_subcommands}" -- ${cur}))
return 0
;;
"repo")
COMPREPLY=($(compgen -W "${repo_subcommands}" -- ${cur}))
return 0
;;
"snapshot")
COMPREPLY=($(compgen -W "${snapshot_subcommands}" -- ${cur}))
return 0
;;
"publish")
COMPREPLY=($(compgen -W "${publish_subcommands}" -- ${cur}))
return 0
;;
"package")
COMPREPLY=($(compgen -W "${package_subcommands}" -- ${cur}))
return 0
;;
"task")
COMPREPLY=($(compgen -W "${task_subcommands}" -- ${cur}))
return 0
;;
"config")
COMPREPLY=($(compgen -W "${config_subcommands}" -- ${cur}))
return 0
;;
"api")
COMPREPLY=($(compgen -W "${api_subcommands}" -- ${cur}))
return 0
;;
*)
;;
esac
fi
case "$cmd" in
"mirror")
case "$subcmd" in
"create")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-filter= -filter-with-deps -force-components -ignore-signatures -keyring= -with-sources -with-udebs" -- ${cur}))
return 0
fi
fi
;;
"edit")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-filter= -filter-with-deps -with-sources -with-udebs" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
fi
;;
"show")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-packages" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
fi
;;
"search")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
fi
;;
"rename")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
return 0
fi
;;
"drop")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
fi
;;
"list")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "-raw" -- ${cur}))
return 0
fi
;;
"update")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force -download-limit= -ignore-checksums -ignore-signatures -keyring=" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
fi
;;
esac
;;
"repo")
case "$subcmd" in
"add")
case $numargs in
0)
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force-replace -remove-files" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
;;
1)
local files=$(find . -mindepth 1 -maxdepth 1 \( -type d -or -name \*.deb -or -name \*.dsc \) -exec basename {} \;)
COMPREPLY=($(compgen -W "${files}" -- ${cur}))
return 0
;;
esac
;;
"copy"|"move")
case $numargs in
0)
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps -dry-run" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
;;
1)
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
return 0
;;
esac
;;
"create")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-comment= -distribution= -component=" -- ${cur}))
return 0
fi
fi
;;
"drop")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
fi
;;
"edit")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-comment= -distribution= -component=" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
fi
;;
"search")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
fi
;;
"list")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "-raw" -- ${cur}))
return 0
fi
;;
"import")
case $numargs in
0)
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps -dry-run" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi
return 0
;;
1)
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
return 0
;;
esac
;;
"remove")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-dry-run" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
fi
;;
"show")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-packages" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
return 0
fi
;;
"rename")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
return 0
fi
;;
esac
;;
"snapshot")
case "$subcmd" in
"create")
case $numargs in
1)
COMPREPLY=($(compgen -W "from empty" -- ${cur}))
return 0
;;
2)
if [[ "$prev" == "from" ]]; then
COMPREPLY=($(compgen -W "mirror repo" -- ${cur}))
return 0
fi
;;
3)
if [[ "$prev" == "mirror" ]]; then
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
return 0
fi
if [[ "$prev" == "repo" ]]; then
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
return 0
fi
;;
esac
;;
"diff")
if [[ $numargs -eq 0 ]] && [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-only-matching" -- ${cur}))
return 0
fi
if [[ $numargs -lt 2 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
return 0
fi
;;
"drop")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
fi
return 0
fi
;;
"list")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "-raw -sort=" -- ${cur}))
return 0
fi
;;
"merge")
if [[ $numargs -gt 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-latest" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
fi
return 0
fi
;;
"pull")
if [[ $numargs -eq 0 ]] && [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-all-matches -dry-run -no-deps -no-remove" -- ${cur}))
return 0
fi
if [[ $numargs -lt 2 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
return 0
fi
;;
"filter")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
fi
return 0
fi
;;
"show")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-packages" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
fi
return 0
fi
;;
"search")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-deps" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
fi
return 0
fi
;;
"rename")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
return 0
fi
;;
"verify")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
return 0
fi
;;
esac
;;
"publish")
case "$subcmd" in
"snapshot"|"repo")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-batch -force-overwrite -distribution= -component= -gpg-key= -keyring= -label= -origin= -passphrase= -passphrase-file= -secret-keyring= -skip-signing" -- ${cur}))
else
if [[ "$subcmd" == "snapshot" ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_repo_list)" -- ${cur}))
fi
fi
return 0
fi
if [[ $numargs -eq 1 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_published_prefixes)" -- ${cur}))
return 0
fi
;;
"list")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "-raw" -- ${cur}))
return 0
fi
;;
"update")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-batch -force-overwrite -gpg-key= -keyring= -passphrase= -passphrase-file= -secret-keyring= -skip-signing" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_published_distributions)" -- ${cur}))
fi
return 0
fi
if [[ $numargs -eq 1 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_prefixes_for_distribution $prev)" -- ${cur}))
return 0
fi
;;
"switch")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-batch -force-overwrite -component= -gpg-key= -keyring= -passphrase= -passphrase-file= -secret-keyring= -skip-signing" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_published_distributions)" -- ${cur}))
fi
return 0
fi
if [[ $numargs -eq 1 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_prefixes_for_distribution $prev)" -- ${cur}))
return 0
fi
if [[ $numargs -ge 2 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_snapshot_list)" -- ${cur}))
return 0
fi
;;
"drop")
if [[ $numargs -eq 0 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_published_distributions)" -- ${cur}))
return 0
fi
if [[ $numargs -eq 1 ]]; then
COMPREPLY=($(compgen -W "$(__aptly_prefixes_for_distribution $prev)" -- ${cur}))
return 0
fi
;;
esac
;;
"package")
case "$subcmd" in
"show")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-with-files -with-references" -- ${cur}))
fi
return 0
fi
;;
esac
;;
"serve")
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-listen=" -- ${cur}))
return 0
fi
;;
"api")
case "$subcmd" in
"serve")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-listen=" -- ${cur}))
fi
return 0
fi
;;
esac
;;
esac
} && complete -F _aptly aptly
+15 -3
View File
@@ -1,17 +1,20 @@
language: go
go:
- 1.2.1
- 1.3.1
- 1.3.3
- tip
env:
global:
- secure: "YSwtFrMqh4oUvdSQTXBXMHHLWeQgyNEL23ChIZwU0nuDGIcQZ65kipu0PzefedtUbK4ieC065YCUi4UDDh6gPotB/Wu1pnYg3dyQ7rFvhaVYAAUEpajAdXZhlx+7+J8a4FZMeC/kqiahxoRgLbthF9019ouIqhGB9zHKI6/yZwc="
- secure: "V7OjWrfQ8UbktgT036jYQPb/7GJT3Ol9LObDr8FYlzsQ+F1uj2wLac6ePuxcOS4FwWOJinWGM1h+JiFkbxbyFqfRNJ0jj0O2p93QyDojxFVOn1mXqqvV66KFqAWR2Vzkny/gDvj8LTvdB1cgAIm2FNOkQc6E1BFnyWS2sN9ea5E="
- secure: "OxiVNmre2JzUszwPNNilKDgIqtfX2gnRSsVz6nuySB1uO2yQsOQmKWJ9cVYgH2IB5H8eWXKOhexcSE28kz6TPLRuEcU9fnqKY3uEkdwm7rJfz9lf+7C4bJEUdA1OIzJppjnWUiXxD7CEPL1DlnMZM24eDQYqa/4WKACAgkK53gE="
before_install:
- sudo apt-get update -qq
- sudo apt-get install -y python-boto
- sudo apt-get install -y python-virtualenv graphviz
- virtualenv env
- . env/bin/activate
- pip install boto requests python-swiftclient
install:
- make prepare
@@ -21,3 +24,12 @@ script: make travis
matrix:
allow_failures:
- go: tip
notifications:
webhooks:
urls:
- "https://webhooks.gitter.im/e/c691da114a41eed6ec45"
on_success: change # options: [always|never|change] default: always
on_failure: always # options: [always|never|change] default: always
on_start: false # default: false
+12 -2
View File
@@ -3,5 +3,15 @@ List of contributors, in chronological order:
* Andrey Smirnov (https://github.com/smira)
* Sebastien Binet (https://github.com/sbinet)
* Ryan Uber (https://github.com/ryanuber)
* Simon Aquino (https://github.com/simonaquino)
* Vincent Batoufflet (https://github.com/vbatoufflet)
* Simon Aquino (https://github.com/queeno)
* Vincent Batoufflet (https://github.com/vbatoufflet)
* Ivan Kurnosov (https://github.com/zerkms)
* Dmitrii Kashin (https://github.com/freehck)
* Chris Read (https://github.com/cread)
* Rohan Garg (https://github.com/shadeslayer)
* Russ Allbery (https://github.com/rra)
* Sylvain Baubeau (https://github.com/lebauce)
* Andrea Bernardo Ciddio (https://github.com/bcandrea)
* Michael Koval (https://github.com/mkoval)
* Alexander Guy (https://github.com/alexanderguy)
* Sebastien Badia (https://github.com/sbadia)
+12 -6
View File
@@ -1,27 +1,33 @@
gom 'code.google.com/p/go-uuid/uuid', :commit => '5fac954758f5'
gom 'code.google.com/p/go.crypto/ssh/terminal', :commit => '7aa593ce8cea'
gom 'code.google.com/p/gographviz', :commit => '454bc64fdfa2'
gom 'code.google.com/p/mxk/go1/flowcontrol', :commit => '5ff2502e2556'
gom 'code.google.com/p/snappy-go/snappy', :commit => '12e4b4183793'
gom 'github.com/cheggaaa/pb', :commit => '74be7a1388046f374ac36e93d46f5d56e856f827'
gom 'github.com/vaughan0/go-ini', :commit => 'a98ad7ee00ec53921f08832bc06ecf7fd600e6a1'
gom 'github.com/AlekSi/pointer', :commit => '5f6d527dae3d678b46fbb20331ddf44e2b841943'
gom 'github.com/cheggaaa/pb', :commit => '2c1b74620cc58a81ac152ee2d322e28c806d81ed'
gom 'github.com/gin-gonic/gin', :commit => 'b1758d3bfa09e61ddbc1c9a627e936eec6a170de'
gom 'github.com/jlaffaye/ftp', :commit => 'fec71e62e457557fbe85cefc847a048d57815d76'
gom 'github.com/julienschmidt/httprouter', :commit => '46807412fe50aaceb73bb57061c2230fd26a1640'
gom 'github.com/mattn/go-shellwords', :commit => 'c7ca6f94add751566a61cf2199e1de78d4c3eee4'
gom 'github.com/mitchellh/goamz/s3', :commit => 'e7664b32019f31fd1bdf33f9e85f28722f700405'
gom 'github.com/mkrautz/goar', :commit => '36eb5f3452b1283a211fa35bc00c646fd0db5c4b'
gom 'github.com/ncw/swift', :commit => '384ef27c70645e285f8bb9d02276bf654d06027e'
gom 'github.com/smira/commander', :commit => 'f408b00e68d5d6e21b9f18bd310978dafc604e47'
gom 'github.com/smira/flag', :commit => '357ed3e599ffcbd4aeaa828e1d10da2df3ea5107'
gom 'github.com/smira/go-ftp-protocol/protocol', :commit => '066b75c2b70dca7ae10b1b88b47534a3c31ccfaa'
gom 'github.com/syndtr/goleveldb/leveldb', :commit => 'e2fa4e6ac1cc41a73bc9fd467878ecbf65df5cc3'
gom 'github.com/syndtr/gosnappy/snappy', :commit => 'ce8acff4829e0c2458a67ead32390ac0a381c862'
gom 'github.com/syndtr/goleveldb/leveldb', :commit => '97e257099d2ab9578151ba85e2641e2cd14d3ca8'
gom 'github.com/ugorji/go/codec', :commit => '71c2886f5a673a35f909803f38ece5810165097b'
gom 'github.com/vaughan0/go-ini', :commit => 'a98ad7ee00ec53921f08832bc06ecf7fd600e6a1'
gom 'github.com/wsxiaoys/terminal/color', :commit => '5668e431776a7957528361f90ce828266c69ed08'
gom 'golang.org/x/crypto/ssh/terminal', :commit => 'a7ead6ddf06233883deca151dffaef2effbf498f'
group :test do
gom 'launchpad.net/gocheck'
gom 'gopkg.in/check.v1'
end
group :development do
gom 'github.com/golang/lint/golint'
gom 'github.com/mattn/goveralls'
gom 'github.com/axw/gocov/gocov'
gom 'code.google.com/p/go.tools/cmd/cover'
gom 'golang.org/x/tools/cmd/cover'
end
+6 -4
View File
@@ -1,6 +1,6 @@
GOVERSION=$(shell go version | awk '{print $$3;}')
PACKAGES=database deb files http query s3 utils
ALL_PACKAGES=aptly cmd console database deb files http query s3 utils
PACKAGES=context database deb files http query swift s3 utils
ALL_PACKAGES=api aptly context cmd console database deb files http query swift s3 utils
BINPATH=$(abspath ./_vendor/bin)
GOM_ENVIRONMENT=-test
PYTHON?=python
@@ -36,7 +36,7 @@ coverage: coverage.out
rm -f coverage.out
check:
$(GOM) exec go tool vet -all=true -shadow=true $(ALL_PACKAGES:%=./%)
$(GOM) exec go tool vet -all=true $(ALL_PACKAGES:%=./%)
$(GOM) exec golint $(ALL_PACKAGES:%=./%)
install:
@@ -67,7 +67,7 @@ package:
(cd root/etc/bash_completion.d && wget https://raw.github.com/aptly-dev/aptly-bash-completion/master/aptly)
gzip root/usr/share/man/man1/aptly.1
fpm -s dir -t deb -n aptly -v $(VERSION) --url=http://www.aptly.info/ --license=MIT --vendor="Andrey Smirnov <me@smira.ru>" \
-f -m "Andrey Smirnov <me@smira.ru>" --description="Debian repository management tool" --deb-recommends bzip2 -C root/ .
-f -m "Andrey Smirnov <me@smira.ru>" --description="Debian repository management tool" --deb-recommends bzip2 --deb-recommends graphviz -C root/ .
mv aptly_$(VERSION)_*.deb ~
src-package:
@@ -77,6 +77,8 @@ src-package:
cd aptly-$(VERSION)/src/github.com/smira/aptly && gom -production install
cd aptly-$(VERSION)/src/github.com/smira/aptly && find . \( -name .git -o -name .bzr -o -name .hg \) -print | xargs rm -rf
rm -rf aptly-$(VERSION)/src/github.com/smira/aptly/_vendor/{pkg,bin}
mkdir -p aptly-$(VERSION)/bash_completion.d
(cd aptly-$(VERSION)/bash_completion.d && wget https://raw.github.com/aptly-dev/aptly-bash-completion/$(VERSION)/aptly)
tar cyf aptly-$(VERSION)-src.tar.bz2 aptly-$(VERSION)
rm -rf aptly-$(VERSION)
curl -T aptly-$(VERSION)-src.tar.bz2 -usmira:$(BINTRAY_KEY) https://api.bintray.com/content/smira/aptly/aptly/$(VERSION)/$(VERSION)/aptly-$(VERSION)-src.tar.bz2
+12 -5
View File
@@ -8,8 +8,11 @@ aptly
.. image:: https://coveralls.io/repos/smira/aptly/badge.png?branch=HEAD
:target: https://coveralls.io/r/smira/aptly?branch=HEAD
.. image:: http://gobuild.io/badge/github.com/smira/aptly/download.png
:target: http://gobuild.io/github.com/smira/aptly
.. image:: https://badges.gitter.im/Join Chat.svg
:target: https://gitter.im/smira/aptly?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
.. image:: http://goreportcard.com/badge/gojp/goreportcard
:target: http://goreportcard.com/report/gojp/goreportcard
Aptly is a swiss army knife for Debian repository management.
@@ -28,6 +31,7 @@ Aptly features: ("+" means planned features)
* merge two or more snapshots into one
* filter repository by search query, pulling dependencies when required
* publish self-made packages as Debian repositories
* REST API for remote access
* mirror repositories "as-is" (without resigning with user's key) (+)
* support for yum repositories (+)
@@ -44,8 +48,7 @@ To install aptly on Debian/Ubuntu, add new repository to /etc/apt/sources.list::
And import key that is used to sign the release::
$ gpg --keyserver keys.gnupg.net --recv-keys 2A194991
$ gpg -a --export 2A194991 | sudo apt-key add -
$ apt-key adv --keyserver keys.gnupg.net --recv-keys E083A3782A194991
After that you can install aptly as any other software package::
@@ -55,9 +58,13 @@ After that you can install aptly as any other software package::
Don't worry about squeeze part in repo name: aptly package should work on Debian squeeze+,
Ubuntu 10.0+. Package contains aptly binary, man page and bash completion.
If you would like to use nightly builds (unstable), please use following repository::
deb http://repo.aptly.info/ nightly main
Binary executables (depends almost only on libc) are available for download from `Bintray <http://dl.bintray.com/smira/aptly/>`_.
If you have Go environment set up, you can build aptly from source by running (go 1.2+ required)::
If you have Go environment set up, you can build aptly from source by running (go 1.3+ required)::
go get -u github.com/mattn/gom
mkdir -p $GOPATH/src/github.com/smira/aptly
@@ -1,2 +0,0 @@
defaultcc: golang-codereviews@googlegroups.com
contributors: http://go.googlecode.com/hg/CONTRIBUTORS
@@ -1,41 +0,0 @@
package openpgp
import (
"testing"
"time"
)
func TestKeyExpiry(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(expiringKeyHex))
entity := kring[0]
const timeFormat = "2006-01-02"
time1, _ := time.Parse(timeFormat, "2013-07-01")
// The expiringKeyHex key is structured as:
//
// pub 1024R/5E237D8C created: 2013-07-01 expires: 2013-07-31 usage: SC
// sub 1024R/1ABB25A0 created: 2013-07-01 expires: 2013-07-08 usage: E
// sub 1024R/96A672F5 created: 2013-07-01 expires: 2013-07-31 usage: E
//
// So this should select the first, non-expired encryption key.
key, _ := entity.encryptionKey(time1)
if id := key.PublicKey.KeyIdShortString(); id != "1ABB25A0" {
t.Errorf("Expected key 1ABB25A0 at time %s, but got key %s", time1.Format(timeFormat), id)
}
// Once the first encryption subkey has expired, the second should be
// selected.
time2, _ := time.Parse(timeFormat, "2013-07-09")
key, _ = entity.encryptionKey(time2)
if id := key.PublicKey.KeyIdShortString(); id != "96A672F5" {
t.Errorf("Expected key 96A672F5 at time %s, but got key %s", time2.Format(timeFormat), id)
}
// Once all the keys have expired, nothing should be returned.
time3, _ := time.Parse(timeFormat, "2013-08-01")
if key, ok := entity.encryptionKey(time3); ok {
t.Errorf("Expected no key at time %s, but got key %s", time3.Format(timeFormat), key.PublicKey.KeyIdShortString())
}
}
const expiringKeyHex = "988d0451d1ec5d010400ba3385721f2dc3f4ab096b2ee867ab77213f0a27a8538441c35d2fa225b08798a1439a66a5150e6bdc3f40f5d28d588c712394c632b6299f77db8c0d48d37903fb72ebd794d61be6aa774688839e5fdecfe06b2684cc115d240c98c66cb1ef22ae84e3aa0c2b0c28665c1e7d4d044e7f270706193f5223c8d44e0d70b7b8da830011010001b40f4578706972792074657374206b657988be041301020028050251d1ec5d021b03050900278d00060b090807030206150802090a0b0416020301021e01021780000a091072589ad75e237d8c033503fd10506d72837834eb7f994117740723adc39227104b0d326a1161871c0b415d25b4aedef946ca77ea4c05af9c22b32cf98be86ab890111fced1ee3f75e87b7cc3c00dc63bbc85dfab91c0dc2ad9de2c4d13a34659333a85c6acc1a669c5e1d6cecb0cf1e56c10e72d855ae177ddc9e766f9b2dda57ccbb75f57156438bbdb4e42b88d0451d1ec5d0104009c64906559866c5cb61578f5846a94fcee142a489c9b41e67b12bb54cfe86eb9bc8566460f9a720cb00d6526fbccfd4f552071a8e3f7744b1882d01036d811ee5a3fb91a1c568055758f43ba5d2c6a9676b012f3a1a89e47bbf624f1ad571b208f3cc6224eb378f1645dd3d47584463f9eadeacfd1ce6f813064fbfdcc4b5a53001101000188a504180102000f021b0c050251d1f06b050900093e89000a091072589ad75e237d8c20e00400ab8310a41461425b37889c4da28129b5fae6084fafbc0a47dd1adc74a264c6e9c9cc125f40462ee1433072a58384daef88c961c390ed06426a81b464a53194c4e291ddd7e2e2ba3efced01537d713bd111f48437bde2363446200995e8e0d4e528dda377fd1e8f8ede9c8e2198b393bd86852ce7457a7e3daf74d510461a5b77b88d0451d1ece8010400b3a519f83ab0010307e83bca895170acce8964a044190a2b368892f7a244758d9fc193482648acb1fb9780d28cc22d171931f38bb40279389fc9bf2110876d4f3db4fcfb13f22f7083877fe56592b3b65251312c36f83ffcb6d313c6a17f197dd471f0712aad15a8537b435a92471ba2e5b0c72a6c72536c3b567c558d7b6051001101000188a504180102000f021b0c050251d1f07b050900279091000a091072589ad75e237d8ce69e03fe286026afacf7c97ee20673864d4459a2240b5655219950643c7dba0ac384b1d4359c67805b21d98211f7b09c2a0ccf6410c8c04d4ff4a51293725d8d6570d9d8bb0e10c07d22357caeb49626df99c180be02d77d1fe8ed25e7a54481237646083a9f89a11566cd20b9e995b1487c5f9e02aeb434f3a1897cd416dd0a87861838da3e9e"
@@ -1,165 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// This file implements the core Keccak permutation function necessary for computing SHA3.
// This is implemented in a separate file to allow for replacement by an optimized implementation.
// Nothing in this package is exported.
// For the detailed specification, refer to the Keccak web site (http://keccak.noekeon.org/).
// rc stores the round constants for use in the ι step.
var rc = [...]uint64{
0x0000000000000001,
0x0000000000008082,
0x800000000000808A,
0x8000000080008000,
0x000000000000808B,
0x0000000080000001,
0x8000000080008081,
0x8000000000008009,
0x000000000000008A,
0x0000000000000088,
0x0000000080008009,
0x000000008000000A,
0x000000008000808B,
0x800000000000008B,
0x8000000000008089,
0x8000000000008003,
0x8000000000008002,
0x8000000000000080,
0x000000000000800A,
0x800000008000000A,
0x8000000080008081,
0x8000000000008080,
0x0000000080000001,
0x8000000080008008,
}
// keccakF computes the complete Keccak-f function consisting of 24 rounds with a different
// constant (rc) in each round. This implementation fully unrolls the round function to avoid
// inner loops, as well as pre-calculating shift offsets.
func keccakF(a *[numLanes]uint64) {
var t, bc0, bc1, bc2, bc3, bc4 uint64
for _, roundConstant := range rc {
// θ step
bc0 = a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20]
bc1 = a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21]
bc2 = a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22]
bc3 = a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23]
bc4 = a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24]
t = bc4 ^ (bc1<<1 ^ bc1>>63)
a[0] ^= t
a[5] ^= t
a[10] ^= t
a[15] ^= t
a[20] ^= t
t = bc0 ^ (bc2<<1 ^ bc2>>63)
a[1] ^= t
a[6] ^= t
a[11] ^= t
a[16] ^= t
a[21] ^= t
t = bc1 ^ (bc3<<1 ^ bc3>>63)
a[2] ^= t
a[7] ^= t
a[12] ^= t
a[17] ^= t
a[22] ^= t
t = bc2 ^ (bc4<<1 ^ bc4>>63)
a[3] ^= t
a[8] ^= t
a[13] ^= t
a[18] ^= t
a[23] ^= t
t = bc3 ^ (bc0<<1 ^ bc0>>63)
a[4] ^= t
a[9] ^= t
a[14] ^= t
a[19] ^= t
a[24] ^= t
// ρ and π steps
t = a[1]
t, a[10] = a[10], t<<1^t>>(64-1)
t, a[7] = a[7], t<<3^t>>(64-3)
t, a[11] = a[11], t<<6^t>>(64-6)
t, a[17] = a[17], t<<10^t>>(64-10)
t, a[18] = a[18], t<<15^t>>(64-15)
t, a[3] = a[3], t<<21^t>>(64-21)
t, a[5] = a[5], t<<28^t>>(64-28)
t, a[16] = a[16], t<<36^t>>(64-36)
t, a[8] = a[8], t<<45^t>>(64-45)
t, a[21] = a[21], t<<55^t>>(64-55)
t, a[24] = a[24], t<<2^t>>(64-2)
t, a[4] = a[4], t<<14^t>>(64-14)
t, a[15] = a[15], t<<27^t>>(64-27)
t, a[23] = a[23], t<<41^t>>(64-41)
t, a[19] = a[19], t<<56^t>>(64-56)
t, a[13] = a[13], t<<8^t>>(64-8)
t, a[12] = a[12], t<<25^t>>(64-25)
t, a[2] = a[2], t<<43^t>>(64-43)
t, a[20] = a[20], t<<62^t>>(64-62)
t, a[14] = a[14], t<<18^t>>(64-18)
t, a[22] = a[22], t<<39^t>>(64-39)
t, a[9] = a[9], t<<61^t>>(64-61)
t, a[6] = a[6], t<<20^t>>(64-20)
a[1] = t<<44 ^ t>>(64-44)
// χ step
bc0 = a[0]
bc1 = a[1]
bc2 = a[2]
bc3 = a[3]
bc4 = a[4]
a[0] ^= bc2 &^ bc1
a[1] ^= bc3 &^ bc2
a[2] ^= bc4 &^ bc3
a[3] ^= bc0 &^ bc4
a[4] ^= bc1 &^ bc0
bc0 = a[5]
bc1 = a[6]
bc2 = a[7]
bc3 = a[8]
bc4 = a[9]
a[5] ^= bc2 &^ bc1
a[6] ^= bc3 &^ bc2
a[7] ^= bc4 &^ bc3
a[8] ^= bc0 &^ bc4
a[9] ^= bc1 &^ bc0
bc0 = a[10]
bc1 = a[11]
bc2 = a[12]
bc3 = a[13]
bc4 = a[14]
a[10] ^= bc2 &^ bc1
a[11] ^= bc3 &^ bc2
a[12] ^= bc4 &^ bc3
a[13] ^= bc0 &^ bc4
a[14] ^= bc1 &^ bc0
bc0 = a[15]
bc1 = a[16]
bc2 = a[17]
bc3 = a[18]
bc4 = a[19]
a[15] ^= bc2 &^ bc1
a[16] ^= bc3 &^ bc2
a[17] ^= bc4 &^ bc3
a[18] ^= bc0 &^ bc4
a[19] ^= bc1 &^ bc0
bc0 = a[20]
bc1 = a[21]
bc2 = a[22]
bc3 = a[23]
bc4 = a[24]
a[20] ^= bc2 &^ bc1
a[21] ^= bc3 &^ bc2
a[22] ^= bc4 &^ bc3
a[23] ^= bc0 &^ bc4
a[24] ^= bc1 &^ bc0
// ι step
a[0] ^= roundConstant
}
}
@@ -1,213 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha3 implements the SHA3 hash algorithm (formerly called Keccak) chosen by NIST in 2012.
// This file provides a SHA3 implementation which implements the standard hash.Hash interface.
// Writing input data, including padding, and reading output data are computed in this file.
// Note that the current implementation can compute the hash of an integral number of bytes only.
// This is a consequence of the hash interface in which a buffer of bytes is passed in.
// The internals of the Keccak-f function are computed in keccakf.go.
// For the detailed specification, refer to the Keccak web site (http://keccak.noekeon.org/).
package sha3
import (
"encoding/binary"
"hash"
)
// laneSize is the size in bytes of each "lane" of the internal state of SHA3 (5 * 5 * 8).
// Note that changing this size would requires using a type other than uint64 to store each lane.
const laneSize = 8
// sliceSize represents the dimensions of the internal state, a square matrix of
// sliceSize ** 2 lanes. This is the size of both the "rows" and "columns" dimensions in the
// terminology of the SHA3 specification.
const sliceSize = 5
// numLanes represents the total number of lanes in the state.
const numLanes = sliceSize * sliceSize
// stateSize is the size in bytes of the internal state of SHA3 (5 * 5 * WSize).
const stateSize = laneSize * numLanes
// digest represents the partial evaluation of a checksum.
// Note that capacity, and not outputSize, is the critical security parameter, as SHA3 can output
// an arbitrary number of bytes for any given capacity. The Keccak proposal recommends that
// capacity = 2*outputSize to ensure that finding a collision of size outputSize requires
// O(2^{outputSize/2}) computations (the birthday lower bound). Future standards may modify the
// capacity/outputSize ratio to allow for more output with lower cryptographic security.
type digest struct {
a [numLanes]uint64 // main state of the hash
outputSize int // desired output size in bytes
capacity int // number of bytes to leave untouched during squeeze/absorb
absorbed int // number of bytes absorbed thus far
}
// minInt returns the lesser of two integer arguments, to simplify the absorption routine.
func minInt(v1, v2 int) int {
if v1 <= v2 {
return v1
}
return v2
}
// rate returns the number of bytes of the internal state which can be absorbed or squeezed
// in between calls to the permutation function.
func (d *digest) rate() int {
return stateSize - d.capacity
}
// Reset clears the internal state by zeroing bytes in the state buffer.
// This can be skipped for a newly-created hash state; the default zero-allocated state is correct.
func (d *digest) Reset() {
d.absorbed = 0
for i := range d.a {
d.a[i] = 0
}
}
// BlockSize, required by the hash.Hash interface, does not have a standard intepretation
// for a sponge-based construction like SHA3. We return the data rate: the number of bytes which
// can be absorbed per invocation of the permutation function. For Merkle-Damgård based hashes
// (ie SHA1, SHA2, MD5) the output size of the internal compression function is returned.
// We consider this to be roughly equivalent because it represents the number of bytes of output
// produced per cryptographic operation.
func (d *digest) BlockSize() int { return d.rate() }
// Size returns the output size of the hash function in bytes.
func (d *digest) Size() int {
return d.outputSize
}
// unalignedAbsorb is a helper function for Write, which absorbs data that isn't aligned with an
// 8-byte lane. This requires shifting the individual bytes into position in a uint64.
func (d *digest) unalignedAbsorb(p []byte) {
var t uint64
for i := len(p) - 1; i >= 0; i-- {
t <<= 8
t |= uint64(p[i])
}
offset := (d.absorbed) % d.rate()
t <<= 8 * uint(offset%laneSize)
d.a[offset/laneSize] ^= t
d.absorbed += len(p)
}
// Write "absorbs" bytes into the state of the SHA3 hash, updating as needed when the sponge
// "fills up" with rate() bytes. Since lanes are stored internally as type uint64, this requires
// converting the incoming bytes into uint64s using a little endian interpretation. This
// implementation is optimized for large, aligned writes of multiples of 8 bytes (laneSize).
// Non-aligned or uneven numbers of bytes require shifting and are slower.
func (d *digest) Write(p []byte) (int, error) {
// An initial offset is needed if the we aren't absorbing to the first lane initially.
offset := d.absorbed % d.rate()
toWrite := len(p)
// The first lane may need to absorb unaligned and/or incomplete data.
if (offset%laneSize != 0 || len(p) < 8) && len(p) > 0 {
toAbsorb := minInt(laneSize-(offset%laneSize), len(p))
d.unalignedAbsorb(p[:toAbsorb])
p = p[toAbsorb:]
offset = (d.absorbed) % d.rate()
// For every rate() bytes absorbed, the state must be permuted via the F Function.
if (d.absorbed)%d.rate() == 0 {
keccakF(&d.a)
}
}
// This loop should absorb the bulk of the data into full, aligned lanes.
// It will call the update function as necessary.
for len(p) > 7 {
firstLane := offset / laneSize
lastLane := minInt(d.rate()/laneSize, firstLane+len(p)/laneSize)
// This inner loop absorbs input bytes into the state in groups of 8, converted to uint64s.
for lane := firstLane; lane < lastLane; lane++ {
d.a[lane] ^= binary.LittleEndian.Uint64(p[:laneSize])
p = p[laneSize:]
}
d.absorbed += (lastLane - firstLane) * laneSize
// For every rate() bytes absorbed, the state must be permuted via the F Function.
if (d.absorbed)%d.rate() == 0 {
keccakF(&d.a)
}
offset = 0
}
// If there are insufficient bytes to fill the final lane, an unaligned absorption.
// This should always start at a correct lane boundary though, or else it would be caught
// by the uneven opening lane case above.
if len(p) > 0 {
d.unalignedAbsorb(p)
}
return toWrite, nil
}
// pad computes the SHA3 padding scheme based on the number of bytes absorbed.
// The padding is a 1 bit, followed by an arbitrary number of 0s and then a final 1 bit, such that
// the input bits plus padding bits are a multiple of rate(). Adding the padding simply requires
// xoring an opening and closing bit into the appropriate lanes.
func (d *digest) pad() {
offset := d.absorbed % d.rate()
// The opening pad bit must be shifted into position based on the number of bytes absorbed
padOpenLane := offset / laneSize
d.a[padOpenLane] ^= 0x0000000000000001 << uint(8*(offset%laneSize))
// The closing padding bit is always in the last position
padCloseLane := (d.rate() / laneSize) - 1
d.a[padCloseLane] ^= 0x8000000000000000
}
// finalize prepares the hash to output data by padding and one final permutation of the state.
func (d *digest) finalize() {
d.pad()
keccakF(&d.a)
}
// squeeze outputs an arbitrary number of bytes from the hash state.
// Squeezing can require multiple calls to the F function (one per rate() bytes squeezed),
// although this is not the case for standard SHA3 parameters. This implementation only supports
// squeezing a single time, subsequent squeezes may lose alignment. Future implementations
// may wish to support multiple squeeze calls, for example to support use as a PRNG.
func (d *digest) squeeze(in []byte, toSqueeze int) []byte {
// Because we read in blocks of laneSize, we need enough room to read
// an integral number of lanes
needed := toSqueeze + (laneSize-toSqueeze%laneSize)%laneSize
if cap(in)-len(in) < needed {
newIn := make([]byte, len(in), len(in)+needed)
copy(newIn, in)
in = newIn
}
out := in[len(in) : len(in)+needed]
for len(out) > 0 {
for i := 0; i < d.rate() && len(out) > 0; i += laneSize {
binary.LittleEndian.PutUint64(out[:], d.a[i/laneSize])
out = out[laneSize:]
}
if len(out) > 0 {
keccakF(&d.a)
}
}
return in[:len(in)+toSqueeze] // Re-slice in case we wrote extra data.
}
// Sum applies padding to the hash state and then squeezes out the desired nubmer of output bytes.
func (d *digest) Sum(in []byte) []byte {
// Make a copy of the original hash so that caller can keep writing and summing.
dup := *d
dup.finalize()
return dup.squeeze(in, dup.outputSize)
}
// The NewKeccakX constructors enable initializing a hash in any of the four recommend sizes
// from the Keccak specification, all of which set capacity=2*outputSize. Note that the final
// NIST standard for SHA3 may specify different input/output lengths.
// The output size is indicated in bits but converted into bytes internally.
func NewKeccak224() hash.Hash { return &digest{outputSize: 224 / 8, capacity: 2 * 224 / 8} }
func NewKeccak256() hash.Hash { return &digest{outputSize: 256 / 8, capacity: 2 * 256 / 8} }
func NewKeccak384() hash.Hash { return &digest{outputSize: 384 / 8, capacity: 2 * 384 / 8} }
func NewKeccak512() hash.Hash { return &digest{outputSize: 512 / 8, capacity: 2 * 512 / 8} }
@@ -1,270 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
// These tests are a subset of those provided by the Keccak web site(http://keccak.noekeon.org/).
import (
"bytes"
"encoding/hex"
"fmt"
"hash"
"strings"
"testing"
)
// testDigests maintains a digest state of each standard type.
var testDigests = map[string]*digest{
"Keccak224": {outputSize: 224 / 8, capacity: 2 * 224 / 8},
"Keccak256": {outputSize: 256 / 8, capacity: 2 * 256 / 8},
"Keccak384": {outputSize: 384 / 8, capacity: 2 * 384 / 8},
"Keccak512": {outputSize: 512 / 8, capacity: 2 * 512 / 8},
}
// testVector represents a test input and expected outputs from multiple algorithm variants.
type testVector struct {
desc string
input []byte
repeat int // input will be concatenated the input this many times.
want map[string]string
}
// decodeHex converts an hex-encoded string into a raw byte string.
func decodeHex(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}
// shortTestVectors stores a series of short testVectors.
// Inputs of 8, 248, and 264 bits from http://keccak.noekeon.org/ are included below.
// The standard defines additional test inputs of all sizes between 0 and 2047 bits.
// Because the current implementation can only handle an integral number of bytes,
// most of the standard test inputs can't be used.
var shortKeccakTestVectors = []testVector{
{
desc: "short-8b",
input: decodeHex("CC"),
repeat: 1,
want: map[string]string{
"Keccak224": "A9CAB59EB40A10B246290F2D6086E32E3689FAF1D26B470C899F2802",
"Keccak256": "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A",
"Keccak384": "1B84E62A46E5A201861754AF5DC95C4A1A69CAF4A796AE405680161E29572641F5FA1E8641D7958336EE7B11C58F73E9",
"Keccak512": "8630C13CBD066EA74BBE7FE468FEC1DEE10EDC1254FB4C1B7C5FD69B646E44160B8CE01D05A0908CA790DFB080F4B513BC3B6225ECE7A810371441A5AC666EB9",
},
},
{
desc: "short-248b",
input: decodeHex("84FB51B517DF6C5ACCB5D022F8F28DA09B10232D42320FFC32DBECC3835B29"),
repeat: 1,
want: map[string]string{
"Keccak224": "81AF3A7A5BD4C1F948D6AF4B96F93C3B0CF9C0E7A6DA6FCD71EEC7F6",
"Keccak256": "D477FB02CAAA95B3280EC8EE882C29D9E8A654B21EF178E0F97571BF9D4D3C1C",
"Keccak384": "503DCAA4ADDA5A9420B2E436DD62D9AB2E0254295C2982EF67FCE40F117A2400AB492F7BD5D133C6EC2232268BC27B42",
"Keccak512": "9D8098D8D6EDBBAA2BCFC6FB2F89C3EAC67FEC25CDFE75AA7BD570A648E8C8945FF2EC280F6DCF73386109155C5BBC444C707BB42EAB873F5F7476657B1BC1A8",
},
},
{
desc: "short-264b",
input: decodeHex("DE8F1B3FAA4B7040ED4563C3B8E598253178E87E4D0DF75E4FF2F2DEDD5A0BE046"),
repeat: 1,
want: map[string]string{
"Keccak224": "F217812E362EC64D4DC5EACFABC165184BFA456E5C32C2C7900253D0",
"Keccak256": "E78C421E6213AFF8DE1F025759A4F2C943DB62BBDE359C8737E19B3776ED2DD2",
"Keccak384": "CF38764973F1EC1C34B5433AE75A3AAD1AAEF6AB197850C56C8617BCD6A882F6666883AC17B2DCCDBAA647075D0972B5",
"Keccak512": "9A7688E31AAF40C15575FC58C6B39267AAD3722E696E518A9945CF7F7C0FEA84CB3CB2E9F0384A6B5DC671ADE7FB4D2B27011173F3EEEAF17CB451CF26542031",
},
},
}
// longTestVectors stores longer testVectors (currently only one).
// The computed test vector is 64 MiB long and is a truncated version of the
// ExtremelyLongMsgKAT taken from http://keccak.noekeon.org/.
var longKeccakTestVectors = []testVector{
{
desc: "long-64MiB",
input: []byte("abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmno"),
repeat: 1024 * 1024,
want: map[string]string{
"Keccak224": "50E35E40980FEEFF1EA490957B0E970257F75EA0D410EE0F0B8A7A58",
"Keccak256": "5015A4935F0B51E091C6550A94DCD262C08998232CCAA22E7F0756DEAC0DC0D0",
"Keccak384": "7907A8D0FAA7BC6A90FE14C6C958C956A0877E751455D8F13ACDB96F144B5896E716C06EC0CB56557A94EF5C3355F6F3",
"Keccak512": "3EC327D6759F769DEB74E80CA70C831BC29CAB048A4BF4190E4A1DD5C6507CF2B4B58937FDE81D36014E7DFE1B1DD8B0F27CB7614F9A645FEC114F1DAAEFC056",
},
},
}
// TestKeccakVectors checks that correct output is produced for a set of known testVectors.
func TestKeccakVectors(t *testing.T) {
testCases := append([]testVector{}, shortKeccakTestVectors...)
if !testing.Short() {
testCases = append(testCases, longKeccakTestVectors...)
}
for _, tc := range testCases {
for alg, want := range tc.want {
d := testDigests[alg]
d.Reset()
for i := 0; i < tc.repeat; i++ {
d.Write(tc.input)
}
got := strings.ToUpper(hex.EncodeToString(d.Sum(nil)))
if got != want {
t.Errorf("%s, alg=%s\ngot %q, want %q", tc.desc, alg, got, want)
}
}
}
}
// dumpState is a debugging function to pretty-print the internal state of the hash.
func (d *digest) dumpState() {
fmt.Printf("SHA3 hash, %d B output, %d B capacity (%d B rate)\n", d.outputSize, d.capacity, d.rate())
fmt.Printf("Internal state after absorbing %d B:\n", d.absorbed)
for x := 0; x < sliceSize; x++ {
for y := 0; y < sliceSize; y++ {
fmt.Printf("%v, ", d.a[x*sliceSize+y])
}
fmt.Println("")
}
}
// TestUnalignedWrite tests that writing data in an arbitrary pattern with small input buffers.
func TestUnalignedWrite(t *testing.T) {
buf := sequentialBytes(0x10000)
for alg, d := range testDigests {
d.Reset()
d.Write(buf)
want := d.Sum(nil)
d.Reset()
for i := 0; i < len(buf); {
// Cycle through offsets which make a 137 byte sequence.
// Because 137 is prime this sequence should exercise all corner cases.
offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
for _, j := range offsets {
j = minInt(j, len(buf)-i)
d.Write(buf[i : i+j])
i += j
}
}
got := d.Sum(nil)
if !bytes.Equal(got, want) {
t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want)
}
}
}
func TestAppend(t *testing.T) {
d := NewKeccak224()
for capacity := 2; capacity < 64; capacity += 64 {
// The first time around the loop, Sum will have to reallocate.
// The second time, it will not.
buf := make([]byte, 2, capacity)
d.Reset()
d.Write([]byte{0xcc})
buf = d.Sum(buf)
expected := "0000A9CAB59EB40A10B246290F2D6086E32E3689FAF1D26B470C899F2802"
if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
t.Errorf("got %s, want %s", got, expected)
}
}
}
func TestAppendNoRealloc(t *testing.T) {
buf := make([]byte, 1, 200)
d := NewKeccak224()
d.Write([]byte{0xcc})
buf = d.Sum(buf)
expected := "00A9CAB59EB40A10B246290F2D6086E32E3689FAF1D26B470C899F2802"
if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
t.Errorf("got %s, want %s", got, expected)
}
}
// sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing.
func sequentialBytes(size int) []byte {
result := make([]byte, size)
for i := range result {
result[i] = byte(i)
}
return result
}
// benchmarkBlockWrite tests the speed of writing data and never calling the permutation function.
func benchmarkBlockWrite(b *testing.B, d *digest) {
b.StopTimer()
d.Reset()
// Write all but the last byte of a block, to ensure that the permutation is not called.
data := sequentialBytes(d.rate() - 1)
b.SetBytes(int64(len(data)))
b.StartTimer()
for i := 0; i < b.N; i++ {
d.absorbed = 0 // Reset absorbed to avoid ever calling the permutation function
d.Write(data)
}
b.StopTimer()
d.Reset()
}
// BenchmarkPermutationFunction measures the speed of the permutation function with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
b.SetBytes(int64(stateSize))
var lanes [numLanes]uint64
for i := 0; i < b.N; i++ {
keccakF(&lanes)
}
}
// BenchmarkSingleByteWrite tests the latency from writing a single byte
func BenchmarkSingleByteWrite(b *testing.B) {
b.StopTimer()
d := testDigests["Keccak512"]
d.Reset()
data := sequentialBytes(1) //1 byte buffer
b.SetBytes(int64(d.rate()) - 1)
b.StartTimer()
for i := 0; i < b.N; i++ {
d.absorbed = 0 // Reset absorbed to avoid ever calling the permutation function
// Write all but the last byte of a block, one byte at a time.
for j := 0; j < d.rate()-1; j++ {
d.Write(data)
}
}
b.StopTimer()
d.Reset()
}
// BenchmarkSingleByteX measures the block write speed for each size of the digest.
func BenchmarkBlockWrite512(b *testing.B) { benchmarkBlockWrite(b, testDigests["Keccak512"]) }
func BenchmarkBlockWrite384(b *testing.B) { benchmarkBlockWrite(b, testDigests["Keccak384"]) }
func BenchmarkBlockWrite256(b *testing.B) { benchmarkBlockWrite(b, testDigests["Keccak256"]) }
func BenchmarkBlockWrite224(b *testing.B) { benchmarkBlockWrite(b, testDigests["Keccak224"]) }
// benchmarkBulkHash tests the speed to hash a 16 KiB buffer.
func benchmarkBulkHash(b *testing.B, h hash.Hash) {
b.StopTimer()
h.Reset()
size := 1 << 14
data := sequentialBytes(size)
b.SetBytes(int64(size))
b.StartTimer()
var digest []byte
for i := 0; i < b.N; i++ {
h.Write(data)
digest = h.Sum(digest[:0])
}
b.StopTimer()
h.Reset()
}
// benchmarkBulkKeccakX test the speed to hash a 16 KiB buffer by calling benchmarkBulkHash.
func BenchmarkBulkKeccak512(b *testing.B) { benchmarkBulkHash(b, NewKeccak512()) }
func BenchmarkBulkKeccak384(b *testing.B) { benchmarkBulkHash(b, NewKeccak384()) }
func BenchmarkBulkKeccak256(b *testing.B) { benchmarkBulkHash(b, NewKeccak256()) }
func BenchmarkBulkKeccak224(b *testing.B) { benchmarkBulkHash(b, NewKeccak224()) }
@@ -1,250 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/base64"
"errors"
"io"
"sync"
)
// See [PROTOCOL.agent], section 3.
const (
// 3.2 Requests from client to agent for protocol 2 key operations
agentRequestIdentities = 11
agentSignRequest = 13
agentAddIdentity = 17
agentRemoveIdentity = 18
agentRemoveAllIdentities = 19
agentAddIdConstrained = 25
// 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20
agentRemoveSmartcardKey = 21
agentLock = 22
agentUnlock = 23
agentAddSmartcardKeyConstrained = 26
// 3.4 Generic replies from agent to client
agentFailure = 5
agentSuccess = 6
// 3.6 Replies from agent to client for protocol 2 key operations
agentIdentitiesAnswer = 12
agentSignResponse = 14
// 3.7 Key constraint identifiers
agentConstrainLifetime = 1
agentConstrainConfirm = 2
)
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
// is a sanity check, not a limit in the spec.
const maxAgentResponseBytes = 16 << 20
// Agent messages:
// These structures mirror the wire format of the corresponding ssh agent
// messages found in [PROTOCOL.agent].
type failureAgentMsg struct{}
type successAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
type requestIdentitiesAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
type identitiesAnswerAgentMsg struct {
NumKeys uint32
Keys []byte `ssh:"rest"`
}
// See [PROTOCOL.agent], section 2.6.2.
type signRequestAgentMsg struct {
KeyBlob []byte
Data []byte
Flags uint32
}
// See [PROTOCOL.agent], section 2.6.2.
type signResponseAgentMsg struct {
SigBlob []byte
}
// AgentKey represents a protocol 2 key as defined in [PROTOCOL.agent],
// section 2.5.2.
type AgentKey struct {
blob []byte
Comment string
}
// String returns the storage form of an agent key with the format, base64
// encoded serialized key, and the comment if it is not empty.
func (ak *AgentKey) String() string {
algo, _, ok := parseString(ak.blob)
if !ok {
return "ssh: malformed key"
}
s := string(algo) + " " + base64.StdEncoding.EncodeToString(ak.blob)
if ak.Comment != "" {
s += " " + ak.Comment
}
return s
}
// Key returns an agent's public key as one of the supported key or certificate types.
func (ak *AgentKey) Key() (PublicKey, error) {
if key, _, ok := ParsePublicKey(ak.blob); ok {
return key, nil
}
return nil, errors.New("ssh: failed to parse key blob")
}
func parseAgentKey(in []byte) (out *AgentKey, rest []byte, ok bool) {
ak := new(AgentKey)
if ak.blob, in, ok = parseString(in); !ok {
return
}
comment, in, ok := parseString(in)
if !ok {
return
}
ak.Comment = string(comment)
return ak, in, true
}
// AgentClient provides a means to communicate with an ssh agent process based
// on the protocol described in [PROTOCOL.agent]?rev=1.6.
type AgentClient struct {
// conn is typically represented by using a *net.UnixConn
conn io.ReadWriter
// mu is used to prevent concurrent access to the agent
mu sync.Mutex
}
// NewAgentClient creates and returns a new *AgentClient using the
// passed in io.ReadWriter as a connection to a ssh agent.
func NewAgentClient(rw io.ReadWriter) *AgentClient {
return &AgentClient{conn: rw}
}
// sendAndReceive sends req to the agent and waits for a reply. On success,
// the reply is unmarshaled into reply and replyType is set to the first byte of
// the reply, which contains the type of the message.
func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
// ac.mu prevents multiple, concurrent requests. Since the agent is typically
// on the same machine, we don't attempt to pipeline the requests.
ac.mu.Lock()
defer ac.mu.Unlock()
msg := make([]byte, stringLength(len(req)))
marshalString(msg, req)
if _, err = ac.conn.Write(msg); err != nil {
return
}
var respSizeBuf [4]byte
if _, err = io.ReadFull(ac.conn, respSizeBuf[:]); err != nil {
return
}
respSize, _, _ := parseUint32(respSizeBuf[:])
if respSize > maxAgentResponseBytes {
err = errors.New("ssh: agent reply too large")
return
}
buf := make([]byte, respSize)
if _, err = io.ReadFull(ac.conn, buf); err != nil {
return
}
return unmarshalAgentMsg(buf)
}
// RequestIdentities queries the agent for protocol 2 keys as defined in
// [PROTOCOL.agent] section 2.5.2.
func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
msg, msgType, err := ac.sendAndReceive(req)
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *identitiesAnswerAgentMsg:
if msg.NumKeys > maxAgentResponseBytes/8 {
return nil, errors.New("ssh: too many keys in agent reply")
}
keys := make([]*AgentKey, msg.NumKeys)
data := msg.Keys
for i := uint32(0); i < msg.NumKeys; i++ {
var key *AgentKey
var ok bool
if key, data, ok = parseAgentKey(data); !ok {
return nil, ParseError{agentIdentitiesAnswer}
}
keys[i] = key
}
return keys, nil
case *failureAgentMsg:
return nil, errors.New("ssh: failed to list keys")
}
return nil, UnexpectedMessageError{agentIdentitiesAnswer, msgType}
}
// SignRequest requests the signing of data by the agent using a protocol 2 key
// as defined in [PROTOCOL.agent] section 2.6.2.
func (ac *AgentClient) SignRequest(key PublicKey, data []byte) ([]byte, error) {
req := marshal(agentSignRequest, signRequestAgentMsg{
KeyBlob: MarshalPublicKey(key),
Data: data,
})
msg, msgType, err := ac.sendAndReceive(req)
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *signResponseAgentMsg:
return msg.SigBlob, nil
case *failureAgentMsg:
return nil, errors.New("ssh: failed to sign challenge")
}
return nil, UnexpectedMessageError{agentSignResponse, msgType}
}
// unmarshalAgentMsg parses an agent message in packet, returning the parsed
// form and the message type of packet.
func unmarshalAgentMsg(packet []byte) (interface{}, uint8, error) {
if len(packet) < 1 {
return nil, 0, ParseError{0}
}
var msg interface{}
switch packet[0] {
case agentFailure:
msg = new(failureAgentMsg)
case agentSuccess:
msg = new(successAgentMsg)
case agentIdentitiesAnswer:
msg = new(identitiesAnswerAgentMsg)
case agentSignResponse:
msg = new(signResponseAgentMsg)
default:
return nil, 0, UnexpectedMessageError{0, packet[0]}
}
if err := unmarshal(msg, packet, packet[0]); err != nil {
return nil, 0, err
}
return msg, packet[0], nil
}
@@ -1,378 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"time"
)
// These constants from [PROTOCOL.certkeys] represent the algorithm names
// for certificate types supported by this package.
const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
)
// Certificate types are used to specify whether a certificate is for identification
// of a user or a host. Current identities are defined in [PROTOCOL.certkeys].
const (
UserCert = 1
HostCert = 2
)
type signature struct {
Format string
Blob []byte
}
type tuple struct {
Name string
Data string
}
const (
maxUint64 = 1<<64 - 1
maxInt64 = 1<<63 - 1
)
// CertTime represents an unsigned 64-bit time value in seconds starting from
// UNIX epoch. We use CertTime instead of time.Time in order to properly handle
// the "infinite" time value ^0, which would become negative when expressed as
// an int64.
type CertTime uint64
func (ct CertTime) Time() time.Time {
if ct > maxInt64 {
return time.Unix(maxInt64, 0)
}
return time.Unix(int64(ct), 0)
}
func (ct CertTime) IsInfinite() bool {
return ct == maxUint64
}
// An OpenSSHCertV01 represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
type OpenSSHCertV01 struct {
Nonce []byte
Key PublicKey
Serial uint64
Type uint32
KeyId string
ValidPrincipals []string
ValidAfter, ValidBefore CertTime
CriticalOptions []tuple
Extensions []tuple
Reserved []byte
SignatureKey PublicKey
Signature *signature
}
// validateOpenSSHCertV01Signature uses the cert's SignatureKey to verify that
// the cert's Signature.Blob is the result of signing the cert bytes starting
// from the algorithm string and going up to and including the SignatureKey.
func validateOpenSSHCertV01Signature(cert *OpenSSHCertV01) bool {
return cert.SignatureKey.Verify(cert.BytesForSigning(), cert.Signature.Blob)
}
var certAlgoNames = map[string]string{
KeyAlgoRSA: CertAlgoRSAv01,
KeyAlgoDSA: CertAlgoDSAv01,
KeyAlgoECDSA256: CertAlgoECDSA256v01,
KeyAlgoECDSA384: CertAlgoECDSA384v01,
KeyAlgoECDSA521: CertAlgoECDSA521v01,
}
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
// Panics if a non-certificate algorithm is passed.
func certToPrivAlgo(algo string) string {
for privAlgo, pubAlgo := range certAlgoNames {
if pubAlgo == algo {
return privAlgo
}
}
panic("unknown cert algorithm")
}
func (cert *OpenSSHCertV01) marshal(includeAlgo, includeSig bool) []byte {
algoName := cert.PublicKeyAlgo()
pubKey := cert.Key.Marshal()
sigKey := MarshalPublicKey(cert.SignatureKey)
var length int
if includeAlgo {
length += stringLength(len(algoName))
}
length += stringLength(len(cert.Nonce))
length += len(pubKey)
length += 8 // Length of Serial
length += 4 // Length of Type
length += stringLength(len(cert.KeyId))
length += lengthPrefixedNameListLength(cert.ValidPrincipals)
length += 8 // Length of ValidAfter
length += 8 // Length of ValidBefore
length += tupleListLength(cert.CriticalOptions)
length += tupleListLength(cert.Extensions)
length += stringLength(len(cert.Reserved))
length += stringLength(len(sigKey))
if includeSig {
length += signatureLength(cert.Signature)
}
ret := make([]byte, length)
r := ret
if includeAlgo {
r = marshalString(r, []byte(algoName))
}
r = marshalString(r, cert.Nonce)
copy(r, pubKey)
r = r[len(pubKey):]
r = marshalUint64(r, cert.Serial)
r = marshalUint32(r, cert.Type)
r = marshalString(r, []byte(cert.KeyId))
r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
r = marshalUint64(r, uint64(cert.ValidAfter))
r = marshalUint64(r, uint64(cert.ValidBefore))
r = marshalTupleList(r, cert.CriticalOptions)
r = marshalTupleList(r, cert.Extensions)
r = marshalString(r, cert.Reserved)
r = marshalString(r, sigKey)
if includeSig {
r = marshalSignature(r, cert.Signature)
}
if len(r) > 0 {
panic("ssh: internal error, marshaling certificate did not fill the entire buffer")
}
return ret
}
func (cert *OpenSSHCertV01) BytesForSigning() []byte {
return cert.marshal(true, false)
}
func (cert *OpenSSHCertV01) Marshal() []byte {
return cert.marshal(false, true)
}
func (c *OpenSSHCertV01) PublicKeyAlgo() string {
algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()]
if !ok {
panic("unknown cert key type")
}
return algo
}
func (c *OpenSSHCertV01) PrivateKeyAlgo() string {
return c.Key.PrivateKeyAlgo()
}
func (c *OpenSSHCertV01) Verify(data []byte, sig []byte) bool {
return c.Key.Verify(data, sig)
}
func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
cert := new(OpenSSHCertV01)
if cert.Nonce, in, ok = parseString(in); !ok {
return
}
privAlgo := certToPrivAlgo(algo)
cert.Key, in, ok = parsePubKey(in, privAlgo)
if !ok {
return
}
// We test PublicKeyAlgo to make sure we don't use some weird sub-cert.
if cert.Key.PublicKeyAlgo() != privAlgo {
ok = false
return
}
if cert.Serial, in, ok = parseUint64(in); !ok {
return
}
if cert.Type, in, ok = parseUint32(in); !ok {
return
}
keyId, in, ok := parseString(in)
if !ok {
return
}
cert.KeyId = string(keyId)
if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok {
return
}
va, in, ok := parseUint64(in)
if !ok {
return
}
cert.ValidAfter = CertTime(va)
vb, in, ok := parseUint64(in)
if !ok {
return
}
cert.ValidBefore = CertTime(vb)
if cert.CriticalOptions, in, ok = parseTupleList(in); !ok {
return
}
if cert.Extensions, in, ok = parseTupleList(in); !ok {
return
}
if cert.Reserved, in, ok = parseString(in); !ok {
return
}
sigKey, in, ok := parseString(in)
if !ok {
return
}
if cert.SignatureKey, _, ok = ParsePublicKey(sigKey); !ok {
return
}
if cert.Signature, in, ok = parseSignature(in); !ok {
return
}
ok = true
return cert, in, ok
}
func lengthPrefixedNameListLength(namelist []string) int {
length := 4 // length prefix for list
for _, name := range namelist {
length += 4 // length prefix for name
length += len(name)
}
return length
}
func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte {
length := uint32(lengthPrefixedNameListLength(namelist) - 4)
to = marshalUint32(to, length)
for _, name := range namelist {
to = marshalString(to, []byte(name))
}
return to
}
func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) {
list, rest, ok := parseString(in)
if !ok {
return
}
for len(list) > 0 {
var next []byte
if next, list, ok = parseString(list); !ok {
return nil, nil, false
}
out = append(out, string(next))
}
ok = true
return
}
func tupleListLength(tupleList []tuple) int {
length := 4 // length prefix for list
for _, t := range tupleList {
length += 4 // length prefix for t.Name
length += len(t.Name)
length += 4 // length prefix for t.Data
length += len(t.Data)
}
return length
}
func marshalTupleList(to []byte, tuplelist []tuple) []byte {
length := uint32(tupleListLength(tuplelist) - 4)
to = marshalUint32(to, length)
for _, t := range tuplelist {
to = marshalString(to, []byte(t.Name))
to = marshalString(to, []byte(t.Data))
}
return to
}
func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
list, rest, ok := parseString(in)
if !ok {
return
}
for len(list) > 0 {
var name, data []byte
var ok bool
name, list, ok = parseString(list)
if !ok {
return nil, nil, false
}
data, list, ok = parseString(list)
if !ok {
return nil, nil, false
}
out = append(out, tuple{string(name), string(data)})
}
ok = true
return
}
func signatureLength(sig *signature) int {
length := 4 // length prefix for signature
length += stringLength(len(sig.Format))
length += stringLength(len(sig.Blob))
return length
}
func marshalSignature(to []byte, sig *signature) []byte {
length := uint32(signatureLength(sig) - 4)
to = marshalUint32(to, length)
to = marshalString(to, []byte(sig.Format))
to = marshalString(to, sig.Blob)
return to
}
func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) {
var format []byte
if format, in, ok = parseString(in); !ok {
return
}
out = &signature{
Format: string(format),
}
if out.Blob, in, ok = parseString(in); !ok {
return
}
return out, in, ok
}
func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
var sigBytes []byte
if sigBytes, rest, ok = parseString(in); !ok {
return
}
out, sigBytes, ok = parseSignatureBody(sigBytes)
if !ok || len(sigBytes) > 0 {
return nil, nil, false
}
return
}
@@ -1,55 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"testing"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
// % ssh-keygen -s ca-key -I test user-key
var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
key, _, _, rest, ok := ParseAuthorizedKey(authKeyBytes)
if !ok {
t.Fatalf("could not parse certificate")
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
if _, ok = key.(*OpenSSHCertV01); !ok {
t.Fatalf("got %#v, want *OpenSSHCertV01", key)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
func TestVerifyCert(t *testing.T) {
key, _, _, _, _ := ParseAuthorizedKey([]byte(exampleSSHCert))
validCert := key.(*OpenSSHCertV01)
if ok := validateOpenSSHCertV01Signature(validCert); !ok {
t.Error("Unable to validate certificate!")
}
invalidCert := &OpenSSHCertV01{
Key: rsaKey.PublicKey(),
SignatureKey: ecdsaKey.PublicKey(),
Signature: &signature{},
}
if ok := validateOpenSSHCertV01Signature(invalidCert); ok {
t.Error("Invalid cert signature passed validation!")
}
}
@@ -1,594 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
)
// extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
// section 5.2.
type extendedDataTypeCode uint32
const (
// extendedDataStderr is the extended data type that is used for stderr.
extendedDataStderr extendedDataTypeCode = 1
// minPacketLength defines the smallest valid packet
minPacketLength = 9
// channelMaxPacketSize defines the maximum packet size advertised in open messages
channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB
// channelWindowSize defines the window size advertised in open messages
channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH
)
// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
// SSH connection. Channel.Read can return a ChannelRequest as an error.
type Channel interface {
// Accept accepts the channel creation request.
Accept() error
// Reject rejects the channel creation request. After calling this, no
// other methods on the Channel may be called. If they are then the
// peer is likely to signal a protocol error and drop the connection.
Reject(reason RejectionReason, message string) error
// Read may return a ChannelRequest as an error.
Read(data []byte) (int, error)
Write(data []byte) (int, error)
Close() error
// Stderr returns an io.Writer that writes to this channel with the
// extended data type set to stderr.
Stderr() io.Writer
// AckRequest either sends an ack or nack to the channel request.
AckRequest(ok bool) error
// ChannelType returns the type of the channel, as supplied by the
// client.
ChannelType() string
// ExtraData returns the arbitrary payload for this channel, as supplied
// by the client. This data is specific to the channel type.
ExtraData() []byte
}
// ChannelRequest represents a request sent on a channel, outside of the normal
// stream of bytes. It may result from calling Read on a Channel.
type ChannelRequest struct {
Request string
WantReply bool
Payload []byte
}
func (c ChannelRequest) Error() string {
return "ssh: channel request received"
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason uint32
const (
Prohibited RejectionReason = iota + 1
ConnectionFailed
UnknownChannelType
ResourceShortage
)
// String converts the rejection reason to human readable form.
func (r RejectionReason) String() string {
switch r {
case Prohibited:
return "administratively prohibited"
case ConnectionFailed:
return "connect failed"
case UnknownChannelType:
return "unknown channel type"
case ResourceShortage:
return "resource shortage"
}
return fmt.Sprintf("unknown reason %d", int(r))
}
type channel struct {
packetConn // the underlying transport
localId, remoteId uint32
remoteWin window
maxPacket uint32
isClosed uint32 // atomic bool, non zero if true
}
func (c *channel) sendWindowAdj(n int) error {
msg := windowAdjustMsg{
PeersId: c.remoteId,
AdditionalBytes: uint32(n),
}
return c.writePacket(marshal(msgChannelWindowAdjust, msg))
}
// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
func (c *channel) sendEOF() error {
return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
PeersId: c.remoteId,
}))
}
// sendClose informs the remote side of our intent to close the channel.
func (c *channel) sendClose() error {
return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
PeersId: c.remoteId,
}))
}
func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
reject := channelOpenFailureMsg{
PeersId: c.remoteId,
Reason: reason,
Message: message,
Language: "en",
}
return c.writePacket(marshal(msgChannelOpenFailure, reject))
}
func (c *channel) writePacket(b []byte) error {
if c.closed() {
return io.EOF
}
if uint32(len(b)) > c.maxPacket {
return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
}
return c.packetConn.writePacket(b)
}
func (c *channel) closed() bool {
return atomic.LoadUint32(&c.isClosed) > 0
}
func (c *channel) setClosed() bool {
return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
}
type serverChan struct {
channel
// immutable once created
chanType string
extraData []byte
serverConn *ServerConn
myWindow uint32
theyClosed bool // indicates the close msg has been received from the remote side
theySentEOF bool
isDead uint32
err error
pendingRequests []ChannelRequest
pendingData []byte
head, length int
// This lock is inferior to serverConn.lock
cond *sync.Cond
}
func (c *serverChan) Accept() error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
confirm := channelOpenConfirmMsg{
PeersId: c.remoteId,
MyId: c.localId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxPacket,
}
return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
}
func (c *serverChan) Reject(reason RejectionReason, message string) error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
return c.sendChannelOpenFailure(reason, message)
}
func (c *serverChan) handlePacket(packet interface{}) {
c.cond.L.Lock()
defer c.cond.L.Unlock()
switch packet := packet.(type) {
case *channelRequestMsg:
req := ChannelRequest{
Request: packet.Request,
WantReply: packet.WantReply,
Payload: packet.RequestSpecificData,
}
c.pendingRequests = append(c.pendingRequests, req)
c.cond.Signal()
case *channelCloseMsg:
c.theyClosed = true
c.cond.Signal()
case *channelEOFMsg:
c.theySentEOF = true
c.cond.Signal()
case *windowAdjustMsg:
if !c.remoteWin.add(packet.AdditionalBytes) {
panic("illegal window update")
}
default:
panic("unknown packet type")
}
}
func (c *serverChan) handleData(data []byte) {
c.cond.L.Lock()
defer c.cond.L.Unlock()
// The other side should never send us more than our window.
if len(data)+c.length > len(c.pendingData) {
// TODO(agl): we should tear down the channel with a protocol
// error.
return
}
c.myWindow -= uint32(len(data))
for i := 0; i < 2; i++ {
tail := c.head + c.length
if tail >= len(c.pendingData) {
tail -= len(c.pendingData)
}
n := copy(c.pendingData[tail:], data)
data = data[n:]
c.length += n
}
c.cond.Signal()
}
func (c *serverChan) Stderr() io.Writer {
return extendedDataChannel{c: c, t: extendedDataStderr}
}
// extendedDataChannel is an io.Writer that writes any data to c as extended
// data of the given type.
type extendedDataChannel struct {
t extendedDataTypeCode
c *serverChan
}
func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
c := edc.c
for len(data) > 0 {
space := min(c.maxPacket-headerLength, len(data))
if space, err = c.getWindowSpace(space); err != nil {
return 0, err
}
todo := data
if uint32(len(todo)) > space {
todo = todo[:space]
}
packet := make([]byte, headerLength+len(todo))
packet[0] = msgChannelExtendedData
marshalUint32(packet[1:], c.remoteId)
marshalUint32(packet[5:], uint32(edc.t))
marshalUint32(packet[9:], uint32(len(todo)))
copy(packet[13:], todo)
if err = c.writePacket(packet); err != nil {
return
}
n += len(todo)
data = data[len(todo):]
}
return
}
func (c *serverChan) Read(data []byte) (n int, err error) {
n, err, windowAdjustment := c.read(data)
if windowAdjustment > 0 {
packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
PeersId: c.remoteId,
AdditionalBytes: windowAdjustment,
})
err = c.writePacket(packet)
}
return
}
func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
c.cond.L.Lock()
defer c.cond.L.Unlock()
if c.err != nil {
return 0, c.err, 0
}
for {
if c.theySentEOF || c.theyClosed || c.dead() {
return 0, io.EOF, 0
}
if len(c.pendingRequests) > 0 {
req := c.pendingRequests[0]
if len(c.pendingRequests) == 1 {
c.pendingRequests = nil
} else {
oldPendingRequests := c.pendingRequests
c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
copy(c.pendingRequests, oldPendingRequests[1:])
}
return 0, req, 0
}
if c.length > 0 {
tail := min(uint32(c.head+c.length), len(c.pendingData))
n = copy(data, c.pendingData[c.head:tail])
c.head += n
c.length -= n
if c.head == len(c.pendingData) {
c.head = 0
}
windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
if windowAdjustment < uint32(len(c.pendingData)/2) {
windowAdjustment = 0
}
c.myWindow += windowAdjustment
return
}
c.cond.Wait()
}
panic("unreachable")
}
// getWindowSpace takes, at most, max bytes of space from the peer's window. It
// returns the number of bytes actually reserved.
func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
if c.dead() || c.closed() {
return 0, io.EOF
}
return c.remoteWin.reserve(max), nil
}
func (c *serverChan) dead() bool {
return atomic.LoadUint32(&c.isDead) > 0
}
func (c *serverChan) setDead() {
atomic.StoreUint32(&c.isDead, 1)
}
func (c *serverChan) Write(data []byte) (n int, err error) {
const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
for len(data) > 0 {
space := min(c.maxPacket-headerLength, len(data))
if space, err = c.getWindowSpace(space); err != nil {
return 0, err
}
todo := data
if uint32(len(todo)) > space {
todo = todo[:space]
}
packet := make([]byte, headerLength+len(todo))
packet[0] = msgChannelData
marshalUint32(packet[1:], c.remoteId)
marshalUint32(packet[5:], uint32(len(todo)))
copy(packet[9:], todo)
if err = c.writePacket(packet); err != nil {
return
}
n += len(todo)
data = data[len(todo):]
}
return
}
// Close signals the intent to close the channel.
func (c *serverChan) Close() error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if !c.setClosed() {
return errors.New("ssh: channel already closed")
}
return c.sendClose()
}
func (c *serverChan) AckRequest(ok bool) error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if !ok {
ack := channelRequestFailureMsg{
PeersId: c.remoteId,
}
return c.writePacket(marshal(msgChannelFailure, ack))
}
ack := channelRequestSuccessMsg{
PeersId: c.remoteId,
}
return c.writePacket(marshal(msgChannelSuccess, ack))
}
func (c *serverChan) ChannelType() string {
return c.chanType
}
func (c *serverChan) ExtraData() []byte {
return c.extraData
}
// A clientChan represents a single RFC 4254 channel multiplexed
// over a SSH connection.
type clientChan struct {
channel
stdin *chanWriter
stdout *chanReader
stderr *chanReader
msg chan interface{}
}
// newClientChan returns a partially constructed *clientChan
// using the local id provided. To be usable clientChan.remoteId
// needs to be assigned once known.
func newClientChan(cc packetConn, id uint32) *clientChan {
c := &clientChan{
channel: channel{
packetConn: cc,
localId: id,
remoteWin: window{Cond: newCond()},
},
msg: make(chan interface{}, 16),
}
c.stdin = &chanWriter{
channel: &c.channel,
}
c.stdout = &chanReader{
channel: &c.channel,
buffer: newBuffer(),
}
c.stderr = &chanReader{
channel: &c.channel,
buffer: newBuffer(),
}
return c
}
// waitForChannelOpenResponse, if successful, fills out
// the remoteId and records any initial window advertisement.
func (c *clientChan) waitForChannelOpenResponse() error {
switch msg := (<-c.msg).(type) {
case *channelOpenConfirmMsg:
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return errors.New("ssh: invalid MaxPacketSize from peer")
}
// fixup remoteId field
c.remoteId = msg.MyId
c.maxPacket = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow)
return nil
case *channelOpenFailureMsg:
return errors.New(safeString(msg.Message))
}
return errors.New("ssh: unexpected packet")
}
// Close signals the intent to close the channel.
func (c *clientChan) Close() error {
if !c.setClosed() {
return errors.New("ssh: channel already closed")
}
c.stdout.eof()
c.stderr.eof()
return c.sendClose()
}
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
*channel
// indicates the writer has been closed. eof is owned by the
// caller of Write/Close.
eof bool
}
// Write writes data to the remote process's standard input.
func (w *chanWriter) Write(data []byte) (written int, err error) {
const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
for len(data) > 0 {
if w.eof || w.closed() {
err = io.EOF
return
}
// never send more data than maxPacket even if
// there is sufficient window.
n := min(w.maxPacket-headerLength, len(data))
r := w.remoteWin.reserve(n)
n = r
remoteId := w.remoteId
packet := []byte{
msgChannelData,
byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
}
if err = w.writePacket(append(packet, data[:n]...)); err != nil {
break
}
data = data[n:]
written += int(n)
}
return
}
func min(a uint32, b int) uint32 {
if a < uint32(b) {
return a
}
return uint32(b)
}
func (w *chanWriter) Close() error {
w.eof = true
return w.sendEOF()
}
// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
*channel // the channel backing this reader
*buffer
}
// Read reads data from the remote process's stdout or stderr.
func (r *chanReader) Read(buf []byte) (int, error) {
n, err := r.buffer.Read(buf)
if err != nil {
if err == io.EOF {
return n, err
}
return 0, err
}
err = r.sendWindowAdj(n)
if err == io.EOF && n > 0 {
// sendWindowAdjust can return io.EOF if the remote peer has
// closed the connection, however we want to defer forwarding io.EOF to the
// caller of Read until the buffer has been drained.
err = nil
}
return n, err
}
@@ -1,100 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
)
// streamDump is used to dump the initial keystream for stream ciphers. It is a
// a write-only buffer, and not intended for reading so do not require a mutex.
var streamDump [512]byte
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}
func (c noneCipher) XORKeyStream(dst, src []byte) {
copy(dst, src)
}
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(c, iv), nil
}
func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type cipherMode struct {
keySize int
ivSize int
skip int
createFunc func(key, iv []byte) (cipher.Stream, error)
}
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
}
// Specifies a default set of ciphers and a preference order. This is based on
// OpenSSH's default client preference order, minus algorithms that are not
// implemented.
var DefaultCipherOrder = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"arcfour256", "arcfour128",
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
}
// defaultKeyExchangeOrder specifies a default set of key exchange algorithms
// with preferences.
var defaultKeyExchangeOrder = []string{
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
@@ -1,62 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"testing"
)
// TestCipherReversal tests that each cipher factory produces ciphers that can
// encrypt and decrypt some data successfully.
func TestCipherReversal(t *testing.T) {
testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
cryptBuffer := make([]byte, 32)
for name, cipherMode := range cipherModes {
encrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create encrypter for %q: %s", name, err)
continue
}
decrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create decrypter for %q: %s", name, err)
continue
}
copy(cryptBuffer, testData)
encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if name == "none" {
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made change with 'none' cipher")
continue
}
} else {
if bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made no change with %q", name)
continue
}
}
decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("decrypted bytes not equal to input with %q", name)
continue
}
}
}
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range DefaultCipherOrder {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
@@ -1,524 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
)
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
transport *transport
config *ClientConfig
chanList // channels associated with this connection
forwardList // forwarded tcpip connections from the remote side
globalRequest
// Address as passed to the Dial function.
dialAddress string
serverVersion string
}
type globalRequest struct {
sync.Mutex
response chan interface{}
}
// Client returns a new SSH client connection using c as the underlying transport.
func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
return clientWithAddress(c, "", config)
}
func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
conn := &ClientConn{
transport: newTransport(c, config.rand(), true /* is client */),
config: config,
globalRequest: globalRequest{response: make(chan interface{}, 1)},
dialAddress: addr,
}
if err := conn.handshake(); err != nil {
conn.transport.Close()
return nil, fmt.Errorf("handshake failed: %v", err)
}
go conn.mainLoop()
return conn, nil
}
// Close closes the connection.
func (c *ClientConn) Close() error { return c.transport.Close() }
// LocalAddr returns the local network address.
func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
// RemoteAddr returns the remote network address.
func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() error {
clientVersion := []byte(packageVersion)
if c.config.ClientVersion != "" {
clientVersion = []byte(c.config.ClientVersion)
}
serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion)
if err != nil {
return err
}
c.serverVersion = string(serverVersion)
clientKexInit := kexInitMsg{
KexAlgos: c.config.Crypto.kexes(),
ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: c.config.Crypto.ciphers(),
CiphersServerClient: c.config.Crypto.ciphers(),
MACsClientServer: c.config.Crypto.macs(),
MACsServerClient: c.config.Crypto.macs(),
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, clientKexInit)
if err := c.transport.writePacket(kexInitPacket); err != nil {
return err
}
packet, err := c.transport.readPacket()
if err != nil {
return err
}
var serverKexInit kexInitMsg
if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
return err
}
algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
}
if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
// The server sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err := c.transport.readPacket(); err != nil {
return err
}
}
kex, ok := kexAlgoMap[algs.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
}
magics := handshakeMagics{
clientVersion: clientVersion,
serverVersion: serverVersion,
clientKexInit: kexInitPacket,
serverKexInit: packet,
}
result, err := kex.Client(c.transport, c.config.rand(), &magics)
if err != nil {
return err
}
err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature)
if err != nil {
return err
}
if checker := c.config.HostKeyChecker; checker != nil {
err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
if err != nil {
return err
}
}
c.transport.prepareKeyChange(algs, result)
if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
if packet, err = c.transport.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
return c.authenticate()
}
// Verify the host key obtained in the key exchange.
func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
if len(rest) > 0 || !ok {
return errors.New("ssh: could not parse hostkey")
}
sig, rest, ok := parseSignatureBody(signature)
if len(rest) > 0 || !ok {
return errors.New("ssh: signature parse error")
}
if sig.Format != hostKeyAlgo {
return fmt.Errorf("ssh: unexpected signature type %q", sig.Format)
}
if !hostKey.Verify(data, sig.Blob) {
return errors.New("ssh: host key signature error")
}
return nil
}
// mainLoop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
defer func() {
c.transport.Close()
c.chanList.closeAll()
c.forwardList.closeAll()
}()
for {
packet, err := c.transport.readPacket()
if err != nil {
break
}
// TODO(dfc) A note on blocking channel use.
// The msg, data and dataExt channels of a clientChan can
// cause this loop to block indefinitely if the consumer does
// not service them.
switch packet[0] {
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
return
}
remoteId := binary.BigEndian.Uint32(packet[1:5])
length := binary.BigEndian.Uint32(packet[5:9])
packet = packet[9:]
if length != uint32(len(packet)) {
return
}
ch, ok := c.getChan(remoteId)
if !ok {
return
}
ch.stdout.write(packet)
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
return
}
remoteId := binary.BigEndian.Uint32(packet[1:5])
datatype := binary.BigEndian.Uint32(packet[5:9])
length := binary.BigEndian.Uint32(packet[9:13])
packet = packet[13:]
if length != uint32(len(packet)) {
return
}
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
ch, ok := c.getChan(remoteId)
if !ok {
return
}
ch.stderr.write(packet)
}
default:
decoded, err := decode(packet)
if err != nil {
if _, ok := err.(UnexpectedMessageError); ok {
fmt.Printf("mainLoop: unexpected message: %v\n", err)
continue
}
return
}
switch msg := decoded.(type) {
case *channelOpenMsg:
c.handleChanOpen(msg)
case *channelOpenConfirmMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.msg <- msg
case *channelOpenFailureMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.msg <- msg
case *channelCloseMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.Close()
close(ch.msg)
c.chanList.remove(msg.PeersId)
case *channelEOFMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.stdout.eof()
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
ch.stderr.eof()
case *channelRequestSuccessMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.msg <- msg
case *channelRequestFailureMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.msg <- msg
case *channelRequestMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
ch.msg <- msg
case *windowAdjustMsg:
ch, ok := c.getChan(msg.PeersId)
if !ok {
return
}
if !ch.remoteWin.add(msg.AdditionalBytes) {
// invalid window update
return
}
case *globalRequestMsg:
// This handles keepalive messages and matches
// the behaviour of OpenSSH.
if msg.WantReply {
c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
}
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
c.globalRequest.response <- msg
case *disconnectMsg:
return
default:
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
}
}
}
}
// Handle channel open messages from the remote side.
func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
c.sendConnectionFailed(msg.PeersId)
}
switch msg.ChanType {
case "forwarded-tcpip":
laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
if !ok {
// invalid request
c.sendConnectionFailed(msg.PeersId)
return
}
l, ok := c.forwardList.lookup(*laddr)
if !ok {
// TODO: print on a more structured log.
fmt.Println("could not find forward list entry for", laddr)
// Section 7.2, implementations MUST reject spurious incoming
// connections.
c.sendConnectionFailed(msg.PeersId)
return
}
raddr, rest, ok := parseTCPAddr(rest)
if !ok {
// invalid request
c.sendConnectionFailed(msg.PeersId)
return
}
ch := c.newChan(c.transport)
ch.remoteId = msg.PeersId
ch.remoteWin.add(msg.PeersWindow)
ch.maxPacket = msg.MaxPacketSize
m := channelOpenConfirmMsg{
PeersId: ch.remoteId,
MyId: ch.localId,
MyWindow: channelWindowSize,
MaxPacketSize: channelMaxPacketSize,
}
c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
l <- forward{ch, raddr}
default:
// unknown channel type
m := channelOpenFailureMsg{
PeersId: msg.PeersId,
Reason: UnknownChannelType,
Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
Language: "en_US.UTF-8",
}
c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
}
// sendGlobalRequest sends a global request message as specified
// in RFC4254 section 4. To correctly synchronise messages, a lock
// is held internally until a response is returned.
func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
c.globalRequest.Lock()
defer c.globalRequest.Unlock()
if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
return nil, err
}
r := <-c.globalRequest.response
if r, ok := r.(*globalRequestSuccessMsg); ok {
return r, nil
}
return nil, errors.New("request failed")
}
// sendConnectionFailed rejects an incoming channel identified
// by remoteId.
func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
m := channelOpenFailureMsg{
PeersId: remoteId,
Reason: ConnectionFailed,
Message: "invalid request",
Language: "en_US.UTF-8",
}
return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
// RFC 4254 section 7.2 is mute on what to do if parsing fails but the forwardlist
// requires a valid *net.TCPAddr to operate, so we enforce that restriction here.
func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) {
addr, b, ok := parseString(b)
if !ok {
return nil, b, false
}
port, b, ok := parseUint32(b)
if !ok {
return nil, b, false
}
ip := net.ParseIP(string(addr))
if ip == nil {
return nil, b, false
}
return &net.TCPAddr{IP: ip, Port: int(port)}, b, true
}
// Dial connects to the given network address using net.Dial and
// then initiates a SSH handshake, returning the resulting client connection.
func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
return clientWithAddress(conn, addr, config)
}
// A ClientConfig structure is used to configure a ClientConn. After one has
// been passed to an SSH function it must not be modified.
type ClientConfig struct {
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
// be used.
Rand io.Reader
// The username to authenticate.
User string
// A slice of ClientAuth methods. Only the first instance
// of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth
// HostKeyChecker, if not nil, is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyChecker
// implies that all host keys are accepted.
HostKeyChecker HostKeyChecker
// Cryptographic-related configuration.
Crypto CryptoConfig
// The identification string that will be used for the connection.
// If empty, a reasonable default is used.
ClientVersion string
}
func (c *ClientConfig) rand() io.Reader {
if c.Rand == nil {
return rand.Reader
}
return c.Rand
}
// Thread safe channel list.
type chanList struct {
// protects concurrent access to chans
sync.Mutex
// chans are indexed by the local id of the channel, clientChan.localId.
// The PeersId value of messages received by ClientConn.mainLoop is
// used to locate the right local clientChan in this slice.
chans []*clientChan
}
// Allocate a new ClientChan with the next avail local id.
func (c *chanList) newChan(p packetConn) *clientChan {
c.Lock()
defer c.Unlock()
for i := range c.chans {
if c.chans[i] == nil {
ch := newClientChan(p, uint32(i))
c.chans[i] = ch
return ch
}
}
i := len(c.chans)
ch := newClientChan(p, uint32(i))
c.chans = append(c.chans, ch)
return ch
}
func (c *chanList) getChan(id uint32) (*clientChan, bool) {
c.Lock()
defer c.Unlock()
if id >= uint32(len(c.chans)) {
return nil, false
}
return c.chans[id], true
}
func (c *chanList) remove(id uint32) {
c.Lock()
defer c.Unlock()
c.chans[id] = nil
}
func (c *chanList) closeAll() {
c.Lock()
defer c.Unlock()
for _, ch := range c.chans {
if ch == nil {
continue
}
ch.Close()
close(ch.msg)
}
}
@@ -1,509 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"fmt"
"io"
"net"
)
// authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate() error {
// initiate user auth session
if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.transport.readPacket()
if err != nil {
return err
}
var serviceAccept serviceAcceptMsg
if err := unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
return err
}
// during the authentication phase the client first attempts the "none" method
// then any untried methods suggested by the server.
tried, remain := make(map[string]bool), make(map[string]bool)
for auth := ClientAuth(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.transport.sessionID, c.config.User, c.transport, c.config.rand())
if err != nil {
return err
}
if ok {
// success
return nil
}
tried[auth.method()] = true
delete(remain, auth.method())
for _, meth := range methods {
if tried[meth] {
// if we've tried meth already, skip it.
continue
}
remain[meth] = true
}
auth = nil
for _, a := range c.config.Auth {
if remain[a.method()] {
auth = a
break
}
}
}
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried))
}
func keys(m map[string]bool) (s []string) {
for k := range m {
s = append(s, k)
}
return
}
// HostKeyChecker represents a database of known server host keys.
type HostKeyChecker interface {
// Check is called during the handshake to check server's
// public key for unexpected changes. The hostKey argument is
// in SSH wire format. It can be parsed using
// ssh.ParsePublicKey. The address before DNS resolution is
// passed in the addr argument, so the key can also be checked
// against the hostname.
Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
}
// A ClientAuth represents an instance of an RFC 4252 authentication method.
type ClientAuth interface {
// auth authenticates user over transport t.
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
// method names is returned.
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name.
method() string
}
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
})); err != nil {
return false, nil, err
}
return handleAuthResponse(c)
}
func (n *noneAuth) method() string {
return "none"
}
// "password" authentication, RFC 4252 Section 8.
type passwordAuth struct {
ClientPassword
}
func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
User string
Service string
Method string
Reply bool
Password string
}
pw, err := p.Password(user)
if err != nil {
return false, nil, err
}
if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
User: user,
Service: serviceSSH,
Method: "password",
Reply: false,
Password: pw,
})); err != nil {
return false, nil, err
}
return handleAuthResponse(c)
}
func (p *passwordAuth) method() string {
return "password"
}
// A ClientPassword implements access to a client's passwords.
type ClientPassword interface {
// Password returns the password to use for user.
Password(user string) (password string, err error)
}
// ClientAuthPassword returns a ClientAuth using password authentication.
func ClientAuthPassword(impl ClientPassword) ClientAuth {
return &passwordAuth{impl}
}
// ClientKeyring implements access to a client key ring.
type ClientKeyring interface {
// Key returns the i'th Publickey, or nil if no key exists at i.
Key(i int) (key PublicKey, err error)
// Sign returns a signature of the given data using the i'th key
// and the supplied random source.
Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
}
// "publickey" authentication, RFC 4252 Section 7.
type publickeyAuth struct {
ClientKeyring
}
type publickeyAuthMsg struct {
User string
Service string
Method string
// HasSig indicates to the receiver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
Pubkey string
// Sig is defined as []byte so marshal will exclude it during validateKey
Sig []byte `ssh:"rest"`
}
func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
var index int
// a map of public keys to their index in the keyring
validKeys := make(map[int]PublicKey)
for {
key, err := p.Key(index)
if err != nil {
return false, nil, err
}
if key == nil {
// no more keys in the keyring
break
}
if ok, err := p.validateKey(key, user, c); ok {
validKeys[index] = key
} else {
if err != nil {
return false, nil, err
}
}
index++
}
// methods that may continue if this auth is not successful.
var methods []string
for i, key := range validKeys {
pubkey := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
data := buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
}, []byte(algoname), pubkey)
sigBlob, err := p.Sign(i, rand, data)
if err != nil {
return false, nil, err
}
// manually wrap the serialized signature in a string
s := serializeSignature(key.PublicKeyAlgo(), sigBlob)
sig := make([]byte, stringLength(len(s)))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: true,
Algoname: algoname,
Pubkey: string(pubkey),
Sig: sig,
}
p := marshal(msgUserAuthRequest, msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
}
success, methods, err := handleAuthResponse(c)
if err != nil {
return false, nil, err
}
if success {
return success, methods, err
}
}
return false, methods, nil
}
// validateKey validates the key provided it is acceptable to the server.
func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubkey := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: false,
Algoname: algoname,
Pubkey: string(pubkey),
}
if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
return false, err
}
return p.confirmKeyAck(key, c)
}
func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
pubkey := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
for {
packet, err := c.readPacket()
if err != nil {
return false, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
case msgUserAuthPubKeyOk:
msg := userAuthPubKeyOkMsg{}
if err := unmarshal(&msg, packet, msgUserAuthPubKeyOk); err != nil {
return false, err
}
if msg.Algo != algoname || msg.PubKey != string(pubkey) {
return false, nil
}
return true, nil
case msgUserAuthFailure:
return false, nil
default:
return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
}
panic("unreachable")
}
func (p *publickeyAuth) method() string {
return "publickey"
}
// ClientAuthKeyring returns a ClientAuth using public key authentication.
func ClientAuthKeyring(impl ClientKeyring) ClientAuth {
return &publickeyAuth{impl}
}
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
func handleAuthResponse(c packetConn) (bool, []string, error) {
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
case msgUserAuthFailure:
msg := userAuthFailureMsg{}
if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
}
panic("unreachable")
}
// ClientAuthAgent returns a ClientAuth using public key authentication via
// an agent.
func ClientAuthAgent(agent *AgentClient) ClientAuth {
return ClientAuthKeyring(&agentKeyring{agent: agent})
}
// agentKeyring implements ClientKeyring.
type agentKeyring struct {
agent *AgentClient
keys []*AgentKey
}
func (kr *agentKeyring) Key(i int) (key PublicKey, err error) {
if kr.keys == nil {
if kr.keys, err = kr.agent.RequestIdentities(); err != nil {
return
}
}
if i >= len(kr.keys) {
return
}
return kr.keys[i].Key()
}
func (kr *agentKeyring) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
var key PublicKey
if key, err = kr.Key(i); err != nil {
return
}
if key == nil {
return nil, errors.New("ssh: key index out of range")
}
if sig, err = kr.agent.SignRequest(key, data); err != nil {
return
}
// Unmarshal the signature.
var ok bool
if _, sig, ok = parseString(sig); !ok {
return nil, errors.New("ssh: malformed signature response from agent")
}
if sig, _, ok = parseString(sig); !ok {
return nil, errors.New("ssh: malformed signature response from agent")
}
return sig, nil
}
// ClientKeyboardInteractive should prompt the user for the given
// questions.
type ClientKeyboardInteractive interface {
// Challenge should print the questions, optionally disabling
// echoing (eg. for passwords), and return all the answers.
// Challenge may be called multiple times in a single
// session. After successful authentication, the server may
// send a challenge with no questions, for which the user and
// instruction messages should be printed. RFC 4256 section
// 3.3 details how the UI should behave for both CLI and
// GUI environments.
Challenge(user, instruction string, questions []string, echos []bool) ([]string, error)
}
// ClientAuthKeyboardInteractive returns a ClientAuth using a
// prompt/response sequence controlled by the server.
func ClientAuthKeyboardInteractive(impl ClientKeyboardInteractive) ClientAuth {
return &keyboardInteractiveAuth{impl}
}
type keyboardInteractiveAuth struct {
ClientKeyboardInteractive
}
func (k *keyboardInteractiveAuth) method() string {
return "keyboard-interactive"
}
func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type initiateMsg struct {
User string
Service string
Method string
Language string
Submethods string
}
if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{
User: user,
Service: serviceSSH,
Method: "keyboard-interactive",
})); err != nil {
return false, nil, err
}
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
// like handleAuthResponse, but with less options.
switch packet[0] {
case msgUserAuthBanner:
// TODO: Print banners during userauth.
continue
case msgUserAuthInfoRequest:
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
default:
return false, nil, UnexpectedMessageError{msgUserAuthInfoRequest, packet[0]}
}
var msg userAuthInfoRequestMsg
if err := unmarshal(&msg, packet, packet[0]); err != nil {
return false, nil, err
}
// Manually unpack the prompt/echo pairs.
rest := msg.Prompts
var prompts []string
var echos []bool
for i := 0; i < int(msg.NumPrompts); i++ {
prompt, r, ok := parseString(rest)
if !ok || len(r) == 0 {
return false, nil, errors.New("ssh: prompt format error")
}
prompts = append(prompts, string(prompt))
echos = append(echos, r[0] != 0)
rest = r[1:]
}
if len(rest) != 0 {
return false, nil, fmt.Errorf("ssh: junk following message %q", rest)
}
answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
}
if len(answers) != len(prompts) {
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
}
responseLength := 1 + 4
for _, a := range answers {
responseLength += stringLength(len(a))
}
serialized := make([]byte, responseLength)
p := serialized
p[0] = msgUserAuthInfoResponse
p = p[1:]
p = marshalUint32(p, uint32(len(answers)))
for _, a := range answers {
p = marshalString(p, []byte(a))
}
if err := c.writePacket(serialized); err != nil {
return false, nil, err
}
}
}
@@ -1,368 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/dsa"
"io"
"io/ioutil"
"math/big"
"strings"
"testing"
_ "crypto/sha1"
)
// private key for mock server
const testServerPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
-----END RSA PRIVATE KEY-----`
const testClientPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
-----END RSA PRIVATE KEY-----`
// keychain implements the ClientKeyring interface
type keychain struct {
keys []Signer
}
func (k *keychain) Key(i int) (PublicKey, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
return k.keys[i].PublicKey(), nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
return k.keys[i].Sign(rand, data)
}
func (k *keychain) add(key Signer) {
k.keys = append(k.keys, key)
}
func (k *keychain) loadPEM(file string) error {
buf, err := ioutil.ReadFile(file)
if err != nil {
return err
}
key, err := ParsePrivateKey(buf)
if err != nil {
return err
}
k.add(key)
return nil
}
// password implements the ClientPassword interface
type password string
func (p password) Password(user string) (string, error) {
return string(p), nil
}
type keyboardInteractive map[string]string
func (cr *keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
var answers []string
for _, q := range questions {
answers = append(answers, (*cr)[q])
}
return answers, nil
}
// reused internally by tests
var (
rsaKey Signer
dsaKey Signer
clientKeychain = new(keychain)
clientPassword = password("tiger")
serverConfig = &ServerConfig{
PasswordCallback: func(conn *ServerConn, user, pass string) bool {
return user == "testuser" && pass == string(clientPassword)
},
PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool {
key, _ := clientKeychain.Key(0)
expected := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
},
KeyboardInteractiveCallback: func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool {
ans, err := client.Challenge("user",
"instruction",
[]string{"question1", "question2"},
[]bool{true, true})
if err != nil {
return false
}
ok := user == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
client.Challenge("user", "motd", nil, nil)
return ok
},
}
)
func init() {
var err error
rsaKey, err = ParsePrivateKey([]byte(testServerPrivateKey))
if err != nil {
panic("unable to set private key: " + err.Error())
}
rawDSAKey := new(dsa.PrivateKey)
// taken from crypto/dsa/dsa_test.go
rawDSAKey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16)
rawDSAKey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16)
rawDSAKey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16)
rawDSAKey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16)
rawDSAKey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16)
dsaKey, err = NewSignerFromKey(rawDSAKey)
if err != nil {
panic("NewSignerFromKey: " + err.Error())
}
clientKeychain.add(rsaKey)
serverConfig.AddHostKey(rsaKey)
}
// newMockAuthServer creates a new Server bound to
// the loopback interface. The server exits after
// processing one handshake.
func newMockAuthServer(t *testing.T) string {
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to newMockAuthServer: %s", err)
}
go func() {
defer l.Close()
c, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept incoming connection: %v", err)
return
}
if err := c.Handshake(); err != nil {
// not Errorf because this is expected to
// fail for some tests.
t.Logf("Handshaking error: %v", err)
return
}
defer c.Close()
}()
return l.Addr().String()
}
func TestClientAuthPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(clientKeychain),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
c.Close()
}
func TestClientAuthPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(clientPassword),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
c.Close()
}
func TestClientAuthWrongPassword(t *testing.T) {
wrongPw := password("wrong")
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(wrongPw),
ClientAuthKeyring(clientKeychain),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
c.Close()
}
func TestClientAuthKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "answer2",
})
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyboardInteractive(&answers),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
c.Close()
}
func TestClientAuthWrongKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "WRONG",
})
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyboardInteractive(&answers),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err == nil {
c.Close()
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
}
}
// the mock server will only authenticate ssh-rsa keys
func TestClientAuthInvalidPublicKey(t *testing.T) {
kc := new(keychain)
kc.add(dsaKey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(kc),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err == nil {
c.Close()
t.Fatalf("dsa private key should not have authenticated with rsa public key")
}
}
// the client should authenticate with the second key
func TestClientAuthRSAandDSA(t *testing.T) {
kc := new(keychain)
kc.add(dsaKey)
kc.add(rsaKey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(kc),
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
c.Close()
}
func TestClientHMAC(t *testing.T) {
kc := new(keychain)
kc.add(rsaKey)
for _, mac := range DefaultMACOrder {
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(kc),
},
Crypto: CryptoConfig{
MACs: []string{mac},
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
}
c.Close()
}
}
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
kc := new(keychain)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(kc),
},
Crypto: CryptoConfig{
Ciphers: []string{"aes128-cbc"}, // not currently supported
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err == nil {
t.Errorf("expected no ciphers in common")
c.Close()
}
}
func TestClientUnsupportedKex(t *testing.T) {
kc := new(keychain)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthKeyring(kc),
},
Crypto: CryptoConfig{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
}
c, err := Dial("tcp", newMockAuthServer(t), config)
if err == nil || !strings.Contains(err.Error(), "no common algorithms") {
t.Errorf("got %v, expected 'no common algorithms'", err)
}
if c != nil {
c.Close()
}
}
@@ -1,352 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto"
"fmt"
"sync"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
// These are string constants in the SSH protocol.
const (
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
serviceSSH = "ssh-connection"
)
var supportedKexAlgos = []string{
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
var supportedHostKeyAlgos = []string{
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSA, KeyAlgoDSA,
}
var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported algorithms to their respective
// hashes needed for signature verification.
var hashFuncs = map[string]crypto.Hash{
KeyAlgoRSA: crypto.SHA1,
KeyAlgoDSA: crypto.SHA1,
KeyAlgoECDSA256: crypto.SHA256,
KeyAlgoECDSA384: crypto.SHA384,
KeyAlgoECDSA521: crypto.SHA512,
CertAlgoRSAv01: crypto.SHA1,
CertAlgoDSAv01: crypto.SHA1,
CertAlgoECDSA256v01: crypto.SHA256,
CertAlgoECDSA384v01: crypto.SHA384,
CertAlgoECDSA521v01: crypto.SHA512,
}
// UnexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
type UnexpectedMessageError struct {
expected, got uint8
}
func (u UnexpectedMessageError) Error() string {
return fmt.Sprintf("ssh: unexpected message type %d (expected %d)", u.got, u.expected)
}
// ParseError results from a malformed SSH message.
type ParseError struct {
msgType uint8
}
func (p ParseError) Error() string {
return fmt.Sprintf("ssh: parse error in message type %d", p.msgType)
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
if clientAlgo == serverAlgo {
return clientAlgo, true
}
}
}
return
}
func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCipher string, ok bool) {
for _, clientCipher := range clientCiphers {
for _, serverCipher := range serverCiphers {
// reject the cipher if we have no cipherModes definition
if clientCipher == serverCipher && cipherModes[clientCipher] != nil {
return clientCipher, true
}
}
}
return
}
type algorithms struct {
kex string
hostKey string
wCipher string
rCipher string
rMAC string
wMAC string
rCompression string
wCompression string
}
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
var ok bool
result := &algorithms{}
result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
return
}
result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
return
}
result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
return result
}
// Cryptographic configuration common to both ServerConfig and ClientConfig.
type CryptoConfig struct {
// The allowed key exchanges algorithms. If unspecified then a
// default set of algorithms is used.
KeyExchanges []string
// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
// used.
Ciphers []string
// The allowed MAC algorithms. If unspecified then DefaultMACOrder is used.
MACs []string
}
func (c *CryptoConfig) ciphers() []string {
if c.Ciphers == nil {
return DefaultCipherOrder
}
return c.Ciphers
}
func (c *CryptoConfig) kexes() []string {
if c.KeyExchanges == nil {
return defaultKeyExchangeOrder
}
return c.KeyExchanges
}
func (c *CryptoConfig) macs() []string {
if c.MACs == nil {
return DefaultMACOrder
}
return c.MACs
}
// serialize a signed slice according to RFC 4254 6.6. The name should
// be a key type name, rather than a cert type name.
func serializeSignature(name string, sig []byte) []byte {
length := stringLength(len(name))
length += stringLength(len(sig))
ret := make([]byte, length)
r := marshalString(ret, []byte(name))
r = marshalString(r, sig)
return ret
}
// MarshalPublicKey serializes a supported key or certificate for use
// by the SSH wire protocol. It can be used for comparison with the
// pubkey argument of ServerConfig's PublicKeyCallback as well as for
// generating an authorized_keys or host_keys file.
func MarshalPublicKey(key PublicKey) []byte {
// See also RFC 4253 6.6.
algoname := key.PublicKeyAlgo()
blob := key.Marshal()
length := stringLength(len(algoname))
length += len(blob)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
copy(r, blob)
return ret
}
// pubAlgoToPrivAlgo returns the private key algorithm format name that
// corresponds to a given public key algorithm format name. For most
// public keys, the private key algorithm name is the same. For some
// situations, such as openssh certificates, the private key algorithm and
// public key algorithm names differ. This accounts for those situations.
func pubAlgoToPrivAlgo(pubAlgo string) string {
switch pubAlgo {
case CertAlgoRSAv01:
return KeyAlgoRSA
case CertAlgoDSAv01:
return KeyAlgoDSA
case CertAlgoECDSA256v01:
return KeyAlgoECDSA256
case CertAlgoECDSA384v01:
return KeyAlgoECDSA384
case CertAlgoECDSA521v01:
return KeyAlgoECDSA521
}
return pubAlgo
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(len(sessionId))
length += 1
length += stringLength(len(user))
length += stringLength(len(service))
length += stringLength(len(method))
length += 1
length += stringLength(len(algo))
length += stringLength(len(pubKey))
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
// safeString sanitises s according to RFC 4251, section 9.2.
// All control characters except tab, carriage return and newline are
// replaced by 0x20.
func safeString(s string) string {
out := []byte(s)
for i, c := range out {
if c < 0x20 && c != 0xd && c != 0xa && c != 0x9 {
out[i] = 0x20
}
}
return string(out)
}
func appendU16(buf []byte, n uint16) []byte {
return append(buf, byte(n>>8), byte(n))
}
func appendU32(buf []byte, n uint32) []byte {
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
func appendInt(buf []byte, n int) []byte {
return appendU32(buf, uint32(n))
}
func appendString(buf []byte, s string) []byte {
buf = appendU32(buf, uint32(len(s)))
buf = append(buf, s...)
return buf
}
func appendBool(buf []byte, b bool) []byte {
if b {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
return buf
}
// newCond is a helper to hide the fact that there is no usable zero
// value for sync.Cond.
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
// window represents the buffer available to clients
// wishing to write to a channel.
type window struct {
*sync.Cond
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
}
// add adds win to the amount of window available
// for consumers.
func (w *window) add(win uint32) bool {
// a zero sized window adjust is a noop.
if win == 0 {
return true
}
w.L.Lock()
if w.win+win < win {
w.L.Unlock()
return false
}
w.win += win
// It is unusual that multiple goroutines would be attempting to reserve
// window space, but not guaranteed. Use broadcast to notify all waiters
// that additional window is available.
w.Broadcast()
w.L.Unlock()
return true
}
// reserve reserves win from the available window capacity.
// If no capacity remains, reserve will block. reserve may
// return less than requested.
func (w *window) reserve(win uint32) uint32 {
w.L.Lock()
for w.win == 0 {
w.Wait()
}
if w.win < win {
win = w.win
}
w.win -= win
w.L.Unlock()
return win
}
@@ -1,57 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"net"
"testing"
)
func TestSafeString(t *testing.T) {
strings := map[string]string{
"\x20\x0d\x0a": "\x20\x0d\x0a",
"flibble": "flibble",
"new\x20line": "new\x20line",
"123456\x07789": "123456 789",
"\t\t\x10\r\n": "\t\t \r\n",
}
for s, expected := range strings {
actual := safeString(s)
if expected != actual {
t.Errorf("expected: %v, actual: %v", []byte(expected), []byte(actual))
}
}
}
// Make sure Read/Write are not exposed.
func TestConnHideRWMethods(t *testing.T) {
for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
if _, ok := c.(io.Reader); ok {
t.Errorf("%T implements io.Reader", c)
}
if _, ok := c.(io.Writer); ok {
t.Errorf("%T implements io.Writer", c)
}
}
}
func TestConnSupportsLocalRemoteMethods(t *testing.T) {
type LocalAddr interface {
LocalAddr() net.Addr
}
type RemoteAddr interface {
RemoteAddr() net.Addr
}
for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
if _, ok := c.(LocalAddr); !ok {
t.Errorf("%T does not implement LocalAddr", c)
}
if _, ok := c.(RemoteAddr); !ok {
t.Errorf("%T does not implement RemoteAddr", c)
}
}
}
@@ -1,214 +0,0 @@
package ssh
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"reflect"
"strings"
"testing"
)
var (
ecdsaKey Signer
ecdsa384Key Signer
ecdsa521Key Signer
testCertKey Signer
)
type testSigner struct {
Signer
pub PublicKey
}
func (ts *testSigner) PublicKey() PublicKey {
if ts.pub != nil {
return ts.pub
}
return ts.Signer.PublicKey()
}
func init() {
raw256, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
ecdsaKey, _ = NewSignerFromKey(raw256)
raw384, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
ecdsa384Key, _ = NewSignerFromKey(raw384)
raw521, _ := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
ecdsa521Key, _ = NewSignerFromKey(raw521)
// Create a cert and sign it for use in tests.
testCert := &OpenSSHCertV01{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: ecdsaKey.PublicKey(),
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: maxUint64, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
SignatureKey: rsaKey.PublicKey(),
}
sigBytes, _ := rsaKey.Sign(rand.Reader, testCert.BytesForSigning())
testCert.Signature = &signature{
Format: testCert.SignatureKey.PublicKeyAlgo(),
Blob: sigBytes,
}
testCertKey = &testSigner{
Signer: ecdsaKey,
pub: testCert,
}
}
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
case *rsaPublicKey:
return (*rsa.PublicKey)(k)
case *dsaPublicKey:
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
case *OpenSSHCertV01:
return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
keys := []Signer{rsaKey, dsaKey, ecdsaKey, ecdsa384Key, ecdsa521Key, testCertKey}
for _, priv := range keys {
pub := priv.PublicKey()
roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub))
if !ok {
t.Errorf("ParsePublicKey(%T) failed", pub)
}
if len(rest) > 0 {
t.Errorf("ParsePublicKey(%T): trailing junk", pub)
}
k1 := rawKey(pub)
k2 := rawKey(roundtrip)
if !reflect.DeepEqual(k1, k2) {
t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
}
}
}
func TestUnsupportedCurves(t *testing.T) {
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err)
}
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err)
}
}
func TestNewPublicKey(t *testing.T) {
keys := []Signer{rsaKey, dsaKey, ecdsaKey}
for _, k := range keys {
raw := rawKey(k.PublicKey())
pub, err := NewPublicKey(raw)
if err != nil {
t.Errorf("NewPublicKey(%#v): %v", raw, err)
}
if !reflect.DeepEqual(k.PublicKey(), pub) {
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
}
}
}
func TestKeySignVerify(t *testing.T) {
keys := []Signer{rsaKey, dsaKey, ecdsaKey, testCertKey}
for _, priv := range keys {
pub := priv.PublicKey()
data := []byte("sign me")
sig, err := priv.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
if !pub.Verify(data, sig) {
t.Errorf("publicKey.Verify(%T) failed", priv)
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
key, err := ParsePrivateKey([]byte(testServerPrivateKey))
if err != nil {
t.Fatalf("ParsePrivateKey: %v", err)
}
rsa, ok := key.(*rsaPrivateKey)
if !ok {
t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
}
if err := rsa.Validate(); err != nil {
t.Errorf("Validate: %v", err)
}
}
func TestParseECPrivateKey(t *testing.T) {
// Taken from the data in test/ .
pem := []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY-----`)
key, err := ParsePrivateKey(pem)
if err != nil {
t.Fatalf("ParsePrivateKey: %v", err)
}
ecKey, ok := key.(*ecdsaPrivateKey)
if !ok {
t.Fatalf("got %T, want *ecdsaPrivateKey", ecKey)
}
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
t.Fatalf("public key does not validate.")
}
}
// ssh-keygen -t dsa -f /tmp/idsa.pem
var dsaPEM = `-----BEGIN DSA PRIVATE KEY-----
MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
Fmsr0W6fHB9nhS4/UXM8
-----END DSA PRIVATE KEY-----`
func TestParseDSA(t *testing.T) {
s, err := ParsePrivateKey([]byte(dsaPEM))
if err != nil {
t.Fatalf("ParsePrivateKey returned error: %s", err)
}
data := []byte("sign me")
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if !s.PublicKey().Verify(data, sig) {
t.Error("Verify failed.")
}
}
@@ -1,692 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
_ "crypto/sha1"
)
type ServerConfig struct {
hostKeys []Signer
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
// be used.
Rand io.Reader
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
NoClientAuth bool
// PasswordCallback, if non-nil, is called when a user attempts to
// authenticate using a password. It may be called concurrently from
// several goroutines.
PasswordCallback func(conn *ServerConn, user, password string) bool
// PublicKeyCallback, if non-nil, is called when a client attempts public
// key authentication. It must return true if the given public key is
// valid for the given user.
PublicKeyCallback func(conn *ServerConn, user, algo string, pubkey []byte) bool
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
// 4256). The client object's Challenge function should be
// used to query the user. The callback may offer multiple
// Challenge rounds. To avoid information leaks, the client
// should be presented a challenge even if the user is
// unknown.
KeyboardInteractiveCallback func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool
// Cryptographic-related configuration.
Crypto CryptoConfig
}
func (c *ServerConfig) rand() io.Reader {
if c.Rand == nil {
return rand.Reader
}
return c.Rand
}
// AddHostKey adds a private key as a host key. If an existing host
// key exists with the same algorithm, it is overwritten.
func (s *ServerConfig) AddHostKey(key Signer) {
for i, k := range s.hostKeys {
if k.PublicKey().PublicKeyAlgo() == key.PublicKey().PublicKeyAlgo() {
s.hostKeys[i] = key
return
}
}
s.hostKeys = append(s.hostKeys, key)
}
// SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key.
func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
priv, err := ParsePrivateKey(pemBytes)
if err != nil {
return err
}
s.AddHostKey(priv)
return nil
}
// cachedPubKey contains the results of querying whether a public key is
// acceptable for a user. The cache only applies to a single ServerConn.
type cachedPubKey struct {
user, algo string
pubKey []byte
result bool
}
const maxCachedPubKeys = 16
// A ServerConn represents an incoming connection.
type ServerConn struct {
transport *transport
config *ServerConfig
channels map[uint32]*serverChan
nextChanId uint32
// lock protects err and channels.
lock sync.Mutex
err error
// cachedPubKeys contains the cache results of tests for public keys.
// Since SSH clients will query whether a public key is acceptable
// before attempting to authenticate with it, we end up with duplicate
// queries for public key validity.
cachedPubKeys []cachedPubKey
// User holds the successfully authenticated user name.
// It is empty if no authentication is used. It is populated before
// any authentication callback is called and not assigned to after that.
User string
// ClientVersion is the client's version, populated after
// Handshake is called. It should not be modified.
ClientVersion []byte
// Our version.
serverVersion []byte
}
// Server returns a new SSH server connection
// using c as the underlying transport.
func Server(c net.Conn, config *ServerConfig) *ServerConn {
return &ServerConn{
transport: newTransport(c, config.rand(), false /* not client */),
channels: make(map[uint32]*serverChan),
config: config,
}
}
// signAndMarshal signs the data with the appropriate algorithm,
// and serializes the result in SSH wire format.
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
sig, err := k.Sign(rand, data)
if err != nil {
return nil, err
}
return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
// Close closes the connection.
func (s *ServerConn) Close() error { return s.transport.Close() }
// LocalAddr returns the local network address.
func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
// RemoteAddr returns the remote network address.
func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var err error
s.serverVersion = []byte(packageVersion)
s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion)
if err != nil {
return err
}
if err := s.clientInitHandshake(nil, nil); err != nil {
return err
}
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
return err
}
var serviceRequest serviceRequestMsg
if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
return err
}
if serviceRequest.Service != serviceUserAuth {
return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
}
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return err
}
if err := s.authenticate(); err != nil {
return err
}
return err
}
func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
serverKexInit := kexInitMsg{
KexAlgos: s.config.Crypto.kexes(),
CiphersClientServer: s.config.Crypto.ciphers(),
CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: s.config.Crypto.macs(),
MACsServerClient: s.config.Crypto.macs(),
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
for _, k := range s.config.hostKeys {
serverKexInit.ServerHostKeyAlgos = append(
serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKeyAlgo())
}
serverKexInitPacket := marshal(msgKexInit, serverKexInit)
if err = s.transport.writePacket(serverKexInitPacket); err != nil {
return
}
if clientKexInitPacket == nil {
clientKexInit = new(kexInitMsg)
if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
return
}
if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
return
}
}
algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
}
if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err = s.transport.readPacket(); err != nil {
return
}
}
var hostKey Signer
for _, k := range s.config.hostKeys {
if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
hostKey = k
}
}
kex, ok := kexAlgoMap[algs.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
}
magics := handshakeMagics{
serverVersion: s.serverVersion,
clientVersion: s.ClientVersion,
serverKexInit: marshal(msgKexInit, serverKexInit),
clientKexInit: clientKexInitPacket,
}
result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
if err != nil {
return err
}
if err = s.transport.prepareKeyChange(algs, result); err != nil {
return err
}
if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
return
}
if packet, err := s.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
return
}
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
return true
}
return false
}
// testPubKey returns true if the given public key is acceptable for the user.
func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
if s.config.PublicKeyCallback == nil || !isAcceptableAlgo(algo) {
return false
}
for _, c := range s.cachedPubKeys {
if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
return c.result
}
}
result := s.config.PublicKeyCallback(s, user, algo, pubKey)
if len(s.cachedPubKeys) < maxCachedPubKeys {
c := cachedPubKey{
user: user,
algo: algo,
pubKey: make([]byte, len(pubKey)),
result: result,
}
copy(c.pubKey, pubKey)
s.cachedPubKeys = append(s.cachedPubKeys, c)
}
return result
}
func (s *ServerConn) authenticate() error {
var userAuthReq userAuthRequestMsg
var err error
var packet []byte
userAuthLoop:
for {
if packet, err = s.transport.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
return err
}
if userAuthReq.Service != serviceSSH {
return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
switch userAuthReq.Method {
case "none":
if s.config.NoClientAuth {
break userAuthLoop
}
case "password":
if s.config.PasswordCallback == nil {
break
}
payload := userAuthReq.Payload
if len(payload) < 1 || payload[0] != 0 {
return ParseError{msgUserAuthRequest}
}
payload = payload[1:]
password, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
s.User = userAuthReq.User
if s.config.PasswordCallback(s, userAuthReq.User, string(password)) {
break userAuthLoop
}
case "keyboard-interactive":
if s.config.KeyboardInteractiveCallback == nil {
break
}
s.User = userAuthReq.User
if s.config.KeyboardInteractiveCallback(s, s.User, &sshClientKeyboardInteractive{s}) {
break userAuthLoop
}
case "publickey":
if s.config.PublicKeyCallback == nil {
break
}
payload := userAuthReq.Payload
if len(payload) < 1 {
return ParseError{msgUserAuthRequest}
}
isQuery := payload[0] == 0
payload = payload[1:]
algoBytes, payload, ok := parseString(payload)
if !ok {
return ParseError{msgUserAuthRequest}
}
algo := string(algoBytes)
pubKey, payload, ok := parseString(payload)
if !ok {
return ParseError{msgUserAuthRequest}
}
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
if s.testPubKey(userAuthReq.User, algo, pubKey) {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: string(pubKey),
}
if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
}
} else {
sig, payload, ok := parseSignature(payload)
if !ok || len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
// Ensure the public key algo and signature algo
// are supported. Compare the private key
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !isAcceptableAlgo(algo) || !isAcceptableAlgo(sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format {
break
}
signedData := buildDataSignedForAuth(s.transport.sessionID, userAuthReq, algoBytes, pubKey)
key, _, ok := ParsePublicKey(pubKey)
if !ok {
return ParseError{msgUserAuthRequest}
}
if !key.Verify(signedData, sig.Blob) {
return ParseError{msgUserAuthRequest}
}
// TODO(jmpittman): Implement full validation for certificates.
s.User = userAuthReq.User
if s.testPubKey(userAuthReq.User, algo, pubKey) {
break userAuthLoop
}
}
}
var failureMsg userAuthFailureMsg
if s.config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
if s.config.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
if s.config.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
}
if len(failureMsg.Methods) == 0 {
return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
if err = s.transport.writePacket(packet); err != nil {
return err
}
return nil
}
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
// asking the client on the other side of a ServerConn.
type sshClientKeyboardInteractive struct {
*ServerConn
}
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
if len(questions) != len(echos) {
return nil, errors.New("ssh: echos and questions must have equal length")
}
var prompts []byte
for i := range questions {
prompts = appendString(prompts, questions[i])
prompts = appendBool(prompts, echos[i])
}
if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
})); err != nil {
return nil, err
}
packet, err := c.transport.readPacket()
if err != nil {
return nil, err
}
if packet[0] != msgUserAuthInfoResponse {
return nil, UnexpectedMessageError{msgUserAuthInfoResponse, packet[0]}
}
packet = packet[1:]
n, packet, ok := parseUint32(packet)
if !ok || int(n) != len(questions) {
return nil, &ParseError{msgUserAuthInfoResponse}
}
for i := uint32(0); i < n; i++ {
ans, rest, ok := parseString(packet)
if !ok {
return nil, &ParseError{msgUserAuthInfoResponse}
}
answers = append(answers, string(ans))
packet = rest
}
if len(packet) != 0 {
return nil, errors.New("ssh: junk at end of message")
}
return answers, nil
}
const defaultWindowSize = 32768
// Accept reads and processes messages on a ServerConn. It must be called
// in order to demultiplex messages to any resulting Channels.
func (s *ServerConn) Accept() (Channel, error) {
// TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed.
if s.err != nil {
return nil, s.err
}
for {
packet, err := s.transport.readPacket()
if err != nil {
s.lock.Lock()
s.err = err
s.lock.Unlock()
// TODO(dfc) s.lock protects s.channels but isn't being held here.
for _, c := range s.channels {
c.setDead()
c.handleData(nil)
}
return nil, err
}
switch packet[0] {
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
return nil, ParseError{msgChannelData}
}
remoteId := binary.BigEndian.Uint32(packet[1:5])
s.lock.Lock()
c, ok := s.channels[remoteId]
if !ok {
s.lock.Unlock()
continue
}
if length := binary.BigEndian.Uint32(packet[5:9]); length > 0 {
packet = packet[9:]
c.handleData(packet[:length])
}
s.lock.Unlock()
default:
decoded, err := decode(packet)
if err != nil {
return nil, err
}
switch msg := decoded.(type) {
case *channelOpenMsg:
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return nil, errors.New("ssh: invalid MaxPacketSize from peer")
}
c := &serverChan{
channel: channel{
packetConn: s.transport,
remoteId: msg.PeersId,
remoteWin: window{Cond: newCond()},
maxPacket: msg.MaxPacketSize,
},
chanType: msg.ChanType,
extraData: msg.TypeSpecificData,
myWindow: defaultWindowSize,
serverConn: s,
cond: newCond(),
pendingData: make([]byte, defaultWindowSize),
}
c.remoteWin.add(msg.PeersWindow)
s.lock.Lock()
c.localId = s.nextChanId
s.nextChanId++
s.channels[c.localId] = c
s.lock.Unlock()
return c, nil
case *channelRequestMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *windowAdjustMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelEOFMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelCloseMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *globalRequestMsg:
if msg.WantReply {
if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
case *kexInitMsg:
s.lock.Lock()
if err := s.clientInitHandshake(msg, packet); err != nil {
s.lock.Unlock()
return nil, err
}
s.lock.Unlock()
case *disconnectMsg:
return nil, io.EOF
default:
// Unknown message. Ignore.
}
}
}
panic("unreachable")
}
// A Listener implements a network listener (net.Listener) for SSH connections.
type Listener struct {
listener net.Listener
config *ServerConfig
}
// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
}
// Close closes the listener.
func (l *Listener) Close() error {
return l.listener.Close()
}
// Accept waits for and returns the next incoming SSH connection.
// The receiver should call Handshake() in another goroutine
// to avoid blocking the accepter.
func (l *Listener) Accept() (*ServerConn, error) {
c, err := l.listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config), nil
}
// Listen creates an SSH listener accepting connections on
// the given network address using net.Listen.
func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &Listener{
l,
config,
}, nil
}
@@ -1,81 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// A Terminal is capable of parsing and generating virtual terminal
// data from an SSH client.
type Terminal interface {
ReadLine() (line string, err error)
SetSize(x, y int)
Write([]byte) (int, error)
}
// ServerTerminal contains the state for running a terminal that is capable of
// reading lines of input.
type ServerTerminal struct {
Term Terminal
Channel Channel
}
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func (ss *ServerTerminal) Write(buf []byte) (n int, err error) {
return ss.Term.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *ServerTerminal) ReadLine() (line string, err error) {
for {
if line, err = ss.Term.ReadLine(); err == nil {
return
}
req, ok := err.(ChannelRequest)
if !ok {
return
}
ok = false
switch req.Request {
case "pty-req":
var width, height int
width, height, ok = parsePtyRequest(req.Payload)
ss.Term.SetSize(width, height)
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
if req.WantReply {
ss.Channel.AckRequest(ok)
}
}
panic("unreachable")
}
@@ -1,12 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux
package terminal
import "syscall"
const ioctlReadTermios = syscall.TCGETS
const ioctlWriteTermios = syscall.TCSETS
@@ -1,246 +0,0 @@
package test
import (
"reflect"
"strings"
"testing"
"code.google.com/p/go.crypto/ssh"
)
var (
validKey = `AAAAB3NzaC1yc2EAAAADAQABAAABAQDEX/dPu4PmtvgK3La9zioCEDrJ` +
`yUr6xEIK7Pr+rLgydcqWTU/kt7w7gKjOw4vvzgHfjKl09CWyvgb+y5dCiTk` +
`9MxI+erGNhs3pwaoS+EavAbawB7iEqYyTep3YaJK+4RJ4OX7ZlXMAIMrTL+` +
`UVrK89t56hCkFYaAgo3VY+z6rb/b3bDBYtE1Y2tS7C3au73aDgeb9psIrSV` +
`86ucKBTl5X62FnYiyGd++xCnLB6uLximM5OKXfLzJQNS/QyZyk12g3D8y69` +
`Xw1GzCSKX1u1+MQboyf0HJcG2ryUCLHdcDVppApyHx2OLq53hlkQ/yxdflD` +
`qCqAE4j+doagSsIfC1T2T`
authWithOptions = []string{
`# comments to ignore before any keys...`,
``,
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + validKey + ` user@host`,
`# comments to ignore, along with a blank line`,
``,
`env="HOME=/home/root2" ssh-rsa ` + validKey + ` user2@host2`,
``,
`# more comments, plus a invalid entry`,
`ssh-rsa data-that-will-not-parse user@host3`,
}
authOptions = strings.Join(authWithOptions, "\n")
authWithCRLF = strings.Join(authWithOptions, "\r\n")
authInvalid = []byte(`ssh-rsa`)
authWithQuotedCommaInEnv = []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + validKey + ` user@host`)
authWithQuotedSpaceInEnv = []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + validKey + ` user@host`)
authWithQuotedQuoteInEnv = []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + validKey + ` user@host`)
authWithDoubleQuotedQuote = []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + validKey + "\t" + `user@host`)
authWithInvalidSpace = []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + validKey + ` user@host
#more to follow but still no valid keys`)
authWithMissingQuote = []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + validKey + ` user@host
env="HOME=/home/root",shared-control ssh-rsa ` + validKey + ` user@host`)
testClientPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAxF/3T7uD5rb4Cty2vc4qAhA6yclK+sRCCuz6/qy4MnXKlk1P
5Le8O4CozsOL784B34ypdPQlsr4G/suXQok5PTMSPnqxjYbN6cGqEvhGrwG2sAe4
hKmMk3qd2GiSvuESeDl+2ZVzACDK0y/lFayvPbeeoQpBWGgIKN1WPs+q2/292wwW
LRNWNrUuwt2ru92g4Hm/abCK0lfOrnCgU5eV+thZ2IshnfvsQpyweri8YpjOTil3
y8yUDUv0MmcpNdoNw/MuvV8NRswkil9btfjEG6Mn9ByXBtq8lAix3XA1aaQKch8d
ji6ud4ZZEP8sXX5Q6gqgBOI/naGoErCHwtU9kwIDAQABAoIBAFJRKAp0QEZmTHPB
MZk+4r0asIoFpziXLFgIHu7C2DPOzK1Umzj1DCKlPB3wOqi7Ym2jOSWdcnAK2EPW
dAGgJC5TSkKGjAcXixmB5RkumfKidUI0+lQh/puTurcMnvcEwglDkLkEvMBA/sSo
Pw9m486rOgOnmNzGPyViItURmD2+0yDdLl/vOsO/L1p76GCd0q0J3LqnmsQmawi7
Zwj2Stm6BIrggG5GsF204Iet5219TYLo4g1Qb2AlJ9C8P1FtAWhMwJalDxH9Os2/
KCDjnaq5n3bXbIU+3QjskjeVXL/Fnbhjnh4zs1EA7eHzl9dCGbcZ2LOimo2PRo8q
wVQmz4ECgYEA9dhiu74TxRVoaO5N2X+FsMzRO8gZdP3Z9IrV4jVN8WT4Vdp0snoF
gkVkqqbQUNKUb5K6B3Js/qNKfcjLbCNq9fewTcT6WsHQdtPbX/QA6Pa2Z29wrlA2
wrIYaAkmVaHny7wsOmgX01aOnuf2MlUnksK43sjZHdIo/m+sDKwwY1cCgYEAzHx4
mwUDMdRF4qpDKJhthraBNejRextNQQYsHVnNaMwZ4aeQcH5l85Cgjm7VpGlbVyBQ
h4zwFvllImp3D2U3mjVkV8Tm9ID98eWvw2YDzBnS3P3SysajD23Z+BXSG9GNv/8k
oAm+bVlvnJy4haK2AcIMk1YFuDuAOmy73abk7iUCgYEAj4qVM1sq/eKfAM1LJRfg
/jbIX+hYfMePD8pUUWygIra6jJ4tjtvSBZrwyPb3IImjY3W/KoP0AcVjxAeORohz
dkP1a6L8LiuFxSuzpdW5BkyuebxGhXCOWKVVvMDC4jLTPVCUXlHSv3GFemCjjgXM
QlNxT5rjsha4Gr8nLIsJAacCgYA4VA1Q/pd7sXKy1p37X8nD8yAyvnh+Be5I/C9I
woUP2jFC9MqYAmmJJ4ziz2swiAkuPeuQ+2Tjnz2ZtmQnrIUdiJmkh8vrDGFnshKx
q7deELsCPzVCwGcIiAUkDra7DQWUHu9y2lxHePyC0rUNst2aLF8UcvzOXC2danhx
vViQtQKBgCmZ7YavE/GNWww8N3xHBJ6UPmUuhQlnAbgNCcdyz30MevBg/JbyUTs2
slftTH15QusJ1UoITnnZuFJ40LqDvh8UhiK09ffM/IbUx839/m2vUOdFZB/WNn9g
Cy0LzddU4KE8JZ/tlk68+hM5fjLLA0aqSunaql5CKfplwLu8x1hL
-----END RSA PRIVATE KEY-----
`
keys = map[string]string{
"ssh_host_dsa_key": `-----BEGIN DSA PRIVATE KEY-----
MIIBugIBAAKBgQDe2SIKvZdBp+InawtSXH0NotiMPhm3udyu4hh/E+icMz264kDX
v+sV7ddnSQGQWZ/eVU7Jtx29dCMD1VlFpEd7yGKzmdwJIeA+YquNWoqBRQEJsWWS
7Fsfvv83dA/DTNIQfOY3+TIs6Mb9vagbgQMU3JUWEhbLE9LCEU6UwwRlpQIVAL4p
JF83SwpE8Jx6KnDpR89npkl/AoGAAy00TdDnAXvStwrZiAFbjZi8xDmPa9WwpfhJ
Rkno45TthDLrS+WmqY8/LTwlqZdOBtoBAynMJfKkUiZM21lWWpL1hRKYdwBlIBy5
XdR2/6wcPSuZ0tCQhDBTstX0Q3P1j198KGKvzy7q9vILKQwtSRqLS1y4JJERafdO
E+9CnGwCgYBz0WwBe2EZtGhGhBdnelTIBeo7PIsr0PzqxQj+dc8PBl8K9FfhRyOp
U39stUvoUxE9vaIFrY1P5xENjLFnPf+hlcuf40GUWEssW9YWPOaBp8afa9hY5Sxs
pvNR6eZFEFOJnx/ZgcA4g+vbrgGi5cM0W470mbGw2CkfJQUafdoIgAIUF+2I9kZe
2FTBuC9uacqczDlc+0k=
-----END DSA PRIVATE KEY-----`,
"ssh_host_rsa_key": `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAuf76Ue2Wtae9oDtaS6rIJgO7iCFTsZUTW9LBsvx/2nli6jKU
d9tUbBRzgdbnRLJ32UljXhERuB/axlrX8/lBzUZ+oYiM0KkEEOXY1z/bcMxdRxGF
XHuf4uXvyC2XyA4+ZvBeS4j1QFyIHZ62o7gAlKMTjiek3B4AQEJAlCLmhH3jB8wc
K/IYXAOlNGM5G44/ZLQpTi8diOV6DLs7tJ7rtEQedOEJfZng5rwp0USFkqcbfDbe
9/hk0J32jZvOtZNBokYtBb4YEdIiWBzzNtHzU3Dzw61+TKVXaH5HaIvzL9iMrw9f
kJbJyogfZk9BJfemEN+xqP72jlhE8LXNhpTxFQIDAQABAoIBAHbdf+Y5+5XuNF6h
b8xpwW2h9whBnDYiOnP1VfroKWFbMB7R4lZS4joMO+FfkP8zOyqvHwTvza4pFWys
g9SUmDvy8FyVYsC7MzEFYzX0xm3o/Te898ip7P1Zy4rXsGeWysSImwqU5X+TYx3i
33/zyNM1APtZVJ+jwK9QZ+sD/uPuZK2yS03HGSMZq6ebdoOSaYhluKrxXllSLO1J
KJxDiDdy2lEFw0W8HcI3ly1lg6OI+TRqqaCcLVNF4fNJmYIFM+2VEI9BdgynIh0Q
pMZlJKgaEBcSqCymnTK81ohYD1cV4st2B0km3Sw35Rl04Ij5ITeiya3hp8VfE6UY
PljkA6UCgYEA4811FTFj+kzNZ86C4OW1T5sM4NZt8gcz6CSvVnl+bDzbEOMMyzP7
2I9zKsR5ApdodH2m8d+RUw1Oe0bNGW5xig/DH/hn9lLQaO52JAi0we8A94dUUMSq
fUk9jKZEXpP/MlfTdJaPos9mxT7z8jREQxIiqH9AV0rLVDOCfDbSWj8CgYEA0QTE
IAUuki3UUqYKzLQrh/QmhY5KTx5amNW9XZ2VGtJvDPJrtBSBZlPEuXZAc4eBWEc7
U3Y9QwsalzupU6Yi6+gmofaXs8xJnj+jKth1DnJvrbLLGlSmf2Ijnwt22TyFUOtt
UAknpjHutDjQPf7pUGWaCPgwwKFsdB8EBjpJF6sCgYAfXesBQAvEK08dPBJJZVfR
3kenrd71tIgxLtv1zETcIoUHjjv0vvOunhH9kZAYC0EWyTZzl5UrGmn0D4uuNMbt
e74iaNHn2P9Zc3xQ+eHp0j8P1lKFzI6tMaiH9Vz0qOw6wl0bcJ/WizhbcI+migvc
MGMVUHBLlMDqly0gbWwJgQKBgQCgtb9ut01FjANSwORQ3L8Tu3/a9Lrh9n7GQKFn
V4CLrP1BwStavOF5ojMCPo/zxF6JV8ufsqwL3n/FhFP/QyBarpb1tTqTPiHkkR2O
Ffx67TY9IdnUFv4lt3mYEiKBiW0f+MSF42Qe/wmAfKZw5IzUCirTdrFVi0huSGK5
vxrwHQKBgHZ7RoC3I2f6F5fflA2ZAe9oJYC7XT624rY7VeOBwK0W0F47iV3euPi/
pKvLIBLcWL1Lboo+girnmSZtIYg2iLS3b4T9VFcKWg0y4AVwmhMWe9jWIltfWAAX
9l0lNikMRGAx3eXudKXEtbGt3/cUzPVaQUHy5LiBxkxnFxgaJPXs
-----END RSA PRIVATE KEY-----`,
"ssh_host_ecdsa_key": `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY-----`,
"authorized_keys": `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDEX/dPu4PmtvgK3La9zioCEDrJyUr6xEIK7Pr+rLgydcqWTU/kt7w7gKjOw4vvzgHfjKl09CWyvgb+y5dCiTk9MxI+erGNhs3pwaoS+EavAbawB7iEqYyTep3YaJK+4RJ4OX7ZlXMAIMrTL+UVrK89t56hCkFYaAgo3VY+z6rb/b3bDBYtE1Y2tS7C3au73aDgeb9psIrSV86ucKBTl5X62FnYiyGd++xCnLB6uLximM5OKXfLzJQNS/QyZyk12g3D8y69Xw1GzCSKX1u1+MQboyf0HJcG2ryUCLHdcDVppApyHx2OLq53hlkQ/yxdflDqCqAE4j+doagSsIfC1T2T user@host`,
}
)
func TestMarshalParsePublicKey(t *testing.T) {
pub := getTestPublicKey(t)
authKeys := ssh.MarshalAuthorizedKey(pub)
actualFields := strings.Fields(string(authKeys))
if len(actualFields) == 0 {
t.Fatalf("failed authKeys: %v", authKeys)
}
// drop the comment
expectedFields := strings.Fields(keys["authorized_keys"])[0:2]
if !reflect.DeepEqual(actualFields, expectedFields) {
t.Errorf("got %v, expected %v", actualFields, expectedFields)
}
actPub, _, _, _, ok := ssh.ParseAuthorizedKey([]byte(keys["authorized_keys"]))
if !ok {
t.Fatalf("cannot parse %v", keys["authorized_keys"])
}
if !reflect.DeepEqual(actPub, pub) {
t.Errorf("got %v, expected %v", actPub, pub)
}
}
type authResult struct {
pubKey interface{} //*rsa.PublicKey
options []string
comments string
rest string
ok bool
}
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
rest := authKeys
var values []authResult
for len(rest) > 0 {
var r authResult
r.pubKey, r.comments, r.options, rest, r.ok = ssh.ParseAuthorizedKey(rest)
r.rest = string(rest)
values = append(values, r)
}
if !reflect.DeepEqual(values, expected) {
t.Errorf("got %q, expected %q", values, expected)
}
}
func getTestPublicKey(t *testing.T) ssh.PublicKey {
priv, err := ssh.ParsePrivateKey([]byte(testClientPrivateKey))
if err != nil {
t.Fatalf("ParsePrivateKey: %v", err)
}
return priv.PublicKey()
}
func TestAuth(t *testing.T) {
pub := getTestPublicKey(t)
rest2 := strings.Join(authWithOptions[3:], "\n")
rest3 := strings.Join(authWithOptions[6:], "\n")
testAuthorizedKeys(t, []byte(authOptions), []authResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
})
}
func TestAuthWithCRLF(t *testing.T) {
pub := getTestPublicKey(t)
rest2 := strings.Join(authWithOptions[3:], "\r\n")
rest3 := strings.Join(authWithOptions[6:], "\r\n")
testAuthorizedKeys(t, []byte(authWithCRLF), []authResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
})
}
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
pub := getTestPublicKey(t)
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
pub := getTestPublicKey(t)
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
pub := getTestPublicKey(t)
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
})
}
func TestAuthWithInvalidSpace(t *testing.T) {
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
{nil, nil, "", "", false},
})
}
func TestAuthWithMissingQuote(t *testing.T) {
pub := getTestPublicKey(t)
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
})
}
func TestInvalidEntry(t *testing.T) {
_, _, _, _, ok := ssh.ParseAuthorizedKey(authInvalid)
if ok {
t.Errorf("Expected invalid entry, returned valid entry")
}
}
@@ -1,47 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip functional tests
import (
"net"
"net/http"
"testing"
)
func TestTCPIPHTTP(t *testing.T) {
// google.com will generate at least one redirect, possibly three
// depending on your location.
doTest(t, "http://google.com")
}
func TestTCPIPHTTPS(t *testing.T) {
doTest(t, "https://encrypted.google.com/")
}
func doTest(t *testing.T, url string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
tr := &http.Transport{
Dial: func(n, addr string) (net.Conn, error) {
return conn.Dial(n, addr)
},
}
client := &http.Client{
Transport: tr,
}
resp, err := client.Get(url)
if err != nil {
t.Fatalf("unable to proxy: %s", err)
}
// got a body without error
t.Log(resp)
}
@@ -1,426 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
"io"
"net"
"sync"
)
const (
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
// indicates implementations SHOULD be able to handle larger packet sizes, but then
// waffles on about reasonable limits.
//
// OpenSSH caps their maxPacket at 256kb so we choose to do the same.
maxPacket = 256 * 1024
)
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
// Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error
// Read a packet from the connection
readPacket() ([]byte, error)
// Close closes the write-side of the connection.
Close() error
}
// transport represents the SSH connection to the remote peer.
type transport struct {
reader
writer
net.Conn
// Initial H used for the session ID. Once assigned this does
// not change, even during subsequent key exchanges.
sessionID []byte
}
// reader represents the incoming connection state.
type reader struct {
io.Reader
common
}
// writer represents the outgoing connection state.
type writer struct {
sync.Mutex // protects writer.Writer from concurrent writes
*bufio.Writer
rand io.Reader
common
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
t.writer.cipherAlgo = algs.wCipher
t.writer.macAlgo = algs.wMAC
t.writer.compressionAlgo = algs.wCompression
t.reader.cipherAlgo = algs.rCipher
t.reader.macAlgo = algs.rMAC
t.reader.compressionAlgo = algs.rCompression
if t.sessionID == nil {
t.sessionID = kexResult.H
}
kexResult.SessionID = t.sessionID
t.reader.pendingKeyChange <- kexResult
t.writer.pendingKeyChange <- kexResult
return nil
}
// common represents the cipher state needed to process messages in a single
// direction.
type common struct {
seqNum uint32
mac hash.Hash
cipher cipher.Stream
cipherAlgo string
macAlgo string
compressionAlgo string
dir direction
pendingKeyChange chan *kexResult
}
// Read and decrypt a single packet from the remote peer.
func (r *reader) readPacket() ([]byte, error) {
var lengthBytes = make([]byte, 5)
var macSize uint32
if _, err := io.ReadFull(r, lengthBytes); err != nil {
return nil, err
}
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
if r.mac != nil {
r.mac.Reset()
seqNumBytes := []byte{
byte(r.seqNum >> 24),
byte(r.seqNum >> 16),
byte(r.seqNum >> 8),
byte(r.seqNum),
}
r.mac.Write(seqNumBytes)
r.mac.Write(lengthBytes)
macSize = uint32(r.mac.Size())
}
length := binary.BigEndian.Uint32(lengthBytes[0:4])
paddingLength := uint32(lengthBytes[4])
if length <= paddingLength+1 {
return nil, errors.New("ssh: invalid packet length, packet too small")
}
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
packet := make([]byte, length-1+macSize)
if _, err := io.ReadFull(r, packet); err != nil {
return nil, err
}
mac := packet[length-1:]
r.cipher.XORKeyStream(packet, packet[:length-1])
if r.mac != nil {
r.mac.Write(packet[:length-1])
if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
return nil, errors.New("ssh: MAC failure")
}
}
r.seqNum++
packet = packet[:length-paddingLength-1]
if len(packet) > 0 && packet[0] == msgNewKeys {
select {
case k := <-r.pendingKeyChange:
if err := r.setupKeys(r.dir, k); err != nil {
return nil, err
}
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}
}
return packet, nil
}
// Read and decrypt next packet discarding debug and noop messages.
func (t *transport) readPacket() ([]byte, error) {
for {
packet, err := t.reader.readPacket()
if err != nil {
return nil, err
}
if len(packet) == 0 {
return nil, errors.New("ssh: zero length packet")
}
if packet[0] != msgIgnore && packet[0] != msgDebug {
return packet, nil
}
}
panic("unreachable")
}
// Encrypt and send a packet of data to the remote peer.
func (w *writer) writePacket(packet []byte) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
if len(packet) > maxPacket {
return errors.New("ssh: packet too large")
}
w.Mutex.Lock()
defer w.Mutex.Unlock()
paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
if paddingLength < 4 {
paddingLength += packetSizeMultiple
}
length := len(packet) + 1 + paddingLength
lengthBytes := []byte{
byte(length >> 24),
byte(length >> 16),
byte(length >> 8),
byte(length),
byte(paddingLength),
}
padding := make([]byte, paddingLength)
_, err := io.ReadFull(w.rand, padding)
if err != nil {
return err
}
if w.mac != nil {
w.mac.Reset()
seqNumBytes := []byte{
byte(w.seqNum >> 24),
byte(w.seqNum >> 16),
byte(w.seqNum >> 8),
byte(w.seqNum),
}
w.mac.Write(seqNumBytes)
w.mac.Write(lengthBytes)
w.mac.Write(packet)
w.mac.Write(padding)
}
// TODO(dfc) lengthBytes, packet and padding should be
// subslices of a single buffer
w.cipher.XORKeyStream(lengthBytes, lengthBytes)
w.cipher.XORKeyStream(packet, packet)
w.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(lengthBytes); err != nil {
return err
}
if _, err := w.Write(packet); err != nil {
return err
}
if _, err := w.Write(padding); err != nil {
return err
}
if w.mac != nil {
if _, err := w.Write(w.mac.Sum(nil)); err != nil {
return err
}
}
w.seqNum++
if err = w.Flush(); err != nil {
return err
}
if changeKeys {
select {
case k := <-w.pendingKeyChange:
err = w.setupKeys(w.dir, k)
default:
panic("ssh: no key material for msgNewKeys")
}
}
return err
}
func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport {
t := &transport{
reader: reader{
Reader: bufio.NewReader(conn),
common: common{
cipher: noneCipher{},
pendingKeyChange: make(chan *kexResult, 1),
},
},
writer: writer{
Writer: bufio.NewWriter(conn),
rand: rand,
common: common{
cipher: noneCipher{},
pendingKeyChange: make(chan *kexResult, 1),
},
},
Conn: conn,
}
if isClient {
t.reader.dir = serverKeys
t.writer.dir = clientKeys
} else {
t.reader.dir = clientKeys
t.writer.dir = serverKeys
}
return t
}
type direction struct {
ivTag []byte
keyTag []byte
macKeyTag []byte
}
// TODO(dfc) can this be made a constant ?
var (
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, r *kexResult) error {
cipherMode := cipherModes[c.cipherAlgo]
macMode := macModes[c.macAlgo]
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macMode.keySize)
h := r.Hash.New()
generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h)
generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h)
generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h)
c.mac = macMode.new(macKey)
var err error
c.cipher, err = cipherMode.createCipher(key, iv)
return err
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
var digestsSoFar []byte
for len(out) > 0 {
h.Reset()
h.Write(K)
h.Write(H)
if len(digestsSoFar) == 0 {
h.Write(tag)
h.Write(sessionId)
} else {
h.Write(digestsSoFar)
}
digest := h.Sum(nil)
n := copy(out, digest)
out = out[n:]
if len(out) > 0 {
digestsSoFar = append(digestsSoFar, digest...)
}
}
}
const packageVersion = "SSH-2.0-Go"
// Sends and receives a version line. The versionLine string should
// be US ASCII, start with "SSH-2.0-", and should not include a
// newline. exchangeVersions returns the other side's version line.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
// Contrary to the RFC, we do not ignore lines that don't
// start with "SSH-2.0-" to make the library usable with
// nonconforming servers.
for _, c := range versionLine {
// The spec disallows non US-ASCII chars, and
// specifically forbids null chars.
if c < 32 {
return nil, errors.New("ssh: junk character in version line")
}
}
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
return
}
them, err = readVersion(rw)
return them, err
}
// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars
const maxVersionStringBytes = 255
// Read version string as specified by RFC 4253, section 4.2.
func readVersion(r io.Reader) ([]byte, error) {
versionString := make([]byte, 0, 64)
var ok bool
var buf [1]byte
for len(versionString) < maxVersionStringBytes {
_, err := io.ReadFull(r, buf[:])
if err != nil {
return nil, err
}
// The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
ok = true
break
}
// non ASCII chars are disallowed, but we are lenient,
// since Go doesn't use null-terminated strings.
// The RFC allows a comment after a space, however,
// all of it (version and comments) goes into the
// session hash.
versionString = append(versionString, buf[0])
}
if !ok {
return nil, errors.New("ssh: overflow reading version string")
}
// There might be a '\r' on the end which we should remove.
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
versionString = versionString[:len(versionString)-1]
}
return versionString, nil
}
@@ -0,0 +1,53 @@
package pointer
func ToString(s string) *string {
return &s
}
func ToUint(u uint) *uint {
return &u
}
func ToUint8(u uint8) *uint8 {
return &u
}
func ToUint16(u uint16) *uint16 {
return &u
}
func ToUint32(u uint32) *uint32 {
return &u
}
func ToUint64(u uint64) *uint64 {
return &u
}
func ToInt(i int) *int {
return &i
}
func ToInt8(i int8) *int8 {
return &i
}
func ToInt16(i int16) *int16 {
return &i
}
func ToInt32(i int32) *int32 {
return &i
}
func ToInt64(i int64) *int64 {
return &i
}
func ToFloat32(f float32) *float32 {
return &f
}
func ToFloat64(f float64) *float64 {
return &f
}
@@ -57,6 +57,12 @@ bar.ShowTimeLeft = true
// show average speed
bar.ShowSpeed = true
// sets the width of the progress bar
bar.SetWith(80)
// sets the width of the progress bar, but if terminal size smaller will be ignored
bar.SetMaxWith(80)
// convert output to readable format (like KB, MB)
bar.SetUnits(pb.U_BYTES)
@@ -25,13 +25,21 @@ var (
// Create new progress bar object
func New(total int) (pb *ProgressBar) {
return New64(int64(total))
}
// Create new progress bar object uding int64 as total
func New64(total int64) (pb *ProgressBar) {
pb = &ProgressBar{
Total: int64(total),
RefreshRate: DEFAULT_REFRESH_RATE,
ShowPercent: true,
ShowCounters: true,
ShowBar: true,
ShowTimeLeft: true,
Total: total,
RefreshRate: DEFAULT_REFRESH_RATE,
ShowPercent: true,
ShowCounters: true,
ShowBar: true,
ShowTimeLeft: true,
ShowFinalTime: true,
ManualUpdate: false,
currentValue: -1,
}
pb.Format(FORMAT)
return
@@ -59,13 +67,20 @@ type ProgressBar struct {
RefreshRate time.Duration
ShowPercent, ShowCounters bool
ShowSpeed, ShowTimeLeft, ShowBar bool
ShowFinalTime bool
Output io.Writer
Callback Callback
NotPrint bool
Units int
Width int
ForceWidth bool
ManualUpdate bool
isFinish bool
startTime time.Time
isFinish int32
startTime time.Time
currentValue int64
prefix, postfix string
BarStart string
BarEnd string
@@ -82,7 +97,9 @@ func (pb *ProgressBar) Start() {
pb.ShowTimeLeft = false
pb.ShowPercent = false
}
go pb.writer()
if !pb.ManualUpdate {
go pb.writer()
}
}
// Increment current value
@@ -97,7 +114,23 @@ func (pb *ProgressBar) Set(current int) {
// Add to current value
func (pb *ProgressBar) Add(add int) int {
return int(atomic.AddInt64(&pb.current, int64(add)))
return int(pb.Add64(int64(add)))
}
func (pb *ProgressBar) Add64(add int64) int64 {
return atomic.AddInt64(&pb.current, add)
}
// Set prefix string
func (pb *ProgressBar) Prefix(prefix string) (bar *ProgressBar) {
pb.prefix = prefix
return pb
}
// Set postfix string
func (pb *ProgressBar) Postfix(postfix string) (bar *ProgressBar) {
pb.postfix = postfix
return pb
}
// Set custom format for bar
@@ -135,9 +168,25 @@ func (pb *ProgressBar) SetUnits(units int) (bar *ProgressBar) {
return
}
// Set max width, if width is bigger than terminal width, will be ignored
func (pb *ProgressBar) SetMaxWidth(width int) (bar *ProgressBar) {
bar = pb
pb.Width = width
pb.ForceWidth = false
return
}
// Set bar width
func (pb *ProgressBar) SetWidth(width int) (bar *ProgressBar) {
bar = pb
pb.Width = width
pb.ForceWidth = true
return
}
// End print
func (pb *ProgressBar) Finish() {
pb.isFinish = true
atomic.StoreInt32(&pb.isFinish, 1)
pb.write(atomic.LoadInt64(&pb.current))
if !pb.NotPrint {
fmt.Println()
@@ -170,7 +219,8 @@ func (pb *ProgressBar) NewProxyReader(r io.Reader) *Reader {
}
func (pb *ProgressBar) write(current int64) {
width, _ := terminalWidth()
width := pb.getWidth()
var percentBox, countersBox, timeLeftBox, speedBox, barBox, end, out string
// percents
@@ -189,14 +239,17 @@ func (pb *ProgressBar) write(current int64) {
}
// time left
if pb.ShowTimeLeft && current > 0 {
fromStart := time.Now().Sub(pb.startTime)
fromStart := time.Now().Sub(pb.startTime)
if atomic.LoadInt32(&pb.isFinish) != 0 {
if pb.ShowFinalTime {
left := (fromStart / time.Second) * time.Second
timeLeftBox = left.String()
}
} else if pb.ShowTimeLeft && current > 0 {
perEntry := fromStart / time.Duration(current)
left := time.Duration(pb.Total-current) * perEntry
left = (left / time.Second) * time.Second
if left > 0 {
timeLeftBox = left.String()
}
timeLeftBox = left.String()
}
// speed
@@ -208,7 +261,7 @@ func (pb *ProgressBar) write(current int64) {
// bar
if pb.ShowBar {
size := width - len(countersBox+pb.BarStart+pb.BarEnd+percentBox+timeLeftBox+speedBox)
size := width - len(countersBox+pb.BarStart+pb.BarEnd+percentBox+timeLeftBox+speedBox+pb.prefix+pb.postfix)
if size > 0 {
curCount := int(math.Ceil((float64(current) / float64(pb.Total)) * float64(size)))
emptCount := size - curCount
@@ -230,17 +283,15 @@ func (pb *ProgressBar) write(current int64) {
}
// check len
out = countersBox + barBox + percentBox + speedBox + timeLeftBox
out = pb.prefix + countersBox + barBox + percentBox + speedBox + timeLeftBox + pb.postfix
if len(out) < width {
end = strings.Repeat(" ", width-len(out))
}
out = countersBox + barBox + percentBox + speedBox + timeLeftBox
// and print!
switch {
case pb.Output != nil:
fmt.Fprint(pb.Output, out+end)
fmt.Fprint(pb.Output, "\r"+out+end)
case pb.Callback != nil:
pb.Callback(out + end)
case !pb.NotPrint:
@@ -248,18 +299,36 @@ func (pb *ProgressBar) write(current int64) {
}
}
func (pb *ProgressBar) getWidth() int {
if pb.ForceWidth {
return pb.Width
}
width := pb.Width
termWidth, _ := terminalWidth()
if width == 0 || termWidth <= width {
width = termWidth
}
return width
}
// Write the current state of the progressbar
func (pb *ProgressBar) Update() {
c := atomic.LoadInt64(&pb.current)
if c != pb.currentValue {
pb.write(c)
pb.currentValue = c
}
}
// Internal loop for writing progressbar
func (pb *ProgressBar) writer() {
var c, oc int64
oc = -1
for {
if pb.isFinish {
if atomic.LoadInt32(&pb.isFinish) != 0 {
break
}
c = atomic.LoadInt64(&pb.current)
if c != oc {
pb.write(c)
oc = c
}
pb.Update()
time.Sleep(pb.RefreshRate)
}
}
@@ -1,35 +1,7 @@
// +build linux darwin freebsd
// +build linux darwin freebsd netbsd openbsd
package pb
import (
"runtime"
"syscall"
"unsafe"
)
import "syscall"
const (
TIOCGWINSZ = 0x5413
TIOCGWINSZ_OSX = 1074295912
)
func bold(str string) string {
return "\033[1m" + str + "\033[0m"
}
func terminalWidth() (int, error) {
w := new(window)
tio := syscall.TIOCGWINSZ
if runtime.GOOS == "darwin" {
tio = TIOCGWINSZ_OSX
}
res, _, err := syscall.Syscall(syscall.SYS_IOCTL,
uintptr(syscall.Stdin),
uintptr(tio),
uintptr(unsafe.Pointer(w)),
)
if int(res) == -1 {
return 0, err
}
return int(w.Col), nil
}
const sys_ioctl = syscall.SYS_IOCTL
@@ -0,0 +1,5 @@
// +build solaris
package pb
const sys_ioctl = 54
@@ -1,7 +1,6 @@
package pb
import (
"fmt"
"testing"
)
@@ -12,6 +11,20 @@ func Test_IncrementAddsOne(t *testing.T) {
actual := bar.Increment()
if actual != expected {
t.Error(fmt.Sprintf("Expected {%d} was {%d}", expected, actual))
t.Errorf("Expected {%d} was {%d}", expected, actual)
}
}
func Test_Width(t *testing.T) {
count := 5000
bar := New(count)
width := 100
bar.SetWidth(100).Callback = func(out string) {
if len(out) != width {
t.Errorf("Bar width expected {%d} was {%d}", len(out), width)
}
}
bar.Start()
bar.Increment()
bar.Finish()
}
@@ -0,0 +1,46 @@
// +build linux darwin freebsd netbsd openbsd solaris
package pb
import (
"os"
"runtime"
"syscall"
"unsafe"
)
const (
TIOCGWINSZ = 0x5413
TIOCGWINSZ_OSX = 1074295912
)
var tty *os.File
func init() {
var err error
tty, err = os.Open("/dev/tty")
if err != nil {
tty = os.Stdin
}
}
func bold(str string) string {
return "\033[1m" + str + "\033[0m"
}
func terminalWidth() (int, error) {
w := new(window)
tio := syscall.TIOCGWINSZ
if runtime.GOOS == "darwin" {
tio = TIOCGWINSZ_OSX
}
res, _, err := syscall.Syscall(sys_ioctl,
tty.Fd(),
uintptr(tio),
uintptr(unsafe.Pointer(w)),
)
if int(res) == -1 {
return 0, err
}
return int(w.Col), nil
}
@@ -0,0 +1,2 @@
Godeps/*
!Godeps/Godeps.json
@@ -0,0 +1,6 @@
language: go
sudo: false
go:
- 1.3
- 1.4
- tip
@@ -0,0 +1,174 @@
List of all the awesome people working to make Gin the best Web Framework in Go.
##gin 0.x series authors
**Original Developer:** Manu Martinez-Almeida (@manucorporat)
**Long-term Maintainer:** Javier Provecho (@javierprovecho)
People and companies, who have contributed, in alphabetical order.
**@achedeuzot (Klemen Sever)**
- Fix newline debug printing
**@adammck (Adam Mckaig)**
- Add MIT license
**@AlexanderChen1989 (Alexander)**
- Typos in README
**@alexandernyquist (Alexander Nyquist)**
- Using template.Must to fix multiple return issue
- ★ Added support for OPTIONS verb
- ★ Setting response headers before calling WriteHeader
- Improved documentation for model binding
- ★ Added Content.Redirect()
- ★ Added tons of Unit tests
**@austinheap (Austin Heap)**
- Added travis CI integration
**@andredublin (Andre Dublin)**
- Fix typo in comment
**@bredov (Ludwig Valda Vasquez)**
- Fix html templating in debug mode
**@bluele (Jun Kimura)**
- Fixes code examples in README
**@chad-russell**
- ★ Support for serializing gin.H into XML
**@dickeyxxx (Jeff Dickey)**
- Typos in README
- Add example about serving static files
**@dutchcoders (DutchCoders)**
- ★ Fix security bug that allows client to spoof ip
- Fix typo. r.HTMLTemplates -> SetHTMLTemplate
**@fmd (Fareed Dudhia)**
- Fix typo. SetHTTPTemplate -> SetHTMLTemplate
**@jammie-stackhouse (Jamie Stackhouse)**
- Add more shortcuts for router methods
**@jasonrhansen**
- Fix spelling and grammar errors in documentation
**@JasonSoft (Jason Lee)**
- Fix typo in comment
**@joiggama (Ignacio Galindo)**
- Add utf-8 charset header on renders
**@julienschmidt (Julien Schmidt)**
- gofmt the code examples
**@kelcecil (Kel Cecil)**
- Fix readme typo
**@kyledinh (Kyle Dinh)**
- Adds RunTLS()
**@LinusU (Linus Unnebäck)**
- Small fixes in README
**@loongmxbt (Saint Asky)**
- Fix typo in example
**@lucas-clemente (Lucas Clemente)**
- ★ work around path.Join removing trailing slashes from routes
**@mdigger (Dmitry Sedykh)**
- Fixes Form binding when content-type is x-www-form-urlencoded
- No repeat call c.Writer.Status() in gin.Logger
- Fixes Content-Type for json render
**@mirzac (Mirza Ceric)**
- Fix debug printing
**@mopemope (Yutaka Matsubara)**
- ★ Adds Godep support (Dependencies Manager)
- Fix variadic parameter in the flexible render API
- Fix Corrupted plain render
- Add Pluggable View Renderer Example
**@msemenistyi (Mykyta Semenistyi)**
- update Readme.md. Add code to String method
**@msoedov (Sasha Myasoedov)**
- ★ Adds tons of unit tests.
**@ngerakines (Nick Gerakines)**
- ★ Improves API, c.GET() doesn't panic
- Adds MustGet() method
**@r8k (Rajiv Kilaparti)**
- Fix Port usage in README.
**@RobAWilkinson (Robert Wilkinson)**
- Add example of forms and params
**@se77en (Damon Zhao)**
- Improve color logging
**@silasb (Silas Baronda)**
- Fixing quotes in README
**@SkuliOskarsson (Skuli Oskarsson)**
- Fixes some texts in README II
**@slimmy (Jimmy Pettersson)**
- Added messages for required bindings
**@smira (Andrey Smirnov)**
- Add support for ignored/unexported fields in binding
**@superalsrk (SRK.Lyu)**
- Update httprouter godeps
**@yosssi (Keiji Yoshida)**
- Fix link in README
**@yuyabee**
- Fixed README
@@ -0,0 +1,64 @@
#Changelog
###Gin 0.6 (Mar 7, 2015)
###Gin 0.5 (Feb 7, 2015)
- [NEW] Content Negotiation
- [FIX] Solved security bug that allow a client to spoof ip
- [FIX] Fix unexported/ignored fields in binding
###Gin 0.4 (Aug 21, 2014)
- [NEW] Development mode
- [NEW] Unit tests
- [NEW] Add Content.Redirect()
- [FIX] Deferring WriteHeader()
- [FIX] Improved documentation for model binding
###Gin 0.3 (Jul 18, 2014)
- [PERFORMANCE] Normal log and error log are printed in the same call.
- [PERFORMANCE] Improve performance of NoRouter()
- [PERFORMANCE] Improve context's memory locality, reduce CPU cache faults.
- [NEW] Flexible rendering API
- [NEW] Add Context.File()
- [NEW] Add shorcut RunTLS() for http.ListenAndServeTLS
- [FIX] Rename NotFound404() to NoRoute()
- [FIX] Errors in context are purged
- [FIX] Adds HEAD method in Static file serving
- [FIX] Refactors Static() file serving
- [FIX] Using keyed initialization to fix app-engine integration
- [FIX] Can't unmarshal JSON array, #63
- [FIX] Renaming Context.Req to Context.Request
- [FIX] Check application/x-www-form-urlencoded when parsing form
###Gin 0.2b (Jul 08, 2014)
- [PERFORMANCE] Using sync.Pool to allocatio/gc overhead
- [NEW] Travis CI integration
- [NEW] Completely new logger
- [NEW] New API for serving static files. gin.Static()
- [NEW] gin.H() can be serialized into XML
- [NEW] Typed errors. Errors can be typed. Internet/external/custom.
- [NEW] Support for Godeps
- [NEW] Travis/Godocs badges in README
- [NEW] New Bind() and BindWith() methods for parsing request body.
- [NEW] Add Content.Copy()
- [NEW] Add context.LastError()
- [NEW] Add shorcut for OPTIONS HTTP method
- [FIX] Tons of README fixes
- [FIX] Header is written before body
- [FIX] BasicAuth() and changes API a little bit
- [FIX] Recovery() middleware only prints panics
- [FIX] Context.Get() does not panic anymore. Use MustGet() instead.
- [FIX] Multiple http.WriteHeader() in NotFound handlers
- [FIX] Engine.Run() panics if http server can't be setted up
- [FIX] Crash when route path doesn't start with '/'
- [FIX] Do not update header when status code is negative
- [FIX] Setting response headers before calling WriteHeader in context.String()
- [FIX] Add MIT license
- [FIX] Changes behaviour of ErrorLogger() and Logger()
@@ -0,0 +1,10 @@
{
"ImportPath": "github.com/gin-gonic/gin",
"GoVersion": "go1.3",
"Deps": [
{
"ImportPath": "github.com/julienschmidt/httprouter",
"Rev": "00ce1c6a267162792c367acc43b1681a884e1872"
}
]
}
@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Manuel Martínez-Almeida
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
@@ -0,0 +1,536 @@
#Gin Web Framework [![GoDoc](https://godoc.org/github.com/gin-gonic/gin?status.svg)](https://godoc.org/github.com/gin-gonic/gin) [![Build Status](https://travis-ci.org/gin-gonic/gin.svg)](https://travis-ci.org/gin-gonic/gin)
Gin is a web framework written in Golang. It features a martini-like API with much better performance, up to 40 times faster thanks to [httprouter](https://github.com/julienschmidt/httprouter). If you need performance and good productivity, you will love Gin.
![Gin console logger](https://gin-gonic.github.io/gin/other/console.png)
```
$ cat test.go
```
```go
package main
import "github.com/gin-gonic/gin"
func main() {
router := gin.Default()
router.GET("/", func(c *gin.Context) {
c.String(200, "hello world")
})
router.GET("/ping", func(c *gin.Context) {
c.String(200, "pong")
})
router.POST("/submit", func(c *gin.Context) {
c.String(401, "not authorized")
})
router.PUT("/error", func(c *gin.Context) {
c.String(500, "and error hapenned :(")
})
router.Run(":8080")
}
```
##Gin is new, will it be supported?
Yes, Gin is an internal tool of [Manu](https://github.com/manucorporat) and [Javi](https://github.com/javierprovecho) for many of our projects/start-ups. We developed it and we are going to continue using and improve it.
##Roadmap for v1.0
- [ ] Ask our designer for a cool logo
- [ ] Add tons of unit tests
- [ ] Add internal benchmarks suite
- [ ] More powerful validation API
- [ ] Improve documentation
- [ ] Add Swagger support
- [x] Stable API
- [x] Improve logging system
- [x] Improve JSON/XML validation using bindings
- [x] Improve XML support
- [x] Flexible rendering system
- [x] Add more cool middlewares, for example redis caching (this also helps developers to understand the framework).
- [x] Continuous integration
- [x] Performance improments, reduce allocation and garbage collection overhead
- [x] Fix bugs
## Start using it
Obviously, you need to have Git and Go already installed to run Gin.
Run this in your terminal
```
go get github.com/gin-gonic/gin
```
Then import it in your Go code:
```
import "github.com/gin-gonic/gin"
```
##Community
If you'd like to help out with the project, there's a mailing list and IRC channel where Gin discussions normally happen.
* IRC
* [irc.freenode.net #getgin](irc://irc.freenode.net:6667/getgin)
* [Webchat](http://webchat.freenode.net?randomnick=1&channels=%23getgin)
* Mailing List
* Subscribe: [getgin@librelist.org](mailto:getgin@librelist.org)
* [Archives](http://librelist.com/browser/getgin/)
##API Examples
#### Create most basic PING/PONG HTTP endpoint
```go
package main
import "github.com/gin-gonic/gin"
func main() {
r := gin.Default()
r.GET("/ping", func(c *gin.Context) {
c.String(200, "pong")
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Using GET, POST, PUT, PATCH, DELETE and OPTIONS
```go
func main() {
// Creates a gin router + logger and recovery (crash-free) middlewares
r := gin.Default()
r.GET("/someGet", getting)
r.POST("/somePost", posting)
r.PUT("/somePut", putting)
r.DELETE("/someDelete", deleting)
r.PATCH("/somePatch", patching)
r.HEAD("/someHead", head)
r.OPTIONS("/someOptions", options)
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Parameters in path
```go
func main() {
r := gin.Default()
// This handler will match /user/john but will not match neither /user/ or /user
r.GET("/user/:name", func(c *gin.Context) {
name := c.Params.ByName("name")
message := "Hello "+name
c.String(200, message)
})
// However, this one will match /user/john/ and also /user/john/send
// If no other routers match /user/john, it will redirect to /user/join/
r.GET("/user/:name/*action", func(c *gin.Context) {
name := c.Params.ByName("name")
action := c.Params.ByName("action")
message := name + " is " + action
c.String(200, message)
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
###Form parameters
```go
func main() {
r := gin.Default()
// This will respond to urls like search?firstname=Jane&lastname=Doe
r.GET("/search", func(c *gin.Context) {
// You need to call ParseForm() on the request to receive url and form params first
c.Request.ParseForm()
firstname := c.Request.Form.Get("firstname")
lastname := c.Request.Form.get("lastname")
message := "Hello "+ firstname + lastname
c.String(200, message)
})
r.Run(":8080")
}
```
#### Grouping routes
```go
func main() {
r := gin.Default()
// Simple group: v1
v1 := r.Group("/v1")
{
v1.POST("/login", loginEndpoint)
v1.POST("/submit", submitEndpoint)
v1.POST("/read", readEndpoint)
}
// Simple group: v2
v2 := r.Group("/v2")
{
v2.POST("/login", loginEndpoint)
v2.POST("/submit", submitEndpoint)
v2.POST("/read", readEndpoint)
}
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Blank Gin without middlewares by default
Use
```go
r := gin.New()
```
instead of
```go
r := gin.Default()
```
#### Using middlewares
```go
func main() {
// Creates a router without any middleware by default
r := gin.New()
// Global middlewares
r.Use(gin.Logger())
r.Use(gin.Recovery())
// Per route middlewares, you can add as many as you desire.
r.GET("/benchmark", MyBenchLogger(), benchEndpoint)
// Authorization group
// authorized := r.Group("/", AuthRequired())
// exactly the same than:
authorized := r.Group("/")
// per group middlewares! in this case we use the custom created
// AuthRequired() middleware just in the "authorized" group.
authorized.Use(AuthRequired())
{
authorized.POST("/login", loginEndpoint)
authorized.POST("/submit", submitEndpoint)
authorized.POST("/read", readEndpoint)
// nested group
testing := authorized.Group("testing")
testing.GET("/analytics", analyticsEndpoint)
}
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Model binding and validation
To bind a request body into a type, use model binding. We currently support binding of JSON, XML and standard form values (foo=bar&boo=baz).
Note that you need to set the corresponding binding tag on all fields you want to bind. For example, when binding from JSON, set `json:"fieldname"`.
When using the Bind-method, Gin tries to infer the binder depending on the Content-Type header. If you are sure what you are binding, you can use BindWith.
You can also specify that specific fields are required. If a field is decorated with `binding:"required"` and has a empty value when binding, the current request will fail with an error.
```go
// Binding from JSON
type LoginJSON struct {
User string `json:"user" binding:"required"`
Password string `json:"password" binding:"required"`
}
// Binding from form values
type LoginForm struct {
User string `form:"user" binding:"required"`
Password string `form:"password" binding:"required"`
}
func main() {
r := gin.Default()
// Example for binding JSON ({"user": "manu", "password": "123"})
r.POST("/loginJSON", func(c *gin.Context) {
var json LoginJSON
c.Bind(&json) // This will infer what binder to use depending on the content-type header.
if json.User == "manu" && json.Password == "123" {
c.JSON(200, gin.H{"status": "you are logged in"})
} else {
c.JSON(401, gin.H{"status": "unauthorized"})
}
})
// Example for binding a HTML form (user=manu&password=123)
r.POST("/loginHTML", func(c *gin.Context) {
var form LoginForm
c.BindWith(&form, binding.Form) // You can also specify which binder to use. We support binding.Form, binding.JSON and binding.XML.
if form.User == "manu" && form.Password == "123" {
c.JSON(200, gin.H{"status": "you are logged in"})
} else {
c.JSON(401, gin.H{"status": "unauthorized"})
}
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### XML and JSON rendering
```go
func main() {
r := gin.Default()
// gin.H is a shortcut for map[string]interface{}
r.GET("/someJSON", func(c *gin.Context) {
c.JSON(200, gin.H{"message": "hey", "status": 200})
})
r.GET("/moreJSON", func(c *gin.Context) {
// You also can use a struct
var msg struct {
Name string `json:"user"`
Message string
Number int
}
msg.Name = "Lena"
msg.Message = "hey"
msg.Number = 123
// Note that msg.Name becomes "user" in the JSON
// Will output : {"user": "Lena", "Message": "hey", "Number": 123}
c.JSON(200, msg)
})
r.GET("/someXML", func(c *gin.Context) {
c.XML(200, gin.H{"message": "hey", "status": 200})
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
####Serving static files
Use Engine.ServeFiles(path string, root http.FileSystem):
```go
func main() {
r := gin.Default()
r.Static("/assets", "./assets")
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
Note: this will use `httpNotFound` instead of the Router's `NotFound` handler.
####HTML rendering
Using LoadHTMLTemplates()
```go
func main() {
r := gin.Default()
r.LoadHTMLGlob("templates/*")
r.GET("/index", func(c *gin.Context) {
obj := gin.H{"title": "Main website"}
c.HTML(200, "index.tmpl", obj)
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
```html
<h1>
{{ .title }}
</h1>
```
You can also use your own html template render
```go
import "html/template"
func main() {
r := gin.Default()
html := template.Must(template.ParseFiles("file1", "file2"))
r.SetHTMLTemplate(html)
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Redirects
Issuing a HTTP redirect is easy:
```go
r.GET("/test", func(c *gin.Context) {
c.Redirect(301, "http://www.google.com/")
})
```
Both internal and external locations are supported.
#### Custom Middlewares
```go
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
t := time.Now()
// Set example variable
c.Set("example", "12345")
// before request
c.Next()
// after request
latency := time.Since(t)
log.Print(latency)
// access the status we are sending
status := c.Writer.Status()
log.Println(status)
}
}
func main() {
r := gin.New()
r.Use(Logger())
r.GET("/test", func(c *gin.Context) {
example := c.MustGet("example").(string)
// it would print: "12345"
log.Println(example)
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Using BasicAuth() middleware
```go
// similate some private data
var secrets = gin.H{
"foo": gin.H{"email": "foo@bar.com", "phone": "123433"},
"austin": gin.H{"email": "austin@example.com", "phone": "666"},
"lena": gin.H{"email": "lena@guapa.com", "phone": "523443"},
}
func main() {
r := gin.Default()
// Group using gin.BasicAuth() middleware
// gin.Accounts is a shortcut for map[string]string
authorized := r.Group("/admin", gin.BasicAuth(gin.Accounts{
"foo": "bar",
"austin": "1234",
"lena": "hello2",
"manu": "4321",
}))
// /admin/secrets endpoint
// hit "localhost:8080/admin/secrets
authorized.GET("/secrets", func(c *gin.Context) {
// get user, it was setted by the BasicAuth middleware
user := c.MustGet(gin.AuthUserKey).(string)
if secret, ok := secrets[user]; ok {
c.JSON(200, gin.H{"user": user, "secret": secret})
} else {
c.JSON(200, gin.H{"user": user, "secret": "NO SECRET :("})
}
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Goroutines inside a middleware
When starting inside a middleware or handler, you **SHOULD NOT** use the original context inside it, you have to use a read-only copy.
```go
func main() {
r := gin.Default()
r.GET("/long_async", func(c *gin.Context) {
// create copy to be used inside the goroutine
c_cp := c.Copy()
go func() {
// simulate a long task with time.Sleep(). 5 seconds
time.Sleep(5 * time.Second)
// note than you are using the copied context "c_cp", IMPORTANT
log.Println("Done! in path " + c_cp.Request.URL.Path)
}()
})
r.GET("/long_sync", func(c *gin.Context) {
// simulate a long task with time.Sleep(). 5 seconds
time.Sleep(5 * time.Second)
// since we are NOT using a goroutine, we do not have to copy the context
log.Println("Done! in path " + c.Request.URL.Path)
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
```
#### Custom HTTP configuration
Use `http.ListenAndServe()` directly, like this:
```go
func main() {
router := gin.Default()
http.ListenAndServe(":8080", router)
}
```
or
```go
func main() {
router := gin.Default()
s := &http.Server{
Addr: ":8080",
Handler: router,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}
s.ListenAndServe()
}
```
@@ -0,0 +1,94 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"crypto/subtle"
"encoding/base64"
"errors"
"sort"
)
const (
AuthUserKey = "user"
)
type (
Accounts map[string]string
authPair struct {
Value string
User string
}
authPairs []authPair
)
func (a authPairs) Len() int { return len(a) }
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
func BasicAuth(accounts Accounts) HandlerFunc {
pairs, err := processAccounts(accounts)
if err != nil {
panic(err)
}
return func(c *Context) {
// Search user in the slice of allowed credentials
user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
if !ok {
// Credentials doesn't match, we return 401 Unauthorized and abort request.
c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
c.Fail(401, errors.New("Unauthorized"))
} else {
// user is allowed, set UserId to key "user" in this context, the userId can be read later using
// c.Get(gin.AuthUserKey)
c.Set(AuthUserKey, user)
}
}
}
func processAccounts(accounts Accounts) (authPairs, error) {
if len(accounts) == 0 {
return nil, errors.New("Empty list of authorized credentials")
}
pairs := make(authPairs, 0, len(accounts))
for user, password := range accounts {
if len(user) == 0 {
return nil, errors.New("User can not be empty")
}
base := user + ":" + password
value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
pairs = append(pairs, authPair{
Value: value,
User: user,
})
}
// We have to sort the credentials in order to use bsearch later.
sort.Sort(pairs)
return pairs, nil
}
func searchCredential(pairs authPairs, auth string) (string, bool) {
if len(auth) == 0 {
return "", false
}
// Search user in the slice of allowed credentials
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
return pairs[r].User, true
} else {
return "", false
}
}
func secureCompare(given, actual string) bool {
if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
} else {
/* Securely compare actual to itself to keep constant time, but always return false */
return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
}
}
@@ -0,0 +1,61 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
)
func TestBasicAuthSucceed(t *testing.T) {
req, _ := http.NewRequest("GET", "/login", nil)
w := httptest.NewRecorder()
r := New()
accounts := Accounts{"admin": "password"}
r.Use(BasicAuth(accounts))
r.GET("/login", func(c *Context) {
c.String(200, "autorized")
})
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("Response code should be Ok, was: %s", w.Code)
}
bodyAsString := w.Body.String()
if bodyAsString != "autorized" {
t.Errorf("Response body should be `autorized`, was %s", bodyAsString)
}
}
func TestBasicAuth401(t *testing.T) {
req, _ := http.NewRequest("GET", "/login", nil)
w := httptest.NewRecorder()
r := New()
accounts := Accounts{"foo": "bar"}
r.Use(BasicAuth(accounts))
r.GET("/login", func(c *Context) {
c.String(200, "autorized")
})
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
r.ServeHTTP(w, req)
if w.Code != 401 {
t.Errorf("Response code should be Not autorized, was: %s", w.Code)
}
if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" {
t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
}
}
@@ -0,0 +1,216 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"encoding/json"
"encoding/xml"
"errors"
"net/http"
"reflect"
"strconv"
"strings"
)
type (
Binding interface {
Bind(*http.Request, interface{}) error
}
// JSON binding
jsonBinding struct{}
// XML binding
xmlBinding struct{}
// // form binding
formBinding struct{}
)
var (
JSON = jsonBinding{}
XML = xmlBinding{}
Form = formBinding{} // todo
)
func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error {
decoder := json.NewDecoder(req.Body)
if err := decoder.Decode(obj); err == nil {
return Validate(obj)
} else {
return err
}
}
func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error {
decoder := xml.NewDecoder(req.Body)
if err := decoder.Decode(obj); err == nil {
return Validate(obj)
} else {
return err
}
}
func (_ formBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return Validate(obj)
}
func mapForm(ptr interface{}, form map[string][]string) error {
typ := reflect.TypeOf(ptr).Elem()
formStruct := reflect.ValueOf(ptr).Elem()
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
structField := formStruct.Field(i)
if !structField.CanSet() {
continue
}
inputValue, exists := form[inputFieldName]
if !exists {
continue
}
numElems := len(inputValue)
if structField.Kind() == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind()
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
return err
}
}
formStruct.Field(i).Set(slice)
} else {
if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
return err
}
}
}
}
return nil
}
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
switch valueKind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if val == "" {
val = "0"
}
intVal, err := strconv.Atoi(val)
if err != nil {
return err
} else {
structField.SetInt(int64(intVal))
}
case reflect.Bool:
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err != nil {
return err
} else {
structField.SetBool(boolVal)
}
case reflect.Float32:
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, 32)
if err != nil {
return err
} else {
structField.SetFloat(floatVal)
}
case reflect.Float64:
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, 64)
if err != nil {
return err
} else {
structField.SetFloat(floatVal)
}
case reflect.String:
structField.SetString(val)
}
return nil
}
// Don't pass in pointers to bind to. Can lead to bugs. See:
// https://github.com/codegangsta/martini-contrib/issues/40
// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
func ensureNotPointer(obj interface{}) {
if reflect.TypeOf(obj).Kind() == reflect.Ptr {
panic("Pointers are not accepted as binding models")
}
}
func Validate(obj interface{}, parents ...string) error {
typ := reflect.TypeOf(obj)
val := reflect.ValueOf(obj)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
val = val.Elem()
}
switch typ.Kind() {
case reflect.Struct:
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Allow ignored and unexported fields in the struct
if len(field.PkgPath) > 0 || field.Tag.Get("form") == "-" {
continue
}
fieldValue := val.Field(i).Interface()
zero := reflect.Zero(field.Type).Interface()
if strings.Index(field.Tag.Get("binding"), "required") > -1 {
fieldType := field.Type.Kind()
if fieldType == reflect.Struct {
if reflect.DeepEqual(zero, fieldValue) {
return errors.New("Required " + field.Name)
}
err := Validate(fieldValue, field.Name)
if err != nil {
return err
}
} else if reflect.DeepEqual(zero, fieldValue) {
if len(parents) > 0 {
return errors.New("Required " + field.Name + " on " + parents[0])
} else {
return errors.New("Required " + field.Name)
}
} else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
err := Validate(fieldValue)
if err != nil {
return err
}
}
}
}
case reflect.Slice:
for i := 0; i < val.Len(); i++ {
fieldValue := val.Index(i).Interface()
err := Validate(fieldValue)
if err != nil {
return err
}
}
default:
return nil
}
return nil
}
@@ -0,0 +1,434 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin/binding"
"github.com/gin-gonic/gin/render"
"github.com/julienschmidt/httprouter"
"log"
"net"
"net/http"
"strings"
)
const (
ErrorTypeInternal = 1 << iota
ErrorTypeExternal = 1 << iota
ErrorTypeAll = 0xffffffff
)
// Used internally to collect errors that occurred during an http request.
type errorMsg struct {
Err string `json:"error"`
Type uint32 `json:"-"`
Meta interface{} `json:"meta"`
}
type errorMsgs []errorMsg
func (a errorMsgs) ByType(typ uint32) errorMsgs {
if len(a) == 0 {
return a
}
result := make(errorMsgs, 0, len(a))
for _, msg := range a {
if msg.Type&typ > 0 {
result = append(result, msg)
}
}
return result
}
func (a errorMsgs) String() string {
if len(a) == 0 {
return ""
}
var buffer bytes.Buffer
for i, msg := range a {
text := fmt.Sprintf("Error #%02d: %s \n Meta: %v\n", (i + 1), msg.Err, msg.Meta)
buffer.WriteString(text)
}
return buffer.String()
}
// Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct {
writermem responseWriter
Request *http.Request
Writer ResponseWriter
Keys map[string]interface{}
Errors errorMsgs
Params httprouter.Params
Engine *Engine
handlers []HandlerFunc
index int8
accepted []string
}
/************************************/
/********** CONTEXT CREATION ********/
/************************************/
func (engine *Engine) createContext(w http.ResponseWriter, req *http.Request, params httprouter.Params, handlers []HandlerFunc) *Context {
c := engine.pool.Get().(*Context)
c.writermem.reset(w)
c.Request = req
c.Params = params
c.handlers = handlers
c.Keys = nil
c.index = -1
c.accepted = nil
c.Errors = c.Errors[0:0]
return c
}
func (engine *Engine) reuseContext(c *Context) {
engine.pool.Put(c)
}
func (c *Context) Copy() *Context {
var cp Context = *c
cp.index = AbortIndex
cp.handlers = nil
return &cp
}
/************************************/
/*************** FLOW ***************/
/************************************/
// Next should be used only in the middlewares.
// It executes the pending handlers in the chain inside the calling handler.
// See example in github.
func (c *Context) Next() {
c.index++
s := int8(len(c.handlers))
for ; c.index < s; c.index++ {
c.handlers[c.index](c)
}
}
// Forces the system to do not continue calling the pending handlers in the chain.
func (c *Context) Abort() {
c.index = AbortIndex
}
// Same than AbortWithStatus() but also writes the specified response status code.
// For example, the first handler checks if the request is authorized. If it's not, context.AbortWithStatus(401) should be called.
func (c *Context) AbortWithStatus(code int) {
c.Writer.WriteHeader(code)
c.Abort()
}
/************************************/
/********* ERROR MANAGEMENT *********/
/************************************/
// Fail is the same as Abort plus an error message.
// Calling `context.Fail(500, err)` is equivalent to:
// ```
// context.Error("Operation aborted", err)
// context.AbortWithStatus(500)
// ```
func (c *Context) Fail(code int, err error) {
c.Error(err, "Operation aborted")
c.AbortWithStatus(code)
}
func (c *Context) ErrorTyped(err error, typ uint32, meta interface{}) {
c.Errors = append(c.Errors, errorMsg{
Err: err.Error(),
Type: typ,
Meta: meta,
})
}
// Attaches an error to the current context. The error is pushed to a list of errors.
// It's a good idea to call Error for each error that occurred during the resolution of a request.
// A middleware can be used to collect all the errors and push them to a database together, print a log, or append it in the HTTP response.
func (c *Context) Error(err error, meta interface{}) {
c.ErrorTyped(err, ErrorTypeExternal, meta)
}
func (c *Context) LastError() error {
nuErrors := len(c.Errors)
if nuErrors > 0 {
return errors.New(c.Errors[nuErrors-1].Err)
} else {
return nil
}
}
/************************************/
/******** METADATA MANAGEMENT********/
/************************************/
// Sets a new pair key/value just for the specified context.
// It also lazy initializes the hashmap.
func (c *Context) Set(key string, item interface{}) {
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = item
}
// Get returns the value for the given key or an error if the key does not exist.
func (c *Context) Get(key string) (interface{}, error) {
if c.Keys != nil {
value, ok := c.Keys[key]
if ok {
return value, nil
}
}
return nil, errors.New("Key does not exist.")
}
// MustGet returns the value for the given key or panics if the value doesn't exist.
func (c *Context) MustGet(key string) interface{} {
value, err := c.Get(key)
if err != nil || value == nil {
log.Panicf("Key %s doesn't exist", value)
}
return value
}
func ipInMasks(ip net.IP, masks []interface{}) bool {
for _, proxy := range masks {
var mask *net.IPNet
var err error
switch t := proxy.(type) {
case string:
if _, mask, err = net.ParseCIDR(t); err != nil {
panic(err)
}
case net.IP:
mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)}
case net.IPNet:
mask = &t
}
if mask.Contains(ip) {
return true
}
}
return false
}
// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this
// middleware if you've got servers in front of this server. The list with (known) proxies and
// local ips are being filtered out of the forwarded for list, giving the last not local ip being
// the real client ip.
func ForwardedFor(proxies ...interface{}) HandlerFunc {
if len(proxies) == 0 {
// default to local ips
var reservedLocalIps = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
proxies = make([]interface{}, len(reservedLocalIps))
for i, v := range reservedLocalIps {
proxies[i] = v
}
}
return func(c *Context) {
// the X-Forwarded-For header contains an array with left most the client ip, then
// comma separated, all proxies the request passed. The last proxy appears
// as the remote address of the request. Returning the client
// ip to comply with default RemoteAddr response.
// check if remoteaddr is local ip or in list of defined proxies
remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0])
if !ipInMasks(remoteIp, proxies) {
return
}
if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" {
parts := strings.Split(forwardedFor, ",")
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
ip := net.ParseIP(strings.TrimSpace(part))
if ipInMasks(ip, proxies) {
continue
}
// returning remote addr conform the original remote addr format
c.Request.RemoteAddr = ip.String() + ":0"
// remove forwarded for address
c.Request.Header.Set("X-Forwarded-For", "")
return
}
}
}
}
func (c *Context) ClientIP() string {
return c.Request.RemoteAddr
}
/************************************/
/********* PARSING REQUEST **********/
/************************************/
// This function checks the Content-Type to select a binding engine automatically,
// Depending the "Content-Type" header different bindings are used:
// "application/json" --> JSON binding
// "application/xml" --> XML binding
// else --> returns an error
// if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid.
func (c *Context) Bind(obj interface{}) bool {
var b binding.Binding
ctype := filterFlags(c.Request.Header.Get("Content-Type"))
switch {
case c.Request.Method == "GET" || ctype == MIMEPOSTForm:
b = binding.Form
case ctype == MIMEJSON:
b = binding.JSON
case ctype == MIMEXML || ctype == MIMEXML2:
b = binding.XML
default:
c.Fail(400, errors.New("unknown content-type: "+ctype))
return false
}
return c.BindWith(obj, b)
}
func (c *Context) BindWith(obj interface{}, b binding.Binding) bool {
if err := b.Bind(c.Request, obj); err != nil {
c.Fail(400, err)
return false
}
return true
}
/************************************/
/******** RESPONSE RENDERING ********/
/************************************/
func (c *Context) Render(code int, render render.Render, obj ...interface{}) {
if err := render.Render(c.Writer, code, obj...); err != nil {
c.ErrorTyped(err, ErrorTypeInternal, obj)
c.AbortWithStatus(500)
}
}
// Serializes the given struct as JSON into the response body in a fast and efficient way.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(code int, obj interface{}) {
c.Render(code, render.JSON, obj)
}
// Serializes the given struct as XML into the response body in a fast and efficient way.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(code int, obj interface{}) {
c.Render(code, render.XML, obj)
}
// Renders the HTTP template specified by its file name.
// It also updates the HTTP code and sets the Content-Type as "text/html".
// See http://golang.org/doc/articles/wiki/
func (c *Context) HTML(code int, name string, obj interface{}) {
c.Render(code, c.Engine.HTMLRender, name, obj)
}
// Writes the given string into the response body and sets the Content-Type to "text/plain".
func (c *Context) String(code int, format string, values ...interface{}) {
c.Render(code, render.Plain, format, values)
}
// Returns a HTTP redirect to the specific location.
func (c *Context) Redirect(code int, location string) {
if code >= 300 && code <= 308 {
c.Render(code, render.Redirect, location)
} else {
panic(fmt.Sprintf("Cannot send a redirect with status code %d", code))
}
}
// Writes some data into the body stream and updates the HTTP code.
func (c *Context) Data(code int, contentType string, data []byte) {
if len(contentType) > 0 {
c.Writer.Header().Set("Content-Type", contentType)
}
c.Writer.WriteHeader(code)
c.Writer.Write(data)
}
// Writes the specified file into the body stream
func (c *Context) File(filepath string) {
http.ServeFile(c.Writer, c.Request, filepath)
}
/************************************/
/******** CONTENT NEGOTIATION *******/
/************************************/
type Negotiate struct {
Offered []string
HTMLPath string
HTMLData interface{}
JSONData interface{}
XMLData interface{}
Data interface{}
}
func (c *Context) Negotiate(code int, config Negotiate) {
switch c.NegotiateFormat(config.Offered...) {
case MIMEJSON:
data := chooseData(config.JSONData, config.Data)
c.JSON(code, data)
case MIMEHTML:
data := chooseData(config.HTMLData, config.Data)
if len(config.HTMLPath) == 0 {
panic("negotiate config is wrong. html path is needed")
}
c.HTML(code, config.HTMLPath, data)
case MIMEXML:
data := chooseData(config.XMLData, config.Data)
c.XML(code, data)
default:
c.Fail(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server"))
}
}
func (c *Context) NegotiateFormat(offered ...string) string {
if len(offered) == 0 {
panic("you must provide at least one offer")
}
if c.accepted == nil {
c.accepted = parseAccept(c.Request.Header.Get("Accept"))
}
if len(c.accepted) == 0 {
return offered[0]
} else {
for _, accepted := range c.accepted {
for _, offert := range offered {
if accepted == offert {
return offert
}
}
}
return ""
}
}
func (c *Context) SetAccepted(formats ...string) {
c.accepted = formats
}
@@ -0,0 +1,483 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bytes"
"errors"
"html/template"
"net/http"
"net/http/httptest"
"testing"
)
// TestContextParamsGet tests that a parameter can be parsed from the URL.
func TestContextParamsByName(t *testing.T) {
req, _ := http.NewRequest("GET", "/test/alexandernyquist", nil)
w := httptest.NewRecorder()
name := ""
r := New()
r.GET("/test/:name", func(c *Context) {
name = c.Params.ByName("name")
})
r.ServeHTTP(w, req)
if name != "alexandernyquist" {
t.Errorf("Url parameter was not correctly parsed. Should be alexandernyquist, was %s.", name)
}
}
// TestContextSetGet tests that a parameter is set correctly on the
// current context and can be retrieved using Get.
func TestContextSetGet(t *testing.T) {
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test", func(c *Context) {
// Key should be lazily created
if c.Keys != nil {
t.Error("Keys should be nil")
}
// Set
c.Set("foo", "bar")
v, err := c.Get("foo")
if err != nil {
t.Errorf("Error on exist key")
}
if v != "bar" {
t.Errorf("Value should be bar, was %s", v)
}
})
r.ServeHTTP(w, req)
}
// TestContextJSON tests that the response is serialized as JSON
// and Content-Type is set to application/json
func TestContextJSON(t *testing.T) {
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test", func(c *Context) {
c.JSON(200, H{"foo": "bar"})
})
r.ServeHTTP(w, req)
if w.Body.String() != "{\"foo\":\"bar\"}\n" {
t.Errorf("Response should be {\"foo\":\"bar\"}, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "application/json; charset=utf-8" {
t.Errorf("Content-Type should be application/json, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestContextHTML tests that the response executes the templates
// and responds with Content-Type set to text/html
func TestContextHTML(t *testing.T) {
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
r := New()
templ, _ := template.New("t").Parse(`Hello {{.Name}}`)
r.SetHTMLTemplate(templ)
type TestData struct{ Name string }
r.GET("/test", func(c *Context) {
c.HTML(200, "t", TestData{"alexandernyquist"})
})
r.ServeHTTP(w, req)
if w.Body.String() != "Hello alexandernyquist" {
t.Errorf("Response should be Hello alexandernyquist, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "text/html; charset=utf-8" {
t.Errorf("Content-Type should be text/html, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestContextString tests that the response is returned
// with Content-Type set to text/plain
func TestContextString(t *testing.T) {
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test", func(c *Context) {
c.String(200, "test")
})
r.ServeHTTP(w, req)
if w.Body.String() != "test" {
t.Errorf("Response should be test, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Errorf("Content-Type should be text/plain, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestContextXML tests that the response is serialized as XML
// and Content-Type is set to application/xml
func TestContextXML(t *testing.T) {
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test", func(c *Context) {
c.XML(200, H{"foo": "bar"})
})
r.ServeHTTP(w, req)
if w.Body.String() != "<map><foo>bar</foo></map>" {
t.Errorf("Response should be <map><foo>bar</foo></map>, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "application/xml; charset=utf-8" {
t.Errorf("Content-Type should be application/xml, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestContextData tests that the response can be written from `bytesting`
// with specified MIME type
func TestContextData(t *testing.T) {
req, _ := http.NewRequest("GET", "/test/csv", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test/csv", func(c *Context) {
c.Data(200, "text/csv", []byte(`foo,bar`))
})
r.ServeHTTP(w, req)
if w.Body.String() != "foo,bar" {
t.Errorf("Response should be foo&bar, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "text/csv" {
t.Errorf("Content-Type should be text/csv, was %s", w.HeaderMap.Get("Content-Type"))
}
}
func TestContextFile(t *testing.T) {
req, _ := http.NewRequest("GET", "/test/file", nil)
w := httptest.NewRecorder()
r := New()
r.GET("/test/file", func(c *Context) {
c.File("./gin.go")
})
r.ServeHTTP(w, req)
bodyAsString := w.Body.String()
if len(bodyAsString) == 0 {
t.Errorf("Got empty body instead of file data")
}
if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Errorf("Content-Type should be text/plain; charset=utf-8, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestHandlerFunc - ensure that custom middleware works properly
func TestHandlerFunc(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
r := New()
var stepsPassed int = 0
r.Use(func(context *Context) {
stepsPassed += 1
context.Next()
stepsPassed += 1
})
r.ServeHTTP(w, req)
if w.Code != 404 {
t.Errorf("Response code should be Not found, was: %s", w.Code)
}
if stepsPassed != 2 {
t.Errorf("Falied to switch context in handler function: %s", stepsPassed)
}
}
// TestBadAbortHandlersChain - ensure that Abort after switch context will not interrupt pending handlers
func TestBadAbortHandlersChain(t *testing.T) {
// SETUP
var stepsPassed int = 0
r := New()
r.Use(func(c *Context) {
stepsPassed += 1
c.Next()
stepsPassed += 1
// after check and abort
c.AbortWithStatus(409)
})
r.Use(func(c *Context) {
stepsPassed += 1
c.Next()
stepsPassed += 1
c.AbortWithStatus(403)
})
// RUN
w := PerformRequest(r, "GET", "/")
// TEST
if w.Code != 409 {
t.Errorf("Response code should be Forbiden, was: %d", w.Code)
}
if stepsPassed != 4 {
t.Errorf("Falied to switch context in handler function: %d", stepsPassed)
}
}
// TestAbortHandlersChain - ensure that Abort interrupt used middlewares in fifo order
func TestAbortHandlersChain(t *testing.T) {
// SETUP
var stepsPassed int = 0
r := New()
r.Use(func(context *Context) {
stepsPassed += 1
context.AbortWithStatus(409)
})
r.Use(func(context *Context) {
stepsPassed += 1
context.Next()
stepsPassed += 1
})
// RUN
w := PerformRequest(r, "GET", "/")
// TEST
if w.Code != 409 {
t.Errorf("Response code should be Conflict, was: %d", w.Code)
}
if stepsPassed != 1 {
t.Errorf("Falied to switch context in handler function: %d", stepsPassed)
}
}
// TestFailHandlersChain - ensure that Fail interrupt used middlewares in fifo order as
// as well as Abort
func TestFailHandlersChain(t *testing.T) {
// SETUP
var stepsPassed int = 0
r := New()
r.Use(func(context *Context) {
stepsPassed += 1
context.Fail(500, errors.New("foo"))
})
r.Use(func(context *Context) {
stepsPassed += 1
context.Next()
stepsPassed += 1
})
// RUN
w := PerformRequest(r, "GET", "/")
// TEST
if w.Code != 500 {
t.Errorf("Response code should be Server error, was: %d", w.Code)
}
if stepsPassed != 1 {
t.Errorf("Falied to switch context in handler function: %d", stepsPassed)
}
}
func TestBindingJSON(t *testing.T) {
body := bytes.NewBuffer([]byte("{\"foo\":\"bar\"}"))
r := New()
r.POST("/binding/json", func(c *Context) {
var body struct {
Foo string `json:"foo"`
}
if c.Bind(&body) {
c.JSON(200, H{"parsed": body.Foo})
}
})
req, _ := http.NewRequest("POST", "/binding/json", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("Response code should be Ok, was: %s", w.Code)
}
if w.Body.String() != "{\"parsed\":\"bar\"}\n" {
t.Errorf("Response should be {\"parsed\":\"bar\"}, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "application/json; charset=utf-8" {
t.Errorf("Content-Type should be application/json, was %s", w.HeaderMap.Get("Content-Type"))
}
}
func TestBindingJSONEncoding(t *testing.T) {
body := bytes.NewBuffer([]byte("{\"foo\":\"嘉\"}"))
r := New()
r.POST("/binding/json", func(c *Context) {
var body struct {
Foo string `json:"foo"`
}
if c.Bind(&body) {
c.JSON(200, H{"parsed": body.Foo})
}
})
req, _ := http.NewRequest("POST", "/binding/json", body)
req.Header.Set("Content-Type", "application/json; charset=utf-8")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("Response code should be Ok, was: %s", w.Code)
}
if w.Body.String() != "{\"parsed\":\"嘉\"}\n" {
t.Errorf("Response should be {\"parsed\":\"嘉\"}, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "application/json; charset=utf-8" {
t.Errorf("Content-Type should be application/json, was %s", w.HeaderMap.Get("Content-Type"))
}
}
func TestBindingJSONNoContentType(t *testing.T) {
body := bytes.NewBuffer([]byte("{\"foo\":\"bar\"}"))
r := New()
r.POST("/binding/json", func(c *Context) {
var body struct {
Foo string `json:"foo"`
}
if c.Bind(&body) {
c.JSON(200, H{"parsed": body.Foo})
}
})
req, _ := http.NewRequest("POST", "/binding/json", body)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 400 {
t.Errorf("Response code should be Bad request, was: %s", w.Code)
}
if w.Body.String() == "{\"parsed\":\"bar\"}\n" {
t.Errorf("Response should not be {\"parsed\":\"bar\"}, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") == "application/json" {
t.Errorf("Content-Type should not be application/json, was %s", w.HeaderMap.Get("Content-Type"))
}
}
func TestBindingJSONMalformed(t *testing.T) {
body := bytes.NewBuffer([]byte("\"foo\":\"bar\"\n"))
r := New()
r.POST("/binding/json", func(c *Context) {
var body struct {
Foo string `json:"foo"`
}
if c.Bind(&body) {
c.JSON(200, H{"parsed": body.Foo})
}
})
req, _ := http.NewRequest("POST", "/binding/json", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 400 {
t.Errorf("Response code should be Bad request, was: %s", w.Code)
}
if w.Body.String() == "{\"parsed\":\"bar\"}\n" {
t.Errorf("Response should not be {\"parsed\":\"bar\"}, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") == "application/json" {
t.Errorf("Content-Type should not be application/json, was %s", w.HeaderMap.Get("Content-Type"))
}
}
func TestClientIP(t *testing.T) {
r := New()
var clientIP string = ""
r.GET("/", func(c *Context) {
clientIP = c.ClientIP()
})
body := bytes.NewBuffer([]byte(""))
req, _ := http.NewRequest("GET", "/", body)
req.RemoteAddr = "clientip:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if clientIP != "clientip:1234" {
t.Errorf("ClientIP should not be %s, but clientip:1234", clientIP)
}
}
func TestClientIPWithXForwardedForWithProxy(t *testing.T) {
r := New()
r.Use(ForwardedFor())
var clientIP string = ""
r.GET("/", func(c *Context) {
clientIP = c.ClientIP()
})
body := bytes.NewBuffer([]byte(""))
req, _ := http.NewRequest("GET", "/", body)
req.RemoteAddr = "172.16.8.3:1234"
req.Header.Set("X-Real-Ip", "realip")
req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.10.0.4, 192.168.0.43, 172.16.8.4")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if clientIP != "1.2.3.4:0" {
t.Errorf("ClientIP should not be %s, but 1.2.3.4:0", clientIP)
}
}
@@ -0,0 +1,47 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"github.com/gin-gonic/gin/binding"
"net/http"
)
// DEPRECATED, use Bind() instead.
// Like ParseBody() but this method also writes a 400 error if the json is not valid.
func (c *Context) EnsureBody(item interface{}) bool {
return c.Bind(item)
}
// DEPRECATED use bindings directly
// Parses the body content as a JSON input. It decodes the json payload into the struct specified as a pointer.
func (c *Context) ParseBody(item interface{}) error {
return binding.JSON.Bind(c.Request, item)
}
// DEPRECATED use gin.Static() instead
// ServeFiles serves files from the given file system root.
// The path must end with "/*filepath", files are then served from the local
// path /defined/root/dir/*filepath.
// For example if root is "/etc" and *filepath is "passwd", the local file
// "/etc/passwd" would be served.
// Internally a http.FileServer is used, therefore http.NotFound is used instead
// of the Router's NotFound handler.
// To use the operating system's file system implementation,
// use http.Dir:
// router.ServeFiles("/src/*filepath", http.Dir("/var/www"))
func (engine *Engine) ServeFiles(path string, root http.FileSystem) {
engine.router.ServeFiles(path, root)
}
// DEPRECATED use gin.LoadHTMLGlob() or gin.LoadHTMLFiles() instead
func (engine *Engine) LoadHTMLTemplates(pattern string) {
engine.LoadHTMLGlob(pattern)
}
// DEPRECATED. Use NoRoute() instead
func (engine *Engine) NotFound404(handlers ...HandlerFunc) {
engine.NoRoute(handlers...)
}
@@ -0,0 +1,7 @@
# Guide to run Gin under App Engine LOCAL Development Server
1. Download, install and setup Go in your computer. (That includes setting your `$GOPATH`.)
2. Download SDK for your platform from here: `https://developers.google.com/appengine/downloads?hl=es#Google_App_Engine_SDK_for_Go`
3. Download Gin source code using: `$ go get github.com/gin-gonic/gin`
4. Navigate to examples folder: `$ cd $GOPATH/src/github.com/gin-gonic/gin/examples/`
5. Run it: `$ goapp serve app-engine/`
@@ -0,0 +1,8 @@
application: hello
version: 1
runtime: go
api_version: go1
handlers:
- url: /.*
script: _go_app
@@ -0,0 +1,23 @@
package hello
import (
"net/http"
"github.com/gin-gonic/gin"
)
// This function's name is a must. App Engine uses it to drive the requests properly.
func init() {
// Starts a new Gin instance with no middle-ware
r := gin.New()
// Define your handlers
r.GET("/", func(c *gin.Context){
c.String(200, "Hello World!")
})
r.GET("/ping", func(c *gin.Context){
c.String(200, "pong")
})
// Handle all requests using net/http
http.Handle("/", r)
}
@@ -0,0 +1,56 @@
package main
import (
"github.com/gin-gonic/gin"
)
var DB = make(map[string]string)
func main() {
r := gin.Default()
// Ping test
r.GET("/ping", func(c *gin.Context) {
c.String(200, "pong")
})
// Get user value
r.GET("/user/:name", func(c *gin.Context) {
user := c.Params.ByName("name")
value, ok := DB[user]
if ok {
c.JSON(200, gin.H{"user": user, "value": value})
} else {
c.JSON(200, gin.H{"user": user, "status": "no value"})
}
})
// Authorized group (uses gin.BasicAuth() middleware)
// Same than:
// authorized := r.Group("/")
// authorized.Use(gin.BasicAuth(gin.Credentials{
// "foo": "bar",
// "manu": "123",
//}))
authorized := r.Group("/", gin.BasicAuth(gin.Accounts{
"foo": "bar", // user:foo password:bar
"manu": "123", // user:manu password:123
}))
authorized.POST("admin", func(c *gin.Context) {
user := c.MustGet(gin.AuthUserKey).(string)
// Parse JSON
var json struct {
Value string `json:"value" binding:"required"`
}
if c.Bind(&json) {
DB[user] = json.Value
c.JSON(200, gin.H{"status": "ok"})
}
})
// Listen and Server in 0.0.0.0:8080
r.Run(":8080")
}
@@ -0,0 +1,58 @@
package main
import (
"github.com/flosch/pongo2"
"github.com/gin-gonic/gin"
"net/http"
)
type pongoRender struct {
cache map[string]*pongo2.Template
}
func newPongoRender() *pongoRender {
return &pongoRender{map[string]*pongo2.Template{}}
}
func writeHeader(w http.ResponseWriter, code int, contentType string) {
if code >= 0 {
w.Header().Set("Content-Type", contentType)
w.WriteHeader(code)
}
}
func (p *pongoRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
file := data[0].(string)
ctx := data[1].(pongo2.Context)
var t *pongo2.Template
if tmpl, ok := p.cache[file]; ok {
t = tmpl
} else {
tmpl, err := pongo2.FromFile(file)
if err != nil {
return err
}
p.cache[file] = tmpl
t = tmpl
}
writeHeader(w, code, "text/html")
return t.ExecuteWriter(ctx, w)
}
func main() {
r := gin.Default()
r.HTMLRender = newPongoRender()
r.GET("/index", func(c *gin.Context) {
name := c.Request.FormValue("name")
ctx := pongo2.Context{
"title": "Gin meets pongo2 !",
"name": name,
}
c.HTML(200, "index.html", ctx)
})
// Listen and server on 0.0.0.0:8080
r.Run(":8080")
}
@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>{{ title }}</title>
<meta name="keywords" content="">
<meta name="description" content="">
</head>
<body>
Hello {{ name }} !
</body>
</html>
@@ -0,0 +1,143 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"github.com/gin-gonic/gin/render"
"github.com/julienschmidt/httprouter"
"html/template"
"math"
"net/http"
"sync"
)
const (
AbortIndex = math.MaxInt8 / 2
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
)
type (
HandlerFunc func(*Context)
// Represents the web framework, it wraps the blazing fast httprouter multiplexer and a list of global middlewares.
Engine struct {
*RouterGroup
HTMLRender render.Render
Default404Body []byte
pool sync.Pool
allNoRoute []HandlerFunc
noRoute []HandlerFunc
router *httprouter.Router
}
)
// Returns a new blank Engine instance without any middleware attached.
// The most basic configuration
func New() *Engine {
engine := &Engine{}
engine.RouterGroup = &RouterGroup{
Handlers: nil,
absolutePath: "/",
engine: engine,
}
engine.router = httprouter.New()
engine.Default404Body = []byte("404 page not found")
engine.router.NotFound = engine.handle404
engine.pool.New = func() interface{} {
c := &Context{Engine: engine}
c.Writer = &c.writermem
return c
}
return engine
}
// Returns a Engine instance with the Logger and Recovery already attached.
func Default() *Engine {
engine := New()
engine.Use(Recovery(), Logger())
return engine
}
func (engine *Engine) LoadHTMLGlob(pattern string) {
if IsDebugging() {
render.HTMLDebug.AddGlob(pattern)
engine.HTMLRender = render.HTMLDebug
} else {
templ := template.Must(template.ParseGlob(pattern))
engine.SetHTMLTemplate(templ)
}
}
func (engine *Engine) LoadHTMLFiles(files ...string) {
if IsDebugging() {
render.HTMLDebug.AddFiles(files...)
engine.HTMLRender = render.HTMLDebug
} else {
templ := template.Must(template.ParseFiles(files...))
engine.SetHTMLTemplate(templ)
}
}
func (engine *Engine) SetHTMLTemplate(templ *template.Template) {
engine.HTMLRender = render.HTMLRender{
Template: templ,
}
}
// Adds handlers for NoRoute. It return a 404 code by default.
func (engine *Engine) NoRoute(handlers ...HandlerFunc) {
engine.noRoute = handlers
engine.rebuild404Handlers()
}
func (engine *Engine) Use(middlewares ...HandlerFunc) {
engine.RouterGroup.Use(middlewares...)
engine.rebuild404Handlers()
}
func (engine *Engine) rebuild404Handlers() {
engine.allNoRoute = engine.combineHandlers(engine.noRoute)
}
func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c := engine.createContext(w, req, nil, engine.allNoRoute)
// set 404 by default, useful for logging
c.Writer.WriteHeader(404)
c.Next()
if !c.Writer.Written() {
if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, engine.Default404Body)
} else {
c.Writer.WriteHeaderNow()
}
}
engine.reuseContext(c)
}
// ServeHTTP makes the router implement the http.Handler interface.
func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
engine.router.ServeHTTP(writer, request)
}
func (engine *Engine) Run(addr string) error {
debugPrint("Listening and serving HTTP on %s\n", addr)
if err := http.ListenAndServe(addr, engine); err != nil {
return err
}
return nil
}
func (engine *Engine) RunTLS(addr string, cert string, key string) error {
debugPrint("Listening and serving HTTPS on %s\n", addr)
if err := http.ListenAndServeTLS(addr, cert, key, engine); err != nil {
return err
}
return nil
}
@@ -0,0 +1,206 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path"
"strings"
"testing"
)
func init() {
SetMode(TestMode)
}
func PerformRequest(r http.Handler, method, path string) *httptest.ResponseRecorder {
req, _ := http.NewRequest(method, path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
return w
}
// TestSingleRouteOK tests that POST route is correctly invoked.
func testRouteOK(method string, t *testing.T) {
// SETUP
passed := false
r := New()
r.Handle(method, "/test", []HandlerFunc{func(c *Context) {
passed = true
}})
// RUN
w := PerformRequest(r, method, "/test")
// TEST
if passed == false {
t.Errorf(method + " route handler was not invoked.")
}
if w.Code != http.StatusOK {
t.Errorf("Status code should be %v, was %d", http.StatusOK, w.Code)
}
}
func TestRouterGroupRouteOK(t *testing.T) {
testRouteOK("POST", t)
testRouteOK("DELETE", t)
testRouteOK("PATCH", t)
testRouteOK("PUT", t)
testRouteOK("OPTIONS", t)
testRouteOK("HEAD", t)
}
// TestSingleRouteOK tests that POST route is correctly invoked.
func testRouteNotOK(method string, t *testing.T) {
// SETUP
passed := false
r := New()
r.Handle(method, "/test_2", []HandlerFunc{func(c *Context) {
passed = true
}})
// RUN
w := PerformRequest(r, method, "/test")
// TEST
if passed == true {
t.Errorf(method + " route handler was invoked, when it should not")
}
if w.Code != http.StatusNotFound {
// If this fails, it's because httprouter needs to be updated to at least f78f58a0db
t.Errorf("Status code should be %v, was %d. Location: %s", http.StatusNotFound, w.Code, w.HeaderMap.Get("Location"))
}
}
// TestSingleRouteOK tests that POST route is correctly invoked.
func TestRouteNotOK(t *testing.T) {
testRouteNotOK("POST", t)
testRouteNotOK("DELETE", t)
testRouteNotOK("PATCH", t)
testRouteNotOK("PUT", t)
testRouteNotOK("OPTIONS", t)
testRouteNotOK("HEAD", t)
}
// TestSingleRouteOK tests that POST route is correctly invoked.
func testRouteNotOK2(method string, t *testing.T) {
// SETUP
passed := false
r := New()
var methodRoute string
if method == "POST" {
methodRoute = "GET"
} else {
methodRoute = "POST"
}
r.Handle(methodRoute, "/test", []HandlerFunc{func(c *Context) {
passed = true
}})
// RUN
w := PerformRequest(r, method, "/test")
// TEST
if passed == true {
t.Errorf(method + " route handler was invoked, when it should not")
}
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Status code should be %v, was %d. Location: %s", http.StatusMethodNotAllowed, w.Code, w.HeaderMap.Get("Location"))
}
}
// TestSingleRouteOK tests that POST route is correctly invoked.
func TestRouteNotOK2(t *testing.T) {
testRouteNotOK2("POST", t)
testRouteNotOK2("DELETE", t)
testRouteNotOK2("PATCH", t)
testRouteNotOK2("PUT", t)
testRouteNotOK2("OPTIONS", t)
testRouteNotOK2("HEAD", t)
}
// TestHandleStaticFile - ensure the static file handles properly
func TestHandleStaticFile(t *testing.T) {
// SETUP file
testRoot, _ := os.Getwd()
f, err := ioutil.TempFile(testRoot, "")
if err != nil {
t.Error(err)
}
defer os.Remove(f.Name())
filePath := path.Join("/", path.Base(f.Name()))
f.WriteString("Gin Web Framework")
f.Close()
// SETUP gin
r := New()
r.Static("./", testRoot)
// RUN
w := PerformRequest(r, "GET", filePath)
// TEST
if w.Code != 200 {
t.Errorf("Response code should be 200, was: %d", w.Code)
}
if w.Body.String() != "Gin Web Framework" {
t.Errorf("Response should be test, was: %s", w.Body.String())
}
if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Errorf("Content-Type should be text/plain, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestHandleStaticDir - ensure the root/sub dir handles properly
func TestHandleStaticDir(t *testing.T) {
// SETUP
r := New()
r.Static("/", "./")
// RUN
w := PerformRequest(r, "GET", "/")
// TEST
bodyAsString := w.Body.String()
if w.Code != 200 {
t.Errorf("Response code should be 200, was: %d", w.Code)
}
if len(bodyAsString) == 0 {
t.Errorf("Got empty body instead of file tree")
}
if !strings.Contains(bodyAsString, "gin.go") {
t.Errorf("Can't find:`gin.go` in file tree: %s", bodyAsString)
}
if w.HeaderMap.Get("Content-Type") != "text/html; charset=utf-8" {
t.Errorf("Content-Type should be text/plain, was %s", w.HeaderMap.Get("Content-Type"))
}
}
// TestHandleHeadToDir - ensure the root/sub dir handles properly
func TestHandleHeadToDir(t *testing.T) {
// SETUP
r := New()
r.Static("/", "./")
// RUN
w := PerformRequest(r, "HEAD", "/")
// TEST
bodyAsString := w.Body.String()
if w.Code != 200 {
t.Errorf("Response code should be Ok, was: %s", w.Code)
}
if len(bodyAsString) == 0 {
t.Errorf("Got empty body instead of file tree")
}
if !strings.Contains(bodyAsString, "gin.go") {
t.Errorf("Can't find:`gin.go` in file tree: %s", bodyAsString)
}
if w.HeaderMap.Get("Content-Type") != "text/html; charset=utf-8" {
t.Errorf("Content-Type should be text/plain, was %s", w.HeaderMap.Get("Content-Type"))
}
}
@@ -0,0 +1,105 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"log"
"os"
"time"
)
var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)
func ErrorLogger() HandlerFunc {
return ErrorLoggerT(ErrorTypeAll)
}
func ErrorLoggerT(typ uint32) HandlerFunc {
return func(c *Context) {
c.Next()
errs := c.Errors.ByType(typ)
if len(errs) > 0 {
// -1 status code = do not change current one
c.JSON(-1, c.Errors)
}
}
}
func Logger() HandlerFunc {
stdlogger := log.New(os.Stdout, "", 0)
//errlogger := log.New(os.Stderr, "", 0)
return func(c *Context) {
// Start timer
start := time.Now()
// Process request
c.Next()
// Stop timer
end := time.Now()
latency := end.Sub(start)
clientIP := c.ClientIP()
method := c.Request.Method
statusCode := c.Writer.Status()
statusColor := colorForStatus(statusCode)
methodColor := colorForMethod(method)
stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s %s %-7s %s\n%s",
end.Format("2006/01/02 - 15:04:05"),
statusColor, statusCode, reset,
latency,
clientIP,
methodColor, reset, method,
c.Request.URL.Path,
c.Errors.String(),
)
}
}
func colorForStatus(code int) string {
switch {
case code >= 200 && code <= 299:
return green
case code >= 300 && code <= 399:
return white
case code >= 400 && code <= 499:
return yellow
default:
return red
}
}
func colorForMethod(method string) string {
switch {
case method == "GET":
return blue
case method == "POST":
return cyan
case method == "PUT":
return yellow
case method == "DELETE":
return red
case method == "PATCH":
return green
case method == "HEAD":
return magenta
case method == "OPTIONS":
return white
default:
return reset
}
}
@@ -0,0 +1,63 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"fmt"
"os"
)
const GIN_MODE = "GIN_MODE"
const (
DebugMode string = "debug"
ReleaseMode string = "release"
TestMode string = "test"
)
const (
debugCode = iota
releaseCode = iota
testCode = iota
)
var gin_mode int = debugCode
var mode_name string = DebugMode
func init() {
value := os.Getenv(GIN_MODE)
if len(value) == 0 {
SetMode(DebugMode)
} else {
SetMode(value)
}
}
func SetMode(value string) {
switch value {
case DebugMode:
gin_mode = debugCode
case ReleaseMode:
gin_mode = releaseCode
case TestMode:
gin_mode = testCode
default:
panic("gin mode unknown: " + value)
}
mode_name = value
}
func Mode() string {
return mode_name
}
func IsDebugging() bool {
return gin_mode == debugCode
}
func debugPrint(format string, values ...interface{}) {
if IsDebugging() {
fmt.Printf("[GIN-debug] "+format, values...)
}
}
@@ -0,0 +1,98 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net/http"
"runtime"
)
var (
dunno = []byte("???")
centerDot = []byte("·")
dot = []byte(".")
slash = []byte("/")
)
// stack returns a nicely formated stack frame, skipping skip frames
func stack(skip int) []byte {
buf := new(bytes.Buffer) // the returned data
// As we loop, we open files and read them. These variables record the currently
// loaded file.
var lines [][]byte
var lastFile string
for i := skip; ; i++ { // Skip the expected number of frames
pc, file, line, ok := runtime.Caller(i)
if !ok {
break
}
// Print this much at least. If we can't find the source, it won't show.
fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
if file != lastFile {
data, err := ioutil.ReadFile(file)
if err != nil {
continue
}
lines = bytes.Split(data, []byte{'\n'})
lastFile = file
}
fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
}
return buf.Bytes()
}
// source returns a space-trimmed slice of the n'th line.
func source(lines [][]byte, n int) []byte {
n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
if n < 0 || n >= len(lines) {
return dunno
}
return bytes.TrimSpace(lines[n])
}
// function returns, if possible, the name of the function containing the PC.
func function(pc uintptr) []byte {
fn := runtime.FuncForPC(pc)
if fn == nil {
return dunno
}
name := []byte(fn.Name())
// The name includes the path name to the package, which is unnecessary
// since the file name is already included. Plus, it has center dots.
// That is, we see
// runtime/debug.*T·ptrmethod
// and want
// *T.ptrmethod
// Also the package path might contains dot (e.g. code.google.com/...),
// so first eliminate the path prefix
if lastslash := bytes.LastIndex(name, slash); lastslash >= 0 {
name = name[lastslash+1:]
}
if period := bytes.Index(name, dot); period >= 0 {
name = name[period+1:]
}
name = bytes.Replace(name, centerDot, dot, -1)
return name
}
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// While Martini is in development mode, Recovery will also output the panic as HTML.
func Recovery() HandlerFunc {
return func(c *Context) {
defer func() {
if err := recover(); err != nil {
stack := stack(3)
log.Printf("PANIC: %s\n%s", err, stack)
c.Writer.WriteHeader(http.StatusInternalServerError)
}
}()
c.Next()
}
}
@@ -0,0 +1,56 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bytes"
"log"
"os"
"testing"
)
// TestPanicInHandler assert that panic has been recovered.
func TestPanicInHandler(t *testing.T) {
// SETUP
log.SetOutput(bytes.NewBuffer(nil)) // Disable panic logs for testing
r := New()
r.Use(Recovery())
r.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := PerformRequest(r, "GET", "/recovery")
// restore logging
log.SetOutput(os.Stderr)
if w.Code != 500 {
t.Errorf("Response code should be Internal Server Error, was: %s", w.Code)
}
}
// TestPanicWithAbort assert that panic has been recovered even if context.Abort was used.
func TestPanicWithAbort(t *testing.T) {
// SETUP
log.SetOutput(bytes.NewBuffer(nil))
r := New()
r.Use(Recovery())
r.GET("/recovery", func(c *Context) {
c.AbortWithStatus(400)
panic("Oupps, Houston, we have a problem")
})
// RUN
w := PerformRequest(r, "GET", "/recovery")
// restore logging
log.SetOutput(os.Stderr)
// TEST
if w.Code != 500 {
t.Errorf("Response code should be Bad request, was: %s", w.Code)
}
}
@@ -0,0 +1,123 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"encoding/json"
"encoding/xml"
"fmt"
"html/template"
"net/http"
)
type (
Render interface {
Render(http.ResponseWriter, int, ...interface{}) error
}
// JSON binding
jsonRender struct{}
// XML binding
xmlRender struct{}
// Plain text
plainRender struct{}
// Redirects
redirectRender struct{}
// Redirects
htmlDebugRender struct {
files []string
globs []string
}
// form binding
HTMLRender struct {
Template *template.Template
}
)
var (
JSON = jsonRender{}
XML = xmlRender{}
Plain = plainRender{}
Redirect = redirectRender{}
HTMLDebug = &htmlDebugRender{}
)
func writeHeader(w http.ResponseWriter, code int, contentType string) {
w.Header().Set("Content-Type", contentType+"; charset=utf-8")
w.WriteHeader(code)
}
func (_ jsonRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
writeHeader(w, code, "application/json")
encoder := json.NewEncoder(w)
return encoder.Encode(data[0])
}
func (_ redirectRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
w.Header().Set("Location", data[0].(string))
w.WriteHeader(code)
return nil
}
func (_ xmlRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
writeHeader(w, code, "application/xml")
encoder := xml.NewEncoder(w)
return encoder.Encode(data[0])
}
func (_ plainRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
writeHeader(w, code, "text/plain")
format := data[0].(string)
args := data[1].([]interface{})
var err error
if len(args) > 0 {
_, err = w.Write([]byte(fmt.Sprintf(format, args...)))
} else {
_, err = w.Write([]byte(format))
}
return err
}
func (r *htmlDebugRender) AddGlob(pattern string) {
r.globs = append(r.globs, pattern)
}
func (r *htmlDebugRender) AddFiles(files ...string) {
r.files = append(r.files, files...)
}
func (r *htmlDebugRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
writeHeader(w, code, "text/html")
file := data[0].(string)
obj := data[1]
t := template.New("")
if len(r.files) > 0 {
if _, err := t.ParseFiles(r.files...); err != nil {
return err
}
}
for _, glob := range r.globs {
if _, err := t.ParseGlob(glob); err != nil {
return err
}
}
return t.ExecuteTemplate(w, file, obj)
}
func (html HTMLRender) Render(w http.ResponseWriter, code int, data ...interface{}) error {
writeHeader(w, code, "text/html")
file := data[0].(string)
obj := data[1]
return html.Template.ExecuteTemplate(w, file, obj)
}
@@ -0,0 +1,100 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bufio"
"errors"
"log"
"net"
"net/http"
)
const (
NoWritten = -1
)
type (
ResponseWriter interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
Status() int
Size() int
Written() bool
WriteHeaderNow()
}
responseWriter struct {
http.ResponseWriter
status int
size int
}
)
func (w *responseWriter) reset(writer http.ResponseWriter) {
w.ResponseWriter = writer
w.status = 200
w.size = NoWritten
}
func (w *responseWriter) WriteHeader(code int) {
if code > 0 {
w.status = code
if w.Written() {
log.Println("[GIN] WARNING. Headers were already written!")
}
}
}
func (w *responseWriter) WriteHeaderNow() {
if !w.Written() {
w.size = 0
w.ResponseWriter.WriteHeader(w.status)
}
}
func (w *responseWriter) Write(data []byte) (n int, err error) {
w.WriteHeaderNow()
n, err = w.ResponseWriter.Write(data)
w.size += n
return
}
func (w *responseWriter) Status() int {
return w.status
}
func (w *responseWriter) Size() int {
return w.size
}
func (w *responseWriter) Written() bool {
return w.size != NoWritten
}
// Implements the http.Hijacker interface
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("the ResponseWriter doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
// Implements the http.CloseNotify interface
func (w *responseWriter) CloseNotify() <-chan bool {
return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Implements the http.Flush interface
func (w *responseWriter) Flush() {
flusher, ok := w.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
@@ -0,0 +1,148 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"github.com/julienschmidt/httprouter"
"net/http"
"path"
)
// Used internally to configure router, a RouterGroup is associated with a prefix
// and an array of handlers (middlewares)
type RouterGroup struct {
Handlers []HandlerFunc
absolutePath string
engine *Engine
}
// Adds middlewares to the group, see example code in github.
func (group *RouterGroup) Use(middlewares ...HandlerFunc) {
group.Handlers = append(group.Handlers, middlewares...)
}
// Creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
// For example, all the routes that use a common middlware for authorization could be grouped.
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
return &RouterGroup{
Handlers: group.combineHandlers(handlers),
absolutePath: group.calculateAbsolutePath(relativePath),
engine: group.engine,
}
}
// Handle registers a new request handle and middlewares with the given path and method.
// The last handler should be the real handler, the other ones should be middlewares that can and should be shared among different routes.
// See the example code in github.
//
// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut
// functions can be used.
//
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers []HandlerFunc) {
absolutePath := group.calculateAbsolutePath(relativePath)
handlers = group.combineHandlers(handlers)
if IsDebugging() {
nuHandlers := len(handlers)
handlerName := nameOfFunction(handlers[nuHandlers-1])
debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers)
}
group.engine.router.Handle(httpMethod, absolutePath, func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
context := group.engine.createContext(w, req, params, handlers)
context.Next()
context.Writer.WriteHeaderNow()
group.engine.reuseContext(context)
})
}
// POST is a shortcut for router.Handle("POST", path, handle)
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) {
group.Handle("POST", relativePath, handlers)
}
// GET is a shortcut for router.Handle("GET", path, handle)
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) {
group.Handle("GET", relativePath, handlers)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle)
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) {
group.Handle("DELETE", relativePath, handlers)
}
// PATCH is a shortcut for router.Handle("PATCH", path, handle)
func (group *RouterGroup) PATCH(relativePath string, handlers ...HandlerFunc) {
group.Handle("PATCH", relativePath, handlers)
}
// PUT is a shortcut for router.Handle("PUT", path, handle)
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) {
group.Handle("PUT", relativePath, handlers)
}
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
func (group *RouterGroup) OPTIONS(relativePath string, handlers ...HandlerFunc) {
group.Handle("OPTIONS", relativePath, handlers)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handle)
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) {
group.Handle("HEAD", relativePath, handlers)
}
// LINK is a shortcut for router.Handle("LINK", path, handle)
func (group *RouterGroup) LINK(relativePath string, handlers ...HandlerFunc) {
group.Handle("LINK", relativePath, handlers)
}
// UNLINK is a shortcut for router.Handle("UNLINK", path, handle)
func (group *RouterGroup) UNLINK(relativePath string, handlers ...HandlerFunc) {
group.Handle("UNLINK", relativePath, handlers)
}
// Static serves files from the given file system root.
// Internally a http.FileServer is used, therefore http.NotFound is used instead
// of the Router's NotFound handler.
// To use the operating system's file system implementation,
// use :
// router.Static("/static", "/var/www")
func (group *RouterGroup) Static(relativePath, root string) {
absolutePath := group.calculateAbsolutePath(relativePath)
handler := group.createStaticHandler(absolutePath, root)
absolutePath = path.Join(absolutePath, "/*filepath")
// Register GET and HEAD handlers
group.GET(absolutePath, handler)
group.HEAD(absolutePath, handler)
}
func (group *RouterGroup) createStaticHandler(absolutePath, root string) func(*Context) {
fileServer := http.StripPrefix(absolutePath, http.FileServer(http.Dir(root)))
return func(c *Context) {
fileServer.ServeHTTP(c.Writer, c.Request)
}
}
func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc {
finalSize := len(group.Handlers) + len(handlers)
mergedHandlers := make([]HandlerFunc, 0, finalSize)
mergedHandlers = append(mergedHandlers, group.Handlers...)
return append(mergedHandlers, handlers...)
}
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
if len(relativePath) == 0 {
return group.absolutePath
}
absolutePath := path.Join(group.absolutePath, relativePath)
appendSlash := lastChar(relativePath) == '/' && lastChar(absolutePath) != '/'
if appendSlash {
return absolutePath + "/"
}
return absolutePath
}
@@ -0,0 +1,82 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"encoding/xml"
"reflect"
"runtime"
"strings"
)
type H map[string]interface{}
// Allows type H to be used with xml.Marshal
func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
start.Name = xml.Name{
Space: "",
Local: "map",
}
if err := e.EncodeToken(start); err != nil {
return err
}
for key, value := range h {
elem := xml.StartElement{
Name: xml.Name{Space: "", Local: key},
Attr: []xml.Attr{},
}
if err := e.EncodeElement(value, elem); err != nil {
return err
}
}
if err := e.EncodeToken(xml.EndElement{Name: start.Name}); err != nil {
return err
}
return nil
}
func filterFlags(content string) string {
for i, char := range content {
if char == ' ' || char == ';' {
return content[:i]
}
}
return content
}
func chooseData(custom, wildcard interface{}) interface{} {
if custom == nil {
if wildcard == nil {
panic("negotiation config is invalid")
}
return wildcard
}
return custom
}
func parseAccept(accept string) []string {
parts := strings.Split(accept, ",")
for i, part := range parts {
index := strings.IndexByte(part, ';')
if index >= 0 {
part = part[0:index]
}
part = strings.TrimSpace(part)
parts[i] = part
}
return parts
}
func lastChar(str string) uint8 {
size := len(str)
if size == 0 {
panic("The length of the string can't be 0")
}
return str[size-1]
}
func nameOfFunction(f interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
}
@@ -0,0 +1,6 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- tip
@@ -0,0 +1,24 @@
Copyright (c) 2013 Julien Schmidt. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* The names of the contributors may not be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,234 @@
# HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.png?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![GoDoc](http://godoc.org/github.com/julienschmidt/httprouter?status.png)](http://godoc.org/github.com/julienschmidt/httprouter)
HttpRouter is a lightweight high performance HTTP request router
(also called *multiplexer* or just *mux* for short) for [Go](http://golang.org/).
In contrast to the default mux of Go's net/http package, this router supports
variables in the routing pattern and matches against the request method.
It also scales better.
The router is optimized for best performance and a small memory footprint.
It scales well even with very long paths and a large number of routes.
A compressing dynamic trie (radix tree) structure is used for efficient matching.
## Features
**Zero Garbage:** The matching and dispatching process generates zero bytes of
garbage. In fact, the only heap allocations that are made, is by building the
slice of the key-value pairs for path parameters. If the request path contains
no parameters, not a single heap allocation is necessary.
**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark).
See below for technical details of the implementation.
**Parameters in your routing pattern:** Stop parsing the requested URL path,
just give the path segment a name and the router delivers the dynamic value to
you. Because of the design of the router, path parameters are very cheap.
**Only explicit matches:** With other routers, like [http.ServeMux](http://golang.org/pkg/net/http/#ServeMux),
a requested URL path could match multiple patterns. Therefore they have some
awkward pattern priority rules, like *longest match* or *first registered,
first matched*. By design of this router, a request can only match exactly one
or no route. As a result, there are also no unintended matches, which makes it
great for SEO and improves the user experience.
**Stop caring about trailing slashes:** Choose the URL style you like, the
router automatically redirects the client if a trailing slash is missing or if
there is one extra. Of course it only does so, if the new path has a handler.
If you don't like it, you can turn off this behavior.
**Path auto-correction:** Besides detecting the missing or additional trailing
slash at no extra cost, the router can also fix wrong cases and remove
superfluous path elements (like `../` or `//`).
Is [CAPTAIN CAPS LOCK](http://www.urbandictionary.com/define.php?term=Captain+Caps+Lock) one of your users?
HttpRouter can help him by making a case-insensitive look-up and redirecting him
to the correct URL.
**No more server crashes:** You can set a PanicHandler to deal with panics
occurring during handling a HTTP request. The router then recovers and lets the
PanicHandler log what happened and deliver a nice error page.
Of course you can also set a **custom NotFound handler** and **serve static files**.
## Usage
This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/julienschmidt/httprouter) for details.
Let's start with a trivial example:
```go
package main
import (
"fmt"
"github.com/julienschmidt/httprouter"
"net/http"
"log"
)
func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
fmt.Fprint(w, "Welcome!\n")
}
func Hello(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
fmt.Fprintf(w, "hello, %s!\n", ps.ByName("name"))
}
func main() {
router := httprouter.New()
router.GET("/", Index)
router.GET("/hello/:name", Hello)
log.Fatal(http.ListenAndServe(":8080", router))
}
```
### Named parameters
As you can see, `:name` is a *named parameter*.
The values are accessible via `httprouter.Params`, which is just a slice of `httprouter.Param`s.
You can get the value of a parameter either by its index in the slice, or by using the `ByName(name)` method:
`:name` can be retrived by `ByName("name")`.
Named parameters only match a single path segment:
```
Pattern: /user/:user
/user/gordon match
/user/you match
/user/gordon/profile no match
/user/ no match
```
**Note:** Since this router has only explicit matches, you can not register static routes and parameters for the same path segment. For example you can not register the patterns `/user/new` and `/user/:user` for the same request method at the same time. The routing of different request methods is independent from each other.
### Catch-All parameters
The second type are *catch-all* parameters and have the form `*name`.
Like the name suggests, they match everything.
Therefore they must always be at the **end** of the pattern:
```
Pattern: /src/*filepath
/src/ match
/src/somefile.go match
/src/subdir/somefile.go match
```
## How does it work?
The router relies on a tree structure which makes heavy use of *common prefixes*,
it is basically a *compact* [*prefix tree*](http://en.wikipedia.org/wiki/Trie)
(or just [*Radix tree*](http://en.wikipedia.org/wiki/Radix_tree)).
Nodes with a common prefix also share a common parent. Here is a short example
what the routing tree for the `GET` request method could look like:
```
Priority Path Handle
9 \ *<1>
3 ├s nil
2 |├earch\ *<2>
1 |└upport\ *<3>
2 ├blog\ *<4>
1 | └:post nil
1 | └\ *<5>
2 ├about-us\ *<6>
1 | └team\ *<7>
1 └contact\ *<8>
```
Every `*<num>` represents the memory address of a handler function (a pointer).
If you follow a path trough the tree from the root to the leaf, you get the
complete route path, e.g `\blog\:post\`, where `:post` is just a placeholder
([*parameter*](#named-parameters)) for an actual post name. Unlike hash-maps, a
tree structure also allows us to use dynamic parts like the `:post` parameter,
since we actually match against the routing patterns instead of just comparing
hashes. [As benchmarks show](https://github.com/julienschmidt/go-http-routing-benchmark),
this works very well and efficient.
Since URL paths have a hierarchical structure and make use only of a limited set
of characters (byte values), it is very likely that there are a lot of common
prefixes. This allows us to easily reduce the routing into ever smaller problems.
Moreover the router manages a separate tree for every request method.
For one thing it is more space efficient than holding a method->handle map in
every single node, for another thing is also allows us to greatly reduce the
routing problem before even starting the look-up in the prefix-tree.
For even better scalability, the child nodes on each tree level are ordered by
priority, where the priority is just the number of handles registered in sub
nodes (children, grandchildren, and so on..).
This helps in two ways:
1. Nodes which are part of the most routing paths are evaluated first. This
helps to make as much routes as possible to be reachable as fast as possible.
2. It is some sort of cost compensation. The longest reachable path (highest
cost) can always be evaluated first. The following scheme visualizes the tree
structure. Nodes are evaluated from top to bottom and from left to right.
```
├------------
├---------
├-----
├----
├--
├--
└-
```
## Why doesn't this work with http.Handler?
**It does!** The router itself implements the http.Handler interface.
Moreover the router provides convenient [adapters for http.Handler](http://godoc.org/github.com/julienschmidt/httprouter#Router.Handler)s and [http.HandlerFunc](http://godoc.org/github.com/julienschmidt/httprouter#Router.HandlerFunc)s
which allows them to be used as a [httprouter.Handle](http://godoc.org/github.com/julienschmidt/httprouter#Router.Handle) when registering a route.
The only disadvantage is, that no parameter values can be retrieved when a
http.Handler or http.HandlerFunc is used, since there is no efficient way to
pass the values with the existing function parameters.
Therefore [httprouter.Handle](http://godoc.org/github.com/julienschmidt/httprouter#Router.Handle) has a third function parameter.
Just try it out for yourself, the usage of HttpRouter is very straightforward. The package is compact and minimalistic, but also probably one of the easiest routers to set up.
## Where can I find Middleware *X*?
This package just provides a very efficient request router with a few extra
features. The router is just a [http.Handler](http://golang.org/pkg/net/http/#Handler),
you can chain any http.Handler compatible middleware before the router,
for example the [Gorilla handlers](http://www.gorillatoolkit.org/pkg/handlers).
Or you could [just write your own](http://justinas.org/writing-http-middleware-in-go/),
it's very easy!
Alternatively, you could try [a framework building upon HttpRouter](#web-frameworks-building-upon-httprouter).
Here is a quick example: Does your server serve multiple domains / hosts?
You want to use sub-domains?
Define a router per host!
```go
// We need an object that implements the http.Handler interface.
// Therefore we need a type for which we implement the ServeHTTP method.
// We just use a map here, in which we map host names (with port) to http.Handlers
type HostSwitch map[string]http.Handler
// Implement the ServerHTTP method on our new type
func (hs HostSwitch) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check if a http.Handler is registered for the given host.
// If yes, use it to handle the request.
if handler := hs[r.Host]; handler != nil {
handler.ServeHTTP(w, r)
} else {
// Handle host names for wich no handler is registered
http.Error(w, "Forbidden", 403) // Or Redirect?
}
}
func main() {
// Initialize a router as usual
router := httprouter.New()
router.GET("/", Index)
router.GET("/hello/:name", Hello)
// Make a new HostSwitch and insert the router (our http handler)
// for example.com and port 12345
hs := make(HostSwitch)
hs["example.com:12345"] = router
// Use the HostSwitch to listen and serve on port 12345
log.Fatal(http.ListenAndServe(":12345", hs))
}
```
## Web Frameworks building upon HttpRouter
If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package:
* [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance
* [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine
@@ -0,0 +1,123 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Based on the path package, Copyright 2009 The Go Authors.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
package httprouter
// CleanPath is the URL version of path.Clean, it returns a canonical URL path
// for p, eliminating . and .. elements.
//
// The following rules are applied iteratively until no further processing can
// be done:
// 1. Replace multiple slashes with a single slash.
// 2. Eliminate each . path name element (the current directory).
// 3. Eliminate each inner .. path name element (the parent directory)
// along with the non-.. element that precedes it.
// 4. Eliminate .. elements that begin a rooted path:
// that is, replace "/.." by "/" at the beginning of a path.
//
// If the result of this process is an empty string, "/" is returned
func CleanPath(p string) string {
// Turn empty string into "/"
if p == "" {
return "/"
}
n := len(p)
var buf []byte
// Invariants:
// reading from path; r is index of next byte to process.
// writing to buf; w is index of next byte to write.
// path must start with '/'
r := 1
w := 1
if p[0] != '/' {
r = 0
buf = make([]byte, n+1)
buf[0] = '/'
}
trailing := n > 2 && p[n-1] == '/'
// A bit more clunky without a 'lazybuf' like the path package, but the loop
// gets completely inlined (bufApp). So in contrast to the path package this
// loop has no expensive function calls (except 1x make)
for r < n {
switch {
case p[r] == '/':
// empty path element, trailing slash is added after the end
r++
case p[r] == '.' && r+1 == n:
trailing = true
r++
case p[r] == '.' && p[r+1] == '/':
// . element
r++
case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
// .. element: remove to last /
r += 2
if w > 1 {
// can backtrack
w--
if buf == nil {
for w > 1 && p[w] != '/' {
w--
}
} else {
for w > 1 && buf[w] != '/' {
w--
}
}
}
default:
// real path element.
// add slash if needed
if w > 1 {
bufApp(&buf, p, w, '/')
w++
}
// copy element
for r < n && p[r] != '/' {
bufApp(&buf, p, w, p[r])
w++
r++
}
}
}
// re-append trailing slash
if trailing && w > 1 {
bufApp(&buf, p, w, '/')
w++
}
if buf == nil {
return p[:w]
}
return string(buf[:w])
}
// internal helper to lazily create a buffer if necessary
func bufApp(buf *[]byte, s string, w int, c byte) {
if *buf == nil {
if s[w] == c {
return
}
*buf = make([]byte, len(s))
copy(*buf, s[:w])
}
(*buf)[w] = c
}
@@ -0,0 +1,92 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Based on the path package, Copyright 2009 The Go Authors.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
package httprouter
import (
"runtime"
"testing"
)
var cleanTests = []struct {
path, result string
}{
// Already clean
{"/", "/"},
{"/abc", "/abc"},
{"/a/b/c", "/a/b/c"},
{"/abc/", "/abc/"},
{"/a/b/c/", "/a/b/c/"},
// missing root
{"", "/"},
{"abc", "/abc"},
{"abc/def", "/abc/def"},
{"a/b/c", "/a/b/c"},
// Remove doubled slash
{"//", "/"},
{"/abc//", "/abc/"},
{"/abc/def//", "/abc/def/"},
{"/a/b/c//", "/a/b/c/"},
{"/abc//def//ghi", "/abc/def/ghi"},
{"//abc", "/abc"},
{"///abc", "/abc"},
{"//abc//", "/abc/"},
// Remove . elements
{".", "/"},
{"./", "/"},
{"/abc/./def", "/abc/def"},
{"/./abc/def", "/abc/def"},
{"/abc/.", "/abc/"},
// Remove .. elements
{"..", "/"},
{"../", "/"},
{"../../", "/"},
{"../..", "/"},
{"../../abc", "/abc"},
{"/abc/def/ghi/../jkl", "/abc/def/jkl"},
{"/abc/def/../ghi/../jkl", "/abc/jkl"},
{"/abc/def/..", "/abc"},
{"/abc/def/../..", "/"},
{"/abc/def/../../..", "/"},
{"/abc/def/../../..", "/"},
{"/abc/def/../../../ghi/jkl/../../../mno", "/mno"},
// Combinations
{"abc/./../def", "/def"},
{"abc//./../def", "/def"},
{"abc/../../././../def", "/def"},
}
func TestPathClean(t *testing.T) {
for _, test := range cleanTests {
if s := CleanPath(test.path); s != test.result {
t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result)
}
if s := CleanPath(test.result); s != test.result {
t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result)
}
}
}
func TestPathCleanMallocs(t *testing.T) {
if testing.Short() {
t.Skip("skipping malloc count in short mode")
}
if runtime.GOMAXPROCS(0) > 1 {
t.Log("skipping AllocsPerRun checks; GOMAXPROCS>1")
return
}
for _, test := range cleanTests {
allocs := testing.AllocsPerRun(100, func() { CleanPath(test.result) })
if allocs > 0 {
t.Errorf("CleanPath(%q): %v allocs, want zero", test.result, allocs)
}
}
}
@@ -0,0 +1,317 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
// Package httprouter is a trie based high performance HTTP request router.
//
// A trivial example is:
//
// package main
//
// import (
// "fmt"
// "github.com/julienschmidt/httprouter"
// "net/http"
// "log"
// )
//
// func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
// fmt.Fprint(w, "Welcome!\n")
// }
//
// func Hello(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// fmt.Fprintf(w, "hello, %s!\n", ps.ByName("name"))
// }
//
// func main() {
// router := httprouter.New()
// router.GET("/", Index)
// router.GET("/hello/:name", Hello)
//
// log.Fatal(http.ListenAndServe(":8080", router))
// }
//
// The router matches incoming requests by the request method and the path.
// If a handle is registered for this path and method, the router delegates the
// request to that function.
// For the methods GET, POST, PUT, PATCH and DELETE shortcut functions exist to
// register handles, for all other methods router.Handle can be used.
//
// The registered path, against which the router matches incoming requests, can
// contain two types of parameters:
// Syntax Type
// :name named parameter
// *name catch-all parameter
//
// Named parameters are dynamic path segments. They match anything until the
// next '/' or the path end:
// Path: /blog/:category/:post
//
// Requests:
// /blog/go/request-routers match: category="go", post="request-routers"
// /blog/go/request-routers/ no match, but the router would redirect
// /blog/go/ no match
// /blog/go/request-routers/comments no match
//
// Catch-all parameters match anything until the path end, including the
// directory index (the '/' before the catch-all). Since they match anything
// until the end, catch-all paramerters must always be the final path element.
// Path: /files/*filepath
//
// Requests:
// /files/ match: filepath="/"
// /files/LICENSE match: filepath="/LICENSE"
// /files/templates/article.html match: filepath="/templates/article.html"
// /files no match, but the router would redirect
//
// The value of parameters is saved as a slice of the Param struct, consisting
// each of a key and a value. The slice is passed to the Handle func as a third
// parameter.
// There are two ways to retrieve the value of a parameter:
// // by the name of the parameter
// user := ps.ByName("user") // defined by :user or *user
//
// // by the index of the parameter. This way you can also get the name (key)
// thirdKey := ps[2].Key // the name of the 3rd parameter
// thirdValue := ps[2].Value // the value of the 3rd parameter
package httprouter
import (
"net/http"
)
// Handle is a function that can be registered to a route to handle HTTP
// requests. Like http.HandlerFunc, but has a third parameter for the values of
// wildcards (variables).
type Handle func(http.ResponseWriter, *http.Request, Params)
// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
Key string
Value string
}
// Params is a Param-slice, as returned by the router.
// The slice is ordered, the first URL parameter is also the first slice value.
// It is therefore safe to read values by the index.
type Params []Param
// ByName returns the value of the first Param which key matches the given name.
// If no matching Param is found, an empty string is returned.
func (ps Params) ByName(name string) string {
for i := range ps {
if ps[i].Key == name {
return ps[i].Value
}
}
return ""
}
// Router is a http.Handler which can be used to dispatch requests to different
// handler functions via configurable routes
type Router struct {
trees map[string]*node
// Enables automatic redirection if the current route can't be matched but a
// handler for the path with (without) the trailing slash exists.
// For example if /foo/ is requested but a route only exists for /foo, the
// client is redirected to /foo with http status code 301 for GET requests
// and 307 for all other request methods.
RedirectTrailingSlash bool
// If enabled, the router tries to fix the current request path, if no
// handle is registered for it.
// First superfluous path elements like ../ or // are removed.
// Afterwards the router does a case-insensitive lookup of the cleaned path.
// If a handle can be found for this route, the router makes a redirection
// to the corrected path with status code 301 for GET requests and 307 for
// all other request methods.
// For example /FOO and /..//Foo could be redirected to /foo.
// RedirectTrailingSlash is independent of this option.
RedirectFixedPath bool
// Configurable http.HandlerFunc which is called when no matching route is
// found. If it is not set, http.NotFound is used.
NotFound http.HandlerFunc
// Function to handle panics recovered from http handlers.
// It should be used to generate a error page and return the http error code
// 500 (Internal Server Error).
// The handler can be used to keep your server from crashing because of
// unrecovered panics.
PanicHandler func(http.ResponseWriter, *http.Request, interface{})
}
// Make sure the Router conforms with the http.Handler interface
var _ http.Handler = New()
// New returns a new initialized Router.
// Path auto-correction, including trailing slashes, is enabled by default.
func New() *Router {
return &Router{
RedirectTrailingSlash: true,
RedirectFixedPath: true,
}
}
// GET is a shortcut for router.Handle("GET", path, handle)
func (r *Router) GET(path string, handle Handle) {
r.Handle("GET", path, handle)
}
// POST is a shortcut for router.Handle("POST", path, handle)
func (r *Router) POST(path string, handle Handle) {
r.Handle("POST", path, handle)
}
// PUT is a shortcut for router.Handle("PUT", path, handle)
func (r *Router) PUT(path string, handle Handle) {
r.Handle("PUT", path, handle)
}
// PATCH is a shortcut for router.Handle("PATCH", path, handle)
func (r *Router) PATCH(path string, handle Handle) {
r.Handle("PATCH", path, handle)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle)
func (r *Router) DELETE(path string, handle Handle) {
r.Handle("DELETE", path, handle)
}
// Handle registers a new request handle with the given path and method.
//
// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut
// functions can be used.
//
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (r *Router) Handle(method, path string, handle Handle) {
if path[0] != '/' {
panic("path must begin with '/'")
}
if r.trees == nil {
r.trees = make(map[string]*node)
}
root := r.trees[method]
if root == nil {
root = new(node)
r.trees[method] = root
}
root.addRoute(path, handle)
}
// Handler is an adapter which allows the usage of an http.Handler as a
// request handle.
func (r *Router) Handler(method, path string, handler http.Handler) {
r.Handle(method, path,
func(w http.ResponseWriter, req *http.Request, _ Params) {
handler.ServeHTTP(w, req)
},
)
}
// HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a
// request handle.
func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
r.Handle(method, path,
func(w http.ResponseWriter, req *http.Request, _ Params) {
handler(w, req)
},
)
}
// ServeFiles serves files from the given file system root.
// The path must end with "/*filepath", files are then served from the local
// path /defined/root/dir/*filepath.
// For example if root is "/etc" and *filepath is "passwd", the local file
// "/etc/passwd" would be served.
// Internally a http.FileServer is used, therefore http.NotFound is used instead
// of the Router's NotFound handler.
// To use the operating system's file system implementation,
// use http.Dir:
// router.ServeFiles("/src/*filepath", http.Dir("/var/www"))
func (r *Router) ServeFiles(path string, root http.FileSystem) {
if len(path) < 10 || path[len(path)-10:] != "/*filepath" {
panic("path must end with /*filepath")
}
fileServer := http.FileServer(root)
r.GET(path, func(w http.ResponseWriter, req *http.Request, ps Params) {
req.URL.Path = ps.ByName("filepath")
fileServer.ServeHTTP(w, req)
})
}
func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
if rcv := recover(); rcv != nil {
r.PanicHandler(w, req, rcv)
}
}
// Lookup allows the manual lookup of a method + path combo.
// This is e.g. useful to build a framework around this router.
func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
if root := r.trees[method]; root != nil {
return root.getValue(path)
}
return nil, nil, false
}
// ServeHTTP makes the router implement the http.Handler interface.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.PanicHandler != nil {
defer r.recv(w, req)
}
if root := r.trees[req.Method]; root != nil {
path := req.URL.Path
if handle, ps, tsr := root.getValue(path); handle != nil {
handle(w, req, ps)
return
} else if req.Method != "CONNECT" && path != "/" {
code := 301 // Permanent redirect, request with GET method
if req.Method != "GET" {
// Temporary redirect, request with same method
// As of Go 1.3, Go does not support status code 308.
code = 307
}
if tsr && r.RedirectTrailingSlash {
if path[len(path)-1] == '/' {
req.URL.Path = path[:len(path)-1]
} else {
req.URL.Path = path + "/"
}
http.Redirect(w, req, req.URL.String(), code)
return
}
// Try to fix the request path
if r.RedirectFixedPath {
fixedPath, found := root.findCaseInsensitivePath(
CleanPath(path),
r.RedirectTrailingSlash,
)
if found {
req.URL.Path = string(fixedPath)
http.Redirect(w, req, req.URL.String(), code)
return
}
}
}
}
// Handle 404
if r.NotFound != nil {
r.NotFound(w, req)
} else {
http.NotFound(w, req)
}
}
@@ -0,0 +1,329 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
package httprouter
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
type mockResponseWriter struct{}
func (m *mockResponseWriter) Header() (h http.Header) {
return http.Header{}
}
func (m *mockResponseWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (m *mockResponseWriter) WriteString(s string) (n int, err error) {
return len(s), nil
}
func (m *mockResponseWriter) WriteHeader(int) {}
func TestParams(t *testing.T) {
ps := Params{
Param{"param1", "value1"},
Param{"param2", "value2"},
Param{"param3", "value3"},
}
for i := range ps {
if val := ps.ByName(ps[i].Key); val != ps[i].Value {
t.Errorf("Wrong value for %s: Got %s; Want %s", ps[i].Key, val, ps[i].Value)
}
}
if val := ps.ByName("noKey"); val != "" {
t.Errorf("Expected empty string for not found key; got: %s", val)
}
}
func TestRouter(t *testing.T) {
router := New()
routed := false
router.Handle("GET", "/user/:name", func(w http.ResponseWriter, r *http.Request, ps Params) {
routed = true
want := Params{Param{"name", "gopher"}}
if !reflect.DeepEqual(ps, want) {
t.Fatalf("wrong wildcard values: want %v, got %v", want, ps)
}
})
w := new(mockResponseWriter)
req, _ := http.NewRequest("GET", "/user/gopher", nil)
router.ServeHTTP(w, req)
if !routed {
t.Fatal("routing failed")
}
}
type handlerStruct struct {
handeled *bool
}
func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
*h.handeled = true
}
func TestRouterAPI(t *testing.T) {
var get, post, put, patch, delete, handler, handlerFunc bool
httpHandler := handlerStruct{&handler}
router := New()
router.GET("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
})
router.POST("/POST", func(w http.ResponseWriter, r *http.Request, _ Params) {
post = true
})
router.PUT("/PUT", func(w http.ResponseWriter, r *http.Request, _ Params) {
put = true
})
router.PATCH("/PATCH", func(w http.ResponseWriter, r *http.Request, _ Params) {
patch = true
})
router.DELETE("/DELETE", func(w http.ResponseWriter, r *http.Request, _ Params) {
delete = true
})
router.Handler("GET", "/Handler", httpHandler)
router.HandlerFunc("GET", "/HandlerFunc", func(w http.ResponseWriter, r *http.Request) {
handlerFunc = true
})
w := new(mockResponseWriter)
r, _ := http.NewRequest("GET", "/GET", nil)
router.ServeHTTP(w, r)
if !get {
t.Error("routing GET failed")
}
r, _ = http.NewRequest("POST", "/POST", nil)
router.ServeHTTP(w, r)
if !post {
t.Error("routing POST failed")
}
r, _ = http.NewRequest("PUT", "/PUT", nil)
router.ServeHTTP(w, r)
if !put {
t.Error("routing PUT failed")
}
r, _ = http.NewRequest("PATCH", "/PATCH", nil)
router.ServeHTTP(w, r)
if !patch {
t.Error("routing PATCH failed")
}
r, _ = http.NewRequest("DELETE", "/DELETE", nil)
router.ServeHTTP(w, r)
if !delete {
t.Error("routing DELETE failed")
}
r, _ = http.NewRequest("GET", "/Handler", nil)
router.ServeHTTP(w, r)
if !handler {
t.Error("routing Handler failed")
}
r, _ = http.NewRequest("GET", "/HandlerFunc", nil)
router.ServeHTTP(w, r)
if !handlerFunc {
t.Error("routing HandlerFunc failed")
}
}
func TestRouterRoot(t *testing.T) {
router := New()
recv := catchPanic(func() {
router.GET("noSlashRoot", nil)
})
if recv == nil {
t.Fatal("registering path not beginning with '/' did not panic")
}
}
func TestRouterNotFound(t *testing.T) {
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
router := New()
router.GET("/path", handlerFunc)
router.GET("/dir/", handlerFunc)
testRoutes := []struct {
route string
code int
header string
}{
{"/path/", 301, "map[Location:[/path]]"}, // TSR -/
{"/dir", 301, "map[Location:[/dir/]]"}, // TSR +/
{"/PATH", 301, "map[Location:[/path]]"}, // Fixed Case
{"/DIR/", 301, "map[Location:[/dir/]]"}, // Fixed Case
{"/PATH/", 301, "map[Location:[/path]]"}, // Fixed Case -/
{"/DIR", 301, "map[Location:[/dir/]]"}, // Fixed Case +/
{"/../path", 301, "map[Location:[/path]]"}, // CleanPath
{"/nope", 404, ""}, // NotFound
}
for _, tr := range testRoutes {
r, _ := http.NewRequest("GET", tr.route, nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == tr.code && (w.Code == 404 || fmt.Sprint(w.Header()) == tr.header)) {
t.Errorf("NotFound handling route %s failed: Code=%d, Header=%v", tr.route, w.Code, w.Header())
}
}
// Test custom not found handler
var notFound bool
router.NotFound = func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(404)
notFound = true
}
r, _ := http.NewRequest("GET", "/nope", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == 404 && notFound == true) {
t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
}
// Test other method than GET (want 307 instead of 301)
router.PATCH("/path", handlerFunc)
r, _ = http.NewRequest("PATCH", "/path/", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == 307 && fmt.Sprint(w.Header()) == "map[Location:[/path]]") {
t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
}
// Test special case where no node for the prefix "/" exists
router = New()
router.GET("/a", handlerFunc)
r, _ = http.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == 404) {
t.Errorf("NotFound handling route / failed: Code=%d", w.Code)
}
}
func TestRouterPanicHandler(t *testing.T) {
router := New()
panicHandled := false
router.PanicHandler = func(rw http.ResponseWriter, r *http.Request, p interface{}) {
panicHandled = true
}
router.Handle("PUT", "/user/:name", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
panic("oops!")
})
w := new(mockResponseWriter)
req, _ := http.NewRequest("PUT", "/user/gopher", nil)
defer func() {
if rcv := recover(); rcv != nil {
t.Fatal("handling panic failed")
}
}()
router.ServeHTTP(w, req)
if !panicHandled {
t.Fatal("simulating failed")
}
}
func TestRouterLookup(t *testing.T) {
routed := false
wantHandle := func(_ http.ResponseWriter, _ *http.Request, _ Params) {
routed = true
}
wantParams := Params{Param{"name", "gopher"}}
router := New()
// try empty router first
handle, _, tsr := router.Lookup("GET", "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
}
// insert route and try again
router.GET("/user/:name", wantHandle)
handle, params, tsr := router.Lookup("GET", "/user/gopher")
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
}
}
if !reflect.DeepEqual(params, wantParams) {
t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
}
handle, _, tsr = router.Lookup("GET", "/user/gopher/")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
}
if !tsr {
t.Error("Got no TSR recommendation!")
}
handle, _, tsr = router.Lookup("GET", "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
}
}
type mockFileSystem struct {
opened bool
}
func (mfs *mockFileSystem) Open(name string) (http.File, error) {
mfs.opened = true
return nil, errors.New("this is just a mock")
}
func TestRouterServeFiles(t *testing.T) {
router := New()
mfs := &mockFileSystem{}
recv := catchPanic(func() {
router.ServeFiles("/noFilepath", mfs)
})
if recv == nil {
t.Fatal("registering path not ending with '*filepath' did not panic")
}
router.ServeFiles("/*filepath", mfs)
w := new(mockResponseWriter)
r, _ := http.NewRequest("GET", "/favicon.ico", nil)
router.ServeHTTP(w, r)
if !mfs.opened {
t.Error("serving file failed")
}
}
@@ -0,0 +1,534 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
package httprouter
import (
"strings"
"unicode"
)
func min(a, b int) int {
if a <= b {
return a
}
return b
}
func countParams(path string) uint8 {
var n uint
for i := 0; i < len(path); i++ {
if path[i] != ':' && path[i] != '*' {
continue
}
n++
}
if n >= 255 {
return 255
}
return uint8(n)
}
type nodeType uint8
const (
static nodeType = 0
param nodeType = 1
catchAll nodeType = 2
)
type node struct {
path string
wildChild bool
nType nodeType
maxParams uint8
indices []byte
children []*node
handle Handle
priority uint32
}
// increments priority of the given child and reorders if necessary
func (n *node) incrementChildPrio(i int) int {
n.children[i].priority++
prio := n.children[i].priority
// adjust position (move to front)
for j := i - 1; j >= 0 && n.children[j].priority < prio; j-- {
// swap node positions
tmpN := n.children[j]
n.children[j] = n.children[i]
n.children[i] = tmpN
tmpI := n.indices[j]
n.indices[j] = n.indices[i]
n.indices[i] = tmpI
i--
}
return i
}
// addRoute adds a node with the given handle to the path.
// Not concurrency-safe!
func (n *node) addRoute(path string, handle Handle) {
n.priority++
numParams := countParams(path)
// non-empty tree
if len(n.path) > 0 || len(n.children) > 0 {
WALK:
for {
// Update maxParams of the current node
if numParams > n.maxParams {
n.maxParams = numParams
}
// Find the longest common prefix.
// This also implies that the commom prefix contains no ':' or '*'
// since the existing key can't contain this chars.
i := 0
for max := min(len(path), len(n.path)); i < max && path[i] == n.path[i]; i++ {
}
// Split edge
if i < len(n.path) {
child := node{
path: n.path[i:],
wildChild: n.wildChild,
indices: n.indices,
children: n.children,
handle: n.handle,
priority: n.priority - 1,
}
// Update maxParams (max of all children)
for i := range child.children {
if child.children[i].maxParams > child.maxParams {
child.maxParams = child.children[i].maxParams
}
}
n.children = []*node{&child}
n.indices = []byte{n.path[i]}
n.path = path[:i]
n.handle = nil
n.wildChild = false
}
// Make new node a child of this node
if i < len(path) {
path = path[i:]
if n.wildChild {
n = n.children[0]
n.priority++
// Update maxParams of the child node
if numParams > n.maxParams {
n.maxParams = numParams
}
numParams--
// Check if the wildcard matches
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
// check for longer wildcard, e.g. :name and :names
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
continue WALK
}
}
panic("conflict with wildcard route")
}
c := path[0]
// slash after param
if n.nType == param && c == '/' && len(n.children) == 1 {
n = n.children[0]
n.priority++
continue WALK
}
// Check if a child with the next path byte exists
for i, index := range n.indices {
if c == index {
i = n.incrementChildPrio(i)
n = n.children[i]
continue WALK
}
}
// Otherwise insert it
if c != ':' && c != '*' {
n.indices = append(n.indices, c)
child := &node{
maxParams: numParams,
}
n.children = append(n.children, child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
}
n.insertChild(numParams, path, handle)
return
} else if i == len(path) { // Make node a (in-path) leaf
if n.handle != nil {
panic("a Handle is already registered for this path")
}
n.handle = handle
}
return
}
} else { // Empty tree
n.insertChild(numParams, path, handle)
}
}
func (n *node) insertChild(numParams uint8, path string, handle Handle) {
var offset int
// find prefix until first wildcard (beginning with ':'' or '*'')
for i, max := 0, len(path); numParams > 0; i++ {
c := path[i]
if c != ':' && c != '*' {
continue
}
// Check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
panic("wildcard route conflicts with existing children")
}
// find wildcard end (either '/' or path end)
end := i + 1
for end < max && path[end] != '/' {
end++
}
if end-i < 2 {
panic("wildcards must be named with a non-empty name")
}
if c == ':' { // param
// split path at the beginning of the wildcard
if i > 0 {
n.path = path[offset:i]
offset = i
}
child := &node{
nType: param,
maxParams: numParams,
}
n.children = []*node{child}
n.wildChild = true
n = child
n.priority++
numParams--
// if the path doesn't end with the wildcard, then there
// will be another non-wildcard subpath starting with '/'
if end < max {
n.path = path[offset:end]
offset = end
child := &node{
maxParams: numParams,
priority: 1,
}
n.children = []*node{child}
n = child
}
} else { // catchAll
if end != max || numParams > 1 {
panic("catch-all routes are only allowed at the end of the path")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
panic("catch-all conflicts with existing handle for the path segment root")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
panic("no / before catch-all")
}
n.path = path[offset:i]
// first node: catchAll node with empty path
child := &node{
wildChild: true,
nType: catchAll,
maxParams: 1,
}
n.children = []*node{child}
n.indices = []byte{path[i]}
n = child
n.priority++
// second node: node holding the variable
child = &node{
path: path[i:],
nType: catchAll,
maxParams: 1,
handle: handle,
priority: 1,
}
n.children = []*node{child}
return
}
}
// insert remaining path part and handle to the leaf
n.path = path[offset:]
n.handle = handle
}
// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string) (handle Handle, p Params, tsr bool) {
walk: // Outer loop for walking the tree
for {
if len(path) > len(n.path) {
if path[:len(n.path)] == n.path {
path = path[len(n.path):]
// If this node does not have a wildcard (param or catchAll)
// child, we can just look up the next child node and continue
// to walk down the tree
if !n.wildChild {
c := path[0]
for i, index := range n.indices {
if c == index {
n = n.children[i]
continue walk
}
}
// Nothing found.
// We can recommend to redirect to the same URL without a
// trailing slash if a leaf exists for that path.
tsr = (path == "/" && n.handle != nil)
return
}
// handle wildcard child
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// save param value
if p == nil {
// lazy allocation
p = make(Params, 0, n.maxParams)
}
i := len(p)
p = p[:i+1] // expand slice within preallocated capacity
p[i].Key = n.path[1:]
p[i].Value = path[:end]
// we need to go deeper!
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
tsr = (len(path) == end+1)
return
}
if handle = n.handle; handle != nil {
return
} else if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists for TSR recommendation
n = n.children[0]
tsr = (n.path == "/" && n.handle != nil)
}
return
case catchAll:
// save param value
if p == nil {
// lazy allocation
p = make(Params, 0, n.maxParams)
}
i := len(p)
p = p[:i+1] // expand slice within preallocated capacity
p[i].Key = n.path[2:]
p[i].Value = path
handle = n.handle
return
default:
panic("Unknown node type")
}
}
} else if path == n.path {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if handle = n.handle; handle != nil {
return
}
// No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation
for i, index := range n.indices {
if index == '/' {
n = n.children[i]
tsr = (n.path == "/" && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil)
return
}
}
return
}
// Nothing found. We can recommend to redirect to the same URL with an
// extra trailing slash if a leaf exists for that path
tsr = (path == "/") ||
(len(n.path) == len(path)+1 && n.path[len(path)] == '/' &&
path == n.path[:len(n.path)-1] && n.handle != nil)
return
}
}
// Makes a case-insensitive lookup of the given path and tries to find a handler.
// It can optionally also fix trailing slashes.
// It returns the case-corrected path and a bool indicating wether the lookup
// was successful.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory
// Outer loop for walking the tree
for len(path) >= len(n.path) && strings.ToLower(path[:len(n.path)]) == strings.ToLower(n.path) {
path = path[len(n.path):]
ciPath = append(ciPath, n.path...)
if len(path) > 0 {
// If this node does not have a wildcard (param or catchAll) child,
// we can just look up the next child node and continue to walk down
// the tree
if !n.wildChild {
r := unicode.ToLower(rune(path[0]))
for i, index := range n.indices {
// must use recursive approach since both index and
// ToLower(index) could exist. We must check both.
if r == unicode.ToLower(rune(index)) {
out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash)
if found {
return append(ciPath, out...), true
}
}
}
// Nothing found. We can recommend to redirect to the same URL
// without a trailing slash if a leaf exists for that path
found = (fixTrailingSlash && path == "/" && n.handle != nil)
return
} else {
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
k := 0
for k < len(path) && path[k] != '/' {
k++
}
// add param value to case insensitive path
ciPath = append(ciPath, path[:k]...)
// we need to go deeper!
if k < len(path) {
if len(n.children) > 0 {
path = path[k:]
n = n.children[0]
continue
} else { // ... but we can't
if fixTrailingSlash && len(path) == k+1 {
return ciPath, true
}
return
}
}
if n.handle != nil {
return ciPath, true
} else if fixTrailingSlash && len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists
n = n.children[0]
if n.path == "/" && n.handle != nil {
return append(ciPath, '/'), true
}
}
return
case catchAll:
return append(ciPath, path...), true
default:
panic("Unknown node type")
}
}
} else {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if n.handle != nil {
return ciPath, true
}
// No handle found.
// Try to fix the path by adding a trailing slash
if fixTrailingSlash {
for i, index := range n.indices {
if index == '/' {
n = n.children[i]
if (n.path == "/" && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil) {
return append(ciPath, '/'), true
}
return
}
}
}
return
}
}
// Nothing found.
// Try to fix the path by adding / removing a trailing slash
if fixTrailingSlash {
if path == "/" {
return ciPath, true
}
if len(path)+1 == len(n.path) && n.path[len(path)] == '/' &&
strings.ToLower(path) == strings.ToLower(n.path[:len(path)]) &&
n.handle != nil {
return append(ciPath, n.path...), true
}
}
return
}
@@ -0,0 +1,559 @@
// Copyright 2013 Julien Schmidt. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file.
package httprouter
import (
"fmt"
"net/http"
"reflect"
"strings"
"testing"
)
func printChildren(n *node, prefix string) {
fmt.Printf(" %02d:%02d %s%s[%d] %v %t %d \r\n", n.priority, n.maxParams, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType)
for l := len(n.path); l > 0; l-- {
prefix += " "
}
for _, child := range n.children {
printChildren(child, prefix)
}
}
// Used as a workaround since we can't compare functions or their adresses
var fakeHandlerValue string
func fakeHandler(val string) Handle {
return func(http.ResponseWriter, *http.Request, Params) {
fakeHandlerValue = val
}
}
type testRequests []struct {
path string
nilHandler bool
route string
ps Params
}
func checkRequests(t *testing.T, tree *node, requests testRequests) {
for _, request := range requests {
handler, ps, _ := tree.getValue(request.path)
if handler == nil {
if !request.nilHandler {
t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path)
}
} else if request.nilHandler {
t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path)
} else {
handler(nil, nil, nil)
if fakeHandlerValue != request.route {
t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route)
}
}
if !reflect.DeepEqual(ps, request.ps) {
t.Errorf("Params mismatch for route '%s'", request.path)
}
}
}
func checkPriorities(t *testing.T, n *node) uint32 {
var prio uint32
for i := range n.children {
prio += checkPriorities(t, n.children[i])
}
if n.handle != nil {
prio++
}
if n.priority != prio {
t.Errorf(
"priority mismatch for node '%s': is %d, should be %d",
n.path, n.priority, prio,
)
}
return prio
}
func checkMaxParams(t *testing.T, n *node) uint8 {
var maxParams uint8
for i := range n.children {
params := checkMaxParams(t, n.children[i])
if params > maxParams {
maxParams = params
}
}
if n.nType != static && !n.wildChild {
maxParams++
}
if n.maxParams != maxParams {
t.Errorf(
"maxParams mismatch for node '%s': is %d, should be %d",
n.path, n.maxParams, maxParams,
)
}
return maxParams
}
func TestCountParams(t *testing.T) {
if countParams("/path/:param1/static/*catch-all") != 2 {
t.Fail()
}
if countParams(strings.Repeat("/:param", 256)) != 255 {
t.Fail()
}
}
func TestTreeAddAndGet(t *testing.T) {
tree := &node{}
routes := [...]string{
"/hi",
"/contact",
"/co",
"/c",
"/a",
"/ab",
"/doc/",
"/doc/go_faq.html",
"/doc/go1.html",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
//printChildren(tree, "")
checkRequests(t, tree, testRequests{
{"/a", false, "/a", nil},
{"/", true, "", nil},
{"/hi", false, "/hi", nil},
{"/contact", false, "/contact", nil},
{"/co", false, "/co", nil},
{"/con", true, "", nil}, // key mismatch
{"/cona", true, "", nil}, // key mismatch
{"/no", true, "", nil}, // no matching child
{"/ab", false, "/ab", nil},
})
checkPriorities(t, tree)
checkMaxParams(t, tree)
}
func TestTreeWildcard(t *testing.T) {
tree := &node{}
routes := [...]string{
"/",
"/cmd/:tool/:sub",
"/cmd/:tool/",
"/src/*filepath",
"/search/",
"/search/:query",
"/user_:name",
"/user_:name/about",
"/files/:dir/*filepath",
"/doc/",
"/doc/go_faq.html",
"/doc/go1.html",
"/info/:user/public",
"/info/:user/project/:project",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
//printChildren(tree, "")
checkRequests(t, tree, testRequests{
{"/", false, "/", nil},
{"/cmd/test/", false, "/cmd/:tool/", Params{Param{"tool", "test"}}},
{"/cmd/test", true, "", Params{Param{"tool", "test"}}},
{"/cmd/test/3", false, "/cmd/:tool/:sub", Params{Param{"tool", "test"}, Param{"sub", "3"}}},
{"/src/", false, "/src/*filepath", Params{Param{"filepath", "/"}}},
{"/src/some/file.png", false, "/src/*filepath", Params{Param{"filepath", "/some/file.png"}}},
{"/search/", false, "/search/", nil},
{"/search/someth!ng+in+ünìcodé", false, "/search/:query", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
{"/search/someth!ng+in+ünìcodé/", true, "", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
{"/user_gopher", false, "/user_:name", Params{Param{"name", "gopher"}}},
{"/user_gopher/about", false, "/user_:name/about", Params{Param{"name", "gopher"}}},
{"/files/js/inc/framework.js", false, "/files/:dir/*filepath", Params{Param{"dir", "js"}, Param{"filepath", "/inc/framework.js"}}},
{"/info/gordon/public", false, "/info/:user/public", Params{Param{"user", "gordon"}}},
{"/info/gordon/project/go", false, "/info/:user/project/:project", Params{Param{"user", "gordon"}, Param{"project", "go"}}},
})
checkPriorities(t, tree)
checkMaxParams(t, tree)
}
func catchPanic(testFunc func()) (recv interface{}) {
defer func() {
recv = recover()
}()
testFunc()
return
}
type testRoute struct {
path string
conflict bool
}
func testRoutes(t *testing.T, routes []testRoute) {
tree := &node{}
for _, route := range routes {
recv := catchPanic(func() {
tree.addRoute(route.path, nil)
})
if route.conflict {
if recv == nil {
t.Errorf("no panic for conflicting route '%s'", route.path)
}
} else if recv != nil {
t.Errorf("unexpected panic for route '%s': %v", route.path, recv)
}
}
//printChildren(tree, "")
}
func TestTreeWildcardConflict(t *testing.T) {
routes := []testRoute{
{"/cmd/:tool/:sub", false},
{"/cmd/vet", true},
{"/src/*filepath", false},
{"/src/*filepathx", true},
{"/src/", true},
{"/src1/", false},
{"/src1/*filepath", true},
{"/src2*filepath", true},
{"/search/:query", false},
{"/search/invalid", true},
{"/user_:name", false},
{"/user_x", true},
{"/user_:name", false},
{"/id:id", false},
{"/id/:id", true},
}
testRoutes(t, routes)
}
func TestTreeChildConflict(t *testing.T) {
routes := []testRoute{
{"/cmd/vet", false},
{"/cmd/:tool/:sub", true},
{"/src/AUTHORS", false},
{"/src/*filepath", true},
{"/user_x", false},
{"/user_:name", true},
{"/id/:id", false},
{"/id:id", true},
{"/:id", true},
{"/*filepath", true},
}
testRoutes(t, routes)
}
func TestTreeDupliatePath(t *testing.T) {
tree := &node{}
routes := [...]string{
"/",
"/doc/",
"/src/*filepath",
"/search/:query",
"/user_:name",
}
for _, route := range routes {
recv := catchPanic(func() {
tree.addRoute(route, fakeHandler(route))
})
if recv != nil {
t.Fatalf("panic inserting route '%s': %v", route, recv)
}
// Add again
recv = catchPanic(func() {
tree.addRoute(route, nil)
})
if recv == nil {
t.Fatalf("no panic while inserting duplicate route '%s", route)
}
}
//printChildren(tree, "")
checkRequests(t, tree, testRequests{
{"/", false, "/", nil},
{"/doc/", false, "/doc/", nil},
{"/src/some/file.png", false, "/src/*filepath", Params{Param{"filepath", "/some/file.png"}}},
{"/search/someth!ng+in+ünìcodé", false, "/search/:query", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
{"/user_gopher", false, "/user_:name", Params{Param{"name", "gopher"}}},
})
}
func TestEmptyWildcardName(t *testing.T) {
tree := &node{}
routes := [...]string{
"/user:",
"/user:/",
"/cmd/:/",
"/src/*",
}
for _, route := range routes {
recv := catchPanic(func() {
tree.addRoute(route, nil)
})
if recv == nil {
t.Fatalf("no panic while inserting route with empty wildcard name '%s", route)
}
}
}
func TestTreeCatchAllConflict(t *testing.T) {
routes := []testRoute{
{"/src/*filepath/x", true},
{"/src2/", false},
{"/src2/*filepath/x", true},
}
testRoutes(t, routes)
}
func TestTreeCatchAllConflictRoot(t *testing.T) {
routes := []testRoute{
{"/", false},
{"/*filepath", true},
}
testRoutes(t, routes)
}
/*func TestTreeDuplicateWildcard(t *testing.T) {
tree := &node{}
routes := [...]string{
"/:id/:name/:id",
}
for _, route := range routes {
...
}
}*/
func TestTreeTrailingSlashRedirect(t *testing.T) {
tree := &node{}
routes := [...]string{
"/hi",
"/b/",
"/search/:query",
"/cmd/:tool/",
"/src/*filepath",
"/x",
"/x/y",
"/y/",
"/y/z",
"/0/:id",
"/0/:id/1",
"/1/:id/",
"/1/:id/2",
"/aa",
"/a/",
"/doc",
"/doc/go_faq.html",
"/doc/go1.html",
"/no/a",
"/no/b",
"/api/hello/:name",
}
for _, route := range routes {
recv := catchPanic(func() {
tree.addRoute(route, fakeHandler(route))
})
if recv != nil {
t.Fatalf("panic inserting route '%s': %v", route, recv)
}
}
//printChildren(tree, "")
tsrRoutes := [...]string{
"/hi/",
"/b",
"/search/gopher/",
"/cmd/vet",
"/src",
"/x/",
"/y",
"/0/go/",
"/1/go",
"/a",
"/doc/",
}
for _, route := range tsrRoutes {
handler, _, tsr := tree.getValue(route)
if handler != nil {
t.Fatalf("non-nil handler for TSR route '%s", route)
} else if !tsr {
t.Errorf("expected TSR recommendation for route '%s'", route)
}
}
noTsrRoutes := [...]string{
"/",
"/no",
"/no/",
"/_",
"/_/",
"/api/world/abc",
}
for _, route := range noTsrRoutes {
handler, _, tsr := tree.getValue(route)
if handler != nil {
t.Fatalf("non-nil handler for No-TSR route '%s", route)
} else if tsr {
t.Errorf("expected no TSR recommendation for route '%s'", route)
}
}
}
func TestTreeFindCaseInsensitivePath(t *testing.T) {
tree := &node{}
routes := [...]string{
"/hi",
"/b/",
"/ABC/",
"/search/:query",
"/cmd/:tool/",
"/src/*filepath",
"/x",
"/x/y",
"/y/",
"/y/z",
"/0/:id",
"/0/:id/1",
"/1/:id/",
"/1/:id/2",
"/aa",
"/a/",
"/doc",
"/doc/go_faq.html",
"/doc/go1.html",
"/doc/go/away",
"/no/a",
"/no/b",
}
for _, route := range routes {
recv := catchPanic(func() {
tree.addRoute(route, fakeHandler(route))
})
if recv != nil {
t.Fatalf("panic inserting route '%s': %v", route, recv)
}
}
// Check out == in for all registered routes
// With fixTrailingSlash = true
for _, route := range routes {
out, found := tree.findCaseInsensitivePath(route, true)
if !found {
t.Errorf("Route '%s' not found!", route)
} else if string(out) != route {
t.Errorf("Wrong result for route '%s': %s", route, string(out))
}
}
// With fixTrailingSlash = false
for _, route := range routes {
out, found := tree.findCaseInsensitivePath(route, false)
if !found {
t.Errorf("Route '%s' not found!", route)
} else if string(out) != route {
t.Errorf("Wrong result for route '%s': %s", route, string(out))
}
}
tests := []struct {
in string
out string
found bool
slash bool
}{
{"/HI", "/hi", true, false},
{"/HI/", "/hi", true, true},
{"/B", "/b/", true, true},
{"/B/", "/b/", true, false},
{"/abc", "/ABC/", true, true},
{"/abc/", "/ABC/", true, false},
{"/aBc", "/ABC/", true, true},
{"/aBc/", "/ABC/", true, false},
{"/abC", "/ABC/", true, true},
{"/abC/", "/ABC/", true, false},
{"/SEARCH/QUERY", "/search/QUERY", true, false},
{"/SEARCH/QUERY/", "/search/QUERY", true, true},
{"/CMD/TOOL/", "/cmd/TOOL/", true, false},
{"/CMD/TOOL", "/cmd/TOOL/", true, true},
{"/SRC/FILE/PATH", "/src/FILE/PATH", true, false},
{"/x/Y", "/x/y", true, false},
{"/x/Y/", "/x/y", true, true},
{"/X/y", "/x/y", true, false},
{"/X/y/", "/x/y", true, true},
{"/X/Y", "/x/y", true, false},
{"/X/Y/", "/x/y", true, true},
{"/Y/", "/y/", true, false},
{"/Y", "/y/", true, true},
{"/Y/z", "/y/z", true, false},
{"/Y/z/", "/y/z", true, true},
{"/Y/Z", "/y/z", true, false},
{"/Y/Z/", "/y/z", true, true},
{"/y/Z", "/y/z", true, false},
{"/y/Z/", "/y/z", true, true},
{"/Aa", "/aa", true, false},
{"/Aa/", "/aa", true, true},
{"/AA", "/aa", true, false},
{"/AA/", "/aa", true, true},
{"/aA", "/aa", true, false},
{"/aA/", "/aa", true, true},
{"/A/", "/a/", true, false},
{"/A", "/a/", true, true},
{"/DOC", "/doc", true, false},
{"/DOC/", "/doc", true, true},
{"/NO", "", false, true},
{"/DOC/GO", "", false, true},
}
// With fixTrailingSlash = true
for _, test := range tests {
out, found := tree.findCaseInsensitivePath(test.in, true)
if found != test.found || (found && (string(out) != test.out)) {
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
test.in, string(out), found, test.out, test.found)
return
}
}
// With fixTrailingSlash = false
for _, test := range tests {
out, found := tree.findCaseInsensitivePath(test.in, false)
if test.slash {
if found { // test needs a trailingSlash fix. It must not be found!
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out))
}
} else {
if found != test.found || (found && (string(out) != test.out)) {
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
test.in, string(out), found, test.out, test.found)
return
}
}
}
}
@@ -0,0 +1,42 @@
# go-colorable
Colorable writer for windows.
For example, most of logger packages doesn't show colors on windows. (I know we can do it with ansicon. But I don't want.)
This package is possible to handle escape sequence for ansi color on windows.
## Too Bad!
![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/bad.png)
## So Good!
![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/good.png)
## Usage
```go
logrus.SetOutput(colorable.NewColorableStdout())
logrus.Info("succeeded")
logrus.Warn("not correct")
logrus.Error("something error")
logrus.Fatal("panic")
```
You can compile above code on non-windows OSs.
## Installation
```
$ go get github.com/mattn/go-colorable
```
# License
MIT
# Author
Yasuhiro Matsumoto (a.k.a mattn)
@@ -0,0 +1,15 @@
package main
import (
"github.com/mattn/go-colorable"
"github.com/Sirupsen/logrus"
)
func main() {
logrus.SetOutput(colorable.NewColorableStdout())
logrus.Info("succeeded")
logrus.Warn("not correct")
logrus.Error("something error")
logrus.Fatal("panic")
}
@@ -0,0 +1,16 @@
// +build !windows
package colorable
import (
"io"
"os"
)
func NewColorableStdout() io.Writer {
return os.Stdout
}
func NewColorableStderr() io.Writer {
return os.Stderr
}
@@ -0,0 +1,594 @@
package colorable
import (
"bytes"
"fmt"
"io"
"os"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/mattn/go-isatty"
)
const (
foregroundBlue = 0x1
foregroundGreen = 0x2
foregroundRed = 0x4
foregroundIntensity = 0x8
foregroundMask = (foregroundRed | foregroundBlue | foregroundGreen | foregroundIntensity)
backgroundBlue = 0x10
backgroundGreen = 0x20
backgroundRed = 0x40
backgroundIntensity = 0x80
backgroundMask = (backgroundRed | backgroundBlue | backgroundGreen | backgroundIntensity)
)
type wchar uint16
type short int16
type dword uint32
type word uint16
type coord struct {
x short
y short
}
type smallRect struct {
left short
top short
right short
bottom short
}
type consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes word
window smallRect
maximumWindowSize coord
}
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute")
)
type Writer struct {
out io.Writer
handle syscall.Handle
lastbuf bytes.Buffer
oldattr word
}
func NewColorableStdout() io.Writer {
var csbi consoleScreenBufferInfo
out := os.Stdout
if !isatty.IsTerminal(out.Fd()) {
return out
}
handle := syscall.Handle(out.Fd())
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
return &Writer{out: out, handle: handle, oldattr: csbi.attributes}
}
func NewColorableStderr() io.Writer {
var csbi consoleScreenBufferInfo
out := os.Stderr
if !isatty.IsTerminal(out.Fd()) {
return out
}
handle := syscall.Handle(out.Fd())
procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi)))
return &Writer{out: out, handle: handle, oldattr: csbi.attributes}
}
var color256 = map[int]int{
0: 0x000000,
1: 0x800000,
2: 0x008000,
3: 0x808000,
4: 0x000080,
5: 0x800080,
6: 0x008080,
7: 0xc0c0c0,
8: 0x808080,
9: 0xff0000,
10: 0x00ff00,
11: 0xffff00,
12: 0x0000ff,
13: 0xff00ff,
14: 0x00ffff,
15: 0xffffff,
16: 0x000000,
17: 0x00005f,
18: 0x000087,
19: 0x0000af,
20: 0x0000d7,
21: 0x0000ff,
22: 0x005f00,
23: 0x005f5f,
24: 0x005f87,
25: 0x005faf,
26: 0x005fd7,
27: 0x005fff,
28: 0x008700,
29: 0x00875f,
30: 0x008787,
31: 0x0087af,
32: 0x0087d7,
33: 0x0087ff,
34: 0x00af00,
35: 0x00af5f,
36: 0x00af87,
37: 0x00afaf,
38: 0x00afd7,
39: 0x00afff,
40: 0x00d700,
41: 0x00d75f,
42: 0x00d787,
43: 0x00d7af,
44: 0x00d7d7,
45: 0x00d7ff,
46: 0x00ff00,
47: 0x00ff5f,
48: 0x00ff87,
49: 0x00ffaf,
50: 0x00ffd7,
51: 0x00ffff,
52: 0x5f0000,
53: 0x5f005f,
54: 0x5f0087,
55: 0x5f00af,
56: 0x5f00d7,
57: 0x5f00ff,
58: 0x5f5f00,
59: 0x5f5f5f,
60: 0x5f5f87,
61: 0x5f5faf,
62: 0x5f5fd7,
63: 0x5f5fff,
64: 0x5f8700,
65: 0x5f875f,
66: 0x5f8787,
67: 0x5f87af,
68: 0x5f87d7,
69: 0x5f87ff,
70: 0x5faf00,
71: 0x5faf5f,
72: 0x5faf87,
73: 0x5fafaf,
74: 0x5fafd7,
75: 0x5fafff,
76: 0x5fd700,
77: 0x5fd75f,
78: 0x5fd787,
79: 0x5fd7af,
80: 0x5fd7d7,
81: 0x5fd7ff,
82: 0x5fff00,
83: 0x5fff5f,
84: 0x5fff87,
85: 0x5fffaf,
86: 0x5fffd7,
87: 0x5fffff,
88: 0x870000,
89: 0x87005f,
90: 0x870087,
91: 0x8700af,
92: 0x8700d7,
93: 0x8700ff,
94: 0x875f00,
95: 0x875f5f,
96: 0x875f87,
97: 0x875faf,
98: 0x875fd7,
99: 0x875fff,
100: 0x878700,
101: 0x87875f,
102: 0x878787,
103: 0x8787af,
104: 0x8787d7,
105: 0x8787ff,
106: 0x87af00,
107: 0x87af5f,
108: 0x87af87,
109: 0x87afaf,
110: 0x87afd7,
111: 0x87afff,
112: 0x87d700,
113: 0x87d75f,
114: 0x87d787,
115: 0x87d7af,
116: 0x87d7d7,
117: 0x87d7ff,
118: 0x87ff00,
119: 0x87ff5f,
120: 0x87ff87,
121: 0x87ffaf,
122: 0x87ffd7,
123: 0x87ffff,
124: 0xaf0000,
125: 0xaf005f,
126: 0xaf0087,
127: 0xaf00af,
128: 0xaf00d7,
129: 0xaf00ff,
130: 0xaf5f00,
131: 0xaf5f5f,
132: 0xaf5f87,
133: 0xaf5faf,
134: 0xaf5fd7,
135: 0xaf5fff,
136: 0xaf8700,
137: 0xaf875f,
138: 0xaf8787,
139: 0xaf87af,
140: 0xaf87d7,
141: 0xaf87ff,
142: 0xafaf00,
143: 0xafaf5f,
144: 0xafaf87,
145: 0xafafaf,
146: 0xafafd7,
147: 0xafafff,
148: 0xafd700,
149: 0xafd75f,
150: 0xafd787,
151: 0xafd7af,
152: 0xafd7d7,
153: 0xafd7ff,
154: 0xafff00,
155: 0xafff5f,
156: 0xafff87,
157: 0xafffaf,
158: 0xafffd7,
159: 0xafffff,
160: 0xd70000,
161: 0xd7005f,
162: 0xd70087,
163: 0xd700af,
164: 0xd700d7,
165: 0xd700ff,
166: 0xd75f00,
167: 0xd75f5f,
168: 0xd75f87,
169: 0xd75faf,
170: 0xd75fd7,
171: 0xd75fff,
172: 0xd78700,
173: 0xd7875f,
174: 0xd78787,
175: 0xd787af,
176: 0xd787d7,
177: 0xd787ff,
178: 0xd7af00,
179: 0xd7af5f,
180: 0xd7af87,
181: 0xd7afaf,
182: 0xd7afd7,
183: 0xd7afff,
184: 0xd7d700,
185: 0xd7d75f,
186: 0xd7d787,
187: 0xd7d7af,
188: 0xd7d7d7,
189: 0xd7d7ff,
190: 0xd7ff00,
191: 0xd7ff5f,
192: 0xd7ff87,
193: 0xd7ffaf,
194: 0xd7ffd7,
195: 0xd7ffff,
196: 0xff0000,
197: 0xff005f,
198: 0xff0087,
199: 0xff00af,
200: 0xff00d7,
201: 0xff00ff,
202: 0xff5f00,
203: 0xff5f5f,
204: 0xff5f87,
205: 0xff5faf,
206: 0xff5fd7,
207: 0xff5fff,
208: 0xff8700,
209: 0xff875f,
210: 0xff8787,
211: 0xff87af,
212: 0xff87d7,
213: 0xff87ff,
214: 0xffaf00,
215: 0xffaf5f,
216: 0xffaf87,
217: 0xffafaf,
218: 0xffafd7,
219: 0xffafff,
220: 0xffd700,
221: 0xffd75f,
222: 0xffd787,
223: 0xffd7af,
224: 0xffd7d7,
225: 0xffd7ff,
226: 0xffff00,
227: 0xffff5f,
228: 0xffff87,
229: 0xffffaf,
230: 0xffffd7,
231: 0xffffff,
232: 0x080808,
233: 0x121212,
234: 0x1c1c1c,
235: 0x262626,
236: 0x303030,
237: 0x3a3a3a,
238: 0x444444,
239: 0x4e4e4e,
240: 0x585858,
241: 0x626262,
242: 0x6c6c6c,
243: 0x767676,
244: 0x808080,
245: 0x8a8a8a,
246: 0x949494,
247: 0x9e9e9e,
248: 0xa8a8a8,
249: 0xb2b2b2,
250: 0xbcbcbc,
251: 0xc6c6c6,
252: 0xd0d0d0,
253: 0xdadada,
254: 0xe4e4e4,
255: 0xeeeeee,
}
func (w *Writer) Write(data []byte) (n int, err error) {
var csbi consoleScreenBufferInfo
procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
er := bytes.NewBuffer(data)
loop:
for {
r1, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi)))
if r1 == 0 {
break loop
}
c1, _, err := er.ReadRune()
if err != nil {
break loop
}
if c1 != 0x1b {
fmt.Fprint(w.out, string(c1))
continue
}
c2, _, err := er.ReadRune()
if err != nil {
w.lastbuf.WriteRune(c1)
break loop
}
if c2 != 0x5b {
w.lastbuf.WriteRune(c1)
w.lastbuf.WriteRune(c2)
continue
}
var buf bytes.Buffer
var m rune
for {
c, _, err := er.ReadRune()
if err != nil {
w.lastbuf.WriteRune(c1)
w.lastbuf.WriteRune(c2)
w.lastbuf.Write(buf.Bytes())
break loop
}
if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' {
m = c
break
}
buf.Write([]byte(string(c)))
}
switch m {
case 'm':
attr := csbi.attributes
cs := buf.String()
if cs == "" {
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(w.oldattr))
continue
}
token := strings.Split(cs, ";")
for i, ns := range token {
if n, err = strconv.Atoi(ns); err == nil {
switch {
case n == 0 || n == 100:
attr = w.oldattr
case 1 <= n && n <= 5:
attr |= foregroundIntensity
case n == 7:
attr = ((attr & foregroundMask) << 4) | ((attr & backgroundMask) >> 4)
case 22 == n || n == 25 || n == 25:
attr |= foregroundIntensity
case n == 27:
attr = ((attr & foregroundMask) << 4) | ((attr & backgroundMask) >> 4)
case 30 <= n && n <= 37:
attr = (attr & backgroundMask)
if (n-30)&1 != 0 {
attr |= foregroundRed
}
if (n-30)&2 != 0 {
attr |= foregroundGreen
}
if (n-30)&4 != 0 {
attr |= foregroundBlue
}
case n == 38: // set foreground color.
if i < len(token)-2 && token[i+1] == "5" {
if n256, err := strconv.Atoi(token[i+2]); err == nil {
if n256foreAttr == nil {
n256setup()
}
attr &= backgroundMask
attr |= n256foreAttr[n256]
i += 2
}
} else {
attr = attr & (w.oldattr & backgroundMask)
}
case n == 39: // reset foreground color.
attr &= backgroundMask
attr |= w.oldattr & foregroundMask
case 40 <= n && n <= 47:
attr = (attr & foregroundMask)
if (n-40)&1 != 0 {
attr |= backgroundRed
}
if (n-40)&2 != 0 {
attr |= backgroundGreen
}
if (n-40)&4 != 0 {
attr |= backgroundBlue
}
case n == 48: // set background color.
if i < len(token)-2 && token[i+1] == "5" {
if n256, err := strconv.Atoi(token[i+2]); err == nil {
if n256backAttr == nil {
n256setup()
}
attr &= foregroundMask
attr |= n256backAttr[n256]
i += 2
}
} else {
attr = attr & (w.oldattr & foregroundMask)
}
case n == 49: // reset foreground color.
attr &= foregroundMask
attr |= w.oldattr & backgroundMask
}
procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(attr))
}
}
}
}
return len(data) - w.lastbuf.Len(), nil
}
type consoleColor struct {
red bool
green bool
blue bool
intensity bool
}
func minmax3(a, b, c int) (min, max int) {
if a < b {
if b < c {
return a, c
} else if a < c {
return a, b
} else {
return c, b
}
} else {
if a < c {
return b, c
} else if b < c {
return b, a
} else {
return c, a
}
}
}
func toConsoleColor(rgb int) (c consoleColor) {
r, g, b := (rgb&0xFF0000)>>16, (rgb&0x00FF00)>>8, rgb&0x0000FF
min, max := minmax3(r, g, b)
a := (min + max) / 2
if r < 128 && g < 128 && b < 128 {
if r >= a {
c.red = true
}
if g >= a {
c.green = true
}
if b >= a {
c.blue = true
}
// non-intensed white is lighter than intensed black, so swap those.
if c.red && c.green && c.blue {
c.red, c.green, c.blue = false, false, false
c.intensity = true
}
} else {
if min < 128 {
min = 128
a = (min + max) / 2
}
if r >= a {
c.red = true
}
if g >= a {
c.green = true
}
if b >= a {
c.blue = true
}
c.intensity = true
// intensed black is darker than non-intensed white, so swap those.
if !c.red && !c.green && !c.blue {
c.red, c.green, c.blue = true, true, true
c.intensity = false
}
}
return c
}
func (c consoleColor) foregroundAttr() (attr word) {
if c.red {
attr |= foregroundRed
}
if c.green {
attr |= foregroundGreen
}
if c.blue {
attr |= foregroundBlue
}
if c.intensity {
attr |= foregroundIntensity
}
return
}
func (c consoleColor) backgroundAttr() (attr word) {
if c.red {
attr |= backgroundRed
}
if c.green {
attr |= backgroundGreen
}
if c.blue {
attr |= backgroundBlue
}
if c.intensity {
attr |= backgroundIntensity
}
return
}
var n256foreAttr []word
var n256backAttr []word
func n256setup() {
n256foreAttr = make([]word, 256)
n256backAttr = make([]word, 256)
for i, rgb := range color256 {
c := toConsoleColor(rgb)
n256foreAttr[i] = c.foregroundAttr()
n256backAttr[i] = c.backgroundAttr()
}
}
@@ -0,0 +1,4 @@
*~
*.pyc
test-env*
junk/
@@ -0,0 +1,10 @@
language: go
go:
- 1.1.2
- 1.2.2
- 1.3
- tip
script:
- go test
@@ -0,0 +1,20 @@
Copyright (C) 2012 by Nick Craig-Wood http://www.craig-wood.com/nick/
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
@@ -0,0 +1,105 @@
Swift
=====
This package provides an easy to use library for interfacing with
Swift / Openstack Object Storage / Rackspace cloud files from the Go
Language
See here for package docs
http://godoc.org/github.com/ncw/swift
[![Build Status](https://travis-ci.org/ncw/swift.png)](https://travis-ci.org/ncw/swift)
Install
-------
Use go to install the library
go get github.com/ncw/swift
Usage
-----
See here for full package docs
- http://godoc.org/github.com/ncw/swift
Here is a short example from the docs
import "github.com/ncw/swift"
// Create a connection
c := swift.Connection{
UserName: "user",
ApiKey: "key",
AuthUrl: "auth_url",
}
// Authenticate
err := c.Authenticate()
if err != nil {
panic(err)
}
// List all the containers
containers, err := c.ContainerNames(nil)
fmt.Println(containers)
// etc...
Additions
---------
The `rs` sub project contains a wrapper for the Rackspace specific CDN Management interface.
Testing
-------
To run the tests you can either use an embedded fake Swift server
either use a real Openstack Swift server or a Rackspace Cloud files account.
When using a real Swift server, you need to set these environment variables
before running the tests
export SWIFT_API_USER='user'
export SWIFT_API_KEY='key'
export SWIFT_AUTH_URL='https://url.of.auth.server/v1.0'
And optionally these if using v2 authentication
export SWIFT_TENANT='TenantName'
export SWIFT_TENANT_ID='TenantId'
Then run the tests with `go test`
License
-------
This is free software under the terms of MIT license (check COPYING file
included in this package).
Contact and support
-------------------
The project website is at:
- https://github.com/ncw/swift
There you can file bug reports, ask for help or contribute patches.
Authors
-------
- Nick Craig-Wood <nick@craig-wood.com>
Contributors
------------
- Brian "bojo" Jones <mojobojo@gmail.com>
- Janika Liiv <janika@toggl.com>
- Yamamoto, Hirotaka <ymmt2005@gmail.com>
- Stephen <yo@groks.org>
- platformpurple <stephen@platformpurple.com>
- Paul Querna <pquerna@apache.org>
- Livio Soares <liviobs@gmail.com>
- thesyncim <thesyncim@gmail.com>
- lsowen <lsowen@s1network.com>
- Sylvain Baubeau <sbaubeau@redhat.com>
@@ -0,0 +1,279 @@
package swift
import (
"bytes"
"encoding/json"
"net/http"
"net/url"
"strings"
)
// Auth defines the operations needed to authenticate with swift
//
// This encapsulates the different authentication schemes in use
type Authenticator interface {
Request(*Connection) (*http.Request, error)
Response(resp *http.Response) error
// The public storage URL - set Internal to true to read
// internal/service net URL
StorageUrl(Internal bool) string
// The access token
Token() string
// The CDN url if available
CdnUrl() string
}
// newAuth - create a new Authenticator from the AuthUrl
//
// A hint for AuthVersion can be provided
func newAuth(c *Connection) (Authenticator, error) {
AuthVersion := c.AuthVersion
if AuthVersion == 0 {
if strings.Contains(c.AuthUrl, "v2") {
AuthVersion = 2
} else if strings.Contains(c.AuthUrl, "v1") {
AuthVersion = 1
} else {
return nil, newErrorf(500, "Can't find AuthVersion in AuthUrl - set explicitly")
}
}
switch AuthVersion {
case 1:
return &v1Auth{}, nil
case 2:
return &v2Auth{
// Guess as to whether using API key or
// password it will try both eventually so
// this is just an optimization.
useApiKey: len(c.ApiKey) >= 32,
}, nil
}
return nil, newErrorf(500, "Auth Version %d not supported", AuthVersion)
}
// ------------------------------------------------------------
// v1 auth
type v1Auth struct {
Headers http.Header // V1 auth: the authentication headers so extensions can access them
}
// v1 Authentication - make request
func (auth *v1Auth) Request(c *Connection) (*http.Request, error) {
req, err := http.NewRequest("GET", c.AuthUrl, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", c.UserAgent)
req.Header.Set("X-Auth-Key", c.ApiKey)
req.Header.Set("X-Auth-User", c.UserName)
return req, nil
}
// v1 Authentication - read response
func (auth *v1Auth) Response(resp *http.Response) error {
auth.Headers = resp.Header
return nil
}
// v1 Authentication - read storage url
func (auth *v1Auth) StorageUrl(Internal bool) string {
storageUrl := auth.Headers.Get("X-Storage-Url")
if Internal {
newUrl, err := url.Parse(storageUrl)
if err != nil {
return storageUrl
}
newUrl.Host = "snet-" + newUrl.Host
storageUrl = newUrl.String()
}
return storageUrl
}
// v1 Authentication - read auth token
func (auth *v1Auth) Token() string {
return auth.Headers.Get("X-Auth-Token")
}
// v1 Authentication - read cdn url
func (auth *v1Auth) CdnUrl() string {
return auth.Headers.Get("X-CDN-Management-Url")
}
// ------------------------------------------------------------
// v2 Authentication
type v2Auth struct {
Auth *v2AuthResponse
Region string
useApiKey bool // if set will use API key not Password
useApiKeyOk bool // if set won't change useApiKey any more
notFirst bool // set after first run
}
// v2 Authentication - make request
func (auth *v2Auth) Request(c *Connection) (*http.Request, error) {
auth.Region = c.Region
// Toggle useApiKey if not first run and not OK yet
if auth.notFirst && !auth.useApiKeyOk {
auth.useApiKey = !auth.useApiKey
}
auth.notFirst = true
// Create a V2 auth request for the body of the connection
var v2i interface{}
if !auth.useApiKey {
// Normal swift authentication
v2 := v2AuthRequest{}
v2.Auth.PasswordCredentials.UserName = c.UserName
v2.Auth.PasswordCredentials.Password = c.ApiKey
v2.Auth.Tenant = c.Tenant
v2.Auth.TenantId = c.TenantId
v2i = v2
} else {
// Rackspace special with API Key
v2 := v2AuthRequestRackspace{}
v2.Auth.ApiKeyCredentials.UserName = c.UserName
v2.Auth.ApiKeyCredentials.ApiKey = c.ApiKey
v2.Auth.Tenant = c.Tenant
v2.Auth.TenantId = c.TenantId
v2i = v2
}
body, err := json.Marshal(v2i)
if err != nil {
return nil, err
}
url := c.AuthUrl
if !strings.HasSuffix(url, "/") {
url += "/"
}
url += "tokens"
req, err := http.NewRequest("POST", url, bytes.NewBuffer(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return req, nil
}
// v2 Authentication - read response
func (auth *v2Auth) Response(resp *http.Response) error {
auth.Auth = new(v2AuthResponse)
err := readJson(resp, auth.Auth)
// If successfully read Auth then no need to toggle useApiKey any more
if err == nil {
auth.useApiKeyOk = true
}
return err
}
// Finds the Endpoint Url of "type" from the v2AuthResponse using the
// Region if set or defaulting to the first one if not
//
// Returns "" if not found
func (auth *v2Auth) endpointUrl(Type string, Internal bool) string {
for _, catalog := range auth.Auth.Access.ServiceCatalog {
if catalog.Type == Type {
for _, endpoint := range catalog.Endpoints {
if auth.Region == "" || (auth.Region == endpoint.Region) {
if Internal {
return endpoint.InternalUrl
} else {
return endpoint.PublicUrl
}
}
}
}
}
return ""
}
// v2 Authentication - read storage url
//
// If Internal is true then it reads the private (internal / service
// net) URL.
func (auth *v2Auth) StorageUrl(Internal bool) string {
return auth.endpointUrl("object-store", Internal)
}
// v2 Authentication - read auth token
func (auth *v2Auth) Token() string {
return auth.Auth.Access.Token.Id
}
// v2 Authentication - read cdn url
func (auth *v2Auth) CdnUrl() string {
return auth.endpointUrl("rax:object-cdn", false)
}
// ------------------------------------------------------------
// V2 Authentication request
//
// http://docs.openstack.org/developer/keystone/api_curl_examples.html
// http://docs.rackspace.com/servers/api/v2/cs-gettingstarted/content/curl_auth.html
// http://docs.openstack.org/api/openstack-identity-service/2.0/content/POST_authenticate_v2.0_tokens_.html
type v2AuthRequest struct {
Auth struct {
PasswordCredentials struct {
UserName string `json:"username"`
Password string `json:"password"`
} `json:"passwordCredentials"`
Tenant string `json:"tenantName,omitempty"`
TenantId string `json:"tenantId,omitempty"`
} `json:"auth"`
}
// V2 Authentication request - Rackspace variant
//
// http://docs.openstack.org/developer/keystone/api_curl_examples.html
// http://docs.rackspace.com/servers/api/v2/cs-gettingstarted/content/curl_auth.html
// http://docs.openstack.org/api/openstack-identity-service/2.0/content/POST_authenticate_v2.0_tokens_.html
type v2AuthRequestRackspace struct {
Auth struct {
ApiKeyCredentials struct {
UserName string `json:"username"`
ApiKey string `json:"apiKey"`
} `json:"RAX-KSKEY:apiKeyCredentials"`
Tenant string `json:"tenantName,omitempty"`
TenantId string `json:"tenantId,omitempty"`
} `json:"auth"`
}
// V2 Authentication reply
//
// http://docs.openstack.org/developer/keystone/api_curl_examples.html
// http://docs.rackspace.com/servers/api/v2/cs-gettingstarted/content/curl_auth.html
// http://docs.openstack.org/api/openstack-identity-service/2.0/content/POST_authenticate_v2.0_tokens_.html
type v2AuthResponse struct {
Access struct {
ServiceCatalog []struct {
Endpoints []struct {
InternalUrl string
PublicUrl string
Region string
TenantId string
}
Name string
Type string
}
Token struct {
Expires string
Id string
Tenant struct {
Id string
Name string
}
}
User struct {
DefaultRegion string `json:"RAX-AUTH:defaultRegion"`
Id string
Name string
Roles []struct {
Description string
Id string
Name string
TenantId string
}
}
}
}
@@ -0,0 +1,28 @@
// Go 1.0 compatibility functions
// +build !go1.1
package swift
import (
"log"
"net/http"
"time"
)
// Cancel the request - doesn't work under < go 1.1
func cancelRequest(transport http.RoundTripper, req *http.Request) {
log.Printf("Tried to cancel a request but couldn't - recompile with go 1.1")
}
// Reset a timer - Doesn't work properly < go 1.1
//
// This is quite hard to do properly under go < 1.1 so we do a crude
// approximation and hope that everyone upgrades to go 1.1 quickly
func resetTimer(t *time.Timer, d time.Duration) {
t.Stop()
// Very likely this doesn't actually work if we are already
// selecting on t.C. However we've stopped the original timer
// so won't break transfers but may not time them out :-(
*t = *time.NewTimer(d)
}
@@ -0,0 +1,24 @@
// Go 1.1 and later compatibility functions
//
// +build go1.1
package swift
import (
"net/http"
"time"
)
// Cancel the request
func cancelRequest(transport http.RoundTripper, req *http.Request) {
if tr, ok := transport.(interface {
CancelRequest(*http.Request)
}); ok {
tr.CancelRequest(req)
}
}
// Reset a timer
func resetTimer(t *time.Timer, d time.Duration) {
t.Reset(d)
}
@@ -0,0 +1,19 @@
/*
Package swift provides an easy to use interface to Swift / Openstack Object Storage / Rackspace Cloud Files
Standard Usage
Most of the work is done through the Container*() and Object*() methods.
All methods are safe to use concurrently in multiple go routines.
Object Versioning
As defined by http://docs.openstack.org/api/openstack-object-storage/1.0/content/Object_Versioning-e1e3230.html#d6e983 one can create a container which allows for version control of files. The suggested method is to create a version container for holding all non-current files, and a current container for holding the latest version that the file points to. The container and objects inside it can be used in the standard manner, however, pushing a file multiple times will result in it being copied to the version container and the new file put in it's place. If the current file is deleted, the previous file in the version container will replace it. This means that if a file is updated 5 times, it must be deleted 5 times to be completely removed from the system.
Rackspace Sub Module
This module specifically allows the enabling/disabling of Rackspace Cloud File CDN management on a container. This is specific to the Rackspace API and not Swift/Openstack, therefore it has been placed in a submodule. One can easily create a RsConnection and use it like the standard Connection to access and manipulate containers and objects.
*/
package swift
@@ -0,0 +1,97 @@
// Copyright...
// This example demonstrates opening a Connection and doing some basic operations.
package swift_test
import (
"fmt"
"github.com/ncw/swift"
)
func ExampleConnection() {
// Create a v1 auth connection
c := swift.Connection{
// This should be your username
UserName: "user",
// This should be your api key
ApiKey: "key",
// This should be a v1 auth url, eg
// Rackspace US https://auth.api.rackspacecloud.com/v1.0
// Rackspace UK https://lon.auth.api.rackspacecloud.com/v1.0
// Memset Memstore UK https://auth.storage.memset.com/v1.0
AuthUrl: "auth_url",
}
// Authenticate
err := c.Authenticate()
if err != nil {
panic(err)
}
// List all the containers
containers, err := c.ContainerNames(nil)
fmt.Println(containers)
// etc...
// ------ or alternatively create a v2 connection ------
// Create a v2 auth connection
c = swift.Connection{
// This is the sub user for the storage - eg "admin"
UserName: "user",
// This should be your api key
ApiKey: "key",
// This should be a version2 auth url, eg
// Rackspace v2 https://identity.api.rackspacecloud.com/v2.0
// Memset Memstore v2 https://auth.storage.memset.com/v2.0
AuthUrl: "v2_auth_url",
// Region to use - default is use first region if unset
Region: "LON",
// Name of the tenant - this is likely your username
Tenant: "jim",
}
// as above...
}
var container string
func ExampleConnection_ObjectsWalk() {
objects := make([]string, 0)
err := c.ObjectsWalk(container, nil, func(opts *swift.ObjectsOpts) (interface{}, error) {
newObjects, err := c.ObjectNames(container, opts)
if err == nil {
objects = append(objects, newObjects...)
}
return newObjects, err
})
fmt.Println("Found all the objects", objects, err)
}
func ExampleConnection_VersionContainerCreate() {
// Use the helper method to create the current and versions container.
if err := c.VersionContainerCreate("cds", "cd-versions"); err != nil {
fmt.Print(err.Error())
}
}
func ExampleConnection_VersionEnable() {
// Build the containers manually and enable them.
if err := c.ContainerCreate("movie-versions", nil); err != nil {
fmt.Print(err.Error())
}
if err := c.ContainerCreate("movies", nil); err != nil {
fmt.Print(err.Error())
}
if err := c.VersionEnable("movies", "movie-versions"); err != nil {
fmt.Print(err.Error())
}
// Access the primary container as usual with ObjectCreate(), ObjectPut(), etc.
// etc...
}
func ExampleConnection_VersionDisable() {
// Disable versioning on a container. Note that this does not delete the versioning container.
c.VersionDisable("movies")
}
@@ -0,0 +1,174 @@
// Metadata manipulation in and out of Headers
package swift
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
// Metadata stores account, container or object metadata.
type Metadata map[string]string
// Metadata gets the Metadata starting with the metaPrefix out of the Headers.
//
// The keys in the Metadata will be converted to lower case
func (h Headers) Metadata(metaPrefix string) Metadata {
m := Metadata{}
metaPrefix = http.CanonicalHeaderKey(metaPrefix)
for key, value := range h {
if strings.HasPrefix(key, metaPrefix) {
metaKey := strings.ToLower(key[len(metaPrefix):])
m[metaKey] = value
}
}
return m
}
// AccountMetadata converts Headers from account to a Metadata.
//
// The keys in the Metadata will be converted to lower case.
func (h Headers) AccountMetadata() Metadata {
return h.Metadata("X-Account-Meta-")
}
// ContainerMetadata converts Headers from container to a Metadata.
//
// The keys in the Metadata will be converted to lower case.
func (h Headers) ContainerMetadata() Metadata {
return h.Metadata("X-Container-Meta-")
}
// ObjectMetadata converts Headers from object to a Metadata.
//
// The keys in the Metadata will be converted to lower case.
func (h Headers) ObjectMetadata() Metadata {
return h.Metadata("X-Object-Meta-")
}
// Headers convert the Metadata starting with the metaPrefix into a
// Headers.
//
// The keys in the Metadata will be converted from lower case to http
// Canonical (see http.CanonicalHeaderKey).
func (m Metadata) Headers(metaPrefix string) Headers {
h := Headers{}
for key, value := range m {
key = http.CanonicalHeaderKey(metaPrefix + key)
h[key] = value
}
return h
}
// AccountHeaders converts the Metadata for the account.
func (m Metadata) AccountHeaders() Headers {
return m.Headers("X-Account-Meta-")
}
// ContainerHeaders converts the Metadata for the container.
func (m Metadata) ContainerHeaders() Headers {
return m.Headers("X-Container-Meta-")
}
// ObjectHeaders converts the Metadata for the object.
func (m Metadata) ObjectHeaders() Headers {
return m.Headers("X-Object-Meta-")
}
// Turns a number of ns into a floating point string in seconds
//
// Trims trailing zeros and guaranteed to be perfectly accurate
func nsToFloatString(ns int64) string {
if ns < 0 {
return "-" + nsToFloatString(-ns)
}
result := fmt.Sprintf("%010d", ns)
split := len(result) - 9
result, decimals := result[:split], result[split:]
decimals = strings.TrimRight(decimals, "0")
if decimals != "" {
result += "."
result += decimals
}
return result
}
// Turns a floating point string in seconds into a ns integer
//
// Guaranteed to be perfectly accurate
func floatStringToNs(s string) (int64, error) {
const zeros = "000000000"
if point := strings.IndexRune(s, '.'); point >= 0 {
tail := s[point+1:]
if fill := 9 - len(tail); fill < 0 {
tail = tail[:9]
} else {
tail += zeros[:fill]
}
s = s[:point] + tail
} else if len(s) > 0 { // Make sure empty string produces an error
s += zeros
}
return strconv.ParseInt(s, 10, 64)
}
// FloatStringToTime converts a floating point number string to a time.Time
//
// The string is floating point number of seconds since the epoch
// (Unix time). The number should be in fixed point format (not
// exponential), eg "1354040105.123456789" which represents the time
// "2012-11-27T18:15:05.123456789Z"
//
// Some care is taken to preserve all the accuracy in the time.Time
// (which wouldn't happen with a naive conversion through float64) so
// a round trip conversion won't change the data.
//
// If an error is returned then time will be returned as the zero time.
func FloatStringToTime(s string) (t time.Time, err error) {
ns, err := floatStringToNs(s)
if err != nil {
return
}
t = time.Unix(0, ns)
return
}
// TimeToFloatString converts a time.Time object to a floating point string
//
// The string is floating point number of seconds since the epoch
// (Unix time). The number is in fixed point format (not
// exponential), eg "1354040105.123456789" which represents the time
// "2012-11-27T18:15:05.123456789Z". Trailing zeros will be dropped
// from the output.
//
// Some care is taken to preserve all the accuracy in the time.Time
// (which wouldn't happen with a naive conversion through float64) so
// a round trip conversion won't change the data.
func TimeToFloatString(t time.Time) string {
return nsToFloatString(t.UnixNano())
}
// Read a modification time (mtime) from a Metadata object
//
// This is a defacto standard (used in the official python-swiftclient
// amongst others) for storing the modification time (as read using
// os.Stat) for an object. It is stored using the key 'mtime', which
// for example when written to an object will be 'X-Object-Meta-Mtime'.
//
// If an error is returned then time will be returned as the zero time.
func (m Metadata) GetModTime() (t time.Time, err error) {
return FloatStringToTime(m["mtime"])
}
// Write an modification time (mtime) to a Metadata object
//
// This is a defacto standard (used in the official python-swiftclient
// amongst others) for storing the modification time (as read using
// os.Stat) for an object. It is stored using the key 'mtime', which
// for example when written to an object will be 'X-Object-Meta-Mtime'.
func (m Metadata) SetModTime(t time.Time) {
m["mtime"] = TimeToFloatString(t)
}
@@ -0,0 +1,213 @@
// Tests for swift metadata
package swift
import (
"testing"
"time"
)
func TestHeadersToMetadata(t *testing.T) {
}
func TestHeadersToAccountMetadata(t *testing.T) {
}
func TestHeadersToContainerMetadata(t *testing.T) {
}
func TestHeadersToObjectMetadata(t *testing.T) {
}
func TestMetadataToHeaders(t *testing.T) {
}
func TestMetadataToAccountHeaders(t *testing.T) {
}
func TestMetadataToContainerHeaders(t *testing.T) {
}
func TestMetadataToObjectHeaders(t *testing.T) {
}
func TestNsToFloatString(t *testing.T) {
for _, d := range []struct {
ns int64
fs string
}{
{0, "0"},
{1, "0.000000001"},
{1000, "0.000001"},
{1000000, "0.001"},
{100000000, "0.1"},
{1000000000, "1"},
{10000000000, "10"},
{12345678912, "12.345678912"},
{12345678910, "12.34567891"},
{12345678900, "12.3456789"},
{12345678000, "12.345678"},
{12345670000, "12.34567"},
{12345600000, "12.3456"},
{12345000000, "12.345"},
{12340000000, "12.34"},
{12300000000, "12.3"},
{12000000000, "12"},
{10000000000, "10"},
{1347717491123123123, "1347717491.123123123"},
} {
if nsToFloatString(d.ns) != d.fs {
t.Error("Failed", d.ns, "!=", d.fs)
}
if d.ns > 0 && nsToFloatString(-d.ns) != "-"+d.fs {
t.Error("Failed on negative", d.ns, "!=", d.fs)
}
}
}
func TestFloatStringToNs(t *testing.T) {
for _, d := range []struct {
ns int64
fs string
}{
{0, "0"},
{0, "0."},
{0, ".0"},
{0, "0.0"},
{0, "0.0000000001"},
{1, "0.000000001"},
{1000, "0.000001"},
{1000000, "0.001"},
{100000000, "0.1"},
{100000000, "0.10"},
{100000000, "0.1000000001"},
{1000000000, "1"},
{1000000000, "1."},
{1000000000, "1.0"},
{10000000000, "10"},
{12345678912, "12.345678912"},
{12345678912, "12.3456789129"},
{12345678912, "12.34567891299"},
{12345678910, "12.34567891"},
{12345678900, "12.3456789"},
{12345678000, "12.345678"},
{12345670000, "12.34567"},
{12345600000, "12.3456"},
{12345000000, "12.345"},
{12340000000, "12.34"},
{12300000000, "12.3"},
{12000000000, "12"},
{10000000000, "10"},
// This is a typical value which has more bits in than a float64
{1347717491123123123, "1347717491.123123123"},
} {
ns, err := floatStringToNs(d.fs)
if err != nil {
t.Error("Failed conversion", err)
}
if ns != d.ns {
t.Error("Failed", d.fs, "!=", d.ns, "was", ns)
}
if d.ns > 0 {
ns, err := floatStringToNs("-" + d.fs)
if err != nil {
t.Error("Failed conversion", err)
}
if ns != -d.ns {
t.Error("Failed on negative", -d.ns, "!=", "-"+d.fs)
}
}
}
// These are expected to produce errors
for _, fs := range []string{
"",
" 1",
"- 1",
"- 1",
"1.-1",
"1.0.0",
"1x0",
} {
ns, err := floatStringToNs(fs)
if err == nil {
t.Error("Didn't produce expected error", fs, ns)
}
}
}
func TestGetModTime(t *testing.T) {
for _, d := range []struct {
ns string
t string
}{
{"1354040105", "2012-11-27T18:15:05Z"},
{"1354040105.", "2012-11-27T18:15:05Z"},
{"1354040105.0", "2012-11-27T18:15:05Z"},
{"1354040105.000000000000", "2012-11-27T18:15:05Z"},
{"1354040105.123", "2012-11-27T18:15:05.123Z"},
{"1354040105.123456", "2012-11-27T18:15:05.123456Z"},
{"1354040105.123456789", "2012-11-27T18:15:05.123456789Z"},
{"1354040105.123456789123", "2012-11-27T18:15:05.123456789Z"},
{"0", "1970-01-01T00:00:00.000000000Z"},
} {
expected, err := time.Parse(time.RFC3339, d.t)
if err != nil {
t.Error("Bad test", err)
}
m := Metadata{"mtime": d.ns}
actual, err := m.GetModTime()
if err != nil {
t.Error("Parse error", err)
}
if !actual.Equal(expected) {
t.Error("Expecting", expected, expected.UnixNano(), "got", actual, actual.UnixNano())
}
}
for _, ns := range []string{
"EMPTY",
"",
" 1",
"- 1",
"- 1",
"1.-1",
"1.0.0",
"1x0",
} {
m := Metadata{}
if ns != "EMPTY" {
m["mtime"] = ns
}
actual, err := m.GetModTime()
if err == nil {
t.Error("Expected error not produced")
}
if !actual.IsZero() {
t.Error("Expected output to be zero")
}
}
}
func TestSetModTime(t *testing.T) {
for _, d := range []struct {
ns string
t string
}{
{"1354040105", "2012-11-27T18:15:05Z"},
{"1354040105", "2012-11-27T18:15:05.000000Z"},
{"1354040105.123", "2012-11-27T18:15:05.123Z"},
{"1354040105.123456", "2012-11-27T18:15:05.123456Z"},
{"1354040105.123456789", "2012-11-27T18:15:05.123456789Z"},
{"0", "1970-01-01T00:00:00.000000000Z"},
} {
time, err := time.Parse(time.RFC3339, d.t)
if err != nil {
t.Error("Bad test", err)
}
m := Metadata{}
m.SetModTime(time)
if m["mtime"] != d.ns {
t.Error("mtime wrong", m, "should be", d.ns)
}
}
}
@@ -0,0 +1,55 @@
Notes on Go Swift
=================
Make a builder style interface like the Google Go APIs? Advantages
are that it is easy to add named methods to the service object to do
specific things. Slightly less efficient. Not sure about how to
return extra stuff though - in an object?
Make a container struct so these could be methods on it?
Make noResponse check for 204?
Make storage public so it can be extended easily?
Rename to go-swift to match user agent string?
Reconnect on auth error - 401 when token expires isn't tested
Make more api compatible with python cloudfiles?
Retry operations on timeout / network errors?
- also 408 error
- GET requests only?
Make Connection thread safe - whenever it is changed take a write lock whenever it is read from a read lock
Add extra headers field to Connection (for via etc)
Make errors use an error heirachy then can catch them with a type assertion
Error(...)
ObjectCorrupted{ Error }
Make a Debug flag in connection for logging stuff
Object If-Match, If-None-Match, If-Modified-Since, If-Unmodified-Since etc
Object range
Object create, update with X-Delete-At or X-Delete-After
Large object support
- check uploads are less than 5GB in normal mode?
Access control CORS?
Swift client retries and backs off for all types of errors
Implement net error interface?
type Error interface {
error
Timeout() bool // Is the error a timeout?
Temporary() bool // Is the error temporary?
}
@@ -0,0 +1,83 @@
package rs
import (
"errors"
"net/http"
"strconv"
"github.com/ncw/swift"
)
// RsConnection is a RackSpace specific wrapper to the core swift library which
// exposes the RackSpace CDN commands via the CDN Management URL interface.
type RsConnection struct {
swift.Connection
cdnUrl string
}
// manage is similar to the swift storage method, but uses the CDN Management URL for CDN specific calls.
func (c *RsConnection) manage(p swift.RequestOpts) (resp *http.Response, headers swift.Headers, err error) {
p.OnReAuth = func() (string, error) {
if c.cdnUrl == "" {
c.cdnUrl = c.Auth.CdnUrl()
}
if c.cdnUrl == "" {
return "", errors.New("The X-CDN-Management-Url does not exist on the authenticated platform")
}
return c.cdnUrl, nil
}
if c.Authenticated() {
_, err = p.OnReAuth()
if err != nil {
return nil, nil, err
}
}
return c.Connection.Call(c.cdnUrl, p)
}
// ContainerCDNEnable enables a container for public CDN usage.
//
// Change the default TTL of 259200 seconds (72 hours) by passing in an integer value.
//
// This method can be called again to change the TTL.
func (c *RsConnection) ContainerCDNEnable(container string, ttl int) (swift.Headers, error) {
h := swift.Headers{"X-CDN-Enabled": "true"}
if ttl > 0 {
h["X-TTL"] = strconv.Itoa(ttl)
}
_, headers, err := c.manage(swift.RequestOpts{
Container: container,
Operation: "PUT",
ErrorMap: swift.ContainerErrorMap,
NoResponse: true,
Headers: h,
})
return headers, err
}
// ContainerCDNDisable disables CDN access to a container.
func (c *RsConnection) ContainerCDNDisable(container string) error {
h := swift.Headers{"X-CDN-Enabled": "false"}
_, _, err := c.manage(swift.RequestOpts{
Container: container,
Operation: "PUT",
ErrorMap: swift.ContainerErrorMap,
NoResponse: true,
Headers: h,
})
return err
}
// ContainerCDNMeta returns the CDN metadata for a container.
func (c *RsConnection) ContainerCDNMeta(container string) (swift.Headers, error) {
_, headers, err := c.manage(swift.RequestOpts{
Container: container,
Operation: "HEAD",
ErrorMap: swift.ContainerErrorMap,
NoResponse: true,
Headers: swift.Headers{},
})
return headers, err
}
@@ -0,0 +1,96 @@
// See swift_test.go for requirements to run this test.
package rs_test
import (
"os"
"testing"
"github.com/ncw/swift/rs"
)
var (
c rs.RsConnection
)
const (
CONTAINER = "GoSwiftUnitTest"
OBJECT = "test_object"
CONTENTS = "12345"
CONTENT_SIZE = int64(len(CONTENTS))
CONTENT_MD5 = "827ccb0eea8a706c4c34a16891f84e7b"
)
// Test functions are run in order - this one must be first!
func TestAuthenticate(t *testing.T) {
UserName := os.Getenv("SWIFT_API_USER")
ApiKey := os.Getenv("SWIFT_API_KEY")
AuthUrl := os.Getenv("SWIFT_AUTH_URL")
if UserName == "" || ApiKey == "" || AuthUrl == "" {
t.Fatal("SWIFT_API_USER, SWIFT_API_KEY and SWIFT_AUTH_URL not all set")
}
c = rs.RsConnection{}
c.UserName = UserName
c.ApiKey = ApiKey
c.AuthUrl = AuthUrl
err := c.Authenticate()
if err != nil {
t.Fatal("Auth failed", err)
}
if !c.Authenticated() {
t.Fatal("Not authenticated")
}
}
// Setup
func TestContainerCreate(t *testing.T) {
err := c.ContainerCreate(CONTAINER, nil)
if err != nil {
t.Fatal(err)
}
}
func TestCDNEnable(t *testing.T) {
headers, err := c.ContainerCDNEnable(CONTAINER, 0)
if err != nil {
t.Error(err)
}
if _, ok := headers["X-Cdn-Uri"]; !ok {
t.Error("Failed to enable CDN for container")
}
}
func TestOnReAuth(t *testing.T) {
c2 := rs.RsConnection{}
c2.UserName = c.UserName
c2.ApiKey = c.ApiKey
c2.AuthUrl = c.AuthUrl
_, err := c2.ContainerCDNEnable(CONTAINER, 0)
if err != nil {
t.Fatalf("Failed to reauthenticate: %v", err)
}
}
func TestCDNMeta(t *testing.T) {
headers, err := c.ContainerCDNMeta(CONTAINER)
if err != nil {
t.Error(err)
}
if _, ok := headers["X-Cdn-Uri"]; !ok {
t.Error("CDN is not enabled")
}
}
func TestCDNDisable(t *testing.T) {
err := c.ContainerCDNDisable(CONTAINER) // files stick in CDN until TTL expires
if err != nil {
t.Error(err)
}
}
// Teardown
func TestContainerDelete(t *testing.T) {
err := c.ContainerDelete(CONTAINER)
if err != nil {
t.Fatal(err)
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,409 @@
// This tests the swift package internals
//
// It does not require access to a swift server
//
// FIXME need to add more tests and to check URLs and parameters
package swift
import (
"fmt"
"io"
"net"
"net/http"
"testing"
// "net/http/httputil"
// "os"
)
const (
TEST_ADDRESS = "localhost:5324"
AUTH_URL = "http://" + TEST_ADDRESS + "/v1.0"
PROXY_URL = "http://" + TEST_ADDRESS + "/proxy"
USERNAME = "test"
APIKEY = "apikey"
AUTH_TOKEN = "token"
)
// Globals
var (
server *SwiftServer
c *Connection
)
// SwiftServer implements a test swift server
type SwiftServer struct {
t *testing.T
checks []*Check
}
// Used to check and reply to http transactions
type Check struct {
in Headers
out Headers
rx *string
tx *string
err *Error
url *string
}
// Add a in check
func (check *Check) In(in Headers) *Check {
check.in = in
return check
}
// Add an out check
func (check *Check) Out(out Headers) *Check {
check.out = out
return check
}
// Add an Error check
func (check *Check) Error(StatusCode int, Text string) *Check {
check.err = newError(StatusCode, Text)
return check
}
// Add a rx check
func (check *Check) Rx(rx string) *Check {
check.rx = &rx
return check
}
// Add an tx check
func (check *Check) Tx(tx string) *Check {
check.tx = &tx
return check
}
// Add an URL check
func (check *Check) Url(url string) *Check {
check.url = &url
return check
}
// Add a check
func (s *SwiftServer) AddCheck(t *testing.T) *Check {
server.t = t
check := &Check{
in: Headers{},
out: Headers{},
err: nil,
}
s.checks = append(s.checks, check)
return check
}
// Responds to a request
func (s *SwiftServer) Respond(w http.ResponseWriter, r *http.Request) {
if len(s.checks) < 1 {
s.t.Fatal("Unexpected http transaction")
}
check := s.checks[0]
s.checks = s.checks[1:]
// Check URL
if check.url != nil && *check.url != r.URL.String() {
s.t.Errorf("Expecting URL %q but got %q", *check.url, r.URL)
}
// Check headers
for k, v := range check.in {
actual := r.Header.Get(k)
if actual != v {
s.t.Errorf("Expecting header %q=%q but got %q", k, v, actual)
}
}
// Write output headers
h := w.Header()
for k, v := range check.out {
h.Set(k, v)
}
// Return an error if required
if check.err != nil {
http.Error(w, check.err.Text, check.err.StatusCode)
} else {
if check.tx != nil {
_, err := w.Write([]byte(*check.tx))
if err != nil {
s.t.Error("Write failed", err)
}
}
}
}
// Checks to see all responses are used up
func (s *SwiftServer) Finished() {
if len(s.checks) > 0 {
s.t.Error("Unused checks", s.checks)
}
}
func handle(w http.ResponseWriter, r *http.Request) {
// out, _ := httputil.DumpRequest(r, true)
// os.Stdout.Write(out)
server.Respond(w, r)
}
func NewSwiftServer() *SwiftServer {
server := &SwiftServer{}
http.HandleFunc("/", handle)
go http.ListenAndServe(TEST_ADDRESS, nil)
fmt.Print("Waiting for server to start ")
for {
fmt.Print(".")
conn, err := net.Dial("tcp", TEST_ADDRESS)
if err == nil {
conn.Close()
fmt.Println(" Started")
break
}
}
return server
}
func init() {
server = NewSwiftServer()
c = &Connection{
UserName: USERNAME,
ApiKey: APIKEY,
AuthUrl: AUTH_URL,
}
}
// Check the error is a swift error
func checkError(t *testing.T, err error, StatusCode int, Text string) {
if err == nil {
t.Fatal("No error returned")
}
err2, ok := err.(*Error)
if !ok {
t.Fatal("Bad error type")
}
if err2.StatusCode != StatusCode {
t.Fatalf("Bad status code, expecting %d got %d", StatusCode, err2.StatusCode)
}
if err2.Text != Text {
t.Fatalf("Bad error string, expecting %q got %q", Text, err2.Text)
}
}
// FIXME copied from swift_test.go
func compareMaps(t *testing.T, a, b map[string]string) {
if len(a) != len(b) {
t.Error("Maps different sizes", a, b)
}
for ka, va := range a {
if vb, ok := b[ka]; !ok || va != vb {
t.Error("Difference in key", ka, va, b[ka])
}
}
for kb, vb := range b {
if va, ok := a[kb]; !ok || vb != va {
t.Error("Difference in key", kb, vb, a[kb])
}
}
}
func TestInternalError(t *testing.T) {
e := newError(404, "Not Found!")
if e.StatusCode != 404 || e.Text != "Not Found!" {
t.Fatal("Bad error")
}
if e.Error() != "Not Found!" {
t.Fatal("Bad error")
}
}
func testCheckClose(c io.Closer, e error) (err error) {
err = e
defer checkClose(c, &err)
return
}
// Make a closer which returns the error of our choice
type myCloser struct {
err error
}
func (c *myCloser) Close() error {
return c.err
}
func TestInternalCheckClose(t *testing.T) {
if testCheckClose(&myCloser{nil}, nil) != nil {
t.Fatal("bad 1")
}
if testCheckClose(&myCloser{nil}, ObjectCorrupted) != ObjectCorrupted {
t.Fatal("bad 2")
}
if testCheckClose(&myCloser{ObjectNotFound}, nil) != ObjectNotFound {
t.Fatal("bad 3")
}
if testCheckClose(&myCloser{ObjectNotFound}, ObjectCorrupted) != ObjectCorrupted {
t.Fatal("bad 4")
}
}
func TestInternalParseHeaders(t *testing.T) {
resp := &http.Response{StatusCode: 200}
if c.parseHeaders(resp, nil) != nil {
t.Error("Bad 1")
}
if c.parseHeaders(resp, authErrorMap) != nil {
t.Error("Bad 1")
}
resp = &http.Response{StatusCode: 299}
if c.parseHeaders(resp, nil) != nil {
t.Error("Bad 1")
}
resp = &http.Response{StatusCode: 199, Status: "BOOM"}
checkError(t, c.parseHeaders(resp, nil), 199, "HTTP Error: 199: BOOM")
resp = &http.Response{StatusCode: 300, Status: "BOOM"}
checkError(t, c.parseHeaders(resp, nil), 300, "HTTP Error: 300: BOOM")
resp = &http.Response{StatusCode: 404, Status: "BOOM"}
checkError(t, c.parseHeaders(resp, nil), 404, "HTTP Error: 404: BOOM")
if c.parseHeaders(resp, ContainerErrorMap) != ContainerNotFound {
t.Error("Bad 1")
}
if c.parseHeaders(resp, objectErrorMap) != ObjectNotFound {
t.Error("Bad 1")
}
}
func TestInternalReadHeaders(t *testing.T) {
resp := &http.Response{Header: http.Header{}}
compareMaps(t, readHeaders(resp), Headers{})
resp = &http.Response{Header: http.Header{
"one": []string{"1"},
"two": []string{"2"},
}}
compareMaps(t, readHeaders(resp), Headers{"one": "1", "two": "2"})
// FIXME this outputs a log which we should test and check
resp = &http.Response{Header: http.Header{
"one": []string{"1", "11", "111"},
"two": []string{"2"},
}}
compareMaps(t, readHeaders(resp), Headers{"one": "1", "two": "2"})
}
func TestInternalStorage(t *testing.T) {
// FIXME
}
// ------------------------------------------------------------
func TestInternalAuthenticate(t *testing.T) {
server.AddCheck(t).In(Headers{
"User-Agent": DefaultUserAgent,
"X-Auth-Key": APIKEY,
"X-Auth-User": USERNAME,
}).Out(Headers{
"X-Storage-Url": PROXY_URL,
"X-Auth-Token": AUTH_TOKEN,
}).Url("/v1.0")
defer server.Finished()
err := c.Authenticate()
if err != nil {
t.Fatal(err)
}
if c.StorageUrl != PROXY_URL {
t.Error("Bad storage url")
}
if c.AuthToken != AUTH_TOKEN {
t.Error("Bad auth token")
}
if !c.Authenticated() {
t.Error("Didn't authenticate")
}
}
func TestInternalAuthenticateDenied(t *testing.T) {
server.AddCheck(t).Error(400, "Bad request")
server.AddCheck(t).Error(401, "DENIED")
defer server.Finished()
c.UnAuthenticate()
err := c.Authenticate()
if err != AuthorizationFailed {
t.Fatal("Expecting AuthorizationFailed", err)
}
// FIXME
// if c.Authenticated() {
// t.Fatal("Expecting not authenticated")
// }
}
func TestInternalAuthenticateBad(t *testing.T) {
server.AddCheck(t).Out(Headers{
"X-Storage-Url": PROXY_URL,
})
defer server.Finished()
err := c.Authenticate()
checkError(t, err, 0, "Response didn't have storage url and auth token")
if c.Authenticated() {
t.Fatal("Expecting not authenticated")
}
server.AddCheck(t).Out(Headers{
"X-Auth-Token": AUTH_TOKEN,
})
err = c.Authenticate()
checkError(t, err, 0, "Response didn't have storage url and auth token")
if c.Authenticated() {
t.Fatal("Expecting not authenticated")
}
server.AddCheck(t)
err = c.Authenticate()
checkError(t, err, 0, "Response didn't have storage url and auth token")
if c.Authenticated() {
t.Fatal("Expecting not authenticated")
}
server.AddCheck(t).Out(Headers{
"X-Storage-Url": PROXY_URL,
"X-Auth-Token": AUTH_TOKEN,
})
err = c.Authenticate()
if err != nil {
t.Fatal(err)
}
if !c.Authenticated() {
t.Fatal("Expecting authenticated")
}
}
func testContainerNames(t *testing.T, rx string, expected []string) {
server.AddCheck(t).In(Headers{
"User-Agent": DefaultUserAgent,
"X-Auth-Token": AUTH_TOKEN,
}).Tx(rx).Url("/proxy")
containers, err := c.ContainerNames(nil)
if err != nil {
t.Fatal(err)
}
if len(containers) != len(expected) {
t.Fatal("Wrong number of containers", len(containers), rx, len(expected), expected)
}
for i := range containers {
if containers[i] != expected[i] {
t.Error("Bad container", containers[i], expected[i])
}
}
}
func TestInternalContainerNames(t *testing.T) {
defer server.Finished()
testContainerNames(t, "", []string{})
testContainerNames(t, "one", []string{"one"})
testContainerNames(t, "one\n", []string{"one"})
testContainerNames(t, "one\ntwo\nthree\n", []string{"one", "two", "three"})
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,806 @@
// This implements a very basic Swift server
// Everything is stored in memory
//
// This comes from the https://github.com/mitchellh/goamz
// and was adapted for Swift
//
package swifttest
import (
"bytes"
"crypto/md5"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"mime"
"net"
"net/http"
"net/url"
"path"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/ncw/swift"
)
const (
DEBUG = false
)
type SwiftServer struct {
t *testing.T
listener net.Listener
url string
reqId int
mu sync.Mutex
containers map[string]*container
accounts map[string]*account
sessions map[string]*session
}
// The Folder type represents a container stored in an account
type Folder struct {
Count int `json:"count"`
Bytes int `json:"bytes"`
Name string `json:"name"`
}
// The Key type represents an item stored in an container.
type Key struct {
Key string `json:"name"`
LastModified string `json:"last_modified"`
Size int64 `json:"bytes"`
// ETag gives the hex-encoded MD5 sum of the contents,
// surrounded with double-quotes.
ETag string `json:"hash"`
ContentType string `json:"content_type"`
// Owner Owner
}
type Subdir struct {
Subdir string `json:"subdir"`
}
type swiftError struct {
statusCode int
Code string
Message string
}
type action struct {
srv *SwiftServer
w http.ResponseWriter
req *http.Request
reqId string
user *account
}
type session struct {
username string
}
type metadata struct {
meta http.Header // metadata to return with requests.
}
type account struct {
swift.Account
metadata
password string
}
type object struct {
metadata
name string
mtime time.Time
checksum []byte // also held as ETag in meta.
data []byte
content_type string
}
type container struct {
metadata
name string
ctime time.Time
objects map[string]*object
bytes int
}
// A resource encapsulates the subject of an HTTP request.
// The resource referred to may or may not exist
// when the request is made.
type resource interface {
put(a *action) interface{}
get(a *action) interface{}
post(a *action) interface{}
delete(a *action) interface{}
copy(a *action) interface{}
}
type objectResource struct {
name string
version string
container *container // always non-nil.
object *object // may be nil.
}
type containerResource struct {
name string
container *container // non-nil if the container already exists.
}
var responseParams = map[string]bool{
"content-type": true,
"content-language": true,
"expires": true,
"cache-control": true,
"content-disposition": true,
"content-encoding": true,
}
func fatalf(code int, codeStr string, errf string, a ...interface{}) {
panic(&swiftError{
statusCode: code,
Code: codeStr,
Message: fmt.Sprintf(errf, a...),
})
}
func (m metadata) setMetadata(a *action, resource string) {
for key, values := range a.req.Header {
key = http.CanonicalHeaderKey(key)
if metaHeaders[key] || strings.HasPrefix(key, "X-"+strings.Title(resource)+"-Meta-") {
if values[0] != "" || resource == "object" {
m.meta[key] = values
} else {
m.meta.Del(key)
}
}
}
}
func (m metadata) getMetadata(a *action) {
h := a.w.Header()
for name, d := range m.meta {
h[name] = d
}
}
// GET on a bucket lists the objects in the bucket.
func (r containerResource) get(a *action) interface{} {
if r.container == nil {
fatalf(404, "NoSuchContainer", "The specified container does not exist")
}
delimiter := a.req.Form.Get("delimiter")
marker := a.req.Form.Get("marker")
prefix := a.req.Form.Get("prefix")
format := a.req.URL.Query().Get("format")
parent := a.req.Form.Get("path")
a.w.Header().Set("X-Container-Bytes-Used", strconv.Itoa(r.container.bytes))
a.w.Header().Set("X-Container-Object-Count", strconv.Itoa(len(r.container.objects)))
r.container.getMetadata(a)
if a.req.Method == "HEAD" {
return nil
}
var tmp orderedObjects
var resp []interface{} = make([]interface{}, 0)
// first get all matching objects and arrange them in alphabetical order.
for _, obj := range r.container.objects {
if strings.HasPrefix(obj.name, prefix) {
tmp = append(tmp, obj)
}
}
sort.Sort(tmp)
var prefixes []string
for _, obj := range tmp {
if !strings.HasPrefix(obj.name, prefix) {
continue
}
isPrefix := false
name := obj.name
if parent != "" {
if path.Dir(obj.name) != path.Clean(parent) {
continue
}
} else if delimiter != "" {
if i := strings.Index(obj.name[len(prefix):], delimiter); i >= 0 {
name = obj.name[:len(prefix)+i+len(delimiter)]
if prefixes != nil && prefixes[len(prefixes)-1] == name {
continue
}
isPrefix = true
}
}
if name <= marker {
continue
}
if isPrefix {
prefixes = append(prefixes, name)
resp = append(resp, Subdir{
Subdir: name,
})
} else {
if format == "json" {
resp = append(resp, obj.Key())
} else {
a.w.Write([]byte(obj.name + "\n"))
}
}
}
if format == "json" {
return resp
} else {
return nil
}
}
// orderedContainers holds a slice of containers that can be sorted
// by name.
type orderedContainers []*container
func (s orderedContainers) Len() int {
return len(s)
}
func (s orderedContainers) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s orderedContainers) Less(i, j int) bool {
return s[i].name < s[j].name
}
func (r containerResource) delete(a *action) interface{} {
b := r.container
if b == nil {
fatalf(404, "NoSuchContainer", "The specified container does not exist")
}
if len(b.objects) > 0 {
fatalf(409, "Conflict", "The container you tried to delete is not empty")
}
delete(a.srv.containers, b.name)
a.user.Containers--
return nil
}
func (r containerResource) put(a *action) interface{} {
if a.req.URL.Query().Get("extract-archive") != "" {
fatalf(403, "Operation forbidden", "Bulk upload is not supported")
}
var created bool
if r.container == nil {
if !validContainerName(r.name) {
fatalf(400, "InvalidContainerName", "The specified container is not valid")
}
r.container = &container{
name: r.name,
objects: make(map[string]*object),
metadata: metadata{
meta: make(http.Header),
},
}
r.container.setMetadata(a, "container")
a.srv.containers[r.name] = r.container
a.user.Containers++
created = true
}
if !created {
fatalf(409, "ContainerAlreadyOwnedByYou", "Your previous request to create the named container succeeded and you already own it.")
}
return nil
}
func (r containerResource) post(a *action) interface{} {
if r.container == nil {
fatalf(400, "Method", "The resource could not be found.")
} else {
r.container.setMetadata(a, "container")
a.w.WriteHeader(201)
jsonMarshal(a.w, Folder{
Count: len(r.container.objects),
Bytes: r.container.bytes,
Name: r.container.name,
})
}
return nil
}
func (containerResource) copy(a *action) interface{} { return notAllowed() }
// validContainerName returns whether name is a valid bucket name.
// Here are the rules, from:
// http://docs.openstack.org/api/openstack-object-storage/1.0/content/ch_object-storage-dev-api-storage.html
//
// Container names cannot exceed 256 bytes and cannot contain the / character.
//
func validContainerName(name string) bool {
if len(name) == 0 || len(name) > 256 {
return false
}
for _, r := range name {
switch {
case r == '/':
return false
default:
}
}
return true
}
// orderedObjects holds a slice of objects that can be sorted
// by name.
type orderedObjects []*object
func (s orderedObjects) Len() int {
return len(s)
}
func (s orderedObjects) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s orderedObjects) Less(i, j int) bool {
return s[i].name < s[j].name
}
func (obj *object) Key() Key {
return Key{
Key: obj.name,
LastModified: obj.mtime.Format("2006-01-02T15:04:05"),
Size: int64(len(obj.data)),
ETag: fmt.Sprintf("%x", obj.checksum),
ContentType: obj.content_type,
}
}
var metaHeaders = map[string]bool{
"Content-Type": true,
"Content-Encoding": true,
"Content-Disposition": true,
}
var rangeRegexp = regexp.MustCompile("([0-9])*-([0-9])*")
// GET on an object gets the contents of the object.
func (objr objectResource) get(a *action) interface{} {
obj := objr.object
if obj == nil {
fatalf(404, "Not Found", "The resource could not be found.")
}
h := a.w.Header()
// add metadata
obj.getMetadata(a)
var (
start int
end int = len(obj.data)
)
if r := a.req.Header.Get("Range"); r != "" {
m := rangeRegexp.FindStringSubmatch(r)
if m[1] != "" {
start, _ = strconv.Atoi(m[1])
}
if m[2] != "" {
end, _ = strconv.Atoi(m[2])
}
}
h.Set("Content-Length", fmt.Sprint(len(obj.data)))
h.Set("ETag", hex.EncodeToString(obj.checksum))
h.Set("Last-Modified", obj.mtime.Format(http.TimeFormat))
if a.req.Method == "HEAD" {
return nil
}
// TODO avoid holding the lock when writing data.
_, err := a.w.Write(obj.data[start:end])
if err != nil {
// we can't do much except just log the fact.
log.Printf("error writing data: %v", err)
}
return nil
}
// PUT on an object creates the object.
func (objr objectResource) put(a *action) interface{} {
var expectHash []byte
if c := a.req.Header.Get("ETag"); c != "" {
var err error
expectHash, err = hex.DecodeString(c)
if err != nil || len(expectHash) != md5.Size {
fatalf(400, "InvalidDigest", "The ETag you specified was invalid")
}
}
sum := md5.New()
// TODO avoid holding lock while reading data.
data, err := ioutil.ReadAll(io.TeeReader(a.req.Body, sum))
if err != nil {
fatalf(400, "TODO", "read error")
}
gotHash := sum.Sum(nil)
if expectHash != nil && bytes.Compare(gotHash, expectHash) != 0 {
fatalf(422, "Bad ETag", "The ETag you specified did not match what we received")
}
if a.req.ContentLength >= 0 && int64(len(data)) != a.req.ContentLength {
fatalf(400, "IncompleteBody", "You did not provide the number of bytes specified by the Content-Length HTTP header")
}
// TODO is this correct, or should we erase all previous metadata?
obj := objr.object
if obj == nil {
obj = &object{
name: objr.name,
metadata: metadata{
meta: make(http.Header),
},
}
a.user.Objects++
} else {
objr.container.bytes -= len(obj.data)
a.user.BytesUsed -= int64(len(obj.data))
}
var content_type string
if content_type = a.req.Header.Get("Content-Type"); content_type == "" {
content_type := mime.TypeByExtension(obj.name)
if content_type == "" {
content_type = "application/octet-stream"
}
}
// PUT request has been successful - save data and metadata
obj.setMetadata(a, "object")
obj.content_type = content_type
obj.data = data
obj.checksum = gotHash
obj.mtime = time.Now().UTC()
objr.container.objects[objr.name] = obj
objr.container.bytes += len(data)
a.user.BytesUsed += int64(len(data))
h := a.w.Header()
h.Set("ETag", hex.EncodeToString(obj.checksum))
return nil
}
func (objr objectResource) delete(a *action) interface{} {
if objr.object == nil {
fatalf(404, "NoSuchKey", "The specified key does not exist.")
}
objr.container.bytes -= len(objr.object.data)
a.user.BytesUsed -= int64(len(objr.object.data))
delete(objr.container.objects, objr.name)
a.user.Objects--
return nil
}
func (objr objectResource) post(a *action) interface{} {
obj := objr.object
obj.setMetadata(a, "object")
return nil
}
func (objr objectResource) copy(a *action) interface{} {
obj := objr.object
destination := a.req.Header.Get("Destination")
if destination == "" {
fatalf(400, "Bad Request", "You must provide a Destination header")
}
var (
obj2 *object
objr2 objectResource
)
destURL, _ := url.Parse("/v1/AUTH_tk/" + destination)
r := a.srv.resourceForURL(destURL)
switch t := r.(type) {
case objectResource:
objr2 = t
if objr2.object == nil {
obj2 = &object{
name: objr2.name,
metadata: metadata{
meta: make(http.Header),
},
}
a.user.Objects++
} else {
obj2 = objr2.object
objr2.container.bytes -= len(obj2.data)
a.user.BytesUsed -= int64(len(obj2.data))
}
default:
fatalf(400, "Bad Request", "Destination must point to a valid object path")
}
obj2.content_type = obj.content_type
obj2.data = obj.data
obj2.checksum = obj.checksum
obj2.mtime = time.Now()
objr2.container.objects[objr2.name] = obj2
objr2.container.bytes += len(obj.data)
a.user.BytesUsed += int64(len(obj.data))
for key, values := range obj.metadata.meta {
obj2.metadata.meta[key] = values
}
obj2.setMetadata(a, "object")
return nil
}
func (s *SwiftServer) serveHTTP(w http.ResponseWriter, req *http.Request) {
// ignore error from ParseForm as it's usually spurious.
req.ParseForm()
s.mu.Lock()
defer s.mu.Unlock()
if DEBUG {
log.Printf("swifttest %q %q", req.Method, req.URL)
}
a := &action{
srv: s,
w: w,
req: req,
reqId: fmt.Sprintf("%09X", s.reqId),
}
s.reqId++
var r resource
defer func() {
switch err := recover().(type) {
case *swiftError:
w.Header().Set("Content-Type", `text/plain; charset=utf-8`)
http.Error(w, err.Message, err.statusCode)
case nil:
default:
panic(err)
}
}()
var resp interface{}
if req.URL.String() == "/v1.0" {
username := req.Header.Get("x-auth-user")
key := req.Header.Get("x-auth-key")
if acct, ok := s.accounts[username]; ok {
if acct.password == key {
r := make([]byte, 16)
_, _ = rand.Read(r)
id := fmt.Sprintf("%X", r)
w.Header().Set("X-Storage-Url", s.url+"/v1/AUTH_"+username)
w.Header().Set("X-Auth-Token", "AUTH_tk"+string(id))
w.Header().Set("X-Storage-Token", "AUTH_tk"+string(id))
s.sessions[id] = &session{
username: username,
}
return
}
}
panic(notAuthorized())
}
key := req.Header.Get("x-auth-token")
session, ok := s.sessions[key[7:]]
if !ok {
panic(notAuthorized())
}
a.user = s.accounts[session.username]
r = s.resourceForURL(req.URL)
switch req.Method {
case "PUT":
resp = r.put(a)
case "GET", "HEAD":
resp = r.get(a)
case "DELETE":
resp = r.delete(a)
case "POST":
resp = r.post(a)
case "COPY":
resp = r.copy(a)
default:
fatalf(400, "MethodNotAllowed", "unknown http request method %q", req.Method)
}
content_type := req.Header.Get("Content-Type")
if resp != nil && req.Method != "HEAD" {
if strings.HasPrefix(content_type, "application/json") ||
req.URL.Query().Get("format") == "json" {
jsonMarshal(w, resp)
} else {
switch r := resp.(type) {
case string:
w.Write([]byte(r))
default:
w.Write(resp.([]byte))
}
}
}
}
func jsonMarshal(w io.Writer, x interface{}) {
if err := json.NewEncoder(w).Encode(x); err != nil {
panic(fmt.Errorf("error marshalling %#v: %v", x, err))
}
}
var pathRegexp = regexp.MustCompile("/v1/AUTH_[a-zA-Z0-9]+(/([^/]+)(/(.*))?)?")
// resourceForURL returns a resource object for the given URL.
func (srv *SwiftServer) resourceForURL(u *url.URL) (r resource) {
m := pathRegexp.FindStringSubmatch(u.Path)
if m == nil {
fatalf(404, "InvalidURI", "Couldn't parse the specified URI")
}
containerName := m[2]
objectName := m[4]
if containerName == "" {
return rootResource{}
}
b := containerResource{
name: containerName,
container: srv.containers[containerName],
}
if objectName == "" {
return b
}
if b.container == nil {
fatalf(404, "NoSuchContainer", "The specified container does not exist")
}
objr := objectResource{
name: objectName,
version: u.Query().Get("versionId"),
container: b.container,
}
if obj := objr.container.objects[objr.name]; obj != nil {
objr.object = obj
}
return objr
}
// nullResource has error stubs for all resource methods.
type nullResource struct{}
func notAllowed() interface{} {
fatalf(400, "MethodNotAllowed", "The specified method is not allowed against this resource")
return nil
}
func notAuthorized() interface{} {
fatalf(401, "Unauthorized", "This server could not verify that you are authorized to access the document you requested.")
return nil
}
func (nullResource) put(a *action) interface{} { return notAllowed() }
func (nullResource) get(a *action) interface{} { return notAllowed() }
func (nullResource) post(a *action) interface{} { return notAllowed() }
func (nullResource) delete(a *action) interface{} { return notAllowed() }
func (nullResource) copy(a *action) interface{} { return notAllowed() }
type rootResource struct{}
func (rootResource) put(a *action) interface{} { return notAllowed() }
func (rootResource) get(a *action) interface{} {
marker := a.req.Form.Get("marker")
prefix := a.req.Form.Get("prefix")
format := a.req.URL.Query().Get("format")
h := a.w.Header()
h.Set("X-Account-Bytes-Used", strconv.Itoa(int(a.user.BytesUsed)))
h.Set("X-Account-Container-Count", strconv.Itoa(int(a.user.Containers)))
h.Set("X-Account-Object-Count", strconv.Itoa(int(a.user.Objects)))
// add metadata
a.user.metadata.getMetadata(a)
if a.req.Method == "HEAD" {
return nil
}
var tmp orderedContainers
// first get all matching objects and arrange them in alphabetical order.
for _, container := range a.srv.containers {
if strings.HasPrefix(container.name, prefix) {
tmp = append(tmp, container)
}
}
sort.Sort(tmp)
resp := make([]Folder, 0)
for _, container := range tmp {
if container.name <= marker {
continue
}
if format == "json" {
resp = append(resp, Folder{
Count: len(container.objects),
Bytes: container.bytes,
Name: container.name,
})
} else {
a.w.Write([]byte(container.name + "\n"))
}
}
if format == "json" {
return resp
} else {
return nil
}
}
func (r rootResource) post(a *action) interface{} {
a.user.metadata.setMetadata(a, "account")
return nil
}
func (rootResource) delete(a *action) interface{} {
if a.req.URL.Query().Get("bulk-delete") == "1" {
fatalf(403, "Operation forbidden", "Bulk delete is not supported")
}
return notAllowed()
}
func (rootResource) copy(a *action) interface{} { return notAllowed() }
func NewSwiftServer(address string) (*SwiftServer, error) {
l, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("cannot listen on %s: %v", address, err)
}
server := &SwiftServer{
listener: l,
url: "http://" + l.Addr().String(),
containers: make(map[string]*container),
accounts: make(map[string]*account),
sessions: make(map[string]*session),
}
server.accounts["swifttest"] = &account{
password: "swifttest",
metadata: metadata{
meta: make(http.Header),
},
}
go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
server.serveHTTP(w, req)
}))
return server, nil
}
func (srv SwiftServer) Close() {
srv.listener.Close()
}
@@ -0,0 +1,57 @@
package swift
import (
"io"
"time"
)
// An io.ReadCloser which obeys an idle timeout
type timeoutReader struct {
reader io.ReadCloser
timeout time.Duration
cancel func()
}
// Returns a wrapper around the reader which obeys an idle
// timeout. The cancel function is called if the timeout happens
func newTimeoutReader(reader io.ReadCloser, timeout time.Duration, cancel func()) *timeoutReader {
return &timeoutReader{
reader: reader,
timeout: timeout,
cancel: cancel,
}
}
// Read reads up to len(p) bytes into p
//
// Waits at most for timeout for the read to complete otherwise returns a timeout
func (t *timeoutReader) Read(p []byte) (int, error) {
// FIXME limit the amount of data read in one chunk so as to not exceed the timeout?
// Do the read in the background
type result struct {
n int
err error
}
done := make(chan result, 1)
go func() {
n, err := t.reader.Read(p)
done <- result{n, err}
}()
// Wait for the read or the timeout
select {
case r := <-done:
return r.n, r.err
case <-time.After(t.timeout):
t.cancel()
return 0, TimeoutError
}
panic("unreachable") // for Go 1.0
}
// Close the channel
func (t *timeoutReader) Close() error {
return t.reader.Close()
}
// Check it satisfies the interface
var _ io.ReadCloser = &timeoutReader{}
@@ -0,0 +1,107 @@
// This tests TimeoutReader
package swift
import (
"io"
"io/ioutil"
"sync"
"testing"
"time"
)
// An io.ReadCloser for testing
type testReader struct {
sync.Mutex
n int
delay time.Duration
closed bool
}
// Returns n bytes with at time.Duration delay
func newTestReader(n int, delay time.Duration) *testReader {
return &testReader{
n: n,
delay: delay,
}
}
// Returns 1 byte at a time after delay
func (t *testReader) Read(p []byte) (n int, err error) {
if t.n <= 0 {
return 0, io.EOF
}
time.Sleep(t.delay)
p[0] = 'A'
t.Lock()
t.n--
t.Unlock()
return 1, nil
}
// Close the channel
func (t *testReader) Close() error {
t.Lock()
t.closed = true
t.Unlock()
return nil
}
func TestTimeoutReaderNoTimeout(t *testing.T) {
test := newTestReader(3, 10*time.Millisecond)
cancelled := false
cancel := func() {
cancelled = true
}
tr := newTimeoutReader(test, 100*time.Millisecond, cancel)
b, err := ioutil.ReadAll(tr)
if err != nil || string(b) != "AAA" {
t.Fatalf("Bad read %s %s", err, b)
}
if cancelled {
t.Fatal("Cancelled when shouldn't have been")
}
if test.n != 0 {
t.Fatal("Didn't read all")
}
if test.closed {
t.Fatal("Shouldn't be closed")
}
tr.Close()
if !test.closed {
t.Fatal("Should be closed")
}
}
func TestTimeoutReaderTimeout(t *testing.T) {
// Return those bytes slowly so we get an idle timeout
test := newTestReader(3, 100*time.Millisecond)
cancelled := false
cancel := func() {
cancelled = true
}
tr := newTimeoutReader(test, 10*time.Millisecond, cancel)
_, err := ioutil.ReadAll(tr)
if err != TimeoutError {
t.Fatal("Expecting TimeoutError, got", err)
}
if !cancelled {
t.Fatal("Not cancelled when should have been")
}
test.Lock()
n := test.n
test.Unlock()
if n == 0 {
t.Fatal("Read all")
}
if n != 3 {
t.Fatal("Didn't read any")
}
if test.closed {
t.Fatal("Shouldn't be closed")
}
tr.Close()
if !test.closed {
t.Fatal("Should be closed")
}
}
@@ -0,0 +1,34 @@
package swift
import (
"io"
"time"
)
// An io.Reader which resets a watchdog timer whenever data is read
type watchdogReader struct {
timeout time.Duration
reader io.Reader
timer *time.Timer
}
// Returns a new reader which will kick the watchdog timer whenever data is read
func newWatchdogReader(reader io.Reader, timeout time.Duration, timer *time.Timer) *watchdogReader {
return &watchdogReader{
timeout: timeout,
reader: reader,
timer: timer,
}
}
// Read reads up to len(p) bytes into p
func (t *watchdogReader) Read(p []byte) (n int, err error) {
// FIXME limit the amount of data read in one chunk so as to not exceed the timeout?
resetTimer(t.timer, t.timeout)
n, err = t.reader.Read(p)
resetTimer(t.timer, t.timeout)
return
}
// Check it satisfies the interface
var _ io.Reader = &watchdogReader{}
@@ -0,0 +1,61 @@
// This tests WatchdogReader
package swift
import (
"io/ioutil"
"testing"
"time"
)
// Uses testReader from timeout_reader_test.go
func testWatchdogReaderTimeout(t *testing.T, initialTimeout, watchdogTimeout time.Duration, expectedTimeout bool) {
test := newTestReader(3, 10*time.Millisecond)
timer := time.NewTimer(initialTimeout)
firedChan := make(chan bool)
started := make(chan bool)
go func() {
started <- true
select {
case <-timer.C:
firedChan <- true
}
}()
<-started
wr := newWatchdogReader(test, watchdogTimeout, timer)
b, err := ioutil.ReadAll(wr)
if err != nil || string(b) != "AAA" {
t.Fatalf("Bad read %s %s", err, b)
}
fired := false
select {
case fired = <-firedChan:
default:
}
if expectedTimeout {
if !fired {
t.Fatal("Timer should have fired")
}
} else {
if fired {
t.Fatal("Timer should not have fired")
}
}
}
func TestWatchdogReaderNoTimeout(t *testing.T) {
testWatchdogReaderTimeout(t, 100*time.Millisecond, 100*time.Millisecond, false)
}
func TestWatchdogReaderTimeout(t *testing.T) {
testWatchdogReaderTimeout(t, 5*time.Millisecond, 5*time.Millisecond, true)
}
func TestWatchdogReaderNoTimeoutShortInitial(t *testing.T) {
testWatchdogReaderTimeout(t, 5*time.Millisecond, 100*time.Millisecond, false)
}
func TestWatchdogReaderTimeoutLongInitial(t *testing.T) {
testWatchdogReaderTimeout(t, 100*time.Millisecond, 5*time.Millisecond, true)
}

Some files were not shown because too many files have changed in this diff Show More