tests: switch some test modules to pytest

Change-Id: I524b5ff2d77f8232f94e21921b00ba4027d2ac4f
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/563081
Tested-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Mike Frysinger <vapier@google.com>
This commit is contained in:
Mike Frysinger
2026-03-18 11:17:12 -04:00
committed by LUCI
parent 83b8ebdbbe
commit f24bc7aed5
12 changed files with 667 additions and 650 deletions
+46 -40
View File
@@ -15,60 +15,66 @@
"""Unittests for the color.py module."""
import os
import unittest
import pytest
import color
import git_config
def fixture(*paths):
def fixture(*paths: str) -> str:
"""Return a path relative to test/fixtures."""
return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
class ColoringTests(unittest.TestCase):
"""tests of the Coloring class."""
@pytest.fixture
def coloring() -> color.Coloring:
"""Create a Coloring object for testing."""
config_fixture = fixture("test.gitconfig")
config = git_config.GitConfig(config_fixture)
color.SetDefaultColoring("true")
return color.Coloring(config, "status")
def setUp(self):
"""Create a GitConfig object using the test.gitconfig fixture."""
config_fixture = fixture("test.gitconfig")
self.config = git_config.GitConfig(config_fixture)
color.SetDefaultColoring("true")
self.color = color.Coloring(self.config, "status")
def test_Color_Parse_all_params_none(self):
"""all params are None"""
val = self.color._parse(None, None, None, None)
self.assertEqual("", val)
def test_Color_Parse_all_params_none(coloring: color.Coloring) -> None:
"""all params are None"""
val = coloring._parse(None, None, None, None)
assert val == ""
def test_Color_Parse_first_parameter_none(self):
"""check fg & bg & attr"""
val = self.color._parse(None, "black", "red", "ul")
self.assertEqual("\x1b[4;30;41m", val)
def test_Color_Parse_one_entry(self):
"""check fg"""
val = self.color._parse("one", None, None, None)
self.assertEqual("\033[33m", val)
def test_Color_Parse_first_parameter_none(coloring: color.Coloring) -> None:
"""check fg & bg & attr"""
val = coloring._parse(None, "black", "red", "ul")
assert val == "\x1b[4;30;41m"
def test_Color_Parse_two_entry(self):
"""check fg & bg"""
val = self.color._parse("two", None, None, None)
self.assertEqual("\033[35;46m", val)
def test_Color_Parse_three_entry(self):
"""check fg & bg & attr"""
val = self.color._parse("three", None, None, None)
self.assertEqual("\033[4;30;41m", val)
def test_Color_Parse_one_entry(coloring: color.Coloring) -> None:
"""check fg"""
val = coloring._parse("one", None, None, None)
assert val == "\033[33m"
def test_Color_Parse_reset_entry(self):
"""check reset entry"""
val = self.color._parse("reset", None, None, None)
self.assertEqual("\033[m", val)
def test_Color_Parse_empty_entry(self):
"""check empty entry"""
val = self.color._parse("none", "blue", "white", "dim")
self.assertEqual("\033[2;34;47m", val)
val = self.color._parse("empty", "green", "white", "bold")
self.assertEqual("\033[1;32;47m", val)
def test_Color_Parse_two_entry(coloring: color.Coloring) -> None:
"""check fg & bg"""
val = coloring._parse("two", None, None, None)
assert val == "\033[35;46m"
def test_Color_Parse_three_entry(coloring: color.Coloring) -> None:
"""check fg & bg & attr"""
val = coloring._parse("three", None, None, None)
assert val == "\033[4;30;41m"
def test_Color_Parse_reset_entry(coloring: color.Coloring) -> None:
"""check reset entry"""
val = coloring._parse("reset", None, None, None)
assert val == "\033[m"
def test_Color_Parse_empty_entry(coloring: color.Coloring) -> None:
"""check empty entry"""
val = coloring._parse("none", "blue", "white", "dim")
assert val == "\033[2;34;47m"
val = coloring._parse("empty", "green", "white", "bold")
assert val == "\033[1;32;47m"
+18 -29
View File
@@ -14,43 +14,32 @@
"""Unittests for the editor.py module."""
import unittest
import pytest
from editor import Editor
class EditorTestCase(unittest.TestCase):
@pytest.fixture(autouse=True)
def reset_editor() -> None:
"""Take care of resetting Editor state across tests."""
def setUp(self):
self.setEditor(None)
def tearDown(self):
self.setEditor(None)
@staticmethod
def setEditor(editor):
Editor._editor = editor
Editor._editor = None
yield
Editor._editor = None
class GetEditor(EditorTestCase):
"""Check GetEditor behavior."""
def test_basic(self):
"""Basic checking of _GetEditor."""
self.setEditor(":")
self.assertEqual(":", Editor._GetEditor())
def test_basic() -> None:
"""Basic checking of _GetEditor."""
Editor._editor = ":"
assert Editor._GetEditor() == ":"
class EditString(EditorTestCase):
"""Check EditString behavior."""
def test_no_editor() -> None:
"""Check behavior when no editor is available."""
Editor._editor = ":"
assert Editor.EditString("foo") == "foo"
def test_no_editor(self):
"""Check behavior when no editor is available."""
self.setEditor(":")
self.assertEqual("foo", Editor.EditString("foo"))
def test_cat_editor(self):
"""Check behavior when editor is `cat`."""
self.setEditor("cat")
self.assertEqual("foo", Editor.EditString("foo"))
def test_cat_editor() -> None:
"""Check behavior when editor is `cat`."""
Editor._editor = "cat"
assert Editor.EditString("foo") == "foo"
+33 -32
View File
@@ -16,7 +16,9 @@
import inspect
import pickle
import unittest
from typing import Iterator, Type
import pytest
import command
import error
@@ -26,7 +28,7 @@ import project
from subcmds import all_modules
imports = all_modules + [
_IMPORTS = all_modules + [
error,
project,
git_command,
@@ -35,36 +37,35 @@ imports = all_modules + [
]
class PickleTests(unittest.TestCase):
"""Make sure all our custom exceptions can be pickled."""
def get_exceptions() -> Iterator[Type[Exception]]:
"""Return all our custom exceptions."""
for entry in _IMPORTS:
for name in dir(entry):
cls = getattr(entry, name)
if isinstance(cls, type) and issubclass(cls, Exception):
yield cls
def getExceptions(self):
"""Return all our custom exceptions."""
for entry in imports:
for name in dir(entry):
cls = getattr(entry, name)
if isinstance(cls, type) and issubclass(cls, Exception):
yield cls
def testExceptionLookup(self):
"""Make sure our introspection logic works."""
classes = list(self.getExceptions())
self.assertIn(error.HookError, classes)
# Don't assert the exact number to avoid being a change-detector test.
self.assertGreater(len(classes), 10)
def test_exception_lookup() -> None:
"""Make sure our introspection logic works."""
classes = list(get_exceptions())
assert error.HookError in classes
# Don't assert the exact number to avoid being a change-detector test.
assert len(classes) > 10
def testPickle(self):
"""Try to pickle all the exceptions."""
for cls in self.getExceptions():
args = inspect.getfullargspec(cls.__init__).args[1:]
obj = cls(*args)
p = pickle.dumps(obj)
try:
newobj = pickle.loads(p)
except Exception as e: # pylint: disable=broad-except
self.fail(
"Class %s is unable to be pickled: %s\n"
"Incomplete super().__init__(...) call?" % (cls, e)
)
self.assertIsInstance(newobj, cls)
self.assertEqual(str(obj), str(newobj))
@pytest.mark.parametrize("cls", get_exceptions())
def test_pickle(cls: Type[Exception]) -> None:
"""Try to pickle all the exceptions."""
args = inspect.getfullargspec(cls.__init__).args[1:]
obj = cls(*args)
p = pickle.dumps(obj)
try:
newobj = pickle.loads(p)
except Exception as e:
pytest.fail(
f"Class {cls} is unable to be pickled: {e}\n"
"Incomplete super().__init__(...) call?"
)
assert isinstance(newobj, cls)
assert str(obj) == str(newobj)
+190 -171
View File
@@ -15,200 +15,219 @@
"""Unittests for the git_config.py module."""
import os
import tempfile
import unittest
from pathlib import Path
from typing import Any
import pytest
import git_config
def fixture(*paths):
def fixture_path(*paths: str) -> str:
"""Return a path relative to test/fixtures."""
return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
class GitConfigReadOnlyTests(unittest.TestCase):
"""Read-only tests of the GitConfig class."""
def setUp(self):
"""Create a GitConfig object using the test.gitconfig fixture."""
config_fixture = fixture("test.gitconfig")
self.config = git_config.GitConfig(config_fixture)
def test_GetString_with_empty_config_values(self):
"""
Test config entries with no value.
[section]
empty
"""
val = self.config.GetString("section.empty")
self.assertEqual(val, None)
def test_GetString_with_true_value(self):
"""
Test config entries with a string value.
[section]
nonempty = true
"""
val = self.config.GetString("section.nonempty")
self.assertEqual(val, "true")
def test_GetString_from_missing_file(self):
"""
Test missing config file
"""
config_fixture = fixture("not.present.gitconfig")
config = git_config.GitConfig(config_fixture)
val = config.GetString("empty")
self.assertEqual(val, None)
def test_GetBoolean_undefined(self):
"""Test GetBoolean on key that doesn't exist."""
self.assertIsNone(self.config.GetBoolean("section.missing"))
def test_GetBoolean_invalid(self):
"""Test GetBoolean on invalid boolean value."""
self.assertIsNone(self.config.GetBoolean("section.boolinvalid"))
def test_GetBoolean_true(self):
"""Test GetBoolean on valid true boolean."""
self.assertTrue(self.config.GetBoolean("section.booltrue"))
def test_GetBoolean_false(self):
"""Test GetBoolean on valid false boolean."""
self.assertFalse(self.config.GetBoolean("section.boolfalse"))
def test_GetInt_undefined(self):
"""Test GetInt on key that doesn't exist."""
self.assertIsNone(self.config.GetInt("section.missing"))
def test_GetInt_invalid(self):
"""Test GetInt on invalid integer value."""
self.assertIsNone(self.config.GetBoolean("section.intinvalid"))
def test_GetInt_valid(self):
"""Test GetInt on valid integers."""
TESTS = (
("inthex", 16),
("inthexk", 16384),
("int", 10),
("intk", 10240),
("intm", 10485760),
("intg", 10737418240),
)
for key, value in TESTS:
self.assertEqual(value, self.config.GetInt(f"section.{key}"))
@pytest.fixture
def readonly_config() -> git_config.GitConfig:
"""Create a GitConfig object using the test.gitconfig fixture."""
config_fixture = fixture_path("test.gitconfig")
return git_config.GitConfig(config_fixture)
class GitConfigReadWriteTests(unittest.TestCase):
"""Read/write tests of the GitConfig class."""
def test_get_string_with_empty_config_values(
readonly_config: git_config.GitConfig,
) -> None:
"""Test config entries with no value.
def setUp(self):
self.tmpfile = tempfile.NamedTemporaryFile()
self.config = self.get_config()
[section]
empty
def get_config(self):
"""Get a new GitConfig instance."""
return git_config.GitConfig(self.tmpfile.name)
"""
val = readonly_config.GetString("section.empty")
assert val is None
def test_SetString(self):
"""Test SetString behavior."""
# Set a value.
self.assertIsNone(self.config.GetString("foo.bar"))
self.config.SetString("foo.bar", "val")
self.assertEqual("val", self.config.GetString("foo.bar"))
# Make sure the value was actually written out.
config = self.get_config()
self.assertEqual("val", config.GetString("foo.bar"))
def test_get_string_with_true_value(
readonly_config: git_config.GitConfig,
) -> None:
"""Test config entries with a string value.
# Update the value.
self.config.SetString("foo.bar", "valll")
self.assertEqual("valll", self.config.GetString("foo.bar"))
config = self.get_config()
self.assertEqual("valll", config.GetString("foo.bar"))
[section]
nonempty = true
# Delete the value.
self.config.SetString("foo.bar", None)
self.assertIsNone(self.config.GetString("foo.bar"))
config = self.get_config()
self.assertIsNone(config.GetString("foo.bar"))
"""
val = readonly_config.GetString("section.nonempty")
assert val == "true"
def test_SetBoolean(self):
"""Test SetBoolean behavior."""
# Set a true value.
self.assertIsNone(self.config.GetBoolean("foo.bar"))
for val in (True, 1):
self.config.SetBoolean("foo.bar", val)
self.assertTrue(self.config.GetBoolean("foo.bar"))
# Make sure the value was actually written out.
config = self.get_config()
self.assertTrue(config.GetBoolean("foo.bar"))
self.assertEqual("true", config.GetString("foo.bar"))
def test_get_string_from_missing_file() -> None:
"""Test missing config file."""
config_fixture = fixture_path("not.present.gitconfig")
config = git_config.GitConfig(config_fixture)
val = config.GetString("empty")
assert val is None
# Set a false value.
for val in (False, 0):
self.config.SetBoolean("foo.bar", val)
self.assertFalse(self.config.GetBoolean("foo.bar"))
# Make sure the value was actually written out.
config = self.get_config()
self.assertFalse(config.GetBoolean("foo.bar"))
self.assertEqual("false", config.GetString("foo.bar"))
def test_get_boolean_undefined(readonly_config: git_config.GitConfig) -> None:
"""Test GetBoolean on key that doesn't exist."""
assert readonly_config.GetBoolean("section.missing") is None
# Delete the value.
self.config.SetBoolean("foo.bar", None)
self.assertIsNone(self.config.GetBoolean("foo.bar"))
config = self.get_config()
self.assertIsNone(config.GetBoolean("foo.bar"))
def test_SetInt(self):
"""Test SetInt behavior."""
# Set a value.
self.assertIsNone(self.config.GetInt("foo.bar"))
self.config.SetInt("foo.bar", 10)
self.assertEqual(10, self.config.GetInt("foo.bar"))
def test_get_boolean_invalid(readonly_config: git_config.GitConfig) -> None:
"""Test GetBoolean on invalid boolean value."""
assert readonly_config.GetBoolean("section.boolinvalid") is None
# Make sure the value was actually written out.
config = self.get_config()
self.assertEqual(10, config.GetInt("foo.bar"))
self.assertEqual("10", config.GetString("foo.bar"))
# Update the value.
self.config.SetInt("foo.bar", 20)
self.assertEqual(20, self.config.GetInt("foo.bar"))
config = self.get_config()
self.assertEqual(20, config.GetInt("foo.bar"))
def test_get_boolean_true(readonly_config: git_config.GitConfig) -> None:
"""Test GetBoolean on valid true boolean."""
assert readonly_config.GetBoolean("section.booltrue") is True
# Delete the value.
self.config.SetInt("foo.bar", None)
self.assertIsNone(self.config.GetInt("foo.bar"))
config = self.get_config()
self.assertIsNone(config.GetInt("foo.bar"))
def test_GetSyncAnalysisStateData(self):
"""Test config entries with a sync state analysis data."""
superproject_logging_data = {}
superproject_logging_data["test"] = False
options = type("options", (object,), {})()
options.verbose = "true"
options.mp_update = "false"
TESTS = (
("superproject.test", "false"),
("options.verbose", "true"),
("options.mpupdate", "false"),
("main.version", "1"),
)
self.config.UpdateSyncAnalysisState(options, superproject_logging_data)
sync_data = self.config.GetSyncAnalysisStateData()
for key, value in TESTS:
self.assertEqual(
sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"], value
)
self.assertTrue(
sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"]
)
def test_get_boolean_false(readonly_config: git_config.GitConfig) -> None:
"""Test GetBoolean on valid false boolean."""
assert readonly_config.GetBoolean("section.boolfalse") is False
def test_get_int_undefined(readonly_config: git_config.GitConfig) -> None:
"""Test GetInt on key that doesn't exist."""
assert readonly_config.GetInt("section.missing") is None
def test_get_int_invalid(readonly_config: git_config.GitConfig) -> None:
"""Test GetInt on invalid integer value."""
assert readonly_config.GetInt("section.intinvalid") is None
@pytest.mark.parametrize(
"key, expected",
(
("inthex", 16),
("inthexk", 16384),
("int", 10),
("intk", 10240),
("intm", 10485760),
("intg", 10737418240),
),
)
def test_get_int_valid(
readonly_config: git_config.GitConfig, key: str, expected: int
) -> None:
"""Test GetInt on valid integers."""
assert readonly_config.GetInt(f"section.{key}") == expected
@pytest.fixture
def rw_config_file(tmp_path: Path) -> Path:
"""Return a path to a temporary config file."""
return tmp_path / "config"
def test_set_string(rw_config_file: Path) -> None:
"""Test SetString behavior."""
config = git_config.GitConfig(str(rw_config_file))
# Set a value.
assert config.GetString("foo.bar") is None
config.SetString("foo.bar", "val")
assert config.GetString("foo.bar") == "val"
# Make sure the value was actually written out.
config2 = git_config.GitConfig(str(rw_config_file))
assert config2.GetString("foo.bar") == "val"
# Update the value.
config.SetString("foo.bar", "valll")
assert config.GetString("foo.bar") == "valll"
config3 = git_config.GitConfig(str(rw_config_file))
assert config3.GetString("foo.bar") == "valll"
# Delete the value.
config.SetString("foo.bar", None)
assert config.GetString("foo.bar") is None
config4 = git_config.GitConfig(str(rw_config_file))
assert config4.GetString("foo.bar") is None
def test_set_boolean(rw_config_file: Path) -> None:
"""Test SetBoolean behavior."""
config = git_config.GitConfig(str(rw_config_file))
# Set a true value.
assert config.GetBoolean("foo.bar") is None
for val in (True, 1):
config.SetBoolean("foo.bar", val)
assert config.GetBoolean("foo.bar") is True
# Make sure the value was actually written out.
config2 = git_config.GitConfig(str(rw_config_file))
assert config2.GetBoolean("foo.bar") is True
assert config2.GetString("foo.bar") == "true"
# Set a false value.
for val in (False, 0):
config.SetBoolean("foo.bar", val)
assert config.GetBoolean("foo.bar") is False
# Make sure the value was actually written out.
config3 = git_config.GitConfig(str(rw_config_file))
assert config3.GetBoolean("foo.bar") is False
assert config3.GetString("foo.bar") == "false"
# Delete the value.
config.SetBoolean("foo.bar", None)
assert config.GetBoolean("foo.bar") is None
config4 = git_config.GitConfig(str(rw_config_file))
assert config4.GetBoolean("foo.bar") is None
def test_set_int(rw_config_file: Path) -> None:
"""Test SetInt behavior."""
config = git_config.GitConfig(str(rw_config_file))
# Set a value.
assert config.GetInt("foo.bar") is None
config.SetInt("foo.bar", 10)
assert config.GetInt("foo.bar") == 10
# Make sure the value was actually written out.
config2 = git_config.GitConfig(str(rw_config_file))
assert config2.GetInt("foo.bar") == 10
assert config2.GetString("foo.bar") == "10"
# Update the value.
config.SetInt("foo.bar", 20)
assert config.GetInt("foo.bar") == 20
config3 = git_config.GitConfig(str(rw_config_file))
assert config3.GetInt("foo.bar") == 20
# Delete the value.
config.SetInt("foo.bar", None)
assert config.GetInt("foo.bar") is None
config4 = git_config.GitConfig(str(rw_config_file))
assert config4.GetInt("foo.bar") is None
def test_get_sync_analysis_state_data(rw_config_file: Path) -> None:
"""Test config entries with a sync state analysis data."""
config = git_config.GitConfig(str(rw_config_file))
superproject_logging_data: dict[str, Any] = {"test": False}
class Options:
"""Container for testing."""
options = Options()
options.verbose = "true"
options.mp_update = "false"
TESTS = (
("superproject.test", "false"),
("options.verbose", "true"),
("options.mpupdate", "false"),
("main.version", "1"),
)
config.UpdateSyncAnalysisState(options, superproject_logging_data)
sync_data = config.GetSyncAnalysisStateData()
for key, value in TESTS:
assert sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"] == value
assert sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"]
+37 -32
View File
@@ -14,42 +14,47 @@
"""Unittests for the hooks.py module."""
import unittest
import pytest
import hooks
class RepoHookShebang(unittest.TestCase):
"""Check shebang parsing in RepoHook."""
@pytest.mark.parametrize(
"data",
(
"",
"#\n# foo\n",
"# Bad shebang in script\n#!/foo\n",
),
)
def test_no_shebang(data: str) -> None:
"""Lines w/out shebangs should be rejected."""
assert hooks.RepoHook._ExtractInterpFromShebang(data) is None
def test_no_shebang(self):
"""Lines w/out shebangs should be rejected."""
DATA = ("", "#\n# foo\n", "# Bad shebang in script\n#!/foo\n")
for data in DATA:
self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data))
def test_direct_interp(self):
"""Lines whose shebang points directly to the interpreter."""
DATA = (
("#!/foo", "/foo"),
("#! /foo", "/foo"),
("#!/bin/foo ", "/bin/foo"),
("#! /usr/foo ", "/usr/foo"),
("#! /usr/foo -args", "/usr/foo"),
)
for shebang, interp in DATA:
self.assertEqual(
hooks.RepoHook._ExtractInterpFromShebang(shebang), interp
)
@pytest.mark.parametrize(
"shebang, interp",
(
("#!/foo", "/foo"),
("#! /foo", "/foo"),
("#!/bin/foo ", "/bin/foo"),
("#! /usr/foo ", "/usr/foo"),
("#! /usr/foo -args", "/usr/foo"),
),
)
def test_direct_interp(shebang: str, interp: str) -> None:
"""Lines whose shebang points directly to the interpreter."""
assert hooks.RepoHook._ExtractInterpFromShebang(shebang) == interp
def test_env_interp(self):
"""Lines whose shebang launches through `env`."""
DATA = (
("#!/usr/bin/env foo", "foo"),
("#!/bin/env foo", "foo"),
("#! /bin/env /bin/foo ", "/bin/foo"),
)
for shebang, interp in DATA:
self.assertEqual(
hooks.RepoHook._ExtractInterpFromShebang(shebang), interp
)
@pytest.mark.parametrize(
"shebang, interp",
(
("#!/usr/bin/env foo", "foo"),
("#!/bin/env foo", "foo"),
("#! /bin/env /bin/foo ", "/bin/foo"),
),
)
def test_env_interp(shebang: str, interp: str) -> None:
"""Lines whose shebang launches through `env`."""
assert hooks.RepoHook._ExtractInterpFromShebang(shebang) == interp
+23 -27
View File
@@ -14,39 +14,35 @@
"""Unittests for the platform_utils.py module."""
import os
import tempfile
import unittest
from pathlib import Path
import pytest
import platform_utils
class RemoveTests(unittest.TestCase):
"""Check remove() helper."""
def test_remove_missing_ok(tmp_path: Path) -> None:
"""Check missing_ok handling."""
path = tmp_path / "test"
def testMissingOk(self):
"""Check missing_ok handling."""
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "test")
# Should not fail.
platform_utils.remove(path, missing_ok=True)
# Should not fail.
platform_utils.remove(path, missing_ok=True)
# Should fail.
with pytest.raises(OSError):
platform_utils.remove(path)
with pytest.raises(OSError):
platform_utils.remove(path, missing_ok=False)
# Should fail.
self.assertRaises(OSError, platform_utils.remove, path)
self.assertRaises(
OSError, platform_utils.remove, path, missing_ok=False
)
# Should not fail if it exists.
path.touch()
platform_utils.remove(path, missing_ok=True)
assert not path.exists()
# Should not fail if it exists.
open(path, "w").close()
platform_utils.remove(path, missing_ok=True)
self.assertFalse(os.path.exists(path))
path.touch()
platform_utils.remove(path)
assert not path.exists()
open(path, "w").close()
platform_utils.remove(path)
self.assertFalse(os.path.exists(path))
open(path, "w").close()
platform_utils.remove(path, missing_ok=False)
self.assertFalse(os.path.exists(path))
path.touch()
platform_utils.remove(path, missing_ok=False)
assert not path.exists()
+77 -71
View File
@@ -12,90 +12,96 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit test for repo_logging module."""
"""Unittests for the repo_logging.py module."""
import contextlib
import io
import logging
import unittest
import re
from unittest import mock
import pytest
from color import SetDefaultColoring
from error import RepoExitError
from repo_logging import RepoLogger
class TestRepoLogger(unittest.TestCase):
@mock.patch.object(RepoLogger, "error")
def test_log_aggregated_errors_logs_aggregated_errors(self, mock_error):
"""Test if log_aggregated_errors logs a list of aggregated errors."""
logger = RepoLogger(__name__)
logger.log_aggregated_errors(
RepoExitError(
aggregate_errors=[
Exception("foo"),
Exception("bar"),
Exception("baz"),
Exception("hello"),
Exception("world"),
Exception("test"),
]
)
)
mock_error.assert_has_calls(
[
mock.call("=" * 80),
mock.call(
"Repo command failed due to the following `%s` errors:",
"RepoExitError",
),
mock.call("foo\nbar\nbaz\nhello\nworld"),
mock.call("+%d additional errors...", 1),
@mock.patch.object(RepoLogger, "error")
def test_log_aggregated_errors_logs_aggregated_errors(mock_error) -> None:
"""Test if log_aggregated_errors logs a list of aggregated errors."""
logger = RepoLogger(__name__)
logger.log_aggregated_errors(
RepoExitError(
aggregate_errors=[
Exception("foo"),
Exception("bar"),
Exception("baz"),
Exception("hello"),
Exception("world"),
Exception("test"),
]
)
)
@mock.patch.object(RepoLogger, "error")
def test_log_aggregated_errors_logs_single_error(self, mock_error):
"""Test if log_aggregated_errors logs empty aggregated_errors."""
mock_error.assert_has_calls(
[
mock.call("=" * 80),
mock.call(
"Repo command failed due to the following `%s` errors:",
"RepoExitError",
),
mock.call("foo\nbar\nbaz\nhello\nworld"),
mock.call("+%d additional errors...", 1),
]
)
@mock.patch.object(RepoLogger, "error")
def test_log_aggregated_errors_logs_single_error(mock_error) -> None:
"""Test if log_aggregated_errors logs empty aggregated_errors."""
logger = RepoLogger(__name__)
logger.log_aggregated_errors(RepoExitError())
mock_error.assert_has_calls(
[
mock.call("=" * 80),
mock.call("Repo command failed: %s", "RepoExitError"),
]
)
@pytest.mark.parametrize(
"level",
(
logging.INFO,
logging.WARN,
logging.ERROR,
),
)
def test_log_with_format_string(level: int) -> None:
"""Test different log levels with format strings."""
name = logging.getLevelName(level)
# Set color output to "always" for consistent test results.
# This ensures the logger's behavior is uniform across different
# environments and git configurations.
SetDefaultColoring("always")
# Regex pattern to match optional ANSI color codes.
# \033 - Escape character
# \[ - Opening square bracket
# [0-9;]* - Zero or more digits or semicolons
# m - Ending 'm' character
# ? - Makes the entire group optional
opt_color = r"(\033\[[0-9;]*m)?"
output = io.StringIO()
with contextlib.redirect_stderr(output):
logger = RepoLogger(__name__)
logger.log_aggregated_errors(RepoExitError())
logger.log(level, "%s", "100% pass")
mock_error.assert_has_calls(
[
mock.call("=" * 80),
mock.call("Repo command failed: %s", "RepoExitError"),
]
)
def test_log_with_format_string(self):
"""Test different log levels with format strings."""
# Set color output to "always" for consistent test results.
# This ensures the logger's behavior is uniform across different
# environments and git configurations.
SetDefaultColoring("always")
# Regex pattern to match optional ANSI color codes.
# \033 - Escape character
# \[ - Opening square bracket
# [0-9;]* - Zero or more digits or semicolons
# m - Ending 'm' character
# ? - Makes the entire group optional
opt_color = r"(\033\[[0-9;]*m)?"
for level in (logging.INFO, logging.WARN, logging.ERROR):
name = logging.getLevelName(level)
with self.subTest(level=level, name=name):
output = io.StringIO()
with contextlib.redirect_stderr(output):
logger = RepoLogger(__name__)
logger.log(level, "%s", "100% pass")
self.assertRegex(
output.getvalue().strip(),
f"^{opt_color}100% pass{opt_color}$",
f"failed for level {name}",
)
assert re.search(
f"^{opt_color}100% pass{opt_color}$", output.getvalue().strip()
), f"failed for level {name}"
+24 -33
View File
@@ -15,46 +15,37 @@
"""Unittests for the repo_trace.py module."""
import os
import unittest
from unittest import mock
import pytest
import repo_trace
class TraceTests(unittest.TestCase):
def test_trace_max_size_enforced(monkeypatch: pytest.MonkeyPatch) -> None:
"""Check Trace behavior."""
content = "git chicken"
def testTrace_MaxSizeEnforced(self):
content = "git chicken"
with repo_trace.Trace(content, first_trace=True):
pass
first_trace_size = os.path.getsize(repo_trace._TRACE_FILE)
with repo_trace.Trace(content, first_trace=True):
pass
first_trace_size = os.path.getsize(repo_trace._TRACE_FILE)
with repo_trace.Trace(content):
pass
assert os.path.getsize(repo_trace._TRACE_FILE) > first_trace_size
with repo_trace.Trace(content):
pass
self.assertGreater(
os.path.getsize(repo_trace._TRACE_FILE), first_trace_size
)
# Check we clear everything if the last chunk is larger than _MAX_SIZE.
monkeypatch.setattr(repo_trace, "_MAX_SIZE", 0)
with repo_trace.Trace(content, first_trace=True):
pass
assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size
# Check we clear everything is the last chunk is larger than _MAX_SIZE.
with mock.patch("repo_trace._MAX_SIZE", 0):
with repo_trace.Trace(content, first_trace=True):
pass
self.assertEqual(
first_trace_size, os.path.getsize(repo_trace._TRACE_FILE)
)
# Check we only clear the chunks we need to.
new_max = (first_trace_size + 1) / (1024 * 1024)
monkeypatch.setattr(repo_trace, "_MAX_SIZE", new_max)
with repo_trace.Trace(content, first_trace=True):
pass
assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size * 2
# Check we only clear the chunks we need to.
repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024)
with repo_trace.Trace(content, first_trace=True):
pass
self.assertEqual(
first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE)
)
with repo_trace.Trace(content, first_trace=True):
pass
self.assertEqual(
first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE)
)
with repo_trace.Trace(content, first_trace=True):
pass
assert os.path.getsize(repo_trace._TRACE_FILE) == first_trace_size * 2
+68 -56
View File
@@ -16,72 +16,84 @@
import multiprocessing
import subprocess
import unittest
from typing import Tuple
from unittest import mock
import pytest
import ssh
class SshTests(unittest.TestCase):
"""Tests the ssh functions."""
@pytest.fixture(autouse=True)
def clear_ssh_version_cache() -> None:
"""Clear the ssh version cache before each test."""
ssh.version.cache_clear()
def setUp(self) -> None:
super().setUp()
ssh.version.cache_clear()
def test_parse_ssh_version(self):
"""Check _parse_ssh_version() handling."""
ver = ssh._parse_ssh_version("Unknown\n")
self.assertEqual(ver, ())
ver = ssh._parse_ssh_version("OpenSSH_1.0\n")
self.assertEqual(ver, (1, 0))
ver = ssh._parse_ssh_version(
"OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n"
)
self.assertEqual(ver, (6, 6, 1))
ver = ssh._parse_ssh_version(
"OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n"
)
self.assertEqual(ver, (7, 6))
ver = ssh._parse_ssh_version("OpenSSH_9.0p1, LibreSSL 3.3.6\n")
self.assertEqual(ver, (9, 0))
@pytest.mark.parametrize(
"input_str, expected",
(
("Unknown\n", ()),
("OpenSSH_1.0\n", (1, 0)),
(
"OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n",
(6, 6, 1),
),
(
"OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n",
(7, 6),
),
("OpenSSH_9.0p1, LibreSSL 3.3.6\n", (9, 0)),
),
)
def test_parse_ssh_version(input_str: str, expected: Tuple[int, ...]) -> None:
"""Check _parse_ssh_version() handling."""
assert ssh._parse_ssh_version(input_str) == expected
def test_version(self):
"""Check version() handling."""
with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"):
self.assertEqual(ssh.version(), (1, 2))
def test_context_manager_empty(self):
"""Verify context manager with no clients works correctly."""
with multiprocessing.Manager() as manager:
with ssh.ProxyManager(manager):
pass
def test_version() -> None:
"""Check version() handling."""
with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"):
assert ssh.version() == (1, 2)
def test_context_manager_child_cleanup(self):
"""Verify orphaned clients & masters get cleaned up."""
with multiprocessing.Manager() as manager:
with mock.patch("ssh.version", return_value=(1, 2)):
with ssh.ProxyManager(manager) as ssh_proxy:
client = subprocess.Popen(["sleep", "964853320"])
ssh_proxy.add_client(client)
master = subprocess.Popen(["sleep", "964853321"])
ssh_proxy.add_master(master)
# If the process still exists, these will throw timeout errors.
client.wait(0)
master.wait(0)
def test_ssh_sock(self):
"""Check sock() function."""
manager = multiprocessing.Manager()
def test_context_manager_empty() -> None:
"""Verify context manager with no clients works correctly."""
with multiprocessing.Manager() as manager:
with ssh.ProxyManager(manager):
pass
def test_context_manager_child_cleanup() -> None:
"""Verify orphaned clients & masters get cleaned up."""
with multiprocessing.Manager() as manager:
with mock.patch("ssh.version", return_value=(1, 2)):
with ssh.ProxyManager(manager) as ssh_proxy:
client = subprocess.Popen(["sleep", "964853320"])
ssh_proxy.add_client(client)
master = subprocess.Popen(["sleep", "964853321"])
ssh_proxy.add_master(master)
# If the process still exists, these will throw timeout errors.
client.wait(0)
master.wait(0)
def test_ssh_sock(monkeypatch: pytest.MonkeyPatch) -> None:
"""Check sock() function."""
with multiprocessing.Manager() as manager:
proxy = ssh.ProxyManager(manager)
with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"):
# Old ssh version uses port.
with mock.patch("ssh.version", return_value=(6, 6)):
with proxy as ssh_proxy:
self.assertTrue(ssh_proxy.sock().endswith("%p"))
monkeypatch.setattr(
"tempfile.mkdtemp", lambda *args, **kwargs: "/tmp/foo"
)
proxy._sock_path = None
# New ssh version uses hash.
with mock.patch("ssh.version", return_value=(6, 7)):
with proxy as ssh_proxy:
self.assertTrue(ssh_proxy.sock().endswith("%C"))
# Old ssh version uses port.
with mock.patch("ssh.version", return_value=(6, 6)):
with proxy as ssh_proxy:
assert ssh_proxy.sock().endswith("%p")
proxy._sock_path = None
# New ssh version uses hash.
with mock.patch("ssh.version", return_value=(6, 7)):
with proxy as ssh_proxy:
assert ssh_proxy.sock().endswith("%C")
proxy._sock_path = None
+120 -126
View File
@@ -15,170 +15,164 @@
"""Unittests for the subcmds module (mostly __init__.py than subcommands)."""
import optparse
import unittest
from typing import Type
import pytest
from command import Command
import subcmds
class AllCommands(unittest.TestCase):
"""Check registered all_commands."""
# NB: We don't test all subcommands as we want to avoid "change detection"
# tests, so we just look for the most common/important ones here that are
# unlikely to ever change.
@pytest.mark.parametrize(
"cmd", ("cherry-pick", "help", "init", "start", "sync", "upload")
)
def test_required_basic(cmd: str) -> None:
"""Basic checking of registered commands."""
assert cmd in subcmds.all_commands
def test_required_basic(self):
"""Basic checking of registered commands."""
# NB: We don't test all subcommands as we want to avoid "change
# detection" tests, so we just look for the most common/important ones
# here that are unlikely to ever change.
for cmd in {"cherry-pick", "help", "init", "start", "sync", "upload"}:
self.assertIn(cmd, subcmds.all_commands)
def test_naming(self):
"""Verify we don't add things that we shouldn't."""
for cmd in subcmds.all_commands:
# Reject filename suffixes like "help.py".
self.assertNotIn(".", cmd)
@pytest.mark.parametrize("name", subcmds.all_commands.keys())
def test_naming(name: str) -> None:
"""Verify we don't add things that we shouldn't."""
# Reject filename suffixes like "help.py".
assert "." not in name
# Make sure all '_' were converted to '-'.
self.assertNotIn("_", cmd)
# Make sure all '_' were converted to '-'.
assert "_" not in name
# Reject internal python paths like "__init__".
self.assertFalse(cmd.startswith("__"))
# Reject internal python paths like "__init__".
assert not name.startswith("__")
def test_help_desc_style(self):
"""Force some consistency in option descriptions.
Python's optparse & argparse has a few default options like --help.
Their option description text uses lowercase sentence fragments, so
enforce our options follow the same style so UI is consistent.
@pytest.mark.parametrize("name, cls", subcmds.all_commands.items())
def test_help_desc_style(name: str, cls: Type[Command]) -> None:
"""Force some consistency in option descriptions.
We enforce:
* Text starts with lowercase.
* Text doesn't end with period.
"""
for name, cls in subcmds.all_commands.items():
cmd = cls()
parser = cmd.OptionParser
for option in parser.option_list:
if option.help == optparse.SUPPRESS_HELP:
continue
Python's optparse & argparse has a few default options like --help.
Their option description text uses lowercase sentence fragments, so
enforce our options follow the same style so UI is consistent.
c = option.help[0]
self.assertEqual(
c.lower(),
c,
msg=f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should start with lowercase: "{option.help}"',
)
We enforce:
* Text starts with lowercase.
* Text doesn't end with period.
"""
cmd = cls()
parser = cmd.OptionParser
for option in parser.option_list:
if option.help == optparse.SUPPRESS_HELP or not option.help:
continue
self.assertNotEqual(
option.help[-1],
".",
msg=f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should not end in a period: "{option.help}"',
)
c = option.help[0]
assert c.lower() == c, (
f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should start with lowercase: "{option.help}"'
)
def test_cli_option_style(self):
"""Force some consistency in option flags."""
for name, cls in subcmds.all_commands.items():
cmd = cls()
parser = cmd.OptionParser
for option in parser.option_list:
for opt in option._long_opts:
self.assertNotIn(
"_",
opt,
msg=f"subcmds/{name}.py: {opt}: only use dashes in "
"options, not underscores",
)
assert option.help[-1] != ".", (
f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should not end in a period: "{option.help}"'
)
def test_cli_option_dest(self):
"""Block redundant dest= arguments."""
def _check_dest(opt):
"""Check the dest= setting."""
# If the destination is not set, nothing to check.
# If long options are not set, then there's no implicit destination.
# If callback is used, then a destination might be needed because
# optparse cannot assume a value is always stored.
if opt.dest is None or not opt._long_opts or opt.callback:
return
@pytest.mark.parametrize("name, cls", subcmds.all_commands.items())
def test_cli_option_style(name: str, cls: Type[Command]) -> None:
"""Force some consistency in option flags."""
cmd = cls()
parser = cmd.OptionParser
for option in parser.option_list:
for opt in option._long_opts:
assert "_" not in opt, (
f"subcmds/{name}.py: {opt}: only use dashes in "
"options, not underscores"
)
long = opt._long_opts[0]
assert long.startswith("--")
# This matches optparse's behavior.
implicit_dest = long[2:].replace("-", "_")
if implicit_dest == opt.dest:
bad_opts.append((str(opt), opt.dest))
# Hook the option check list.
optparse.Option.CHECK_METHODS.insert(0, _check_dest)
def test_cli_option_dest() -> None:
"""Block redundant dest= arguments."""
bad_opts: list[tuple[str, str]] = []
def _check_dest(opt: optparse.Option) -> None:
"""Check the dest= setting."""
# If the destination is not set, nothing to check.
# If long options are not set, then there's no implicit destination.
# If callback is used, then a destination might be needed because
# optparse cannot assume a value is always stored.
if opt.dest is None or not opt._long_opts or opt.callback:
return
long = opt._long_opts[0]
assert long.startswith("--")
# This matches optparse's behavior.
implicit_dest = long[2:].replace("-", "_")
if implicit_dest == opt.dest:
bad_opts.append((str(opt), opt.dest))
# Hook the option check list.
optparse.Option.CHECK_METHODS.insert(0, _check_dest)
try:
# Gather all the bad options up front so people can see all bad options
# instead of failing at the first one.
all_bad_opts = {}
all_bad_opts: dict[str, list[tuple[str, str]]] = {}
for name, cls in subcmds.all_commands.items():
bad_opts = all_bad_opts[name] = []
bad_opts = []
cmd = cls()
# Trigger construction of parser.
cmd.OptionParser
_ = cmd.OptionParser
all_bad_opts[name] = bad_opts
errmsg = None
for name, bad_opts in sorted(all_bad_opts.items()):
if bad_opts:
errmsg = ""
for name, bad_opts_list in sorted(all_bad_opts.items()):
if bad_opts_list:
if not errmsg:
errmsg = "Omit redundant dest= when defining options.\n"
errmsg += f"\nSubcommand {name} (subcmds/{name}.py):\n"
errmsg += "".join(
f" {opt}: dest='{dest}'\n" for opt, dest in bad_opts
f" {opt}: dest='{dest}'\n" for opt, dest in bad_opts_list
)
if errmsg:
self.fail(errmsg)
pytest.fail(errmsg)
finally:
# Make sure we aren't popping the wrong stuff.
assert optparse.Option.CHECK_METHODS.pop(0) is _check_dest
def test_common_validate_options(self):
"""Verify CommonValidateOptions sets up expected fields."""
for name, cls in subcmds.all_commands.items():
cmd = cls()
opts, args = cmd.OptionParser.parse_args([])
# Verify the fields don't exist yet.
self.assertFalse(
hasattr(opts, "verbose"),
msg=f"{name}: has verbose before validation",
)
self.assertFalse(
hasattr(opts, "quiet"),
msg=f"{name}: has quiet before validation",
)
@pytest.mark.parametrize("name, cls", subcmds.all_commands.items())
def test_common_validate_options(name: str, cls: Type[Command]) -> None:
"""Verify CommonValidateOptions sets up expected fields."""
cmd = cls()
opts, args = cmd.OptionParser.parse_args([])
cmd.CommonValidateOptions(opts, args)
# Verify the fields don't exist yet.
assert not hasattr(
opts, "verbose"
), f"{name}: has verbose before validation"
assert not hasattr(opts, "quiet"), f"{name}: has quiet before validation"
# Verify the fields exist now.
self.assertTrue(
hasattr(opts, "verbose"),
msg=f"{name}: missing verbose after validation",
)
self.assertTrue(
hasattr(opts, "quiet"),
msg=f"{name}: missing quiet after validation",
)
self.assertTrue(
hasattr(opts, "outer_manifest"),
msg=f"{name}: missing outer_manifest after validation",
)
cmd.CommonValidateOptions(opts, args)
def test_attribute_error_repro(self):
"""Confirm that accessing verbose before CommonValidateOptions fails."""
from subcmds.sync import Sync
# Verify the fields exist now.
assert hasattr(opts, "verbose"), f"{name}: missing verbose after validation"
assert hasattr(opts, "quiet"), f"{name}: missing quiet after validation"
assert hasattr(
opts, "outer_manifest"
), f"{name}: missing outer_manifest after validation"
cmd = Sync()
opts, args = cmd.OptionParser.parse_args([])
# This confirms that without the fix in main.py, an AttributeError
# would be raised because CommonValidateOptions hasn't been called yet.
with self.assertRaises(AttributeError):
_ = opts.verbose
def test_attribute_error_repro() -> None:
"""Confirm that accessing verbose before CommonValidateOptions fails."""
from subcmds.sync import Sync
cmd.CommonValidateOptions(opts, args)
self.assertTrue(hasattr(opts, "verbose"))
cmd = Sync()
opts, args = cmd.OptionParser.parse_args([])
# This confirms that without the fix in main.py, an AttributeError
# would be raised because CommonValidateOptions hasn't been called yet.
with pytest.raises(AttributeError):
_ = opts.verbose
cmd.CommonValidateOptions(opts, args)
assert hasattr(opts, "verbose")
+27 -24
View File
@@ -14,33 +14,36 @@
"""Unittests for the subcmds/init.py module."""
import unittest
from typing import List
import pytest
from subcmds import init
class InitCommand(unittest.TestCase):
"""Check registered all_commands."""
@pytest.mark.parametrize(
"argv",
([],),
)
def test_cli_parser_good(argv: List[str]) -> None:
"""Check valid command line options."""
cmd = init.Init()
opts, args = cmd.OptionParser.parse_args(argv)
cmd.ValidateOptions(opts, args)
def setUp(self):
self.cmd = init.Init()
def test_cli_parser_good(self):
"""Check valid command line options."""
ARGV = ([],)
for argv in ARGV:
opts, args = self.cmd.OptionParser.parse_args(argv)
self.cmd.ValidateOptions(opts, args)
def test_cli_parser_bad(self):
"""Check invalid command line options."""
ARGV = (
# Too many arguments.
["url", "asdf"],
# Conflicting options.
["--mirror", "--archive"],
)
for argv in ARGV:
opts, args = self.cmd.OptionParser.parse_args(argv)
with self.assertRaises(SystemExit):
self.cmd.ValidateOptions(opts, args)
@pytest.mark.parametrize(
"argv",
(
# Too many arguments.
["url", "asdf"],
# Conflicting options.
["--mirror", "--archive"],
),
)
def test_cli_parser_bad(argv: List[str]) -> None:
"""Check invalid command line options."""
cmd = init.Init()
opts, args = cmd.OptionParser.parse_args(argv)
with pytest.raises(SystemExit):
cmd.ValidateOptions(opts, args)
+4 -9
View File
@@ -14,15 +14,10 @@
"""Unittests for the update_manpages module."""
import unittest
from release import update_manpages
class UpdateManpagesTest(unittest.TestCase):
"""Tests the update-manpages code."""
def test_replace_regex(self):
"""Check that replace_regex works."""
data = "\n\033[1mSummary\033[m\n"
self.assertEqual(update_manpages.replace_regex(data), "\nSummary\n")
def test_replace_regex() -> None:
"""Check that replace_regex works."""
data = "\n\033[1mSummary\033[m\n"
assert update_manpages.replace_regex(data) == "\nSummary\n"