diff --git a/ssh.py b/ssh.py index ffa0d6c02..6de8b89e8 100644 --- a/ssh.py +++ b/ssh.py @@ -52,12 +52,12 @@ def _parse_ssh_version(ver_str=None): @functools.lru_cache(maxsize=None) def version(): - """return ssh version as a tuple""" + """Return ssh version as a tuple. + + If ssh is not available, a FileNotFoundError will be raised. + """ try: return _parse_ssh_version() - except FileNotFoundError: - print("fatal: ssh not installed", file=sys.stderr) - sys.exit(1) except subprocess.CalledProcessError as e: print( "fatal: unable to detect ssh version" @@ -102,9 +102,18 @@ class ProxyManager: self._clients = manager.list() # Path to directory for holding master sockets. self._sock_path = None + # See if ssh is usable. + self._ssh_installed = False def __enter__(self): """Enter a new context.""" + # Check which version of ssh is available. + try: + version() + self._ssh_installed = True + except FileNotFoundError: + self._ssh_installed = False + return self def __exit__(self, exc_type, exc_value, traceback): @@ -282,6 +291,9 @@ class ProxyManager: def preconnect(self, url): """If |uri| will create a ssh connection, setup the ssh master for it.""" # noqa: E501 + if not self._ssh_installed: + return False + m = URI_ALL.match(url) if m: scheme = m.group(1) @@ -306,6 +318,9 @@ class ProxyManager: This has all the master sockets so clients can talk to them. """ + if not self._ssh_installed: + return None + if self._sock_path is None: if not create: return None diff --git a/tests/test_ssh.py b/tests/test_ssh.py index a15c0b5fe..ce65fac45 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -25,6 +25,10 @@ import ssh class SshTests(unittest.TestCase): """Tests the ssh functions.""" + def setUp(self) -> None: + super().setUp() + ssh.version.cache_clear() + def test_parse_ssh_version(self): """Check _parse_ssh_version() handling.""" ver = ssh._parse_ssh_version("Unknown\n") @@ -56,11 +60,12 @@ class SshTests(unittest.TestCase): def test_context_manager_child_cleanup(self): """Verify orphaned clients & masters get cleaned up.""" with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager) as ssh_proxy: - client = subprocess.Popen(["sleep", "964853320"]) - ssh_proxy.add_client(client) - master = subprocess.Popen(["sleep", "964853321"]) - ssh_proxy.add_master(master) + with mock.patch("ssh.version", return_value=(1, 2)): + with ssh.ProxyManager(manager) as ssh_proxy: + client = subprocess.Popen(["sleep", "964853320"]) + ssh_proxy.add_client(client) + master = subprocess.Popen(["sleep", "964853321"]) + ssh_proxy.add_master(master) # If the process still exists, these will throw timeout errors. client.wait(0) master.wait(0) @@ -72,9 +77,11 @@ class SshTests(unittest.TestCase): with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"): # Old ssh version uses port. with mock.patch("ssh.version", return_value=(6, 6)): - self.assertTrue(proxy.sock().endswith("%p")) + with proxy as ssh_proxy: + self.assertTrue(ssh_proxy.sock().endswith("%p")) proxy._sock_path = None # New ssh version uses hash. with mock.patch("ssh.version", return_value=(6, 7)): - self.assertTrue(proxy.sock().endswith("%C")) + with proxy as ssh_proxy: + self.assertTrue(ssh_proxy.sock().endswith("%C"))