From 3608c137a05ca29ca161133b0317cb22030fb2b8 Mon Sep 17 00:00:00 2001 From: Agustin Henze Date: Tue, 19 Aug 2025 14:37:52 +0200 Subject: [PATCH] Add mutex on LinkFromPool to fix #1449 This fixes the race condition that happens when you call publish concurrently. It adds a valuable test that reproduces the error almost deterministically, it's hard to say always but I have run this in loop 100 times and it reproduces the error consistently without the patch and after the patch it works consistently. --- AUTHORS | 3 +- files/linkfrompool_concurrency_test.go | 283 +++++++++++++++++++++++++ files/public.go | 40 +++- 3 files changed, 318 insertions(+), 8 deletions(-) create mode 100644 files/linkfrompool_concurrency_test.go diff --git a/AUTHORS b/AUTHORS index 8eef529d..e2b318c8 100644 --- a/AUTHORS +++ b/AUTHORS @@ -69,4 +69,5 @@ List of contributors, in chronological order: * Leigh London (https://github.com/leighlondon) * Gordian Schoenherr (https://github.com/schoenherrg) * Silke Hofstra (https://github.com/silkeh) -* Itay Porezky (https://github.com/itayporezky) \ No newline at end of file +* Itay Porezky (https://github.com/itayporezky) +* Agustin Henze (https://github.com/agustinhenze) diff --git a/files/linkfrompool_concurrency_test.go b/files/linkfrompool_concurrency_test.go new file mode 100644 index 00000000..0acbab3f --- /dev/null +++ b/files/linkfrompool_concurrency_test.go @@ -0,0 +1,283 @@ +package files + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/aptly-dev/aptly/aptly" + "github.com/aptly-dev/aptly/utils" + + . "gopkg.in/check.v1" +) + +type LinkFromPoolConcurrencySuite struct { + root string + poolDir string + storage *PublishedStorage + pool *PackagePool + cs aptly.ChecksumStorage + testFile string + testContent []byte + testChecksums utils.ChecksumInfo + srcPoolPath string +} + +var _ = Suite(&LinkFromPoolConcurrencySuite{}) + +func (s *LinkFromPoolConcurrencySuite) SetUpTest(c *C) { + s.root = c.MkDir() + s.poolDir = filepath.Join(s.root, "pool") + publishDir := filepath.Join(s.root, "public") + + // Create package pool and published storage + s.pool = NewPackagePool(s.poolDir, true) + s.storage = NewPublishedStorage(publishDir, "copy", "checksum") + s.cs = NewMockChecksumStorage() + + // Create test file content + s.testContent = []byte("test package content for concurrency testing") + s.testFile = filepath.Join(s.root, "test-package.deb") + + err := os.WriteFile(s.testFile, s.testContent, 0644) + c.Assert(err, IsNil) + + // Calculate checksums + md5sum, err := utils.MD5ChecksumForFile(s.testFile) + c.Assert(err, IsNil) + + s.testChecksums = utils.ChecksumInfo{ + Size: int64(len(s.testContent)), + MD5: md5sum, + } + + // Import the test file into the pool + s.srcPoolPath, err = s.pool.Import(s.testFile, "test-package.deb", &s.testChecksums, false, s.cs) + c.Assert(err, IsNil) +} + +func (s *LinkFromPoolConcurrencySuite) TestLinkFromPoolConcurrency(c *C) { + // Test concurrent LinkFromPool operations to ensure no race conditions + concurrency := 5000 + iterations := 10 + + for iter := 0; iter < iterations; iter++ { + c.Logf("Iteration %d: Testing concurrent LinkFromPool with %d goroutines", iter+1, concurrency) + + destPath := fmt.Sprintf("main/t/test%d", iter) + + var wg sync.WaitGroup + errors := make(chan error, concurrency) + successes := make(chan struct{}, concurrency) + + start := time.Now() + + // Launch concurrent LinkFromPool operations + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Use force=true to test the most vulnerable code path (remove-then-create) + err := s.storage.LinkFromPool( + "", // publishedPrefix + destPath, // publishedRelPath + "test-package.deb", // fileName + s.pool, // sourcePool + s.srcPoolPath, // sourcePath + s.testChecksums, // sourceChecksums + true, // force - this triggers vulnerable remove-then-create pattern + ) + + if err != nil { + errors <- fmt.Errorf("goroutine %d failed: %v", id, err) + } else { + successes <- struct{}{} + } + }(i) + } + + // Wait for completion + wg.Wait() + duration := time.Since(start) + + close(errors) + close(successes) + + // Count results + errorCount := 0 + successCount := 0 + var firstError error + + for err := range errors { + errorCount++ + if firstError == nil { + firstError = err + } + c.Logf("Race condition error: %v", err) + } + + for range successes { + successCount++ + } + + c.Logf("Results: %d successes, %d errors, took %v", successCount, errorCount, duration) + + // Assert no race conditions occurred + if errorCount > 0 { + c.Fatalf("Race condition detected in iteration %d! "+ + "Errors: %d out of %d operations (%.1f%% failure rate). "+ + "First error: %v. "+ + "This indicates the fix is not working properly.", + iter+1, errorCount, concurrency, + float64(errorCount)/float64(concurrency)*100, firstError) + } + + // Verify the final file exists and has correct content + finalFile := filepath.Join(s.storage.rootPath, destPath, "test-package.deb") + _, err := os.Stat(finalFile) + c.Assert(err, IsNil, Commentf("Final file should exist after concurrent operations")) + + content, err := os.ReadFile(finalFile) + c.Assert(err, IsNil, Commentf("Should be able to read final file")) + c.Assert(content, DeepEquals, s.testContent, Commentf("File content should be intact after concurrent operations")) + + c.Logf("✓ Iteration %d: No race conditions detected", iter+1) + } + + c.Logf("SUCCESS: Handled %d total concurrent operations across %d iterations with no race conditions", + concurrency*iterations, iterations) +} + +func (s *LinkFromPoolConcurrencySuite) TestLinkFromPoolConcurrencyDifferentFiles(c *C) { + // Test concurrent operations on different files to ensure no blocking + concurrency := 10 + + var wg sync.WaitGroup + errors := make(chan error, concurrency) + + start := time.Now() + + // Launch concurrent operations on different destination files + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + destPath := fmt.Sprintf("main/t/test-file-%d", id) + + err := s.storage.LinkFromPool( + "", // publishedPrefix + destPath, // publishedRelPath + "test-package.deb", // fileName + s.pool, // sourcePool + s.srcPoolPath, // sourcePath + s.testChecksums, // sourceChecksums + false, // force + ) + + if err != nil { + errors <- fmt.Errorf("goroutine %d failed: %v", id, err) + } + }(i) + } + + // Wait for completion + wg.Wait() + duration := time.Since(start) + + close(errors) + + // Count errors + errorCount := 0 + for err := range errors { + errorCount++ + c.Logf("Error: %v", err) + } + + c.Assert(errorCount, Equals, 0, Commentf("No errors should occur when linking to different files")) + c.Logf("SUCCESS: %d concurrent operations on different files completed in %v", concurrency, duration) + + // Verify all files were created correctly + for i := 0; i < concurrency; i++ { + finalFile := filepath.Join(s.storage.rootPath, fmt.Sprintf("main/t/test-file-%d", i), "test-package.deb") + _, err := os.Stat(finalFile) + c.Assert(err, IsNil, Commentf("File %d should exist", i)) + + content, err := os.ReadFile(finalFile) + c.Assert(err, IsNil, Commentf("Should be able to read file %d", i)) + c.Assert(content, DeepEquals, s.testContent, Commentf("File %d content should be correct", i)) + } +} + +func (s *LinkFromPoolConcurrencySuite) TestLinkFromPoolWithoutForceNoConcurrencyIssues(c *C) { + // Test that when force=false, concurrent operations fail gracefully without corruption + concurrency := 20 + destPath := "main/t/single-dest" + + var wg sync.WaitGroup + errors := make(chan error, concurrency) + successes := make(chan struct{}, concurrency) + + // First, create the file so subsequent operations will conflict + err := s.storage.LinkFromPool("", destPath, "test-package.deb", s.pool, s.srcPoolPath, s.testChecksums, false) + c.Assert(err, IsNil) + + start := time.Now() + + // Launch concurrent operations that should mostly fail + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + err := s.storage.LinkFromPool( + "", // publishedPrefix + destPath, // publishedRelPath + "test-package.deb", // fileName + s.pool, // sourcePool + s.srcPoolPath, // sourcePath + s.testChecksums, // sourceChecksums + false, // force=false - should fail if file exists and is same + ) + + if err != nil { + errors <- err + } else { + successes <- struct{}{} + } + }(i) + } + + // Wait for completion + wg.Wait() + duration := time.Since(start) + + close(errors) + close(successes) + + errorCount := 0 + successCount := 0 + + for range errors { + errorCount++ + } + + for range successes { + successCount++ + } + + c.Logf("Results with force=false: %d successes, %d errors, took %v", successCount, errorCount, duration) + + // With force=false and identical files, operations should succeed (file already exists with same content) + // No race conditions should cause crashes or corruption + c.Assert(errorCount, Equals, 0, Commentf("With identical files and force=false, operations should succeed")) + + // Verify the file still exists and has correct content + finalFile := filepath.Join(s.storage.rootPath, destPath, "test-package.deb") + content, err := os.ReadFile(finalFile) + c.Assert(err, IsNil) + c.Assert(content, DeepEquals, s.testContent, Commentf("File should not be corrupted by concurrent access")) +} diff --git a/files/public.go b/files/public.go index f3756aeb..7cf36ec8 100644 --- a/files/public.go +++ b/files/public.go @@ -22,6 +22,26 @@ type PublishedStorage struct { verifyMethod uint } +// Global mutex map to prevent concurrent access to the same destinationPath in LinkFromPool +var ( + fileLockMutex sync.Mutex + fileLocks = make(map[string]*sync.Mutex) +) + +// getFileLock returns a mutex for a specific file path to prevent concurrent modifications +func getFileLock(filePath string) *sync.Mutex { + fileLockMutex.Lock() + defer fileLockMutex.Unlock() + + if mutex, exists := fileLocks[filePath]; exists { + return mutex + } + + mutex := &sync.Mutex{} + fileLocks[filePath] = mutex + return mutex +} + // Check interfaces var ( _ aptly.PublishedStorage = (*PublishedStorage)(nil) @@ -136,6 +156,12 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, baseName := filepath.Base(fileName) poolPath := filepath.Join(storage.rootPath, publishedPrefix, publishedRelPath, filepath.Dir(fileName)) + destinationPath := filepath.Join(poolPath, baseName) + + // Acquire file-specific lock to prevent concurrent access to the same file + fileLock := getFileLock(destinationPath) + fileLock.Lock() + defer fileLock.Unlock() var localSourcePool aptly.LocalPackagePool if storage.linkMethod != LinkMethodCopy { @@ -154,7 +180,7 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, var dstStat os.FileInfo - dstStat, err = os.Stat(filepath.Join(poolPath, baseName)) + dstStat, err = os.Stat(destinationPath) if err == nil { // already exists, check source file @@ -173,7 +199,7 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, } else { // if source and destination have the same checksums, no need to copy var dstMD5 string - dstMD5, err = utils.MD5ChecksumForFile(filepath.Join(poolPath, baseName)) + dstMD5, err = utils.MD5ChecksumForFile(destinationPath) if err != nil { return err @@ -204,11 +230,11 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, // source and destination have different inodes, if !forced, this is fatal error if !force { - return fmt.Errorf("error linking file to %s: file already exists and is different", filepath.Join(poolPath, baseName)) + return fmt.Errorf("error linking file to %s: file already exists and is different", destinationPath) } // forced, so remove destination - err = os.Remove(filepath.Join(poolPath, baseName)) + err = os.Remove(destinationPath) if err != nil { return err } @@ -223,7 +249,7 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, } var dst *os.File - dst, err = os.Create(filepath.Join(poolPath, baseName)) + dst, err = os.Create(destinationPath) if err != nil { _ = r.Close() return err @@ -244,9 +270,9 @@ func (storage *PublishedStorage) LinkFromPool(publishedPrefix, publishedRelPath, err = dst.Close() } else if storage.linkMethod == LinkMethodSymLink { - err = localSourcePool.Symlink(sourcePath, filepath.Join(poolPath, baseName)) + err = localSourcePool.Symlink(sourcePath, destinationPath) } else { - err = localSourcePool.Link(sourcePath, filepath.Join(poolPath, baseName)) + err = localSourcePool.Link(sourcePath, destinationPath) } return err