From f24bc7aed59a8fb4b89afa06b82eae1aad067983 Mon Sep 17 00:00:00 2001 From: Mike Frysinger Date: Wed, 18 Mar 2026 11:17:12 -0400 Subject: [PATCH] tests: switch some test modules to pytest Change-Id: I524b5ff2d77f8232f94e21921b00ba4027d2ac4f Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/563081 Tested-by: Mike Frysinger Reviewed-by: Gavin Mak Commit-Queue: Mike Frysinger --- tests/test_color.py | 86 ++++---- tests/test_editor.py | 47 ++--- tests/test_error.py | 65 +++--- tests/test_git_config.py | 361 ++++++++++++++++++---------------- tests/test_hooks.py | 69 ++++--- tests/test_platform_utils.py | 50 +++-- tests/test_repo_logging.py | 148 +++++++------- tests/test_repo_trace.py | 57 +++--- tests/test_ssh.py | 124 ++++++------ tests/test_subcmds.py | 246 +++++++++++------------ tests/test_subcmds_init.py | 51 ++--- tests/test_update_manpages.py | 13 +- 12 files changed, 667 insertions(+), 650 deletions(-) diff --git a/tests/test_color.py b/tests/test_color.py index 91a1bf250..923f7e355 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -15,60 +15,66 @@ """Unittests for the color.py module.""" import os -import unittest + +import pytest import color import git_config -def fixture(*paths): +def fixture(*paths: str) -> str: """Return a path relative to test/fixtures.""" return os.path.join(os.path.dirname(__file__), "fixtures", *paths) -class ColoringTests(unittest.TestCase): - """tests of the Coloring class.""" +@pytest.fixture +def coloring() -> color.Coloring: + """Create a Coloring object for testing.""" + config_fixture = fixture("test.gitconfig") + config = git_config.GitConfig(config_fixture) + color.SetDefaultColoring("true") + return color.Coloring(config, "status") - def setUp(self): - """Create a GitConfig object using the test.gitconfig fixture.""" - config_fixture = fixture("test.gitconfig") - self.config = git_config.GitConfig(config_fixture) - color.SetDefaultColoring("true") - self.color = color.Coloring(self.config, "status") - def test_Color_Parse_all_params_none(self): - """all params are None""" - val = self.color._parse(None, None, None, None) - self.assertEqual("", val) +def test_Color_Parse_all_params_none(coloring: color.Coloring) -> None: + """all params are None""" + val = coloring._parse(None, None, None, None) + assert val == "" - def test_Color_Parse_first_parameter_none(self): - """check fg & bg & attr""" - val = self.color._parse(None, "black", "red", "ul") - self.assertEqual("\x1b[4;30;41m", val) - def test_Color_Parse_one_entry(self): - """check fg""" - val = self.color._parse("one", None, None, None) - self.assertEqual("\033[33m", val) +def test_Color_Parse_first_parameter_none(coloring: color.Coloring) -> None: + """check fg & bg & attr""" + val = coloring._parse(None, "black", "red", "ul") + assert val == "\x1b[4;30;41m" - def test_Color_Parse_two_entry(self): - """check fg & bg""" - val = self.color._parse("two", None, None, None) - self.assertEqual("\033[35;46m", val) - def test_Color_Parse_three_entry(self): - """check fg & bg & attr""" - val = self.color._parse("three", None, None, None) - self.assertEqual("\033[4;30;41m", val) +def test_Color_Parse_one_entry(coloring: color.Coloring) -> None: + """check fg""" + val = coloring._parse("one", None, None, None) + assert val == "\033[33m" - def test_Color_Parse_reset_entry(self): - """check reset entry""" - val = self.color._parse("reset", None, None, None) - self.assertEqual("\033[m", val) - def test_Color_Parse_empty_entry(self): - """check empty entry""" - val = self.color._parse("none", "blue", "white", "dim") - self.assertEqual("\033[2;34;47m", val) - val = self.color._parse("empty", "green", "white", "bold") - self.assertEqual("\033[1;32;47m", val) +def test_Color_Parse_two_entry(coloring: color.Coloring) -> None: + """check fg & bg""" + val = coloring._parse("two", None, None, None) + assert val == "\033[35;46m" + + +def test_Color_Parse_three_entry(coloring: color.Coloring) -> None: + """check fg & bg & attr""" + val = coloring._parse("three", None, None, None) + assert val == "\033[4;30;41m" + + +def test_Color_Parse_reset_entry(coloring: color.Coloring) -> None: + """check reset entry""" + val = coloring._parse("reset", None, None, None) + assert val == "\033[m" + + +def test_Color_Parse_empty_entry(coloring: color.Coloring) -> None: + """check empty entry""" + val = coloring._parse("none", "blue", "white", "dim") + assert val == "\033[2;34;47m" + val = coloring._parse("empty", "green", "white", "bold") + assert val == "\033[1;32;47m" diff --git a/tests/test_editor.py b/tests/test_editor.py index 8f5d160e2..d51765dd1 100644 --- a/tests/test_editor.py +++ b/tests/test_editor.py @@ -14,43 +14,32 @@ """Unittests for the editor.py module.""" -import unittest +import pytest from editor import Editor -class EditorTestCase(unittest.TestCase): +@pytest.fixture(autouse=True) +def reset_editor() -> None: """Take care of resetting Editor state across tests.""" - - def setUp(self): - self.setEditor(None) - - def tearDown(self): - self.setEditor(None) - - @staticmethod - def setEditor(editor): - Editor._editor = editor + Editor._editor = None + yield + Editor._editor = None -class GetEditor(EditorTestCase): - """Check GetEditor behavior.""" - - def test_basic(self): - """Basic checking of _GetEditor.""" - self.setEditor(":") - self.assertEqual(":", Editor._GetEditor()) +def test_basic() -> None: + """Basic checking of _GetEditor.""" + Editor._editor = ":" + assert Editor._GetEditor() == ":" -class EditString(EditorTestCase): - """Check EditString behavior.""" +def test_no_editor() -> None: + """Check behavior when no editor is available.""" + Editor._editor = ":" + assert Editor.EditString("foo") == "foo" - def test_no_editor(self): - """Check behavior when no editor is available.""" - self.setEditor(":") - self.assertEqual("foo", Editor.EditString("foo")) - def test_cat_editor(self): - """Check behavior when editor is `cat`.""" - self.setEditor("cat") - self.assertEqual("foo", Editor.EditString("foo")) +def test_cat_editor() -> None: + """Check behavior when editor is `cat`.""" + Editor._editor = "cat" + assert Editor.EditString("foo") == "foo" diff --git a/tests/test_error.py b/tests/test_error.py index a0ff32c31..3eb82835e 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -16,7 +16,9 @@ import inspect import pickle -import unittest +from typing import Iterator, Type + +import pytest import command import error @@ -26,7 +28,7 @@ import project from subcmds import all_modules -imports = all_modules + [ +_IMPORTS = all_modules + [ error, project, git_command, @@ -35,36 +37,35 @@ imports = all_modules + [ ] -class PickleTests(unittest.TestCase): - """Make sure all our custom exceptions can be pickled.""" +def get_exceptions() -> Iterator[Type[Exception]]: + """Return all our custom exceptions.""" + for entry in _IMPORTS: + for name in dir(entry): + cls = getattr(entry, name) + if isinstance(cls, type) and issubclass(cls, Exception): + yield cls - def getExceptions(self): - """Return all our custom exceptions.""" - for entry in imports: - for name in dir(entry): - cls = getattr(entry, name) - if isinstance(cls, type) and issubclass(cls, Exception): - yield cls - def testExceptionLookup(self): - """Make sure our introspection logic works.""" - classes = list(self.getExceptions()) - self.assertIn(error.HookError, classes) - # Don't assert the exact number to avoid being a change-detector test. - self.assertGreater(len(classes), 10) +def test_exception_lookup() -> None: + """Make sure our introspection logic works.""" + classes = list(get_exceptions()) + assert error.HookError in classes + # Don't assert the exact number to avoid being a change-detector test. + assert len(classes) > 10 - def testPickle(self): - """Try to pickle all the exceptions.""" - for cls in self.getExceptions(): - args = inspect.getfullargspec(cls.__init__).args[1:] - obj = cls(*args) - p = pickle.dumps(obj) - try: - newobj = pickle.loads(p) - except Exception as e: # pylint: disable=broad-except - self.fail( - "Class %s is unable to be pickled: %s\n" - "Incomplete super().__init__(...) call?" % (cls, e) - ) - self.assertIsInstance(newobj, cls) - self.assertEqual(str(obj), str(newobj)) + +@pytest.mark.parametrize("cls", get_exceptions()) +def test_pickle(cls: Type[Exception]) -> None: + """Try to pickle all the exceptions.""" + args = inspect.getfullargspec(cls.__init__).args[1:] + obj = cls(*args) + p = pickle.dumps(obj) + try: + newobj = pickle.loads(p) + except Exception as e: + pytest.fail( + f"Class {cls} is unable to be pickled: {e}\n" + "Incomplete super().__init__(...) call?" + ) + assert isinstance(newobj, cls) + assert str(obj) == str(newobj) diff --git a/tests/test_git_config.py b/tests/test_git_config.py index e1604bd76..496d97141 100644 --- a/tests/test_git_config.py +++ b/tests/test_git_config.py @@ -15,200 +15,219 @@ """Unittests for the git_config.py module.""" import os -import tempfile -import unittest +from pathlib import Path +from typing import Any + +import pytest import git_config -def fixture(*paths): +def fixture_path(*paths: str) -> str: """Return a path relative to test/fixtures.""" return os.path.join(os.path.dirname(__file__), "fixtures", *paths) -class GitConfigReadOnlyTests(unittest.TestCase): - """Read-only tests of the GitConfig class.""" - - def setUp(self): - """Create a GitConfig object using the test.gitconfig fixture.""" - config_fixture = fixture("test.gitconfig") - self.config = git_config.GitConfig(config_fixture) - - def test_GetString_with_empty_config_values(self): - """ - Test config entries with no value. - - [section] - empty - - """ - val = self.config.GetString("section.empty") - self.assertEqual(val, None) - - def test_GetString_with_true_value(self): - """ - Test config entries with a string value. - - [section] - nonempty = true - - """ - val = self.config.GetString("section.nonempty") - self.assertEqual(val, "true") - - def test_GetString_from_missing_file(self): - """ - Test missing config file - """ - config_fixture = fixture("not.present.gitconfig") - config = git_config.GitConfig(config_fixture) - val = config.GetString("empty") - self.assertEqual(val, None) - - def test_GetBoolean_undefined(self): - """Test GetBoolean on key that doesn't exist.""" - self.assertIsNone(self.config.GetBoolean("section.missing")) - - def test_GetBoolean_invalid(self): - """Test GetBoolean on invalid boolean value.""" - self.assertIsNone(self.config.GetBoolean("section.boolinvalid")) - - def test_GetBoolean_true(self): - """Test GetBoolean on valid true boolean.""" - self.assertTrue(self.config.GetBoolean("section.booltrue")) - - def test_GetBoolean_false(self): - """Test GetBoolean on valid false boolean.""" - self.assertFalse(self.config.GetBoolean("section.boolfalse")) - - def test_GetInt_undefined(self): - """Test GetInt on key that doesn't exist.""" - self.assertIsNone(self.config.GetInt("section.missing")) - - def test_GetInt_invalid(self): - """Test GetInt on invalid integer value.""" - self.assertIsNone(self.config.GetBoolean("section.intinvalid")) - - def test_GetInt_valid(self): - """Test GetInt on valid integers.""" - TESTS = ( - ("inthex", 16), - ("inthexk", 16384), - ("int", 10), - ("intk", 10240), - ("intm", 10485760), - ("intg", 10737418240), - ) - for key, value in TESTS: - self.assertEqual(value, self.config.GetInt(f"section.{key}")) +@pytest.fixture +def readonly_config() -> git_config.GitConfig: + """Create a GitConfig object using the test.gitconfig fixture.""" + config_fixture = fixture_path("test.gitconfig") + return git_config.GitConfig(config_fixture) -class GitConfigReadWriteTests(unittest.TestCase): - """Read/write tests of the GitConfig class.""" +def test_get_string_with_empty_config_values( + readonly_config: git_config.GitConfig, +) -> None: + """Test config entries with no value. - def setUp(self): - self.tmpfile = tempfile.NamedTemporaryFile() - self.config = self.get_config() + [section] + empty - def get_config(self): - """Get a new GitConfig instance.""" - return git_config.GitConfig(self.tmpfile.name) + """ + val = readonly_config.GetString("section.empty") + assert val is None - def test_SetString(self): - """Test SetString behavior.""" - # Set a value. - self.assertIsNone(self.config.GetString("foo.bar")) - self.config.SetString("foo.bar", "val") - self.assertEqual("val", self.config.GetString("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertEqual("val", config.GetString("foo.bar")) +def test_get_string_with_true_value( + readonly_config: git_config.GitConfig, +) -> None: + """Test config entries with a string value. - # Update the value. - self.config.SetString("foo.bar", "valll") - self.assertEqual("valll", self.config.GetString("foo.bar")) - config = self.get_config() - self.assertEqual("valll", config.GetString("foo.bar")) + [section] + nonempty = true - # Delete the value. - self.config.SetString("foo.bar", None) - self.assertIsNone(self.config.GetString("foo.bar")) - config = self.get_config() - self.assertIsNone(config.GetString("foo.bar")) + """ + val = readonly_config.GetString("section.nonempty") + assert val == "true" - def test_SetBoolean(self): - """Test SetBoolean behavior.""" - # Set a true value. - self.assertIsNone(self.config.GetBoolean("foo.bar")) - for val in (True, 1): - self.config.SetBoolean("foo.bar", val) - self.assertTrue(self.config.GetBoolean("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertTrue(config.GetBoolean("foo.bar")) - self.assertEqual("true", config.GetString("foo.bar")) +def test_get_string_from_missing_file() -> None: + """Test missing config file.""" + config_fixture = fixture_path("not.present.gitconfig") + config = git_config.GitConfig(config_fixture) + val = config.GetString("empty") + assert val is None - # Set a false value. - for val in (False, 0): - self.config.SetBoolean("foo.bar", val) - self.assertFalse(self.config.GetBoolean("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertFalse(config.GetBoolean("foo.bar")) - self.assertEqual("false", config.GetString("foo.bar")) +def test_get_boolean_undefined(readonly_config: git_config.GitConfig) -> None: + """Test GetBoolean on key that doesn't exist.""" + assert readonly_config.GetBoolean("section.missing") is None - # Delete the value. - self.config.SetBoolean("foo.bar", None) - self.assertIsNone(self.config.GetBoolean("foo.bar")) - config = self.get_config() - self.assertIsNone(config.GetBoolean("foo.bar")) - def test_SetInt(self): - """Test SetInt behavior.""" - # Set a value. - self.assertIsNone(self.config.GetInt("foo.bar")) - self.config.SetInt("foo.bar", 10) - self.assertEqual(10, self.config.GetInt("foo.bar")) +def test_get_boolean_invalid(readonly_config: git_config.GitConfig) -> None: + """Test GetBoolean on invalid boolean value.""" + assert readonly_config.GetBoolean("section.boolinvalid") is None - # Make sure the value was actually written out. - config = self.get_config() - self.assertEqual(10, config.GetInt("foo.bar")) - self.assertEqual("10", config.GetString("foo.bar")) - # Update the value. - self.config.SetInt("foo.bar", 20) - self.assertEqual(20, self.config.GetInt("foo.bar")) - config = self.get_config() - self.assertEqual(20, config.GetInt("foo.bar")) +def test_get_boolean_true(readonly_config: git_config.GitConfig) -> None: + """Test GetBoolean on valid true boolean.""" + assert readonly_config.GetBoolean("section.booltrue") is True - # Delete the value. - self.config.SetInt("foo.bar", None) - self.assertIsNone(self.config.GetInt("foo.bar")) - config = self.get_config() - self.assertIsNone(config.GetInt("foo.bar")) - def test_GetSyncAnalysisStateData(self): - """Test config entries with a sync state analysis data.""" - superproject_logging_data = {} - superproject_logging_data["test"] = False - options = type("options", (object,), {})() - options.verbose = "true" - options.mp_update = "false" - TESTS = ( - ("superproject.test", "false"), - ("options.verbose", "true"), - ("options.mpupdate", "false"), - ("main.version", "1"), - ) - self.config.UpdateSyncAnalysisState(options, superproject_logging_data) - sync_data = self.config.GetSyncAnalysisStateData() - for key, value in TESTS: - self.assertEqual( - sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"], value - ) - self.assertTrue( - sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"] - ) +def test_get_boolean_false(readonly_config: git_config.GitConfig) -> None: + """Test GetBoolean on valid false boolean.""" + assert readonly_config.GetBoolean("section.boolfalse") is False + + +def test_get_int_undefined(readonly_config: git_config.GitConfig) -> None: + """Test GetInt on key that doesn't exist.""" + assert readonly_config.GetInt("section.missing") is None + + +def test_get_int_invalid(readonly_config: git_config.GitConfig) -> None: + """Test GetInt on invalid integer value.""" + assert readonly_config.GetInt("section.intinvalid") is None + + +@pytest.mark.parametrize( + "key, expected", + ( + ("inthex", 16), + ("inthexk", 16384), + ("int", 10), + ("intk", 10240), + ("intm", 10485760), + ("intg", 10737418240), + ), +) +def test_get_int_valid( + readonly_config: git_config.GitConfig, key: str, expected: int +) -> None: + """Test GetInt on valid integers.""" + assert readonly_config.GetInt(f"section.{key}") == expected + + +@pytest.fixture +def rw_config_file(tmp_path: Path) -> Path: + """Return a path to a temporary config file.""" + return tmp_path / "config" + + +def test_set_string(rw_config_file: Path) -> None: + """Test SetString behavior.""" + config = git_config.GitConfig(str(rw_config_file)) + + # Set a value. + assert config.GetString("foo.bar") is None + config.SetString("foo.bar", "val") + assert config.GetString("foo.bar") == "val" + + # Make sure the value was actually written out. + config2 = git_config.GitConfig(str(rw_config_file)) + assert config2.GetString("foo.bar") == "val" + + # Update the value. + config.SetString("foo.bar", "valll") + assert config.GetString("foo.bar") == "valll" + config3 = git_config.GitConfig(str(rw_config_file)) + assert config3.GetString("foo.bar") == "valll" + + # Delete the value. + config.SetString("foo.bar", None) + assert config.GetString("foo.bar") is None + config4 = git_config.GitConfig(str(rw_config_file)) + assert config4.GetString("foo.bar") is None + + +def test_set_boolean(rw_config_file: Path) -> None: + """Test SetBoolean behavior.""" + config = git_config.GitConfig(str(rw_config_file)) + + # Set a true value. + assert config.GetBoolean("foo.bar") is None + for val in (True, 1): + config.SetBoolean("foo.bar", val) + assert config.GetBoolean("foo.bar") is True + + # Make sure the value was actually written out. + config2 = git_config.GitConfig(str(rw_config_file)) + assert config2.GetBoolean("foo.bar") is True + assert config2.GetString("foo.bar") == "true" + + # Set a false value. + for val in (False, 0): + config.SetBoolean("foo.bar", val) + assert config.GetBoolean("foo.bar") is False + + # Make sure the value was actually written out. + config3 = git_config.GitConfig(str(rw_config_file)) + assert config3.GetBoolean("foo.bar") is False + assert config3.GetString("foo.bar") == "false" + + # Delete the value. + config.SetBoolean("foo.bar", None) + assert config.GetBoolean("foo.bar") is None + config4 = git_config.GitConfig(str(rw_config_file)) + assert config4.GetBoolean("foo.bar") is None + + +def test_set_int(rw_config_file: Path) -> None: + """Test SetInt behavior.""" + config = git_config.GitConfig(str(rw_config_file)) + + # Set a value. + assert config.GetInt("foo.bar") is None + config.SetInt("foo.bar", 10) + assert config.GetInt("foo.bar") == 10 + + # Make sure the value was actually written out. + config2 = git_config.GitConfig(str(rw_config_file)) + assert config2.GetInt("foo.bar") == 10 + assert config2.GetString("foo.bar") == "10" + + # Update the value. + config.SetInt("foo.bar", 20) + assert config.GetInt("foo.bar") == 20 + config3 = git_config.GitConfig(str(rw_config_file)) + assert config3.GetInt("foo.bar") == 20 + + # Delete the value. + config.SetInt("foo.bar", None) + assert config.GetInt("foo.bar") is None + config4 = git_config.GitConfig(str(rw_config_file)) + assert config4.GetInt("foo.bar") is None + + +def test_get_sync_analysis_state_data(rw_config_file: Path) -> None: + """Test config entries with a sync state analysis data.""" + config = git_config.GitConfig(str(rw_config_file)) + superproject_logging_data: dict[str, Any] = {"test": False} + + class Options: + """Container for testing.""" + + options = Options() + options.verbose = "true" + options.mp_update = "false" + + TESTS = ( + ("superproject.test", "false"), + ("options.verbose", "true"), + ("options.mpupdate", "false"), + ("main.version", "1"), + ) + config.UpdateSyncAnalysisState(options, superproject_logging_data) + sync_data = config.GetSyncAnalysisStateData() + for key, value in TESTS: + assert sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"] == value + assert sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"] diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 76e928f7a..9d52d1849 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -14,42 +14,47 @@ """Unittests for the hooks.py module.""" -import unittest +import pytest import hooks -class RepoHookShebang(unittest.TestCase): - """Check shebang parsing in RepoHook.""" +@pytest.mark.parametrize( + "data", + ( + "", + "#\n# foo\n", + "# Bad shebang in script\n#!/foo\n", + ), +) +def test_no_shebang(data: str) -> None: + """Lines w/out shebangs should be rejected.""" + assert hooks.RepoHook._ExtractInterpFromShebang(data) is None - def test_no_shebang(self): - """Lines w/out shebangs should be rejected.""" - DATA = ("", "#\n# foo\n", "# Bad shebang in script\n#!/foo\n") - for data in DATA: - self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) - def test_direct_interp(self): - """Lines whose shebang points directly to the interpreter.""" - DATA = ( - ("#!/foo", "/foo"), - ("#! /foo", "/foo"), - ("#!/bin/foo ", "/bin/foo"), - ("#! /usr/foo ", "/usr/foo"), - ("#! /usr/foo -args", "/usr/foo"), - ) - for shebang, interp in DATA: - self.assertEqual( - hooks.RepoHook._ExtractInterpFromShebang(shebang), interp - ) +@pytest.mark.parametrize( + "shebang, interp", + ( + ("#!/foo", "/foo"), + ("#! /foo", "/foo"), + ("#!/bin/foo ", "/bin/foo"), + ("#! /usr/foo ", "/usr/foo"), + ("#! /usr/foo -args", "/usr/foo"), + ), +) +def test_direct_interp(shebang: str, interp: str) -> None: + """Lines whose shebang points directly to the interpreter.""" + assert hooks.RepoHook._ExtractInterpFromShebang(shebang) == interp - def test_env_interp(self): - """Lines whose shebang launches through `env`.""" - DATA = ( - ("#!/usr/bin/env foo", "foo"), - ("#!/bin/env foo", "foo"), - ("#! /bin/env /bin/foo ", "/bin/foo"), - ) - for shebang, interp in DATA: - self.assertEqual( - hooks.RepoHook._ExtractInterpFromShebang(shebang), interp - ) + +@pytest.mark.parametrize( + "shebang, interp", + ( + ("#!/usr/bin/env foo", "foo"), + ("#!/bin/env foo", "foo"), + ("#! /bin/env /bin/foo ", "/bin/foo"), + ), +) +def test_env_interp(shebang: str, interp: str) -> None: + """Lines whose shebang launches through `env`.""" + assert hooks.RepoHook._ExtractInterpFromShebang(shebang) == interp diff --git a/tests/test_platform_utils.py b/tests/test_platform_utils.py index aab8a52ab..e89965391 100644 --- a/tests/test_platform_utils.py +++ b/tests/test_platform_utils.py @@ -14,39 +14,35 @@ """Unittests for the platform_utils.py module.""" -import os -import tempfile -import unittest +from pathlib import Path + +import pytest import platform_utils -class RemoveTests(unittest.TestCase): - """Check remove() helper.""" +def test_remove_missing_ok(tmp_path: Path) -> None: + """Check missing_ok handling.""" + path = tmp_path / "test" - def testMissingOk(self): - """Check missing_ok handling.""" - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "test") + # Should not fail. + platform_utils.remove(path, missing_ok=True) - # Should not fail. - platform_utils.remove(path, missing_ok=True) + # Should fail. + with pytest.raises(OSError): + platform_utils.remove(path) + with pytest.raises(OSError): + platform_utils.remove(path, missing_ok=False) - # Should fail. - self.assertRaises(OSError, platform_utils.remove, path) - self.assertRaises( - OSError, platform_utils.remove, path, missing_ok=False - ) + # Should not fail if it exists. + path.touch() + platform_utils.remove(path, missing_ok=True) + assert not path.exists() - # Should not fail if it exists. - open(path, "w").close() - platform_utils.remove(path, missing_ok=True) - self.assertFalse(os.path.exists(path)) + path.touch() + platform_utils.remove(path) + assert not path.exists() - open(path, "w").close() - platform_utils.remove(path) - self.assertFalse(os.path.exists(path)) - - open(path, "w").close() - platform_utils.remove(path, missing_ok=False) - self.assertFalse(os.path.exists(path)) + path.touch() + platform_utils.remove(path, missing_ok=False) + assert not path.exists() diff --git a/tests/test_repo_logging.py b/tests/test_repo_logging.py index e072039ed..c5bba6b8e 100644 --- a/tests/test_repo_logging.py +++ b/tests/test_repo_logging.py @@ -12,90 +12,96 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit test for repo_logging module.""" +"""Unittests for the repo_logging.py module.""" import contextlib import io import logging -import unittest +import re from unittest import mock +import pytest + from color import SetDefaultColoring from error import RepoExitError from repo_logging import RepoLogger -class TestRepoLogger(unittest.TestCase): - @mock.patch.object(RepoLogger, "error") - def test_log_aggregated_errors_logs_aggregated_errors(self, mock_error): - """Test if log_aggregated_errors logs a list of aggregated errors.""" - logger = RepoLogger(__name__) - logger.log_aggregated_errors( - RepoExitError( - aggregate_errors=[ - Exception("foo"), - Exception("bar"), - Exception("baz"), - Exception("hello"), - Exception("world"), - Exception("test"), - ] - ) - ) - - mock_error.assert_has_calls( - [ - mock.call("=" * 80), - mock.call( - "Repo command failed due to the following `%s` errors:", - "RepoExitError", - ), - mock.call("foo\nbar\nbaz\nhello\nworld"), - mock.call("+%d additional errors...", 1), +@mock.patch.object(RepoLogger, "error") +def test_log_aggregated_errors_logs_aggregated_errors(mock_error) -> None: + """Test if log_aggregated_errors logs a list of aggregated errors.""" + logger = RepoLogger(__name__) + logger.log_aggregated_errors( + RepoExitError( + aggregate_errors=[ + Exception("foo"), + Exception("bar"), + Exception("baz"), + Exception("hello"), + Exception("world"), + Exception("test"), ] ) + ) - @mock.patch.object(RepoLogger, "error") - def test_log_aggregated_errors_logs_single_error(self, mock_error): - """Test if log_aggregated_errors logs empty aggregated_errors.""" + mock_error.assert_has_calls( + [ + mock.call("=" * 80), + mock.call( + "Repo command failed due to the following `%s` errors:", + "RepoExitError", + ), + mock.call("foo\nbar\nbaz\nhello\nworld"), + mock.call("+%d additional errors...", 1), + ] + ) + + +@mock.patch.object(RepoLogger, "error") +def test_log_aggregated_errors_logs_single_error(mock_error) -> None: + """Test if log_aggregated_errors logs empty aggregated_errors.""" + logger = RepoLogger(__name__) + logger.log_aggregated_errors(RepoExitError()) + + mock_error.assert_has_calls( + [ + mock.call("=" * 80), + mock.call("Repo command failed: %s", "RepoExitError"), + ] + ) + + +@pytest.mark.parametrize( + "level", + ( + logging.INFO, + logging.WARN, + logging.ERROR, + ), +) +def test_log_with_format_string(level: int) -> None: + """Test different log levels with format strings.""" + name = logging.getLevelName(level) + + # Set color output to "always" for consistent test results. + # This ensures the logger's behavior is uniform across different + # environments and git configurations. + SetDefaultColoring("always") + + # Regex pattern to match optional ANSI color codes. + # \033 - Escape character + # \[ - Opening square bracket + # [0-9;]* - Zero or more digits or semicolons + # m - Ending 'm' character + # ? - Makes the entire group optional + opt_color = r"(\033\[[0-9;]*m)?" + + output = io.StringIO() + + with contextlib.redirect_stderr(output): logger = RepoLogger(__name__) - logger.log_aggregated_errors(RepoExitError()) + logger.log(level, "%s", "100% pass") - mock_error.assert_has_calls( - [ - mock.call("=" * 80), - mock.call("Repo command failed: %s", "RepoExitError"), - ] - ) - - def test_log_with_format_string(self): - """Test different log levels with format strings.""" - - # Set color output to "always" for consistent test results. - # This ensures the logger's behavior is uniform across different - # environments and git configurations. - SetDefaultColoring("always") - - # Regex pattern to match optional ANSI color codes. - # \033 - Escape character - # \[ - Opening square bracket - # [0-9;]* - Zero or more digits or semicolons - # m - Ending 'm' character - # ? - Makes the entire group optional - opt_color = r"(\033\[[0-9;]*m)?" - - for level in (logging.INFO, logging.WARN, logging.ERROR): - name = logging.getLevelName(level) - - with self.subTest(level=level, name=name): - output = io.StringIO() - - with contextlib.redirect_stderr(output): - logger = RepoLogger(__name__) - logger.log(level, "%s", "100% pass") - - self.assertRegex( - output.getvalue().strip(), - f"^{opt_color}100% pass{opt_color}$", - f"failed for level {name}", - ) + assert re.search( + f"^{opt_color}100% pass{opt_color}$", output.getvalue().strip() + ), f"failed for level {name}" diff --git a/tests/test_repo_trace.py b/tests/test_repo_trace.py index 2ea341025..3ec540b25 100644 --- a/tests/test_repo_trace.py +++ b/tests/test_repo_trace.py @@ -15,46 +15,37 @@ """Unittests for the repo_trace.py module.""" import os -import unittest -from unittest import mock + +import pytest import repo_trace -class TraceTests(unittest.TestCase): +def test_trace_max_size_enforced(monkeypatch: pytest.MonkeyPatch) -> None: """Check Trace behavior.""" + content = "git chicken" - def testTrace_MaxSizeEnforced(self): - content = "git chicken" + with repo_trace.Trace(content, first_trace=True): + pass + first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) - with repo_trace.Trace(content, first_trace=True): - pass - first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) + with repo_trace.Trace(content): + pass + assert os.path.getsize(repo_trace._TRACE_FILE) > first_trace_size - with repo_trace.Trace(content): - pass - self.assertGreater( - os.path.getsize(repo_trace._TRACE_FILE), first_trace_size - ) + # Check we clear everything if the last chunk is larger than _MAX_SIZE. + monkeypatch.setattr(repo_trace, "_MAX_SIZE", 0) + with repo_trace.Trace(content, first_trace=True): + pass + assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size - # Check we clear everything is the last chunk is larger than _MAX_SIZE. - with mock.patch("repo_trace._MAX_SIZE", 0): - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual( - first_trace_size, os.path.getsize(repo_trace._TRACE_FILE) - ) + # Check we only clear the chunks we need to. + new_max = (first_trace_size + 1) / (1024 * 1024) + monkeypatch.setattr(repo_trace, "_MAX_SIZE", new_max) + with repo_trace.Trace(content, first_trace=True): + pass + assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size * 2 - # Check we only clear the chunks we need to. - repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024) - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual( - first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) - ) - - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual( - first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) - ) + with repo_trace.Trace(content, first_trace=True): + pass + assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size * 2 diff --git a/tests/test_ssh.py b/tests/test_ssh.py index ce65fac45..6bb2c2f4e 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -16,72 +16,84 @@ import multiprocessing import subprocess -import unittest +from typing import Tuple from unittest import mock +import pytest + import ssh -class SshTests(unittest.TestCase): - """Tests the ssh functions.""" +@pytest.fixture(autouse=True) +def clear_ssh_version_cache() -> None: + """Clear the ssh version cache before each test.""" + ssh.version.cache_clear() - def setUp(self) -> None: - super().setUp() - ssh.version.cache_clear() - def test_parse_ssh_version(self): - """Check _parse_ssh_version() handling.""" - ver = ssh._parse_ssh_version("Unknown\n") - self.assertEqual(ver, ()) - ver = ssh._parse_ssh_version("OpenSSH_1.0\n") - self.assertEqual(ver, (1, 0)) - ver = ssh._parse_ssh_version( - "OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n" - ) - self.assertEqual(ver, (6, 6, 1)) - ver = ssh._parse_ssh_version( - "OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n" - ) - self.assertEqual(ver, (7, 6)) - ver = ssh._parse_ssh_version("OpenSSH_9.0p1, LibreSSL 3.3.6\n") - self.assertEqual(ver, (9, 0)) +@pytest.mark.parametrize( + "input_str, expected", + ( + ("Unknown\n", ()), + ("OpenSSH_1.0\n", (1, 0)), + ( + "OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n", + (6, 6, 1), + ), + ( + "OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n", + (7, 6), + ), + ("OpenSSH_9.0p1, LibreSSL 3.3.6\n", (9, 0)), + ), +) +def test_parse_ssh_version(input_str: str, expected: Tuple[int, ...]) -> None: + """Check _parse_ssh_version() handling.""" + assert ssh._parse_ssh_version(input_str) == expected - def test_version(self): - """Check version() handling.""" - with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"): - self.assertEqual(ssh.version(), (1, 2)) - def test_context_manager_empty(self): - """Verify context manager with no clients works correctly.""" - with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager): - pass +def test_version() -> None: + """Check version() handling.""" + with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"): + assert ssh.version() == (1, 2) - def test_context_manager_child_cleanup(self): - """Verify orphaned clients & masters get cleaned up.""" - with multiprocessing.Manager() as manager: - with mock.patch("ssh.version", return_value=(1, 2)): - with ssh.ProxyManager(manager) as ssh_proxy: - client = subprocess.Popen(["sleep", "964853320"]) - ssh_proxy.add_client(client) - master = subprocess.Popen(["sleep", "964853321"]) - ssh_proxy.add_master(master) - # If the process still exists, these will throw timeout errors. - client.wait(0) - master.wait(0) - def test_ssh_sock(self): - """Check sock() function.""" - manager = multiprocessing.Manager() +def test_context_manager_empty() -> None: + """Verify context manager with no clients works correctly.""" + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager): + pass + + +def test_context_manager_child_cleanup() -> None: + """Verify orphaned clients & masters get cleaned up.""" + with multiprocessing.Manager() as manager: + with mock.patch("ssh.version", return_value=(1, 2)): + with ssh.ProxyManager(manager) as ssh_proxy: + client = subprocess.Popen(["sleep", "964853320"]) + ssh_proxy.add_client(client) + master = subprocess.Popen(["sleep", "964853321"]) + ssh_proxy.add_master(master) + # If the process still exists, these will throw timeout errors. + client.wait(0) + master.wait(0) + + +def test_ssh_sock(monkeypatch: pytest.MonkeyPatch) -> None: + """Check sock() function.""" + with multiprocessing.Manager() as manager: proxy = ssh.ProxyManager(manager) - with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"): - # Old ssh version uses port. - with mock.patch("ssh.version", return_value=(6, 6)): - with proxy as ssh_proxy: - self.assertTrue(ssh_proxy.sock().endswith("%p")) + monkeypatch.setattr( + "tempfile.mkdtemp", lambda *args, **kwargs: "/tmp/foo" + ) - proxy._sock_path = None - # New ssh version uses hash. - with mock.patch("ssh.version", return_value=(6, 7)): - with proxy as ssh_proxy: - self.assertTrue(ssh_proxy.sock().endswith("%C")) + # Old ssh version uses port. + with mock.patch("ssh.version", return_value=(6, 6)): + with proxy as ssh_proxy: + assert ssh_proxy.sock().endswith("%p") + + proxy._sock_path = None + # New ssh version uses hash. + with mock.patch("ssh.version", return_value=(6, 7)): + with proxy as ssh_proxy: + assert ssh_proxy.sock().endswith("%C") + proxy._sock_path = None diff --git a/tests/test_subcmds.py b/tests/test_subcmds.py index 3382d822d..12bbb2700 100644 --- a/tests/test_subcmds.py +++ b/tests/test_subcmds.py @@ -15,170 +15,164 @@ """Unittests for the subcmds module (mostly __init__.py than subcommands).""" import optparse -import unittest +from typing import Type +import pytest + +from command import Command import subcmds -class AllCommands(unittest.TestCase): - """Check registered all_commands.""" +# NB: We don't test all subcommands as we want to avoid "change detection" +# tests, so we just look for the most common/important ones here that are +# unlikely to ever change. +@pytest.mark.parametrize( + "cmd", ("cherry-pick", "help", "init", "start", "sync", "upload") +) +def test_required_basic(cmd: str) -> None: + """Basic checking of registered commands.""" + assert cmd in subcmds.all_commands - def test_required_basic(self): - """Basic checking of registered commands.""" - # NB: We don't test all subcommands as we want to avoid "change - # detection" tests, so we just look for the most common/important ones - # here that are unlikely to ever change. - for cmd in {"cherry-pick", "help", "init", "start", "sync", "upload"}: - self.assertIn(cmd, subcmds.all_commands) - def test_naming(self): - """Verify we don't add things that we shouldn't.""" - for cmd in subcmds.all_commands: - # Reject filename suffixes like "help.py". - self.assertNotIn(".", cmd) +@pytest.mark.parametrize("name", subcmds.all_commands.keys()) +def test_naming(name: str) -> None: + """Verify we don't add things that we shouldn't.""" + # Reject filename suffixes like "help.py". + assert "." not in name - # Make sure all '_' were converted to '-'. - self.assertNotIn("_", cmd) + # Make sure all '_' were converted to '-'. + assert "_" not in name - # Reject internal python paths like "__init__". - self.assertFalse(cmd.startswith("__")) + # Reject internal python paths like "__init__". + assert not name.startswith("__") - def test_help_desc_style(self): - """Force some consistency in option descriptions. - Python's optparse & argparse has a few default options like --help. - Their option description text uses lowercase sentence fragments, so - enforce our options follow the same style so UI is consistent. +@pytest.mark.parametrize("name, cls", subcmds.all_commands.items()) +def test_help_desc_style(name: str, cls: Type[Command]) -> None: + """Force some consistency in option descriptions. - We enforce: - * Text starts with lowercase. - * Text doesn't end with period. - """ - for name, cls in subcmds.all_commands.items(): - cmd = cls() - parser = cmd.OptionParser - for option in parser.option_list: - if option.help == optparse.SUPPRESS_HELP: - continue + Python's optparse & argparse has a few default options like --help. + Their option description text uses lowercase sentence fragments, so + enforce our options follow the same style so UI is consistent. - c = option.help[0] - self.assertEqual( - c.lower(), - c, - msg=f"subcmds/{name}.py: {option.get_opt_string()}: " - f'help text should start with lowercase: "{option.help}"', - ) + We enforce: + * Text starts with lowercase. + * Text doesn't end with period. + """ + cmd = cls() + parser = cmd.OptionParser + for option in parser.option_list: + if option.help == optparse.SUPPRESS_HELP or not option.help: + continue - self.assertNotEqual( - option.help[-1], - ".", - msg=f"subcmds/{name}.py: {option.get_opt_string()}: " - f'help text should not end in a period: "{option.help}"', - ) + c = option.help[0] + assert c.lower() == c, ( + f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should start with lowercase: "{option.help}"' + ) - def test_cli_option_style(self): - """Force some consistency in option flags.""" - for name, cls in subcmds.all_commands.items(): - cmd = cls() - parser = cmd.OptionParser - for option in parser.option_list: - for opt in option._long_opts: - self.assertNotIn( - "_", - opt, - msg=f"subcmds/{name}.py: {opt}: only use dashes in " - "options, not underscores", - ) + assert option.help[-1] != ".", ( + f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should not end in a period: "{option.help}"' + ) - def test_cli_option_dest(self): - """Block redundant dest= arguments.""" - def _check_dest(opt): - """Check the dest= setting.""" - # If the destination is not set, nothing to check. - # If long options are not set, then there's no implicit destination. - # If callback is used, then a destination might be needed because - # optparse cannot assume a value is always stored. - if opt.dest is None or not opt._long_opts or opt.callback: - return +@pytest.mark.parametrize("name, cls", subcmds.all_commands.items()) +def test_cli_option_style(name: str, cls: Type[Command]) -> None: + """Force some consistency in option flags.""" + cmd = cls() + parser = cmd.OptionParser + for option in parser.option_list: + for opt in option._long_opts: + assert "_" not in opt, ( + f"subcmds/{name}.py: {opt}: only use dashes in " + "options, not underscores" + ) - long = opt._long_opts[0] - assert long.startswith("--") - # This matches optparse's behavior. - implicit_dest = long[2:].replace("-", "_") - if implicit_dest == opt.dest: - bad_opts.append((str(opt), opt.dest)) - # Hook the option check list. - optparse.Option.CHECK_METHODS.insert(0, _check_dest) +def test_cli_option_dest() -> None: + """Block redundant dest= arguments.""" + bad_opts: list[tuple[str, str]] = [] + def _check_dest(opt: optparse.Option) -> None: + """Check the dest= setting.""" + # If the destination is not set, nothing to check. + # If long options are not set, then there's no implicit destination. + # If callback is used, then a destination might be needed because + # optparse cannot assume a value is always stored. + if opt.dest is None or not opt._long_opts or opt.callback: + return + + long = opt._long_opts[0] + assert long.startswith("--") + # This matches optparse's behavior. + implicit_dest = long[2:].replace("-", "_") + if implicit_dest == opt.dest: + bad_opts.append((str(opt), opt.dest)) + + # Hook the option check list. + optparse.Option.CHECK_METHODS.insert(0, _check_dest) + + try: # Gather all the bad options up front so people can see all bad options # instead of failing at the first one. - all_bad_opts = {} + all_bad_opts: dict[str, list[tuple[str, str]]] = {} for name, cls in subcmds.all_commands.items(): - bad_opts = all_bad_opts[name] = [] + bad_opts = [] cmd = cls() # Trigger construction of parser. - cmd.OptionParser + _ = cmd.OptionParser + all_bad_opts[name] = bad_opts - errmsg = None - for name, bad_opts in sorted(all_bad_opts.items()): - if bad_opts: + errmsg = "" + for name, bad_opts_list in sorted(all_bad_opts.items()): + if bad_opts_list: if not errmsg: errmsg = "Omit redundant dest= when defining options.\n" errmsg += f"\nSubcommand {name} (subcmds/{name}.py):\n" errmsg += "".join( - f" {opt}: dest='{dest}'\n" for opt, dest in bad_opts + f" {opt}: dest='{dest}'\n" for opt, dest in bad_opts_list ) if errmsg: - self.fail(errmsg) - + pytest.fail(errmsg) + finally: # Make sure we aren't popping the wrong stuff. assert optparse.Option.CHECK_METHODS.pop(0) is _check_dest - def test_common_validate_options(self): - """Verify CommonValidateOptions sets up expected fields.""" - for name, cls in subcmds.all_commands.items(): - cmd = cls() - opts, args = cmd.OptionParser.parse_args([]) - # Verify the fields don't exist yet. - self.assertFalse( - hasattr(opts, "verbose"), - msg=f"{name}: has verbose before validation", - ) - self.assertFalse( - hasattr(opts, "quiet"), - msg=f"{name}: has quiet before validation", - ) +@pytest.mark.parametrize("name, cls", subcmds.all_commands.items()) +def test_common_validate_options(name: str, cls: Type[Command]) -> None: + """Verify CommonValidateOptions sets up expected fields.""" + cmd = cls() + opts, args = cmd.OptionParser.parse_args([]) - cmd.CommonValidateOptions(opts, args) + # Verify the fields don't exist yet. + assert not hasattr( + opts, "verbose" + ), f"{name}: has verbose before validation" + assert not hasattr(opts, "quiet"), f"{name}: has quiet before validation" - # Verify the fields exist now. - self.assertTrue( - hasattr(opts, "verbose"), - msg=f"{name}: missing verbose after validation", - ) - self.assertTrue( - hasattr(opts, "quiet"), - msg=f"{name}: missing quiet after validation", - ) - self.assertTrue( - hasattr(opts, "outer_manifest"), - msg=f"{name}: missing outer_manifest after validation", - ) + cmd.CommonValidateOptions(opts, args) - def test_attribute_error_repro(self): - """Confirm that accessing verbose before CommonValidateOptions fails.""" - from subcmds.sync import Sync + # Verify the fields exist now. + assert hasattr(opts, "verbose"), f"{name}: missing verbose after validation" + assert hasattr(opts, "quiet"), f"{name}: missing quiet after validation" + assert hasattr( + opts, "outer_manifest" + ), f"{name}: missing outer_manifest after validation" - cmd = Sync() - opts, args = cmd.OptionParser.parse_args([]) - # This confirms that without the fix in main.py, an AttributeError - # would be raised because CommonValidateOptions hasn't been called yet. - with self.assertRaises(AttributeError): - _ = opts.verbose +def test_attribute_error_repro() -> None: + """Confirm that accessing verbose before CommonValidateOptions fails.""" + from subcmds.sync import Sync - cmd.CommonValidateOptions(opts, args) - self.assertTrue(hasattr(opts, "verbose")) + cmd = Sync() + opts, args = cmd.OptionParser.parse_args([]) + + # This confirms that without the fix in main.py, an AttributeError + # would be raised because CommonValidateOptions hasn't been called yet. + with pytest.raises(AttributeError): + _ = opts.verbose + + cmd.CommonValidateOptions(opts, args) + assert hasattr(opts, "verbose") diff --git a/tests/test_subcmds_init.py b/tests/test_subcmds_init.py index 25e5be567..960fb93fb 100644 --- a/tests/test_subcmds_init.py +++ b/tests/test_subcmds_init.py @@ -14,33 +14,36 @@ """Unittests for the subcmds/init.py module.""" -import unittest +from typing import List + +import pytest from subcmds import init -class InitCommand(unittest.TestCase): - """Check registered all_commands.""" +@pytest.mark.parametrize( + "argv", + ([],), +) +def test_cli_parser_good(argv: List[str]) -> None: + """Check valid command line options.""" + cmd = init.Init() + opts, args = cmd.OptionParser.parse_args(argv) + cmd.ValidateOptions(opts, args) - def setUp(self): - self.cmd = init.Init() - def test_cli_parser_good(self): - """Check valid command line options.""" - ARGV = ([],) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - self.cmd.ValidateOptions(opts, args) - - def test_cli_parser_bad(self): - """Check invalid command line options.""" - ARGV = ( - # Too many arguments. - ["url", "asdf"], - # Conflicting options. - ["--mirror", "--archive"], - ) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - with self.assertRaises(SystemExit): - self.cmd.ValidateOptions(opts, args) +@pytest.mark.parametrize( + "argv", + ( + # Too many arguments. + ["url", "asdf"], + # Conflicting options. + ["--mirror", "--archive"], + ), +) +def test_cli_parser_bad(argv: List[str]) -> None: + """Check invalid command line options.""" + cmd = init.Init() + opts, args = cmd.OptionParser.parse_args(argv) + with pytest.raises(SystemExit): + cmd.ValidateOptions(opts, args) diff --git a/tests/test_update_manpages.py b/tests/test_update_manpages.py index 06e5d3f5f..f52f8575f 100644 --- a/tests/test_update_manpages.py +++ b/tests/test_update_manpages.py @@ -14,15 +14,10 @@ """Unittests for the update_manpages module.""" -import unittest - from release import update_manpages -class UpdateManpagesTest(unittest.TestCase): - """Tests the update-manpages code.""" - - def test_replace_regex(self): - """Check that replace_regex works.""" - data = "\n\033[1mSummary\033[m\n" - self.assertEqual(update_manpages.replace_regex(data), "\nSummary\n") +def test_replace_regex() -> None: + """Check that replace_regex works.""" + data = "\n\033[1mSummary\033[m\n" + assert update_manpages.replace_regex(data) == "\nSummary\n"