diff --git a/cmd/api_serve.go b/cmd/api_serve.go index 22af04a8..d3ccdf77 100644 --- a/cmd/api_serve.go +++ b/cmd/api_serve.go @@ -65,6 +65,8 @@ func aptlyAPIServe(cmd *commander.Command, args []string) error { signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM) go (func() { if _, ok := <-sigchan; ok { + fmt.Printf("\nShutdown signal received, waiting for background tasks...\n") + context.TaskList().Wait() server.Shutdown(stdcontext.Background()) } })() diff --git a/context/context.go b/context/context.go index 30e84d42..7e468b0f 100644 --- a/context/context.go +++ b/context/context.go @@ -13,6 +13,7 @@ import ( "runtime/pprof" "strings" "sync" + "syscall" "time" "github.com/aptly-dev/aptly/aptly" @@ -359,10 +360,7 @@ func (context *AptlyContext) ReOpenDatabase() error { // NewCollectionFactory builds factory producing all kinds of collections func (context *AptlyContext) NewCollectionFactory() *deb.CollectionFactory { - context.Lock() - defer context.Unlock() - - db, err := context._database() + db, err := context.Database() if err != nil { Fatal(err) } @@ -560,7 +558,7 @@ func (context *AptlyContext) GoContextHandleSignals() { // Catch ^C sigch := make(chan os.Signal, 1) - signal.Notify(sigch, os.Interrupt) + signal.Notify(sigch, syscall.SIGINT, syscall.SIGTERM) var cancel gocontext.CancelFunc diff --git a/http/download.go b/http/download.go index 5bb9d050..be7b2c78 100644 --- a/http/download.go +++ b/http/download.go @@ -133,6 +133,8 @@ func retryableError(err error) bool { } switch err { + case context.Canceled: + return false case io.EOF: return true case io.ErrUnexpectedEOF: diff --git a/http/download_test.go b/http/download_test.go index 9cde0716..d0feccf3 100644 --- a/http/download_test.go +++ b/http/download_test.go @@ -154,3 +154,13 @@ func (s *DownloaderSuite) TestGetLengthConnectError(c *C) { c.Assert(err, ErrorMatches, ".*no such host") } + +func (s *DownloaderSuite) TestContextCancel(c *C) { + ctx, cancel := context.WithCancel(s.ctx) + s.ctx = ctx + + cancel() + _, err := s.d.GetLength(s.ctx, "http://nosuch.host.invalid./") + + c.Assert(err, ErrorMatches, ".*context canceled.*") +}