From 0c0270763858fae79e7825904c64908311a40f3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Roth?= Date: Mon, 25 May 2026 12:13:41 +0000 Subject: [PATCH] snapshot: fix race conditions * perform collection.LoadComplete inside maybeRunTaskInBackground * have tasks use a fresh copy of taskCollectionFactory, taskCollection * fix locking for snapshots of snapshots by locking SourceSnapshots * use uuids, since names can be renamed --- api/snapshot.go | 223 ++++++++++++++++++++++++++++++++++-------------- deb/snapshot.go | 6 -- 2 files changed, 161 insertions(+), 68 deletions(-) diff --git a/api/snapshot.go b/api/snapshot.go index 2c50566f..7b6422f0 100644 --- a/api/snapshot.go +++ b/api/snapshot.go @@ -83,26 +83,33 @@ func apiSnapshotsCreateFromMirror(c *gin.Context) { } collectionFactory := context.NewCollectionFactory() - collection := collectionFactory.RemoteRepoCollection() - snapshotCollection := collectionFactory.SnapshotCollection() name := c.Params.ByName("name") - repo, err = collection.ByName(name) + repo, err = collectionFactory.RemoteRepoCollection().ByName(name) if err != nil { AbortWithJSONError(c, 404, err) return } // including snapshot resource key - resources := []string{string(repo.Key()), "S" + b.Name} + resources := []string{string(repo.Key())} taskName := fmt.Sprintf("Create snapshot of mirror %s", name) maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - err := repo.CheckLock() + taskCollectionFactory := context.NewCollectionFactory() + taskMirrorCollection := taskCollectionFactory.RemoteRepoCollection() + taskSnapshotCollection := taskCollectionFactory.SnapshotCollection() + + repo, err := taskMirrorCollection.ByName(name) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + + err = repo.CheckLock() if err != nil { return &task.ProcessReturnValue{Code: http.StatusConflict, Value: nil}, err } - err = collection.LoadComplete(repo) + err = taskMirrorCollection.LoadComplete(repo) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } @@ -116,7 +123,7 @@ func apiSnapshotsCreateFromMirror(c *gin.Context) { snapshot.Description = b.Description } - err = snapshotCollection.Add(snapshot) + err = taskSnapshotCollection.Add(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusBadRequest, Value: nil}, err } @@ -165,6 +172,7 @@ func apiSnapshotsCreate(c *gin.Context) { } } + // Phase 1: Pre-task validation (shallow load for 404 checks only) collectionFactory := context.NewCollectionFactory() snapshotCollection := collectionFactory.SnapshotCollection() var resources []string @@ -178,37 +186,62 @@ func apiSnapshotsCreate(c *gin.Context) { return } - resources = append(resources, string(sources[i].ResourceKey())) + resources = append(resources, string(sources[i].Key())) } maybeRunTaskInBackground(c, "Create snapshot "+b.Name, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - for i := range sources { - err = snapshotCollection.LoadComplete(sources[i]) + // Phase 2: Inside task lock - create fresh factory + taskCollectionFactory := context.NewCollectionFactory() + taskSnapshotCollection := taskCollectionFactory.SnapshotCollection() + taskPackageCollection := taskCollectionFactory.PackageCollection() + + // Fresh load of all sources after lock acquired + freshSources := make([]*deb.Snapshot, len(b.SourceSnapshots)) + for i := range b.SourceSnapshots { + freshSources[i], err = taskSnapshotCollection.ByName(b.SourceSnapshots[i]) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + // LoadComplete on fresh copy + err = taskSnapshotCollection.LoadComplete(freshSources[i]) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } } - list := deb.NewPackageList() + // Merge packages from all source snapshots + var refList *deb.PackageRefList + if len(freshSources) > 0 { + refList = freshSources[0].RefList() + for i := 1; i < len(freshSources); i++ { + refList = refList.Merge(freshSources[i].RefList(), true, false) + } + } else { + refList = deb.NewPackageRefList() + } - // verify package refs and build package list - for _, ref := range b.PackageRefs { - p, err := collectionFactory.PackageCollection().ByKey([]byte(ref)) - if err != nil { - if err == database.ErrNotFound { - return &task.ProcessReturnValue{Code: http.StatusNotFound, Value: nil}, fmt.Errorf("package %s: %s", ref, err) + // Add any explicitly specified package refs on top + if len(b.PackageRefs) > 0 { + list := deb.NewPackageList() + for _, ref := range b.PackageRefs { + p, err := taskPackageCollection.ByKey([]byte(ref)) + if err != nil { + if err == database.ErrNotFound { + return &task.ProcessReturnValue{Code: http.StatusNotFound, Value: nil}, fmt.Errorf("package %s: %s", ref, err) + } + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + err = list.Add(p) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusBadRequest, Value: nil}, err } - return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err - } - err = list.Add(p) - if err != nil { - return &task.ProcessReturnValue{Code: http.StatusBadRequest, Value: nil}, err } + refList = refList.Merge(deb.NewPackageRefListFromPackageList(list), true, false) } - snapshot = deb.NewSnapshotFromRefList(b.Name, sources, deb.NewPackageRefListFromPackageList(list), b.Description) + snapshot = deb.NewSnapshotFromRefList(b.Name, freshSources, refList, b.Description) - err = snapshotCollection.Add(snapshot) + err = taskSnapshotCollection.Add(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusBadRequest, Value: nil}, err } @@ -249,21 +282,28 @@ func apiSnapshotsCreateFromRepository(c *gin.Context) { } collectionFactory := context.NewCollectionFactory() - collection := collectionFactory.LocalRepoCollection() - snapshotCollection := collectionFactory.SnapshotCollection() name := c.Params.ByName("name") - repo, err = collection.ByName(name) + repo, err = collectionFactory.LocalRepoCollection().ByName(name) if err != nil { AbortWithJSONError(c, 404, err) return } // including snapshot resource key - resources := []string{string(repo.Key()), "S" + b.Name} + resources := []string{string(repo.Key())} taskName := fmt.Sprintf("Create snapshot of repo %s", name) maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - err := collection.LoadComplete(repo) + taskCollectionFactory := context.NewCollectionFactory() + taskRepoCollection := taskCollectionFactory.LocalRepoCollection() + taskSnapshotCollection := taskCollectionFactory.SnapshotCollection() + + repo, err := taskRepoCollection.ByName(name) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + + err = taskRepoCollection.LoadComplete(repo) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } @@ -277,7 +317,7 @@ func apiSnapshotsCreateFromRepository(c *gin.Context) { snapshot.Description = b.Description } - err = snapshotCollection.Add(snapshot) + err = taskSnapshotCollection.Add(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusBadRequest, Value: nil}, err } @@ -315,6 +355,7 @@ func apiSnapshotsUpdate(c *gin.Context) { return } + // Phase 1: Pre-task validation (shallow load for 404 check only) collectionFactory := context.NewCollectionFactory() collection := collectionFactory.SnapshotCollection() name := c.Params.ByName("name") @@ -325,14 +366,38 @@ func apiSnapshotsUpdate(c *gin.Context) { return } - resources := []string{string(snapshot.ResourceKey()), "S" + b.Name} - taskName := fmt.Sprintf("Update snapshot %s", name) - maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - _, err := collection.ByName(b.Name) + // Pre-task validation of new name if provided (skip if renaming to same name) + if b.Name != "" && b.Name != name { + _, err = collection.ByName(b.Name) if err == nil { - return &task.ProcessReturnValue{Code: http.StatusConflict, Value: nil}, fmt.Errorf("unable to rename: snapshot %s already exists", b.Name) + AbortWithJSONError(c, 409, fmt.Errorf("unable to rename: snapshot %s already exists", b.Name)) + return + } + } + + resources := []string{string(snapshot.Key())} + taskName := fmt.Sprintf("Update snapshot %s", name) + + maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { + // Phase 2: Inside task lock - create fresh factory + taskCollectionFactory := context.NewCollectionFactory() + taskCollection := taskCollectionFactory.SnapshotCollection() + + // Fresh load after lock acquired + snapshot, err = taskCollection.ByName(name) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } + // Fresh duplicate check inside lock + if b.Name != "" { + _, err := taskCollection.ByName(b.Name) + if err == nil { + return &task.ProcessReturnValue{Code: http.StatusConflict, Value: nil}, fmt.Errorf("unable to rename: snapshot %s already exists", b.Name) + } + } + + // Update fresh copy if b.Name != "" { snapshot.Name = b.Name } @@ -341,7 +406,7 @@ func apiSnapshotsUpdate(c *gin.Context) { snapshot.Description = b.Description } - err = collectionFactory.SnapshotCollection().Update(snapshot) + err = taskCollection.Update(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } @@ -395,9 +460,9 @@ func apiSnapshotsDrop(c *gin.Context) { name := c.Params.ByName("name") force := c.Request.URL.Query().Get("force") == "1" + // Phase 1: Pre-task validation (shallow load for 404 check only) collectionFactory := context.NewCollectionFactory() snapshotCollection := collectionFactory.SnapshotCollection() - publishedCollection := collectionFactory.PublishedRepoCollection() snapshot, err := snapshotCollection.ByName(name) if err != nil { @@ -405,23 +470,37 @@ func apiSnapshotsDrop(c *gin.Context) { return } - resources := []string{string(snapshot.ResourceKey())} + resources := []string{string(snapshot.Key())} taskName := fmt.Sprintf("Delete snapshot %s", name) + maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - published := publishedCollection.BySnapshot(snapshot) + // Phase 2: Inside task lock - create fresh collections + taskCollectionFactory := context.NewCollectionFactory() + taskSnapshotCollection := taskCollectionFactory.SnapshotCollection() + taskPublishedCollection := taskCollectionFactory.PublishedRepoCollection() + + // Fresh load after lock acquired + snapshot, err := taskSnapshotCollection.ByName(name) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + + // Fresh checks with current collections + published := taskPublishedCollection.BySnapshot(snapshot) if len(published) > 0 { return &task.ProcessReturnValue{Code: http.StatusConflict, Value: nil}, fmt.Errorf("unable to drop: snapshot is published") } if !force { - snapshots := snapshotCollection.BySnapshotSource(snapshot) + // Using fresh collection for dependency check + snapshots := taskSnapshotCollection.BySnapshotSource(snapshot) if len(snapshots) > 0 { return &task.ProcessReturnValue{Code: http.StatusConflict, Value: nil}, fmt.Errorf("won't delete snapshot that was used as source for other snapshots, use ?force=1 to override") } } - err = snapshotCollection.Drop(snapshot) + err = taskSnapshotCollection.Drop(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } @@ -576,6 +655,7 @@ func apiSnapshotsMerge(c *gin.Context) { return } + // Phase 1: Pre-task validation (shallow load for 404 checks only) collectionFactory := context.NewCollectionFactory() snapshotCollection := collectionFactory.SnapshotCollection() @@ -588,36 +668,47 @@ func apiSnapshotsMerge(c *gin.Context) { return } - resources[i] = string(sources[i].ResourceKey()) + resources[i] = string(sources[i].Key()) } maybeRunTaskInBackground(c, "Merge snapshot "+name, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - err = snapshotCollection.LoadComplete(sources[0]) - if err != nil { - return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err - } - result := sources[0].RefList() - for i := 1; i < len(sources); i++ { - err = snapshotCollection.LoadComplete(sources[i]) + // Phase 2: Inside task lock - create fresh factory + taskCollectionFactory := context.NewCollectionFactory() + taskSnapshotCollection := taskCollectionFactory.SnapshotCollection() + + // Fresh load of all sources inside task + freshSources := make([]*deb.Snapshot, len(body.Sources)) + for i := range body.Sources { + freshSources[i], err = taskSnapshotCollection.ByName(body.Sources[i]) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } - result = result.Merge(sources[i].RefList(), overrideMatching, false) + // LoadComplete on fresh copy + err = taskSnapshotCollection.LoadComplete(freshSources[i]) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + } + + // Merge using fresh sources + result := freshSources[0].RefList() + for i := 1; i < len(freshSources); i++ { + result = result.Merge(freshSources[i].RefList(), overrideMatching, false) } if latest { result.FilterLatestRefs() } - sourceDescription := make([]string, len(sources)) - for i, s := range sources { + sourceDescription := make([]string, len(freshSources)) + for i, s := range freshSources { sourceDescription[i] = fmt.Sprintf("'%s'", s.Name) } - snapshot = deb.NewSnapshotFromRefList(name, sources, result, + snapshot = deb.NewSnapshotFromRefList(name, freshSources, result, fmt.Sprintf("Merged from sources: %s", strings.Join(sourceDescription, ", "))) - err = collectionFactory.SnapshotCollection().Add(snapshot) + err = taskCollectionFactory.SnapshotCollection().Add(snapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, fmt.Errorf("unable to create snapshot: %s", err) } @@ -698,24 +789,32 @@ func apiSnapshotsPull(c *gin.Context) { return } - resources := []string{string(sourceSnapshot.ResourceKey()), string(toSnapshot.ResourceKey())} + resources := []string{string(sourceSnapshot.Key()), string(toSnapshot.Key())} taskName := fmt.Sprintf("Pull snapshot %s into %s and save as %s", body.Source, name, body.Destination) maybeRunTaskInBackground(c, taskName, resources, func(_ aptly.Progress, _ *task.Detail) (*task.ProcessReturnValue, error) { - err = collectionFactory.SnapshotCollection().LoadComplete(toSnapshot) + // Phase 2: Inside task lock - create fresh factory + taskCollectionFactory := context.NewCollectionFactory() + + // Fresh load of snapshots after lock acquired + freshToSnapshot, err := taskCollectionFactory.SnapshotCollection().ByName(name) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } - err = collectionFactory.SnapshotCollection().LoadComplete(sourceSnapshot) + freshSourceSnapshot, err := taskCollectionFactory.SnapshotCollection().ByName(body.Source) + if err != nil { + return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err + } + err = taskCollectionFactory.SnapshotCollection().LoadComplete(freshSourceSnapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } // convert snapshots to package list - toPackageList, err := deb.NewPackageListFromRefList(toSnapshot.RefList(), collectionFactory.PackageCollection(), context.Progress()) + toPackageList, err := deb.NewPackageListFromRefList(freshToSnapshot.RefList(), taskCollectionFactory.PackageCollection(), context.Progress()) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } - sourcePackageList, err := deb.NewPackageListFromRefList(sourceSnapshot.RefList(), collectionFactory.PackageCollection(), context.Progress()) + sourcePackageList, err := deb.NewPackageListFromRefList(freshSourceSnapshot.RefList(), taskCollectionFactory.PackageCollection(), context.Progress()) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } @@ -812,10 +911,10 @@ func apiSnapshotsPull(c *gin.Context) { } // Create snapshot - destinationSnapshot = deb.NewSnapshotFromPackageList(body.Destination, []*deb.Snapshot{toSnapshot, sourceSnapshot}, toPackageList, - fmt.Sprintf("Pulled into '%s' with '%s' as source, pull request was: '%s'", toSnapshot.Name, sourceSnapshot.Name, strings.Join(body.Queries, ", "))) + destinationSnapshot = deb.NewSnapshotFromPackageList(body.Destination, []*deb.Snapshot{freshToSnapshot, freshSourceSnapshot}, toPackageList, + fmt.Sprintf("Pulled into '%s' with '%s' as source, pull request was: '%s'", freshToSnapshot.Name, freshSourceSnapshot.Name, strings.Join(body.Queries, ", "))) - err = collectionFactory.SnapshotCollection().Add(destinationSnapshot) + err = taskCollectionFactory.SnapshotCollection().Add(destinationSnapshot) if err != nil { return &task.ProcessReturnValue{Code: http.StatusInternalServerError, Value: nil}, err } diff --git a/deb/snapshot.go b/deb/snapshot.go index f79a50d2..c308403c 100644 --- a/deb/snapshot.go +++ b/deb/snapshot.go @@ -143,12 +143,6 @@ func (s *Snapshot) Key() []byte { return []byte("S" + s.UUID) } -// ResourceKey is a unique identifier of the resource -// this snapshot uses. Instead of uuid it uses name -// which needs to be unique as well. -func (s *Snapshot) ResourceKey() []byte { - return []byte("S" + s.Name) -} // RefKey is a unique id for package reference list func (s *Snapshot) RefKey() []byte {