diff --git a/command.py b/command.py index e74a94fa3..ee1f68c14 100644 --- a/command.py +++ b/command.py @@ -101,6 +101,11 @@ class Command: def WantPager(self, _opt): return False + @staticmethod + def is_multiprocessing_active() -> bool: + """Whether the current process is a worker in a pool.""" + return multiprocessing.current_process().name != "MainProcess" + def ReadEnvironmentOptions(self, opts): """Set options from environment variables.""" diff --git a/subcmds/sync.py b/subcmds/sync.py index b1c30867e..2517694e7 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py @@ -841,6 +841,8 @@ later is required to fix a server side protocol bug. ) except KeyboardInterrupt: logger.error("Keyboard interrupt while processing %s", project.name) + if not cls.is_multiprocessing_active(): + raise except GitError as e: logger.error("error.GitError: Cannot fetch %s", e) errors.append(e) @@ -1104,6 +1106,8 @@ later is required to fix a server side protocol bug. errors.extend(syncbuf.errors) except KeyboardInterrupt: logger.error("Keyboard interrupt while processing %s", project.name) + if not cls.is_multiprocessing_active(): + raise except GitError as e: logger.error( "error.GitError: Cannot checkout %s: %s", project.name, e @@ -2410,6 +2414,8 @@ later is required to fix a server side protocol bug. logger.error( "Keyboard interrupt while processing %s", project.name ) + if not cls.is_multiprocessing_active(): + raise except GitError as e: fetch_errors.append(e) logger.error("error.GitError: Cannot fetch %s", e) @@ -2460,6 +2466,8 @@ later is required to fix a server side protocol bug. logger.error( "Keyboard interrupt while processing %s", project.name ) + if not cls.is_multiprocessing_active(): + raise except GitError as e: checkout_errors.append(e) logger.error( diff --git a/tests/test_subcmds_sync.py b/tests/test_subcmds_sync.py index 0a46bbdf3..669027c98 100644 --- a/tests/test_subcmds_sync.py +++ b/tests/test_subcmds_sync.py @@ -477,6 +477,56 @@ class GetPreciousObjectsState(unittest.TestCase): ) +class KeyboardInterruptTest(unittest.TestCase): + """Tests for KeyboardInterrupt handling in Sync operations.""" + + def setUp(self): + self.project = mock.MagicMock(name="project") + self.project.name = "project" + self.project.relpath = "proj" + self.project.manifest.IsArchive = False + self.opt = mock.Mock() + self.opt.quiet = True + self.opt.verbose = False + self.opt.tags = False + + self.sync_dict = {} + + self.get_parallel_context_mock = { + "projects": [self.project], + "sync_dict": self.sync_dict, + "ssh_proxy": None, + } + + @mock.patch("subcmds.sync.Sync.is_multiprocessing_active") + def test_fetch_one_keyboard_interrupt_main_process(self, mock_is_active): + """Test that _FetchOne re-raises KeyboardInterrupt if not worker.""" + mock_is_active.return_value = False + self.project.Sync_NetworkHalf.side_effect = KeyboardInterrupt() + + with mock.patch.object( + sync.Sync, + "get_parallel_context", + return_value=self.get_parallel_context_mock, + ): + with self.assertRaises(KeyboardInterrupt): + sync.Sync._FetchOne(self.opt, 0) + + @mock.patch("subcmds.sync.Sync.is_multiprocessing_active") + def test_fetch_one_keyboard_interrupt_worker_process(self, mock_is_active): + """Test that _FetchOne suppresses KeyboardInterrupt in workers.""" + mock_is_active.return_value = True + self.project.Sync_NetworkHalf.side_effect = KeyboardInterrupt() + + with mock.patch.object( + sync.Sync, + "get_parallel_context", + return_value=self.get_parallel_context_mock, + ): + result = sync.Sync._FetchOne(self.opt, 0) + self.assertFalse(result.success) + + class CheckForBloatedProjects(unittest.TestCase): """Tests for Sync._CheckForBloatedProjects."""