Merge branch 'master' into dpkg-compliant-version-compare

This commit is contained in:
André Roth
2025-12-26 16:55:52 +01:00
committed by GitHub
17 changed files with 336 additions and 21 deletions

View File

@@ -35,7 +35,7 @@ jobs:
- name: Install and initialize swagger
run: |
go install github.com/swaggo/swag/cmd/swag@latest
swag init -q --markdownFiles docs
swag init -q --propertyStrategy pascalcase --markdownFiles docs
shell: sh
- name: golangci-lint

View File

@@ -75,3 +75,5 @@ List of contributors, in chronological order:
* Agustin Henze (https://github.com/agustinhenze)
* Tobias Assarsson (https://github.com/daedaluz)
* Yaksh Bariya (https://github.com/thunder-coding)
* Juan Calderon-Perez (https://github.com/gaby)
* Ato Araki (https://github.com/atotto)

View File

@@ -77,7 +77,7 @@ azurite-stop:
swagger: swagger-install
# Generate swagger docs
@PATH=$(BINPATH)/:$(PATH) swag init --parseDependency --parseInternal --markdownFiles docs --generalInfo docs/swagger.conf
@PATH=$(BINPATH)/:$(PATH) swag init --propertyStrategy pascalcase --parseDependency --parseInternal --markdownFiles docs --generalInfo docs/swagger.conf
etcd-install:
# Install etcd
@@ -131,7 +131,7 @@ serve: prepare swagger-install ## Run development server (auto recompiling)
test -f $(BINPATH)/air || go install github.com/air-verse/air@v1.52.3
cp debian/aptly.conf ~/.aptly.conf
sed -i /enable_swagger_endpoint/s/false/true/ ~/.aptly.conf
PATH=$(BINPATH):$$PATH air -build.pre_cmd 'swag init -q --markdownFiles docs --generalInfo docs/swagger.conf' -build.exclude_dir docs,system,debian,pgp/keyrings,pgp/test-bins,completion.d,man,deb/testdata,console,_man,systemd,obj-x86_64-linux-gnu -- api serve -listen 0.0.0.0:3142
PATH=$(BINPATH):$$PATH air -build.pre_cmd 'swag init -q --propertyStrategy pascalcase --markdownFiles docs --generalInfo docs/swagger.conf' -build.exclude_dir docs,system,debian,pgp/keyrings,pgp/test-bins,completion.d,man,deb/testdata,console,_man,systemd,obj-x86_64-linux-gnu -- api serve -listen 0.0.0.0:3142
dpkg: prepare swagger ## Build debian packages
@test -n "$(DEBARCH)" || (echo "please define DEBARCH"; exit 1)

View File

@@ -343,6 +343,8 @@ type mirrorUpdateParams struct {
ForceUpdate bool ` json:"ForceUpdate"`
// Set "true" to skip downloading already downloaded packages
SkipExistingPackages bool ` json:"SkipExistingPackages"`
// Set "true" to download only the latest version per package/architecture
LatestOnly bool ` json:"LatestOnly"`
}
// @Summary Update Mirror
@@ -434,7 +436,7 @@ func apiMirrorsUpdate(c *gin.Context) {
}
queue, downloadSize, err := remote.BuildDownloadQueue(context.PackagePool(), collectionFactory.PackageCollection(),
collectionFactory.ChecksumCollection(nil), b.SkipExistingPackages)
collectionFactory.ChecksumCollection(nil), b.SkipExistingPackages, b.LatestOnly)
if err != nil {
return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, fmt.Errorf("unable to update: %s", err)
}

View File

@@ -87,10 +87,11 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
)
skipExistingPackages := context.Flags().Lookup("skip-existing-packages").Value.Get().(bool)
latestOnly := context.Flags().Lookup("latest").Value.Get().(bool)
context.Progress().Printf("Building download queue...\n")
queue, downloadSize, err = repo.BuildDownloadQueue(context.PackagePool(), collectionFactory.PackageCollection(),
collectionFactory.ChecksumCollection(nil), skipExistingPackages)
collectionFactory.ChecksumCollection(nil), skipExistingPackages, latestOnly)
if err != nil {
return fmt.Errorf("unable to update: %s", err)
@@ -292,6 +293,7 @@ Example:
cmd.Flag.Bool("ignore-checksums", false, "ignore checksum mismatches while downloading package files and metadata")
cmd.Flag.Bool("ignore-signatures", false, "disable verification of Release file signatures")
cmd.Flag.Bool("skip-existing-packages", false, "do not check file existence for packages listed in the internal database of the mirror")
cmd.Flag.Bool("latest", false, "download only latest version of each package (per architecture)")
cmd.Flag.Int64("download-limit", 0, "limit download speed (kbytes/sec)")
cmd.Flag.String("downloader", "default", "downloader to use (e.g. grab)")
cmd.Flag.Int("max-tries", 1, "max download tries till process fails with download error")

View File

@@ -211,6 +211,7 @@ local keyring="*-keyring=[gpg keyring to use when verifying Release file (could
$keyring \
"-max-tries=[max download tries till process fails with download error]:number: " \
"-skip-existing-packages=[do not check file existence for packages listed in the internal database of the mirror]:$bool" \
"-latest=[download only latest version of each package (per architecture)]:$bool" \
"(-)2:mirror name:$mirrors"
;;
rename)

View File

@@ -263,7 +263,7 @@ _aptly()
"update")
if [[ $numargs -eq 0 ]]; then
if [[ "$cur" == -* ]]; then
COMPREPLY=($(compgen -W "-force -download-limit= -downloader= -ignore-checksums -ignore-signatures -keyring= -skip-existing-packages" -- ${cur}))
COMPREPLY=($(compgen -W "-force -download-limit= -downloader= -ignore-checksums -ignore-signatures -keyring= -skip-existing-packages -latest" -- ${cur}))
else
COMPREPLY=($(compgen -W "$(__aptly_mirror_list)" -- ${cur}))
fi

View File

@@ -172,6 +172,39 @@ func (l *PackageList) ForEach(handler func(*Package) error) error {
return err
}
// FilterLatest creates a copy of the package list containing only the
// latest version for each package name/architecture pair.
func (l *PackageList) FilterLatest() (*PackageList, error) {
if l == nil {
return nil, fmt.Errorf("package list is nil")
}
filtered := make(map[string]*Package, l.Len())
err := l.ForEach(func(p *Package) error {
key := p.Architecture + "|" + p.Name
if existing, found := filtered[key]; !found || CompareVersions(p.Version, existing.Version) > 0 {
filtered[key] = p
}
return nil
})
if err != nil {
return nil, err
}
result := NewPackageListWithDuplicates(l.duplicatesAllowed, len(filtered))
for _, pkg := range filtered {
if err = result.Add(pkg); err != nil {
return nil, err
}
}
return result, nil
}
// ForEachIndexed calls handler for each package in list in indexed order
func (l *PackageList) ForEachIndexed(handler func(*Package) error) error {
if !l.indexed {

View File

@@ -503,3 +503,52 @@ func (s *PackageListSuite) TestArchitectures(c *C) {
sort.Strings(archs)
c.Check(archs, DeepEquals, []string{"amd64", "arm", "i386", "s390"})
}
func (s *PackageListSuite) TestFilterLatest(c *C) {
list := NewPackageList()
older := packageStanza.Copy()
older["Version"] = "1.0"
olderPkg := NewPackageFromControlFile(older)
_ = list.Add(olderPkg)
newer := packageStanza.Copy()
newer["Version"] = "2.0"
newerPkg := NewPackageFromControlFile(newer)
_ = list.Add(newerPkg)
shared := packageStanza.Copy()
shared["Architecture"] = ArchitectureAll
shared["Version"] = "3.0"
shared["Package"] = "shared"
sharedPkg := NewPackageFromControlFile(shared)
_ = list.Add(sharedPkg)
filtered, err := list.FilterLatest()
c.Assert(err, IsNil)
c.Assert(filtered.Len(), Equals, 2)
c.Check(filtered.Has(newerPkg), Equals, true)
c.Check(filtered.Has(sharedPkg), Equals, true)
}
func (s *PackageListSuite) TestFilterLatestPreservesDuplicatesFlag(c *C) {
list := NewPackageListWithDuplicates(true, 2)
_ = list.Add(NewPackageFromControlFile(packageStanza.Copy()))
another := packageStanza.Copy()
another["Version"] = "7.41-1"
_ = list.Add(NewPackageFromControlFile(another))
filtered, err := list.FilterLatest()
c.Assert(err, IsNil)
c.Assert(filtered.duplicatesAllowed, Equals, true)
}
func (s *PackageListSuite) TestFilterLatestNil(c *C) {
var list *PackageList
filtered, err := list.FilterLatest()
c.Assert(err, ErrorMatches, "package list is nil")
c.Assert(filtered, IsNil)
}

View File

@@ -612,7 +612,19 @@ func (repo *RemoteRepo) ApplyFilter(dependencyOptions int, filterQuery PackageQu
}
// BuildDownloadQueue builds queue, discards current PackageList
func (repo *RemoteRepo) BuildDownloadQueue(packagePool aptly.PackagePool, packageCollection *PackageCollection, checksumStorage aptly.ChecksumStorage, skipExistingPackages bool) (queue []PackageDownloadTask, downloadSize int64, err error) {
func (repo *RemoteRepo) BuildDownloadQueue(packagePool aptly.PackagePool, packageCollection *PackageCollection, checksumStorage aptly.ChecksumStorage, skipExistingPackages, latestOnly bool) (queue []PackageDownloadTask, downloadSize int64, err error) {
if repo.packageList == nil {
err = fmt.Errorf("package list is empty, please (re)download package indexes")
return
}
if latestOnly {
repo.packageList, err = repo.packageList.FilterLatest()
if err != nil {
return
}
}
queue = make([]PackageDownloadTask, 0, repo.packageList.Len())
seen := make(map[string]int, repo.packageList.Len())

View File

@@ -281,7 +281,7 @@ func (s *RemoteRepoSuite) TestDownload(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(3))
c.Check(queue, HasLen, 1)
@@ -308,7 +308,7 @@ func (s *RemoteRepoSuite) TestDownload(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(0))
c.Check(queue, HasLen, 0)
@@ -329,7 +329,7 @@ func (s *RemoteRepoSuite) TestDownload(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(3))
c.Check(queue, HasLen, 1)
@@ -356,7 +356,7 @@ func (s *RemoteRepoSuite) TestDownloadWithInstaller(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(3)+int64(len(exampleInstallerManifestFile)))
c.Check(queue, HasLen, 2)
@@ -382,6 +382,35 @@ func (s *RemoteRepoSuite) TestDownloadWithInstaller(c *C) {
c.Check(pkg.Name, Equals, "installer")
}
func (s *RemoteRepoSuite) TestBuildDownloadQueueLatestOnly(c *C) {
s.repo.Architectures = []string{"i386"}
err := s.repo.Fetch(s.downloader, nil, true)
c.Assert(err, IsNil)
s.downloader.ExpectError("http://mirror.yandex.ru/debian/dists/squeeze/main/binary-i386/Packages.bz2", &http.Error{Code: 404})
s.downloader.ExpectError("http://mirror.yandex.ru/debian/dists/squeeze/main/binary-i386/Packages.gz", &http.Error{Code: 404})
s.downloader.ExpectResponse("http://mirror.yandex.ru/debian/dists/squeeze/main/binary-i386/Packages", examplePackagesFile)
err = s.repo.DownloadPackageIndexes(s.progress, s.downloader, nil, s.collectionFactory, true, false)
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
stanza := packageStanza.Copy()
stanza["Package"] = "amanda-client"
stanza["Version"] = "1:3.4.0-1"
stanza["Filename"] = "pool/main/a/amanda/amanda-client_3.4.0-1_i386.deb"
newest := NewPackageFromControlFile(stanza)
_ = s.repo.packageList.Add(newest)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, true)
c.Assert(err, IsNil)
c.Check(queue, HasLen, 1)
c.Check(queue[0].File.DownloadURL(), Equals, "pool/main/a/amanda/amanda-client_3.4.0-1_i386.deb")
c.Check(size, Equals, int64(187518))
}
func (s *RemoteRepoSuite) TestDownloadWithSources(c *C) {
s.repo.Architectures = []string{"i386"}
s.repo.DownloadSources = true
@@ -400,7 +429,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSources(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err := s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(15))
c.Check(queue, HasLen, 4)
@@ -444,7 +473,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSources(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(0))
c.Check(queue, HasLen, 0)
@@ -469,7 +498,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSources(c *C) {
c.Assert(err, IsNil)
c.Assert(s.downloader.Empty(), Equals, true)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err = s.repo.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(15))
c.Check(queue, HasLen, 4)
@@ -493,7 +522,7 @@ func (s *RemoteRepoSuite) TestDownloadFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err := s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err := s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(3))
c.Check(queue, HasLen, 1)
@@ -521,7 +550,7 @@ func (s *RemoteRepoSuite) TestDownloadFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(0))
c.Check(queue, HasLen, 0)
@@ -543,7 +572,7 @@ func (s *RemoteRepoSuite) TestDownloadFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(3))
c.Check(queue, HasLen, 1)
@@ -574,7 +603,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSourcesFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err := s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err := s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(15))
c.Check(queue, HasLen, 4)
@@ -620,7 +649,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSourcesFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, true, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(0))
c.Check(queue, HasLen, 0)
@@ -646,7 +675,7 @@ func (s *RemoteRepoSuite) TestDownloadWithSourcesFlat(c *C) {
c.Assert(err, IsNil)
c.Assert(downloader.Empty(), Equals, true)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false)
queue, size, err = s.flat.BuildDownloadQueue(s.packagePool, s.collectionFactory.PackageCollection(), s.cs, false, false)
c.Assert(err, IsNil)
c.Check(size, Equals, int64(15))
c.Check(queue, HasLen, 4)

4
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/aptly-dev/aptly
go 1.24
go 1.24.0
require (
github.com/AlekSi/pointer v1.1.0
@@ -41,6 +41,7 @@ require (
)
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
@@ -128,5 +129,6 @@ require (
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.3
go.etcd.io/etcd/client/v3 v3.5.15
golang.org/x/oauth2 v0.33.0
gopkg.in/yaml.v3 v3.0.1
)

4
go.sum
View File

@@ -1,3 +1,5 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/AlekSi/pointer v1.1.0 h1:SSDMPcXD9jSl8FPy9cRzoRaMJtm9g9ggGTxecRUbQoI=
github.com/AlekSi/pointer v1.1.0/go.mod h1:y7BvfRI3wXPWKXEBhU71nbnIEEZX0QTSB2Bj48UJIZE=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8=
@@ -348,6 +350,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -44,6 +44,7 @@ func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly
transport.DisableCompression = true
initTransport(&transport)
transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{})
transport.RegisterProtocol("ar+https", NewGCPRoundTripper(&transport))
downloader := &downloaderImpl{
progress: progress,

64
http/gcp_auth.go Normal file
View File

@@ -0,0 +1,64 @@
package http
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// gcpRoundTripper wraps http.RoundTripper to add Google Cloud authentication.
// It delays GCP authentication initialization until the first actual request is made.
// This avoids unnecessary credential loading when ar+https protocol is not actually used.
//
// It uses Application Default Credentials (ADC) which checks:
// 1. GOOGLE_APPLICATION_CREDENTIALS environment variable
// 2. gcloud auth application-default credentials
// 3. GCE/GKE metadata server
// See https://cloud.google.com/docs/authentication/application-default-credentials for usage details.
type gcpRoundTripper struct {
base http.RoundTripper
initOnce sync.Once
tokenSrc oauth2.TokenSource
initErr error
}
func (t *gcpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Lazy initialization: only initialize GCP credentials on first request
t.initOnce.Do(func() {
creds, err := google.FindDefaultCredentials(context.Background(),
"https://www.googleapis.com/auth/cloud-platform")
if err != nil {
t.initErr = fmt.Errorf("failed to find default credentials: %w", err)
return
}
t.tokenSrc = creds.TokenSource
})
reqCopy := req.Clone(req.Context())
reqCopy.URL.Scheme = strings.TrimPrefix(reqCopy.URL.Scheme, "ar+")
// Fall back to base transport if GCP auth initialization failed
if t.initErr != nil {
return t.base.RoundTrip(reqCopy)
}
token, err := t.tokenSrc.Token()
if err != nil {
return nil, fmt.Errorf("failed to get OAuth2 token: %w", err)
}
token.SetAuthHeader(reqCopy)
return t.base.RoundTrip(reqCopy)
}
// NewGCPRoundTripper creates a new RoundTripper that handles GCP authentication for ar+https protocol.
func NewGCPRoundTripper(base http.RoundTripper) http.RoundTripper {
return &gcpRoundTripper{
base: base,
}
}

110
http/gcp_auth_test.go Normal file
View File

@@ -0,0 +1,110 @@
package http
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"golang.org/x/oauth2"
)
func TestGCPAuthTransport_RoundTrip(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth == "" {
t.Error("Expected Authorization header, got none")
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
transport := NewGCPRoundTripper(http.DefaultTransport)
if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" {
t.Skip("Skipping test: GOOGLE_APPLICATION_CREDENTIALS not set")
}
client := &http.Client{Transport: transport}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
}
func TestGCPAuthTransport_RoundTrip_with_dummy_tokenSource(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != "Bearer dummy-token" {
t.Errorf("Expected Authorization header 'Bearer dummy-token', got '%s'", auth)
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
// Use a dummy token source for testing
transport := &gcpRoundTripper{
base: http.DefaultTransport,
tokenSrc: oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "dummy-token",
}),
}
transport.initOnce.Do(func() {}) // Mark as initialized for testing
client := &http.Client{Transport: transport}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
}
func TestGCPAuthTransport_RoundTrip_with_InvalidCredentials(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer ts.Close()
// Create a temporary invalid credentials file
tmpFile, err := os.CreateTemp("", "invalid_credentials.json")
if err != nil {
t.Fatalf("Failed to create temp file: %s", err)
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.WriteString(`{"invalid": "data"}`); err != nil {
t.Fatalf("Failed to write to temp file: %s", err)
}
tmpFile.Close()
defaultEnv := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS")
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", tmpFile.Name())
defer os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", defaultEnv)
transport := &gcpRoundTripper{
base: http.DefaultTransport,
}
client := &http.Client{Transport: transport}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatalf("Failed to make request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected status 403, got %d", resp.StatusCode)
}
if transport.initErr == nil {
t.Error("Expected init error due to invalid credentials, got none")
}
}

View File

@@ -738,6 +738,10 @@ max download tries till process fails with download error
\-\fBskip\-existing\-packages\fR
do not check file existence for packages listed in the internal database of the mirror
.
.TP
\-\fBlatest\fR
download only latest version of each package (per architecture)
.
.SH "RENAMES MIRROR"
\fBaptly\fR \fBmirror\fR \fBrename\fR \fIold\-name\fR \fInew\-name\fR
.