diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py index be2d09b07..9a6ba2052 100644 --- a/tests/test_git_trace2_event_log.py +++ b/tests/test_git_trace2_event_log.py @@ -18,17 +18,24 @@ import contextlib import io import json import os +import re import socket import tempfile import threading -import unittest +from typing import Any, Dict, List, Optional from unittest import mock +import pytest + import git_trace2_event_log import platform_utils -def serverLoggingThread(socket_path, server_ready, received_traces): +def server_logging_thread( + socket_path: str, + server_ready: threading.Condition, + received_traces: List[str], +) -> None: """Helper function to receive logs over a Unix domain socket. Appends received messages on the provided socket and appends to @@ -57,405 +64,425 @@ def serverLoggingThread(socket_path, server_ready, received_traces): received_traces.extend(data.decode("utf-8").splitlines()) -class EventLogTestCase(unittest.TestCase): - """TestCase for the EventLog module.""" +PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" +PARENT_SID_VALUE = "parent_sid" +SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" +FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}" - PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" - PARENT_SID_VALUE = "parent_sid" - SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" - FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}" - def setUp(self): - """Load the event_log module every time.""" - self._event_log = None - # By default we initialize with the expected case where - # repo launches us (so GIT_TRACE2_PARENT_SID is set). - env = { - self.PARENT_SID_KEY: self.PARENT_SID_VALUE, - } - self._event_log = git_trace2_event_log.EventLog(env=env) - self._log_data = None +@pytest.fixture +def event_log() -> git_trace2_event_log.EventLog: + """Fixture for the EventLog module.""" + # By default we initialize with the expected case where + # repo launches us (so GIT_TRACE2_PARENT_SID is set). + env = {PARENT_SID_KEY: PARENT_SID_VALUE} + return git_trace2_event_log.EventLog(env=env) - def verifyCommonKeys( - self, log_entry, expected_event_name=None, full_sid=True + +def verify_common_keys( + log_entry: Dict[str, Any], + expected_event_name: Optional[str] = None, + full_sid: bool = True, +) -> None: + """Helper function to verify common event log keys.""" + assert "event" in log_entry + assert "sid" in log_entry + assert "thread" in log_entry + assert "time" in log_entry + + # Do basic data format validation. + if expected_event_name: + assert expected_event_name == log_entry["event"] + if full_sid: + assert re.match(FULL_SID_REGEX, log_entry["sid"]) + else: + assert re.match(SELF_SID_REGEX, log_entry["sid"]) + assert re.match(r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$", log_entry["time"]) + + +def read_log(log_path: str) -> List[Dict[str, Any]]: + """Helper function to read log data into a list.""" + log_data = [] + with open(log_path, mode="rb") as f: + for line in f: + log_data.append(json.loads(line)) + return log_data + + +def remove_prefix(s: str, prefix: str) -> str: + """Return a copy string after removing |prefix| from |s|, if present or + the original string.""" + if s.startswith(prefix): + return s[len(prefix) :] + else: + return s + + +def test_initial_state_with_parent_sid( + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" + assert re.match(FULL_SID_REGEX, event_log.full_sid) + + +def test_initial_state_no_parent_sid() -> None: + """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" + # Setup an empty environment dict (no parent sid). + event_log = git_trace2_event_log.EventLog(env={}) + assert re.match(SELF_SID_REGEX, event_log.full_sid) + + +def test_version_event(event_log: git_trace2_event_log.EventLog) -> None: + """Test 'version' event data is valid. + + Verify that the 'version' event is written even when no other + events are added. + + Expected event log: + + """ + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + # A log with no added events should only have the version entry. + assert len(log_data) == 1 + version_event = log_data[0] + verify_common_keys(version_event, expected_event_name="version") + # Check for 'version' event specific fields. + assert "evt" in version_event + assert "exe" in version_event + # Verify "evt" version field is a string. + assert isinstance(version_event["evt"], str) + + +def test_start_event(event_log: git_trace2_event_log.EventLog) -> None: + """Test and validate 'start' event data is valid. + + Expected event log: + + + """ + event_log.StartEvent([]) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 2 + start_event = log_data[1] + verify_common_keys(log_data[0], expected_event_name="version") + verify_common_keys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + assert "argv" in start_event + assert isinstance(start_event["argv"], list) + + +def test_exit_event_result_none( + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test 'exit' event data is valid when result is None. + + We expect None result to be converted to 0 in the exit event data. + + Expected event log: + + + """ + event_log.ExitEvent(None) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 2 + exit_event = log_data[1] + verify_common_keys(log_data[0], expected_event_name="version") + verify_common_keys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + assert "code" in exit_event + # 'None' result should convert to 0 (successful) return code. + assert exit_event["code"] == 0 + + +def test_exit_event_result_integer( + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test 'exit' event data is valid when result is an integer. + + Expected event log: + + + """ + event_log.ExitEvent(2) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 2 + exit_event = log_data[1] + verify_common_keys(log_data[0], expected_event_name="version") + verify_common_keys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + assert "code" in exit_event + assert exit_event["code"] == 2 + + +def test_command_event(event_log: git_trace2_event_log.EventLog) -> None: + """Test and validate 'command' event data is valid. + + Expected event log: + + + """ + event_log.CommandEvent(name="repo", subcommands=["init", "this"]) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 2 + command_event = log_data[1] + verify_common_keys(log_data[0], expected_event_name="version") + verify_common_keys(command_event, expected_event_name="cmd_name") + # Check for 'command' event specific fields. + assert "name" in command_event + assert command_event["name"] == "repo-init-this" + + +def test_def_params_event_repo_config( + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test 'def_params' event data outputs only repo config keys. + + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "true", + "repo.partialclonefilter": "blob:none", + } + event_log.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 3 + def_param_events = log_data[1:] + verify_common_keys(log_data[0], expected_event_name="version") + + for event in def_param_events: + verify_common_keys(event, expected_event_name="def_param") + # Check for 'def_param' event specific fields. + assert "param" in event + assert "value" in event + assert event["param"].startswith("repo.") + + +def test_def_params_event_no_repo_config( + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test 'def_params' event data won't output non-repo config keys. + + Expected event log: + + """ + config = { + "git.foo": "bar", + "git.core.foo2": "baz", + } + event_log.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 1 + verify_common_keys(log_data[0], expected_event_name="version") + + +def test_data_event_config(event_log: git_trace2_event_log.EventLog) -> None: + """Test 'data' event data outputs all config keys. + + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "false", + "repo.syncstate.superproject.hassuperprojecttag": "true", + "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"], + } + prefix_value = "prefix" + event_log.LogDataConfigEvents(config, prefix_value) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 5 + data_events = log_data[1:] + verify_common_keys(log_data[0], expected_event_name="version") + + for event in data_events: + verify_common_keys(event) + # Check for 'data' event specific fields. + assert "key" in event + assert "value" in event + key = event["key"] + key = remove_prefix(key, f"{prefix_value}/") + value = event["value"] + assert event_log.GetDataEventName(value) == event["event"] + assert key in config + assert value == config[key] + + +def test_error_event(event_log: git_trace2_event_log.EventLog) -> None: + """Test and validate 'error' event data is valid. + + Expected event log: + + + """ + msg = "invalid option: --cahced" + fmt = "invalid option: %s" + event_log.ErrorEvent(msg, fmt) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = event_log.Write(path=tempdir) + log_data = read_log(log_path) + + assert len(log_data) == 2 + error_event = log_data[1] + verify_common_keys(log_data[0], expected_event_name="version") + verify_common_keys(error_event, expected_event_name="error") + # Check for 'error' event specific fields. + assert "msg" in error_event + assert "fmt" in error_event + assert error_event["msg"] == f"RepoErrorEvent:{msg}" + assert error_event["fmt"] == f"RepoErrorEvent:{fmt}" + + +def test_write_with_filename(event_log: git_trace2_event_log.EventLog) -> None: + """Test Write() with a path to a file exits with None.""" + assert event_log.Write(path="path/to/file") is None + + +def test_write_with_git_config( + tmp_path, + event_log: git_trace2_event_log.EventLog, +) -> None: + """Test Write() uses the git config path when 'git config' call succeeds.""" + with mock.patch.object( + event_log, + "_GetEventTargetPath", + return_value=str(tmp_path), ): - """Helper function to verify common event log keys.""" - self.assertIn("event", log_entry) - self.assertIn("sid", log_entry) - self.assertIn("thread", log_entry) - self.assertIn("time", log_entry) + assert os.path.dirname(event_log.Write()) == str(tmp_path) - # Do basic data format validation. - if expected_event_name: - self.assertEqual(expected_event_name, log_entry["event"]) - if full_sid: - self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) - else: - self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) - self.assertRegex( - log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$" + +def test_write_no_git_config(event_log: git_trace2_event_log.EventLog) -> None: + """Test Write() with no git config variable present exits with None.""" + with mock.patch.object(event_log, "_GetEventTargetPath", return_value=None): + assert event_log.Write() is None + + +def test_write_non_string(event_log: git_trace2_event_log.EventLog) -> None: + """Test Write() with non-string type for |path| throws TypeError.""" + with pytest.raises(TypeError): + event_log.Write(path=1234) + + +@pytest.mark.skipif( + not hasattr(socket, "AF_UNIX"), reason="Requires AF_UNIX sockets" +) +def test_write_socket(event_log: git_trace2_event_log.EventLog) -> None: + """Test Write() with Unix domain socket and validate received traces.""" + received_traces: List[str] = [] + with tempfile.TemporaryDirectory(prefix="test_server_sockets") as tempdir: + socket_path = os.path.join(tempdir, "server.sock") + server_ready = threading.Condition() + # Start "server" listening on Unix domain socket at socket_path. + server_thread = threading.Thread( + target=server_logging_thread, + args=(socket_path, server_ready, received_traces), ) + try: + server_thread.start() - def readLog(self, log_path): - """Helper function to read log data into a list.""" - log_data = [] - with open(log_path, mode="rb") as f: - for line in f: - log_data.append(json.loads(line)) - return log_data + with server_ready: + server_ready.wait(timeout=120) - def remove_prefix(self, s, prefix): - """Return a copy string after removing |prefix| from |s|, if present or - the original string.""" - if s.startswith(prefix): - return s[len(prefix) :] - else: - return s + event_log.StartEvent([]) + path = event_log.Write(path=f"af_unix:{socket_path}") + finally: + server_thread.join(timeout=5) - def test_initial_state_with_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" - self.assertRegex(self._event_log.full_sid, self.FULL_SID_REGEX) - - def test_initial_state_no_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" - # Setup an empty environment dict (no parent sid). - self._event_log = git_trace2_event_log.EventLog(env={}) - self.assertRegex(self._event_log.full_sid, self.SELF_SID_REGEX) - - def test_version_event(self): - """Test 'version' event data is valid. - - Verify that the 'version' event is written even when no other - events are addded. - - Expected event log: - - """ - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - # A log with no added events should only have the version entry. - self.assertEqual(len(self._log_data), 1) - version_event = self._log_data[0] - self.verifyCommonKeys(version_event, expected_event_name="version") - # Check for 'version' event specific fields. - self.assertIn("evt", version_event) - self.assertIn("exe", version_event) - # Verify "evt" version field is a string. - self.assertIsInstance(version_event["evt"], str) - - def test_start_event(self): - """Test and validate 'start' event data is valid. - - Expected event log: - - - """ - self._event_log.StartEvent([]) - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - start_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - self.verifyCommonKeys(start_event, expected_event_name="start") - # Check for 'start' event specific fields. - self.assertIn("argv", start_event) - self.assertTrue(isinstance(start_event["argv"], list)) - - def test_exit_event_result_none(self): - """Test 'exit' event data is valid when result is None. - - We expect None result to be converted to 0 in the exit event data. - - Expected event log: - - - """ - self._event_log.ExitEvent(None) - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - self.verifyCommonKeys(exit_event, expected_event_name="exit") - # Check for 'exit' event specific fields. - self.assertIn("code", exit_event) - # 'None' result should convert to 0 (successful) return code. - self.assertEqual(exit_event["code"], 0) - - def test_exit_event_result_integer(self): - """Test 'exit' event data is valid when result is an integer. - - Expected event log: - - - """ - self._event_log.ExitEvent(2) - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - self.verifyCommonKeys(exit_event, expected_event_name="exit") - # Check for 'exit' event specific fields. - self.assertIn("code", exit_event) - self.assertEqual(exit_event["code"], 2) - - def test_command_event(self): - """Test and validate 'command' event data is valid. - - Expected event log: - - - """ - self._event_log.CommandEvent(name="repo", subcommands=["init", "this"]) - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - command_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - self.verifyCommonKeys(command_event, expected_event_name="cmd_name") - # Check for 'command' event specific fields. - self.assertIn("name", command_event) - self.assertEqual(command_event["name"], "repo-init-this") - - def test_def_params_event_repo_config(self): - """Test 'def_params' event data outputs only repo config keys. - - Expected event log: - - - - """ - config = { - "git.foo": "bar", - "repo.partialclone": "true", - "repo.partialclonefilter": "blob:none", - } - self._event_log.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 3) - def_param_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - - for event in def_param_events: - self.verifyCommonKeys(event, expected_event_name="def_param") - # Check for 'def_param' event specific fields. - self.assertIn("param", event) - self.assertIn("value", event) - self.assertTrue(event["param"].startswith("repo.")) - - def test_def_params_event_no_repo_config(self): - """Test 'def_params' event data won't output non-repo config keys. - - Expected event log: - - """ - config = { - "git.foo": "bar", - "git.core.foo2": "baz", - } - self._event_log.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 1) - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - - def test_data_event_config(self): - """Test 'data' event data outputs all config keys. - - Expected event log: - - - - """ - config = { - "git.foo": "bar", - "repo.partialclone": "false", - "repo.syncstate.superproject.hassuperprojecttag": "true", - "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"], - } - prefix_value = "prefix" - self._event_log.LogDataConfigEvents(config, prefix_value) - - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 5) - data_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - - for event in data_events: - self.verifyCommonKeys(event) - # Check for 'data' event specific fields. - self.assertIn("key", event) - self.assertIn("value", event) - key = event["key"] - key = self.remove_prefix(key, f"{prefix_value}/") - value = event["value"] - self.assertEqual( - self._event_log.GetDataEventName(value), event["event"] - ) - self.assertTrue(key in config and value == config[key]) - - def test_error_event(self): - """Test and validate 'error' event data is valid. - - Expected event log: - - - """ - msg = "invalid option: --cahced" - fmt = "invalid option: %s" - self._event_log.ErrorEvent(msg, fmt) - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - log_path = self._event_log.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - error_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name="version") - self.verifyCommonKeys(error_event, expected_event_name="error") - # Check for 'error' event specific fields. - self.assertIn("msg", error_event) - self.assertIn("fmt", error_event) - self.assertEqual(error_event["msg"], f"RepoErrorEvent:{msg}") - self.assertEqual(error_event["fmt"], f"RepoErrorEvent:{fmt}") - - def test_write_with_filename(self): - """Test Write() with a path to a file exits with None.""" - self.assertIsNone(self._event_log.Write(path="path/to/file")) - - def test_write_with_git_config(self): - """Test Write() uses the git config path when 'git config' call - succeeds.""" - with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: - with mock.patch.object( - self._event_log, - "_GetEventTargetPath", - return_value=tempdir, - ): - self.assertEqual( - os.path.dirname(self._event_log.Write()), tempdir - ) - - def test_write_no_git_config(self): - """Test Write() with no git config variable present exits with None.""" - with mock.patch.object( - self._event_log, "_GetEventTargetPath", return_value=None - ): - self.assertIsNone(self._event_log.Write()) - - def test_write_non_string(self): - """Test Write() with non-string type for |path| throws TypeError.""" - with self.assertRaises(TypeError): - self._event_log.Write(path=1234) - - @unittest.skipIf(not hasattr(socket, "AF_UNIX"), "Requires AF_UNIX sockets") - def test_write_socket(self): - """Test Write() with Unix domain socket for |path| and validate received - traces.""" - received_traces = [] - with tempfile.TemporaryDirectory( - prefix="test_server_sockets" - ) as tempdir: - socket_path = os.path.join(tempdir, "server.sock") - server_ready = threading.Condition() - # Start "server" listening on Unix domain socket at socket_path. - server_thread = threading.Thread( - target=serverLoggingThread, - args=(socket_path, server_ready, received_traces), - ) - try: - server_thread.start() - - with server_ready: - server_ready.wait(timeout=120) - - self._event_log.StartEvent([]) - path = self._event_log.Write(path=f"af_unix:{socket_path}") - finally: - server_thread.join(timeout=5) - - self.assertEqual(path, f"af_unix:stream:{socket_path}") - self.assertEqual(len(received_traces), 2) - version_event = json.loads(received_traces[0]) - start_event = json.loads(received_traces[1]) - self.verifyCommonKeys(version_event, expected_event_name="version") - self.verifyCommonKeys(start_event, expected_event_name="start") - # Check for 'start' event specific fields. - self.assertIn("argv", start_event) - self.assertIsInstance(start_event["argv"], list) + assert path == f"af_unix:stream:{socket_path}" + assert len(received_traces) == 2 + version_event = json.loads(received_traces[0]) + start_event = json.loads(received_traces[1]) + verify_common_keys(version_event, expected_event_name="version") + verify_common_keys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + assert "argv" in start_event + assert isinstance(start_event["argv"], list) -class EventLogVerboseTestCase(unittest.TestCase): +class TestEventLogVerbose: """TestCase for the EventLog module verbose logging.""" - def setUp(self): - self._event_log = git_trace2_event_log.EventLog(env={}) - - def test_write_socket_error_no_verbose(self): + def test_write_socket_error_no_verbose(self) -> None: """Test Write() suppression of socket errors when not verbose.""" - self._event_log.verbose = False + event_log = git_trace2_event_log.EventLog(env={}) + event_log.verbose = False with contextlib.redirect_stderr( io.StringIO() ) as mock_stderr, mock.patch("socket.socket", side_effect=OSError): - self._event_log.Write(path="af_unix:stream:/tmp/test_sock") - self.assertEqual(mock_stderr.getvalue(), "") + event_log.Write(path="af_unix:stream:/tmp/test_sock") + assert mock_stderr.getvalue() == "" - def test_write_socket_error_verbose(self): + def test_write_socket_error_verbose(self) -> None: """Test Write() printing of socket errors when verbose.""" - self._event_log.verbose = True + event_log = git_trace2_event_log.EventLog(env={}) + event_log.verbose = True with contextlib.redirect_stderr( io.StringIO() ) as mock_stderr, mock.patch( "socket.socket", side_effect=OSError("Mock error") ): - self._event_log.Write(path="af_unix:stream:/tmp/test_sock") - self.assertIn( - "git trace2 logging failed: Mock error", - mock_stderr.getvalue(), + event_log.Write(path="af_unix:stream:/tmp/test_sock") + assert ( + "git trace2 logging failed: Mock error" + in mock_stderr.getvalue() ) - def test_write_file_error_no_verbose(self): + def test_write_file_error_no_verbose(self) -> None: """Test Write() suppression of file errors when not verbose.""" - self._event_log.verbose = False + event_log = git_trace2_event_log.EventLog(env={}) + event_log.verbose = False with contextlib.redirect_stderr( io.StringIO() ) as mock_stderr, mock.patch( "tempfile.NamedTemporaryFile", side_effect=FileExistsError ): - self._event_log.Write(path="/tmp") - self.assertEqual(mock_stderr.getvalue(), "") + event_log.Write(path="/tmp") + assert mock_stderr.getvalue() == "" - def test_write_file_error_verbose(self): + def test_write_file_error_verbose(self) -> None: """Test Write() printing of file errors when verbose.""" - self._event_log.verbose = True + event_log = git_trace2_event_log.EventLog(env={}) + event_log.verbose = True with contextlib.redirect_stderr( io.StringIO() ) as mock_stderr, mock.patch( "tempfile.NamedTemporaryFile", side_effect=FileExistsError("Mock error"), ): - self._event_log.Write(path="/tmp") - self.assertIn( - "git trace2 logging failed: FileExistsError", - mock_stderr.getvalue(), + event_log.Write(path="/tmp") + assert ( + "git trace2 logging failed: FileExistsError" + in mock_stderr.getvalue() ) diff --git a/tests/test_manifest_xml.py b/tests/test_manifest_xml.py index 5e0c78334..473f781ba 100644 --- a/tests/test_manifest_xml.py +++ b/tests/test_manifest_xml.py @@ -18,10 +18,10 @@ import os from pathlib import Path import platform import re -import tempfile -import unittest import xml.dom.minidom +import pytest + import error import manifest_xml @@ -66,7 +66,7 @@ if os.path.sep != "/": ) -def sort_attributes(manifest): +def sort_attributes(manifest: str) -> str: """Sort the attributes of all elements alphabetically. This is needed because different versions of the toxml() function from @@ -93,13 +93,12 @@ def sort_attributes(manifest): return new_manifest -class ManifestParseTestCase(unittest.TestCase): - """TestCase for parsing manifests.""" +class RepoClient: + """Basic empty repo checkout.""" - def setUp(self): - self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") - self.tempdir = Path(self.tempdirobj.name) - self.repodir = self.tempdir / ".repo" + def __init__(self, topdir: Path): + self.topdir = topdir + self.repodir = self.topdir / ".repo" self.manifest_dir = self.repodir / "manifests" self.manifest_file = self.repodir / manifest_xml.MANIFEST_FILE_NAME self.local_manifest_dir = ( @@ -107,7 +106,6 @@ class ManifestParseTestCase(unittest.TestCase): ) self.repodir.mkdir() self.manifest_dir.mkdir() - # The manifest parsing really wants a git repo currently. gitdir = self.repodir / "manifests.git" gitdir.mkdir() @@ -117,10 +115,7 @@ class ManifestParseTestCase(unittest.TestCase): """ ) - def tearDown(self): - self.tempdirobj.cleanup() - - def getXmlManifest(self, data): + def get_xml_manifest(self, data: str) -> manifest_xml.XmlManifest: """Helper to initialize a manifest for testing.""" self.manifest_file.write_text(data, encoding="utf-8") return manifest_xml.XmlManifest( @@ -128,33 +123,43 @@ class ManifestParseTestCase(unittest.TestCase): ) @staticmethod - def encodeXmlAttr(attr): + def encode_xml_attr(attr: str) -> str: """Encode |attr| using XML escape rules.""" return attr.replace("\r", " ").replace("\n", " ") -class ManifestValidateFilePaths(unittest.TestCase): +@pytest.fixture +def repo_client(tmp_path: Path) -> RepoClient: + """Generate a basic empty repo checkout. + + The manifest is not generated. + """ + return RepoClient(tmp_path) + + +class TestManifestValidateFilePaths: """Check _ValidateFilePaths helper. This doesn't access a real filesystem. """ - def check_both(self, *args): - manifest_xml.XmlManifest._ValidateFilePaths("copyfile", *args) - manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) + def check_both(self, src: str, dest: str) -> None: + """Check copyfile & linkfile.""" + manifest_xml.XmlManifest._ValidateFilePaths("copyfile", src, dest) + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", src, dest) - def test_normal_path(self): + def test_normal_path(self) -> None: """Make sure good paths are accepted.""" self.check_both("foo", "bar") self.check_both("foo/bar", "bar") self.check_both("foo", "bar/bar") self.check_both("foo/bar", "bar/bar") - def test_symlink_targets(self): + def test_symlink_targets(self) -> None: """Some extra checks for symlinks.""" - def check(*args): - manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) + def check(src: str, dest: str) -> None: + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", src, dest) # We allow symlinks to end in a slash since we allow them to point to # dirs in general. Technically the slash isn't necessary. @@ -162,114 +167,111 @@ class ManifestValidateFilePaths(unittest.TestCase): # We allow a single '.' to get a reference to the project itself. check(".", "bar") - def test_bad_paths(self): + def test_bad_paths(self) -> None: """Make sure bad paths (src & dest) are rejected.""" for path in INVALID_FS_PATHS: - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, path, "a" - ) - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, "a", path - ) + with pytest.raises(error.ManifestInvalidPathError): + self.check_both(path, "a") + with pytest.raises(error.ManifestInvalidPathError): + self.check_both("a", path) -class ValueTests(unittest.TestCase): +class TestValue: """Check utility parsing code.""" - def _get_node(self, text): + def _get_node(self, text: str) -> xml.dom.minidom.Element: return xml.dom.minidom.parseString(text).firstChild - def test_bool_default(self): + def test_bool_default(self) -> None: """Check XmlBool default handling.""" node = self._get_node("") - self.assertIsNone(manifest_xml.XmlBool(node, "a")) - self.assertIsNone(manifest_xml.XmlBool(node, "a", None)) - self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) + assert manifest_xml.XmlBool(node, "a") is None + assert manifest_xml.XmlBool(node, "a", None) is None + assert manifest_xml.XmlBool(node, "a", 123) == 123 node = self._get_node('') - self.assertIsNone(manifest_xml.XmlBool(node, "a")) + assert manifest_xml.XmlBool(node, "a") is None - def test_bool_invalid(self): + def test_bool_invalid(self) -> None: """Check XmlBool invalid handling.""" node = self._get_node('') - self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) + assert manifest_xml.XmlBool(node, "a", 123) == 123 - def test_bool_true(self): + def test_bool_true(self) -> None: """Check XmlBool true values.""" for value in ("yes", "true", "1"): node = self._get_node(f'') - self.assertTrue(manifest_xml.XmlBool(node, "a")) + assert manifest_xml.XmlBool(node, "a") is True - def test_bool_false(self): + def test_bool_false(self) -> None: """Check XmlBool false values.""" for value in ("no", "false", "0"): node = self._get_node(f'') - self.assertFalse(manifest_xml.XmlBool(node, "a")) + assert manifest_xml.XmlBool(node, "a") is False - def test_int_default(self): + def test_int_default(self) -> None: """Check XmlInt default handling.""" node = self._get_node("") - self.assertIsNone(manifest_xml.XmlInt(node, "a")) - self.assertIsNone(manifest_xml.XmlInt(node, "a", None)) - self.assertEqual(123, manifest_xml.XmlInt(node, "a", 123)) + assert manifest_xml.XmlInt(node, "a") is None + assert manifest_xml.XmlInt(node, "a", None) is None + assert manifest_xml.XmlInt(node, "a", 123) == 123 node = self._get_node('') - self.assertIsNone(manifest_xml.XmlInt(node, "a")) + assert manifest_xml.XmlInt(node, "a") is None - def test_int_good(self): + def test_int_good(self) -> None: """Check XmlInt numeric handling.""" for value in (-1, 0, 1, 50000): node = self._get_node(f'') - self.assertEqual(value, manifest_xml.XmlInt(node, "a")) + assert manifest_xml.XmlInt(node, "a") == value - def test_int_invalid(self): + def test_int_invalid(self) -> None: """Check XmlInt invalid handling.""" - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): node = self._get_node('') manifest_xml.XmlInt(node, "a") -class XmlManifestTests(ManifestParseTestCase): +class TestXmlManifest: """Check manifest processing.""" - def test_empty(self): + def test_empty(self, repo_client: RepoClient) -> None: """Parse an 'empty' manifest file.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' "" ) - self.assertEqual(manifest.remotes, {}) - self.assertEqual(manifest.projects, []) + assert manifest.remotes == {} + assert manifest.projects == [] - def test_link(self): + def test_link(self, repo_client: RepoClient) -> None: """Verify Link handling with new names.""" - manifest = manifest_xml.XmlManifest( - str(self.repodir), str(self.manifest_file) + manifest = repo_client.get_xml_manifest("") + (repo_client.manifest_dir / "foo.xml").write_text( + "" ) - (self.manifest_dir / "foo.xml").write_text("") manifest.Link("foo.xml") - self.assertIn( - '', self.manifest_file.read_text() + assert ( + '' + in repo_client.manifest_file.read_text() ) - def test_toxml_empty(self): + def test_toxml_empty(self, repo_client: RepoClient) -> None: """Verify the ToXml() helper.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' "" ) - self.assertEqual( - manifest.ToXml().toxml(), '' - ) + assert manifest.ToXml().toxml() == '' - def test_todict_empty(self): + def test_todict_empty(self, repo_client: RepoClient) -> None: """Verify the ToDict() helper.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' "" ) - self.assertEqual(manifest.ToDict(), {}) + assert manifest.ToDict() == {} - def test_toxml_omit_local(self): + def test_toxml_omit_local(self, repo_client: RepoClient) -> None: """Does not include local_manifests projects when omit_local=True.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' '' '' @@ -277,16 +279,16 @@ class XmlManifestTests(ManifestParseTestCase): '' "" ) - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=True).toxml()), - '' + assert ( + sort_attributes(manifest.ToXml(omit_local=True).toxml()) + == '' '' - '', + '' ) - def test_toxml_with_local(self): + def test_toxml_with_local(self, repo_client: RepoClient) -> None: """Does include local_manifests projects when omit_local=False.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' '' '' @@ -294,17 +296,17 @@ class XmlManifestTests(ManifestParseTestCase): '' "" ) - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=False).toxml()), - '' + assert ( + sort_attributes(manifest.ToXml(omit_local=False).toxml()) + == '' '' '' - '', + '' ) - def test_repo_hooks(self): + def test_repo_hooks(self, repo_client: RepoClient) -> None: """Check repo-hooks settings.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -314,14 +316,12 @@ class XmlManifestTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.repo_hooks_project.name, "repohooks") - self.assertEqual( - manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] - ) + assert manifest.repo_hooks_project.name == "repohooks" + assert manifest.repo_hooks_project.enabled_repo_hooks == ["a", "b"] - def test_repo_hooks_unordered(self): - """Check repo-hooks settings work even if the project def comes second.""" # noqa: E501 - manifest = self.getXmlManifest( + def test_repo_hooks_unordered(self, repo_client: RepoClient) -> None: + """Check repo-hooks settings work when the project comes after.""" + manifest = repo_client.get_xml_manifest( """ @@ -331,14 +331,12 @@ class XmlManifestTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.repo_hooks_project.name, "repohooks") - self.assertEqual( - manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] - ) + assert manifest.repo_hooks_project.name == "repohooks" + assert manifest.repo_hooks_project.enabled_repo_hooks == ["a", "b"] - def test_unknown_tags(self): + def test_unknown_tags(self, repo_client: RepoClient) -> None: """Check superproject settings.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -349,20 +347,20 @@ class XmlManifestTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "test-remote") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "test-remote" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' - "", + "" ) - def test_remote_annotations(self): + def test_remote_annotations(self, repo_client: RepoClient) -> None: """Check remote settings.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -371,24 +369,20 @@ class XmlManifestTests(ManifestParseTestCase): """ ) - self.assertEqual( - manifest.remotes["test-remote"].annotations[0].name, "foo" - ) - self.assertEqual( - manifest.remotes["test-remote"].annotations[0].value, "bar" - ) - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.remotes["test-remote"].annotations[0].name == "foo" + assert manifest.remotes["test-remote"].annotations[0].value == "bar" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' "" - "", + "" ) - def test_parse_with_xml_doctype(self): + def test_parse_with_xml_doctype(self, repo_client: RepoClient) -> None: """Check correct manifest parse with DOCTYPE node present.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -398,42 +392,41 @@ class XmlManifestTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].name, "test-project") + assert len(manifest.projects) == 1 + assert manifest.projects[0].name == "test-project" - def test_sync_j_max(self): + def test_sync_j_max(self, repo_client: RepoClient) -> None: """Check sync-j-max handling.""" # Check valid value. - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( '' ) - self.assertEqual(manifest.default.sync_j_max, 5) - self.assertEqual( - manifest.ToXml().toxml(), - '' - '', + assert manifest.default.sync_j_max == 5 + assert ( + manifest.ToXml().toxml() == '' + '' ) # Check invalid values. - with self.assertRaises(error.ManifestParseError): - manifest = self.getXmlManifest( + with pytest.raises(error.ManifestParseError): + manifest = repo_client.get_xml_manifest( '' ) manifest.ToXml() - with self.assertRaises(error.ManifestParseError): - manifest = self.getXmlManifest( + with pytest.raises(error.ManifestParseError): + manifest = repo_client.get_xml_manifest( '' ) manifest.ToXml() -class IncludeElementTests(ManifestParseTestCase): +class TestIncludeElement: """Tests for .""" - def test_revision_default(self): + def test_revision_default(self, repo_client: RepoClient) -> None: """Check handling of revision attribute.""" - root_m = self.manifest_dir / "root.xml" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -445,7 +438,7 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "stable.xml").write_text( + (repo_client.manifest_dir / "stable.xml").write_text( """ @@ -455,7 +448,7 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man1.xml").write_text( + (repo_client.manifest_dir / "man1.xml").write_text( """ @@ -463,7 +456,7 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man2.xml").write_text( + (repo_client.manifest_dir / "man2.xml").write_text( """ @@ -471,31 +464,34 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) + include_m = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) for proj in include_m.projects: if proj.name == "root-name1": # Check include revision not set on root level proj. - self.assertNotEqual("stable-branch", proj.revisionExpr) + assert proj.revisionExpr != "stable-branch" if proj.name == "root-name2": # Check root proj revision not removed. - self.assertEqual("refs/heads/main", proj.revisionExpr) + assert proj.revisionExpr == "refs/heads/main" if proj.name == "stable-name1": # Check stable proj has inherited revision include node. - self.assertEqual("stable-branch", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch" if proj.name == "stable-name2": # Check stable proj revision can override include node. - self.assertEqual("stable-branch2", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch2" if proj.name == "man1-name1": - self.assertEqual("stable-branch", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch" if proj.name == "man1-name2": - self.assertEqual("stable-branch3", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch3" if proj.name == "man2-name1": - self.assertEqual("stable-branch2", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch2" if proj.name == "man2-name2": - self.assertEqual("stable-branch3", proj.revisionExpr) + assert proj.revisionExpr == "stable-branch3" - def test_group_levels(self): - root_m = self.manifest_dir / "root.xml" + def test_group_levels(self, repo_client: RepoClient) -> None: + """Check handling of nested include groups.""" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -507,7 +503,7 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "level1.xml").write_text( + (repo_client.manifest_dir / "level1.xml").write_text( """ @@ -515,33 +511,38 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "level2.xml").write_text( + (repo_client.manifest_dir / "level2.xml").write_text( """ """ ) - include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) + include_m = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) for proj in include_m.projects: if proj.name == "root-name1": # Check include group not set on root level proj. - self.assertNotIn("level1-group", proj.groups) + assert "level1-group" not in proj.groups if proj.name == "root-name2": # Check root proj group not removed. - self.assertIn("r2g1", proj.groups) + assert "r2g1" in proj.groups if proj.name == "level1-name1": # Check level1 proj has inherited group level 1. - self.assertIn("level1-group", proj.groups) + assert "level1-group" in proj.groups if proj.name == "level2-name1": # Check level2 proj has inherited group levels 1 and 2. - self.assertIn("level1-group", proj.groups) - self.assertIn("level2-group", proj.groups) + assert "level1-group" in proj.groups + assert "level2-group" in proj.groups # Check level2 proj group not removed. - self.assertIn("l2g1", proj.groups) + assert "l2g1" in proj.groups - def test_group_levels_with_extend_project(self): - root_m = self.manifest_dir / "root.xml" + def test_group_levels_with_extend_project( + self, repo_client: RepoClient + ) -> None: + """Check inheritance of groups via extend-project.""" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -552,32 +553,36 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man1.xml").write_text( + (repo_client.manifest_dir / "man1.xml").write_text( """ """ ) - (self.manifest_dir / "man2.xml").write_text( + (repo_client.manifest_dir / "man2.xml").write_text( """ """ ) - include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) + include_m = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) proj = include_m.projects[0] # Check project has inherited group via project element. - self.assertIn("top-group1", proj.groups) + assert "top-group1" in proj.groups # Check project has inherited group via extend-project element. - self.assertIn("top-group2", proj.groups) + assert "top-group2" in proj.groups # Check project has set group via extend-project element. - self.assertIn("eg1", proj.groups) + assert "eg1" in proj.groups - def test_extend_project_does_not_inherit_local_groups(self): + def test_extend_project_does_not_inherit_local_groups( + self, repo_client: RepoClient + ) -> None: """Check that extend-project does not inherit local groups.""" - root_m = self.manifest_dir / "root.xml" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -588,26 +593,28 @@ class IncludeElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man1.xml").write_text( + (repo_client.manifest_dir / "man1.xml").write_text( """ """ ) - include_m = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) + include_m = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) proj = include_m.projects[0] - self.assertIn("g1", proj.groups) - self.assertNotIn("local:g2", proj.groups) - self.assertIn("g3", proj.groups) + assert "g1" in proj.groups + assert "local:g2" not in proj.groups + assert "g3" in proj.groups - def test_allow_bad_name_from_user(self): + def test_allow_bad_name_from_user(self, repo_client: RepoClient) -> None: """Check handling of bad name attribute from the user's input.""" - def parse(name): - name = self.encodeXmlAttr(name) - manifest = self.getXmlManifest( + def parse(name: str) -> None: + name = repo_client.encode_xml_attr(name) + manifest = repo_client.get_xml_manifest( f""" @@ -620,26 +627,26 @@ class IncludeElementTests(ManifestParseTestCase): manifest.ToXml() # Setup target of the include. - target = self.tempdir / "target.xml" + target = repo_client.topdir / "target.xml" target.write_text("") # Include with absolute path. - parse(os.path.abspath(target)) + parse(str(target.absolute())) # Include with relative path. - parse(os.path.relpath(target, self.manifest_dir)) + parse(os.path.relpath(str(target), str(repo_client.manifest_dir))) - def test_bad_name_checks(self): + def test_bad_name_checks(self, repo_client: RepoClient) -> None: """Check handling of bad name attribute.""" - def parse(name): - name = self.encodeXmlAttr(name) + def parse(name: str) -> None: + name = repo_client.encode_xml_attr(name) # Setup target of the include. - (self.manifest_dir / "target.xml").write_text( + (repo_client.manifest_dir / "target.xml").write_text( f'' ) - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -652,23 +659,23 @@ class IncludeElementTests(ManifestParseTestCase): manifest.ToXml() # Handle empty name explicitly because a different codepath rejects it. - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): parse("") for path in INVALID_FS_PATHS: if not path: continue - with self.assertRaises(error.ManifestInvalidPathError): + with pytest.raises(error.ManifestInvalidPathError): parse(path) -class ProjectElementTests(ManifestParseTestCase): +class TestProjectElement: """Tests for .""" - def test_group(self): + def test_group(self, repo_client: RepoClient) -> None: """Check project group settings.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -678,28 +685,33 @@ class ProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 2) + assert len(manifest.projects) == 2 # Ordering isn't guaranteed. result = { manifest.projects[0].name: manifest.projects[0].groups, manifest.projects[1].name: manifest.projects[1].groups, } - self.assertEqual( - result["test-name"], {"name:test-name", "all", "path:test-path"} - ) - self.assertEqual( - result["extras"], - {"g1", "g2", "name:extras", "all", "path:path"}, - ) + assert result["test-name"] == { + "name:test-name", + "all", + "path:test-path", + } + assert result["extras"] == { + "g1", + "g2", + "name:extras", + "all", + "path:path", + } groupstr = "default,platform-" + platform.system().lower() - self.assertEqual(groupstr, manifest.GetManifestGroupsStr()) + assert manifest.GetManifestGroupsStr() == groupstr groupstr = "g1,g2,g1" manifest.manifestProject.config.SetString("manifest.groups", groupstr) - self.assertEqual(groupstr, manifest.GetManifestGroupsStr()) + assert manifest.GetManifestGroupsStr() == groupstr - def test_set_revision_id(self): + def test_set_revision_id(self, repo_client: RepoClient) -> None: """Check setting of project's revisionId.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -708,25 +720,25 @@ class ProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 1) + assert len(manifest.projects) == 1 project = manifest.projects[0] project.SetRevisionId("ABCDEF") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' # noqa: E501 - "", + "" ) - def test_trailing_slash(self): + def test_trailing_slash(self, repo_client: RepoClient) -> None: """Check handling of trailing slashes in attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest( + def parse(name: str, path: str) -> manifest_xml.XmlManifest: + name = repo_client.encode_xml_attr(name) + path = repo_client.encode_xml_attr(path) + return repo_client.get_xml_manifest( f""" @@ -737,48 +749,36 @@ class ProjectElementTests(ManifestParseTestCase): ) manifest = parse("a/path/", "foo") - self.assertEqual( - os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join( + str(repo_client.topdir), ".repo", "projects", "foo.git" ) - self.assertEqual( - os.path.normpath(manifest.projects[0].objdir), - os.path.join( - self.tempdir, ".repo", "project-objects", "a", "path.git" - ), + assert os.path.normpath(manifest.projects[0].objdir) == os.path.join( + str(repo_client.topdir), ".repo", "project-objects", "a", "path.git" ) manifest = parse("a/path", "foo/") - self.assertEqual( - os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join( + str(repo_client.topdir), ".repo", "projects", "foo.git" ) - self.assertEqual( - os.path.normpath(manifest.projects[0].objdir), - os.path.join( - self.tempdir, ".repo", "project-objects", "a", "path.git" - ), + assert os.path.normpath(manifest.projects[0].objdir) == os.path.join( + str(repo_client.topdir), ".repo", "project-objects", "a", "path.git" ) manifest = parse("a/path", "foo//////") - self.assertEqual( - os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + assert os.path.normpath(manifest.projects[0].gitdir) == os.path.join( + str(repo_client.topdir), ".repo", "projects", "foo.git" ) - self.assertEqual( - os.path.normpath(manifest.projects[0].objdir), - os.path.join( - self.tempdir, ".repo", "project-objects", "a", "path.git" - ), + assert os.path.normpath(manifest.projects[0].objdir) == os.path.join( + str(repo_client.topdir), ".repo", "project-objects", "a", "path.git" ) - def test_toplevel_path(self): + def test_toplevel_path(self, repo_client: RepoClient) -> None: """Check handling of path=. specially.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest( + def parse(name: str, path: str) -> manifest_xml.XmlManifest: + name = repo_client.encode_xml_attr(name) + path = repo_client.encode_xml_attr(path) + return repo_client.get_xml_manifest( f""" @@ -790,18 +790,19 @@ class ProjectElementTests(ManifestParseTestCase): for path in (".", "./", ".//", ".///"): manifest = parse("server/path", path) - self.assertEqual( - os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, ".repo", "projects", "..git"), + assert os.path.normpath( + manifest.projects[0].gitdir + ) == os.path.join( + str(repo_client.topdir), ".repo", "projects", "..git" ) - def test_bad_path_name_checks(self): + def test_bad_path_name_checks(self, repo_client: RepoClient) -> None: """Check handling of bad path & name attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - manifest = self.getXmlManifest( + def parse(name: str, path: str) -> None: + name = repo_client.encode_xml_attr(name) + path = repo_client.encode_xml_attr(path) + manifest = repo_client.get_xml_manifest( f""" @@ -818,28 +819,28 @@ class ProjectElementTests(ManifestParseTestCase): # Handle empty name explicitly because a different codepath rejects it. # Empty path is OK because it defaults to the name field. - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): parse("", "ok") for path in INVALID_FS_PATHS: if not path or path.endswith("/") or path.endswith(os.path.sep): continue - with self.assertRaises(error.ManifestInvalidPathError): + with pytest.raises(error.ManifestInvalidPathError): parse(path, "ok") # We have a dedicated test for path=".". if path not in {"."}: - with self.assertRaises(error.ManifestInvalidPathError): + with pytest.raises(error.ManifestInvalidPathError): parse("ok", path) -class SuperProjectElementTests(ManifestParseTestCase): +class TestSuperProjectElement: """Tests for .""" - def test_superproject(self): + def test_superproject(self, repo_client: RepoClient) -> None: """Check superproject settings.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -848,25 +849,24 @@ class SuperProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "test-remote") - self.assertEqual( - manifest.superproject.remote.url, "http://localhost/superproject" + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "test-remote" + assert ( + manifest.superproject.remote.url == "http://localhost/superproject" ) - self.assertEqual(manifest.superproject.revision, "refs/heads/main") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.revision == "refs/heads/main" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' - "", + "" ) - def test_superproject_revision(self): + def test_superproject_revision(self, repo_client: RepoClient) -> None: """Check superproject settings with a different revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -875,25 +875,26 @@ class SuperProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "test-remote") - self.assertEqual( - manifest.superproject.remote.url, "http://localhost/superproject" + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "test-remote" + assert ( + manifest.superproject.remote.url == "http://localhost/superproject" ) - self.assertEqual(manifest.superproject.revision, "refs/heads/stable") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.revision == "refs/heads/stable" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' - "", + "" ) - def test_superproject_revision_default_negative(self): + def test_superproject_revision_default_negative( + self, repo_client: RepoClient + ) -> None: """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -902,51 +903,53 @@ class SuperProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "test-remote") - self.assertEqual( - manifest.superproject.remote.url, "http://localhost/superproject" + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "test-remote" + assert ( + manifest.superproject.remote.url == "http://localhost/superproject" ) - self.assertEqual(manifest.superproject.revision, "refs/heads/stable") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.revision == "refs/heads/stable" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' - "", + "" ) - def test_superproject_revision_remote(self): + def test_superproject_revision_remote( + self, repo_client: RepoClient + ) -> None: """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ - + -""" # noqa: E501 +""" ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "test-remote") - self.assertEqual( - manifest.superproject.remote.url, "http://localhost/superproject" + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "test-remote" + assert ( + manifest.superproject.remote.url == "http://localhost/superproject" ) - self.assertEqual(manifest.superproject.revision, "refs/heads/stable") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.revision == "refs/heads/stable" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' # noqa: E501 '' '' - "", + "" ) - def test_remote(self): + def test_remote(self, repo_client: RepoClient) -> None: """Check superproject settings with a remote.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -956,28 +959,26 @@ class SuperProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "platform/superproject") - self.assertEqual( - manifest.superproject.remote.name, "superproject-remote" + assert manifest.superproject.name == "platform/superproject" + assert manifest.superproject.remote.name == "superproject-remote" + assert ( + manifest.superproject.remote.url + == "http://localhost/platform/superproject" ) - self.assertEqual( - manifest.superproject.remote.url, - "http://localhost/platform/superproject", - ) - self.assertEqual(manifest.superproject.revision, "refs/heads/main") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.revision == "refs/heads/main" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' '' # noqa: E501 - "", + "" ) - def test_defalut_remote(self): + def test_default_remote(self, repo_client: RepoClient) -> None: """Check superproject settings with a default remote.""" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( """ @@ -986,62 +987,61 @@ class SuperProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.superproject.name, "superproject") - self.assertEqual(manifest.superproject.remote.name, "default-remote") - self.assertEqual(manifest.superproject.revision, "refs/heads/main") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.superproject.name == "superproject" + assert manifest.superproject.remote.name == "default-remote" + assert manifest.superproject.revision == "refs/heads/main" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' - "", + "" ) -class ContactinfoElementTests(ManifestParseTestCase): +class TestContactinfoElement: """Tests for .""" - def test_contactinfo(self): + def test_contactinfo(self, repo_client: RepoClient) -> None: """Check contactinfo settings.""" bugurl = "http://localhost/contactinfo" - manifest = self.getXmlManifest( + manifest = repo_client.get_xml_manifest( f""" """ ) - self.assertEqual(manifest.contactinfo.bugurl, bugurl) - self.assertEqual( - manifest.ToXml().toxml(), - '' + assert manifest.contactinfo.bugurl == bugurl + assert ( + manifest.ToXml().toxml() == '' f'' - "", + "" ) -class DefaultElementTests(ManifestParseTestCase): +class TestDefaultElement: """Tests for .""" - def test_default(self): + def test_default(self) -> None: """Check default settings.""" a = manifest_xml._Default() a.revisionExpr = "foo" a.remote = manifest_xml._XmlRemote(name="remote") b = manifest_xml._Default() b.revisionExpr = "bar" - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(b, a.remote) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + assert a == a + assert a != b + assert b != a.remote + assert a != 123 + assert a is not None -class RemoteElementTests(ManifestParseTestCase): +class TestRemoteElement: """Tests for .""" - def test_remote(self): + def test_remote(self) -> None: """Check remote settings.""" a = manifest_xml._XmlRemote(name="foo") a.AddAnnotation("key1", "value1", "true") @@ -1051,20 +1051,21 @@ class RemoteElementTests(ManifestParseTestCase): c.AddAnnotation("key1", "value2", "true") d = manifest_xml._XmlRemote(name="foo") d.AddAnnotation("key1", "value1", "false") - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(a, c) - self.assertNotEqual(a, d) - self.assertNotEqual(a, manifest_xml._Default()) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + assert a == a + assert a != b + assert a != c + assert a != d + assert a != manifest_xml._Default() + assert a != 123 + assert a is not None -class RemoveProjectElementTests(ManifestParseTestCase): +class TestRemoveProjectElement: """Tests for .""" - def test_remove_one_project(self): - manifest = self.getXmlManifest( + def test_remove_one_project(self, repo_client: RepoClient) -> None: + """Check removal of a single project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1074,10 +1075,13 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.projects, []) + assert manifest.projects == [] - def test_remove_one_project_one_remains(self): - manifest = self.getXmlManifest( + def test_remove_one_project_one_remains( + self, repo_client: RepoClient + ) -> None: + """Check removal of one project while another remains.""" + manifest = repo_client.get_xml_manifest( """ @@ -1089,24 +1093,30 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].name, "yourproject") + assert len(manifest.projects) == 1 + assert manifest.projects[0].name == "yourproject" - def test_remove_one_project_doesnt_exist(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest( - """ + def test_remove_one_project_doesnt_exist( + self, repo_client: RepoClient + ) -> None: + """Check removal of non-existent project fails.""" + manifest = repo_client.get_xml_manifest( + """ """ - ) + ) + with pytest.raises(error.ManifestParseError): manifest.projects - def test_remove_one_optional_project_doesnt_exist(self): - manifest = self.getXmlManifest( + def test_remove_one_optional_project_doesnt_exist( + self, repo_client: RepoClient + ) -> None: + """Check optional removal of non-existent project passes.""" + manifest = repo_client.get_xml_manifest( """ @@ -1115,10 +1125,11 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.projects, []) + assert manifest.projects == [] - def test_remove_using_path_attrib(self): - manifest = self.getXmlManifest( + def test_remove_using_path_attrib(self, repo_client: RepoClient) -> None: + """Check removal using name and path attributes.""" + manifest = repo_client.get_xml_manifest( """ @@ -1145,18 +1156,21 @@ class RemoveProjectElementTests(ManifestParseTestCase): for proj in manifest.projects: if proj.name == "project1": found_proj1_path1 = True - self.assertEqual(proj.relpath, "tests/path1") + assert proj.relpath == "tests/path1" if proj.name == "project2": found_proj2 = True - self.assertNotEqual(proj.name, "project3") - self.assertNotEqual(proj.name, "project4") - self.assertNotEqual(proj.name, "project5") - self.assertNotEqual(proj.name, "project6") - self.assertTrue(found_proj1_path1) - self.assertTrue(found_proj2) + assert proj.name != "project3" + assert proj.name != "project4" + assert proj.name != "project5" + assert proj.name != "project6" + assert found_proj1_path1 + assert found_proj2 - def test_base_revision_checks_on_patching(self): - manifest_fail_wrong_tag = self.getXmlManifest( + def test_base_revision_checks_on_patching( + self, repo_client: RepoClient + ) -> None: + """Check base-rev validation during patching.""" + manifest_fail_wrong_tag = repo_client.get_xml_manifest( """ @@ -1166,10 +1180,10 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): manifest_fail_wrong_tag.ToXml() - manifest_fail_remove = self.getXmlManifest( + manifest_fail_remove = repo_client.get_xml_manifest( """ @@ -1179,10 +1193,10 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): manifest_fail_remove.ToXml() - manifest_fail_extend = self.getXmlManifest( + manifest_fail_extend = repo_client.get_xml_manifest( """ @@ -1192,10 +1206,10 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): manifest_fail_extend.ToXml() - manifest_fail_unknown = self.getXmlManifest( + manifest_fail_unknown = repo_client.get_xml_manifest( """ @@ -1205,10 +1219,10 @@ class RemoveProjectElementTests(ManifestParseTestCase): """ ) - with self.assertRaises(error.ManifestParseError): + with pytest.raises(error.ManifestParseError): manifest_fail_unknown.ToXml() - manifest_ok = self.getXmlManifest( + manifest_ok = repo_client.get_xml_manifest( """ @@ -1234,18 +1248,21 @@ class RemoveProjectElementTests(ManifestParseTestCase): found_proj2 = True if proj.name == "project3": found_proj3 = True - self.assertNotEqual(proj.name, "project1") - self.assertNotEqual(proj.name, "project4") - self.assertTrue(found_proj2) - self.assertTrue(found_proj3) - self.assertTrue(len(manifest_ok.projects) == 2) + assert proj.name != "project1" + assert proj.name != "project4" + assert found_proj2 + assert found_proj3 + assert len(manifest_ok.projects) == 2 -class ExtendProjectElementTests(ManifestParseTestCase): +class TestExtendProjectElement: """Tests for .""" - def test_extend_project_dest_path_single_match(self): - manifest = self.getXmlManifest( + def test_extend_project_dest_path_single_match( + self, repo_client: RepoClient + ) -> None: + """Check dest-path when single match exists.""" + manifest = repo_client.get_xml_manifest( """ @@ -1255,13 +1272,15 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].relpath, "bar") + assert len(manifest.projects) == 1 + assert manifest.projects[0].relpath == "bar" - def test_extend_project_dest_path_multi_match(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest( - """ + def test_extend_project_dest_path_multi_match( + self, repo_client: RepoClient + ) -> None: + """Check dest-path when multiple matches exist fails.""" + manifest = repo_client.get_xml_manifest( + """ @@ -1270,11 +1289,15 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ - ) + ) + with pytest.raises(error.ManifestParseError): manifest.projects - def test_extend_project_dest_path_multi_match_path_specified(self): - manifest = self.getXmlManifest( + def test_extend_project_dest_path_multi_match_path_specified( + self, repo_client: RepoClient + ) -> None: + """Check dest-path when path is specified for multi-match.""" + manifest = repo_client.get_xml_manifest( """ @@ -1285,29 +1308,32 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 2) + assert len(manifest.projects) == 2 if manifest.projects[0].relpath == "y": - self.assertEqual(manifest.projects[1].relpath, "bar") + assert manifest.projects[1].relpath == "bar" else: - self.assertEqual(manifest.projects[0].relpath, "bar") - self.assertEqual(manifest.projects[1].relpath, "y") + assert manifest.projects[0].relpath == "bar" + assert manifest.projects[1].relpath == "y" - def test_extend_project_dest_branch(self): - manifest = self.getXmlManifest( + def test_extend_project_dest_branch(self, repo_client: RepoClient) -> None: + """Check dest-branch update via extend-project.""" + manifest = repo_client.get_xml_manifest( """ - + -""" # noqa: E501 +""" ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].dest_branch, "bar") + assert len(manifest.projects) == 1 + assert manifest.projects[0].dest_branch == "bar" - def test_extend_project_upstream(self): - manifest = self.getXmlManifest( + def test_extend_project_upstream(self, repo_client: RepoClient) -> None: + """Check upstream update via extend-project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1317,11 +1343,12 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].upstream, "bar") + assert len(manifest.projects) == 1 + assert manifest.projects[0].upstream == "bar" - def test_extend_project_copyfiles(self): - manifest = self.getXmlManifest( + def test_extend_project_copyfiles(self, repo_client: RepoClient) -> None: + """Check copyfile addition via extend-project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1333,21 +1360,24 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(list(manifest.projects[0].copyfiles)[0].src, "foo") - self.assertEqual(list(manifest.projects[0].copyfiles)[0].dest, "bar") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert list(manifest.projects[0].copyfiles)[0].src == "foo" + assert list(manifest.projects[0].copyfiles)[0].dest == "bar" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' '' "" - "", + "" ) - def test_extend_project_duplicate_copyfiles(self): - root_m = self.manifest_dir / "root.xml" + def test_extend_project_duplicate_copyfiles( + self, repo_client: RepoClient + ) -> None: + """Check duplicate copyfile handling in includes.""" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -1359,21 +1389,21 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man1.xml").write_text( + (repo_client.manifest_dir / "man1.xml").write_text( """ """ ) - (self.manifest_dir / "man2.xml").write_text( + (repo_client.manifest_dir / "man2.xml").write_text( """ """ ) - (self.manifest_dir / "common.xml").write_text( + (repo_client.manifest_dir / "common.xml").write_text( """ @@ -1382,13 +1412,16 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - manifest = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) - self.assertEqual(len(manifest.projects[0].copyfiles), 1) - self.assertEqual(list(manifest.projects[0].copyfiles)[0].src, "foo") - self.assertEqual(list(manifest.projects[0].copyfiles)[0].dest, "bar") + manifest = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) + assert len(manifest.projects[0].copyfiles) == 1 + assert list(manifest.projects[0].copyfiles)[0].src == "foo" + assert list(manifest.projects[0].copyfiles)[0].dest == "bar" - def test_extend_project_linkfiles(self): - manifest = self.getXmlManifest( + def test_extend_project_linkfiles(self, repo_client: RepoClient) -> None: + """Check linkfile addition via extend-project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1400,21 +1433,24 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(list(manifest.projects[0].linkfiles)[0].src, "foo") - self.assertEqual(list(manifest.projects[0].linkfiles)[0].dest, "bar") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert list(manifest.projects[0].linkfiles)[0].src == "foo" + assert list(manifest.projects[0].linkfiles)[0].dest == "bar" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' '' "" - "", + "" ) - def test_extend_project_duplicate_linkfiles(self): - root_m = self.manifest_dir / "root.xml" + def test_extend_project_duplicate_linkfiles( + self, repo_client: RepoClient + ) -> None: + """Check duplicate linkfile handling in includes.""" + root_m = repo_client.manifest_dir / "root.xml" root_m.write_text( """ @@ -1426,21 +1462,21 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - (self.manifest_dir / "man1.xml").write_text( + (repo_client.manifest_dir / "man1.xml").write_text( """ """ ) - (self.manifest_dir / "man2.xml").write_text( + (repo_client.manifest_dir / "man2.xml").write_text( """ """ ) - (self.manifest_dir / "common.xml").write_text( + (repo_client.manifest_dir / "common.xml").write_text( """ @@ -1449,13 +1485,16 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - manifest = manifest_xml.XmlManifest(str(self.repodir), str(root_m)) - self.assertEqual(len(manifest.projects[0].linkfiles), 1) - self.assertEqual(list(manifest.projects[0].linkfiles)[0].src, "foo") - self.assertEqual(list(manifest.projects[0].linkfiles)[0].dest, "bar") + manifest = manifest_xml.XmlManifest( + str(repo_client.repodir), str(root_m) + ) + assert len(manifest.projects[0].linkfiles) == 1 + assert list(manifest.projects[0].linkfiles)[0].src == "foo" + assert list(manifest.projects[0].linkfiles)[0].dest == "bar" - def test_extend_project_annotations(self): - manifest = self.getXmlManifest( + def test_extend_project_annotations(self, repo_client: RepoClient) -> None: + """Check annotation addition via extend-project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1467,21 +1506,24 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual(manifest.projects[0].annotations[0].name, "foo") - self.assertEqual(manifest.projects[0].annotations[0].value, "bar") - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert manifest.projects[0].annotations[0].name == "foo" + assert manifest.projects[0].annotations[0].value == "bar" + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' '' "" - "", + "" ) - def test_extend_project_annotations_multiples(self): - manifest = self.getXmlManifest( + def test_extend_project_annotations_multiples( + self, repo_client: RepoClient + ) -> None: + """Check multiple annotation additions via extend-project.""" + manifest = repo_client.get_xml_manifest( """ @@ -1497,18 +1539,17 @@ class ExtendProjectElementTests(ManifestParseTestCase): """ ) - self.assertEqual( - [(a.name, a.value) for a in manifest.projects[0].annotations], - [ - ("foo", "bar"), - ("few", "bar"), - ("foo", "new_bar"), - ("new", "anno"), - ], - ) - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' + assert [ + (a.name, a.value) for a in manifest.projects[0].annotations + ] == [ + ("foo", "bar"), + ("few", "bar"), + ("foo", "new_bar"), + ("new", "anno"), + ] + assert ( + sort_attributes(manifest.ToXml().toxml()) + == '' '' '' '' @@ -1517,81 +1558,78 @@ class ExtendProjectElementTests(ManifestParseTestCase): '' '' "" - "", + "" ) -class NormalizeUrlTests(ManifestParseTestCase): +class TestNormalizeUrl: """Tests for normalize_url() in manifest_xml.py""" - def test_has_trailing_slash(self): + def test_has_trailing_slash(self) -> None: + """Trailing slashes should be removed.""" url = "http://foo.com/bar/baz/" - self.assertEqual( - "http://foo.com/bar/baz", manifest_xml.normalize_url(url) - ) + assert manifest_xml.normalize_url(url) == "http://foo.com/bar/baz" url = "http://foo.com/bar/" - self.assertEqual("http://foo.com/bar", manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == "http://foo.com/bar" - def test_has_leading_slash(self): + def test_has_leading_slash(self) -> None: """SCP-like syntax except a / comes before the : which git disallows.""" url = "/git@foo.com:bar/baf" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "gi/t@foo.com:bar/baf" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "git@fo/o.com:bar/baf" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url - def test_has_no_scheme(self): + def test_has_no_scheme(self) -> None: """Deal with cases where we have no scheme, but we also aren't dealing with the git SCP-like syntax """ url = "foo.com/baf/bat" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "foo.com/baf" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "git@foo.com/baf/bat" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "git@foo.com/baf" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url url = "/file/path/here" - self.assertEqual(url, manifest_xml.normalize_url(url)) + assert manifest_xml.normalize_url(url) == url - def test_has_no_scheme_matches_scp_like_syntax(self): + def test_has_no_scheme_matches_scp_like_syntax(self) -> None: + """SCP-like syntax should be converted to ssh://.""" url = "git@foo.com:bar/baf" - self.assertEqual( - "ssh://git@foo.com/bar/baf", manifest_xml.normalize_url(url) - ) + assert manifest_xml.normalize_url(url) == "ssh://git@foo.com/bar/baf" url = "git@foo.com:bar/" - self.assertEqual( - "ssh://git@foo.com/bar", manifest_xml.normalize_url(url) - ) + assert manifest_xml.normalize_url(url) == "ssh://git@foo.com/bar" - def test_remote_url_resolution(self): + def test_remote_url_resolution(self) -> None: + """Check resolvedFetchUrl calculation.""" remote = manifest_xml._XmlRemote( name="foo", fetch="git@github.com:org2/", manifestUrl="git@github.com:org2/custom_manifest.git", ) - self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) + assert remote.resolvedFetchUrl == "ssh://git@github.com/org2" remote = manifest_xml._XmlRemote( name="foo", fetch="ssh://git@github.com/org2/", manifestUrl="git@github.com:org2/custom_manifest.git", ) - self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) + assert remote.resolvedFetchUrl == "ssh://git@github.com/org2" remote = manifest_xml._XmlRemote( name="foo", fetch="git@github.com:org2/", manifestUrl="ssh://git@github.com/org2/custom_manifest.git", ) - self.assertEqual("ssh://git@github.com/org2", remote.resolvedFetchUrl) + assert remote.resolvedFetchUrl == "ssh://git@github.com/org2" diff --git a/tests/test_subcmds_upload.py b/tests/test_subcmds_upload.py index cd8889778..51c0a4cb7 100644 --- a/tests/test_subcmds_upload.py +++ b/tests/test_subcmds_upload.py @@ -14,9 +14,10 @@ """Unittests for the subcmds/upload.py module.""" -import unittest from unittest import mock +import pytest + from error import GitError from error import UploadError from subcmds import upload @@ -26,45 +27,39 @@ class UnexpectedError(Exception): """An exception not expected by upload command.""" -class UploadCommand(unittest.TestCase): - """Check registered all_commands.""" +# A stub people list (reviewers, cc). +_STUB_PEOPLE = ([], []) - def setUp(self): - self.cmd = upload.Upload() - self.branch = mock.MagicMock() - self.people = mock.MagicMock() - self.opt, _ = self.cmd.OptionParser.parse_args([]) - mock.patch.object( - self.cmd, "_AppendAutoList", return_value=None - ).start() - mock.patch.object(self.cmd, "git_event_log").start() - def tearDown(self): - mock.patch.stopall() +@pytest.fixture +def cmd() -> upload.Upload: + """Fixture to provide an Upload command instance with mocked methods.""" + cmd = upload.Upload() + with mock.patch.object( + cmd, "_AppendAutoList", return_value=None + ), mock.patch.object(cmd, "git_event_log"): + yield cmd - def test_UploadAndReport_UploadError(self): - """Check UploadExitError raised when UploadError encountered.""" - side_effect = UploadError("upload error") - with mock.patch.object( - self.cmd, "_UploadBranch", side_effect=side_effect - ): - with self.assertRaises(upload.UploadExitError): - self.cmd._UploadAndReport(self.opt, [self.branch], self.people) - def test_UploadAndReport_GitError(self): - """Check UploadExitError raised when GitError encountered.""" - side_effect = GitError("some git error") - with mock.patch.object( - self.cmd, "_UploadBranch", side_effect=side_effect - ): - with self.assertRaises(upload.UploadExitError): - self.cmd._UploadAndReport(self.opt, [self.branch], self.people) +def test_UploadAndReport_UploadError(cmd: upload.Upload) -> None: + """Check UploadExitError raised when UploadError encountered.""" + opt, _ = cmd.OptionParser.parse_args([]) + with mock.patch.object(cmd, "_UploadBranch", side_effect=UploadError("")): + with pytest.raises(upload.UploadExitError): + cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE) - def test_UploadAndReport_UnhandledError(self): - """Check UnexpectedError passed through.""" - side_effect = UnexpectedError("some os error") - with mock.patch.object( - self.cmd, "_UploadBranch", side_effect=side_effect - ): - with self.assertRaises(type(side_effect)): - self.cmd._UploadAndReport(self.opt, [self.branch], self.people) + +def test_UploadAndReport_GitError(cmd: upload.Upload) -> None: + """Check UploadExitError raised when GitError encountered.""" + opt, _ = cmd.OptionParser.parse_args([]) + with mock.patch.object(cmd, "_UploadBranch", side_effect=GitError("")): + with pytest.raises(upload.UploadExitError): + cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE) + + +def test_UploadAndReport_UnhandledError(cmd: upload.Upload) -> None: + """Check UnexpectedError passed through.""" + opt, _ = cmd.OptionParser.parse_args([]) + with mock.patch.object(cmd, "_UploadBranch", side_effect=UnexpectedError): + with pytest.raises(UnexpectedError): + cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE) diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 7845ee163..7a1a87289 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -19,261 +19,298 @@ import os import re import subprocess import sys -import tempfile -import unittest from unittest import mock +import pytest import utils_for_test import main import wrapper -class RepoWrapperTestCase(unittest.TestCase): - """TestCase for the wrapper module.""" - - def setUp(self): - """Load the wrapper module every time.""" - wrapper.Wrapper.cache_clear() - self.wrapper = wrapper.Wrapper() +@pytest.fixture(autouse=True) +def reset_wrapper() -> None: + """Reset the wrapper module every time.""" + wrapper.Wrapper.cache_clear() -class RepoWrapperUnitTest(RepoWrapperTestCase): +@pytest.fixture +def repo_wrapper() -> wrapper.Wrapper: + """Fixture for the wrapper module.""" + return wrapper.Wrapper() + + +class GitCheckout: + """Class to hold git checkout info for tests.""" + + def __init__(self, git_dir, rev_list): + self.git_dir = git_dir + self.rev_list = rev_list + + +@pytest.fixture(scope="module") +def git_checkout(tmp_path_factory) -> GitCheckout: + """Fixture for tests that use a real/small git checkout. + + Create a repo to operate on, but do it once per-test-run. + """ + tempdir = tmp_path_factory.mktemp("repo-rev-tests") + run_git = wrapper.Wrapper().run_git + + remote = os.path.join(tempdir, "remote") + os.mkdir(remote) + + utils_for_test.init_git_tree(remote) + run_git("commit", "--allow-empty", "-minit", cwd=remote) + run_git("branch", "stable", cwd=remote) + run_git("tag", "v1.0", cwd=remote) + run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote) + rev_list = run_git("rev-list", "HEAD", cwd=remote).stdout.splitlines() + + run_git("init", cwd=tempdir) + run_git( + "fetch", + remote, + "+refs/heads/*:refs/remotes/origin/*", + cwd=tempdir, + ) + yield GitCheckout(tempdir, rev_list) + + +class TestRepoWrapper: """Tests helper functions in the repo wrapper""" - def test_version(self): + def test_version(self, repo_wrapper: wrapper.Wrapper) -> None: """Make sure _Version works.""" - with self.assertRaises(SystemExit) as e: + with pytest.raises(SystemExit) as e: with mock.patch("sys.stdout", new_callable=io.StringIO) as stdout: with mock.patch( "sys.stderr", new_callable=io.StringIO ) as stderr: - self.wrapper._Version() - self.assertEqual(0, e.exception.code) - self.assertEqual("", stderr.getvalue()) - self.assertIn("repo launcher version", stdout.getvalue()) + repo_wrapper._Version() + assert e.value.code == 0 + assert stderr.getvalue() == "" + assert "repo launcher version" in stdout.getvalue() - def test_python_constraints(self): + def test_python_constraints(self, repo_wrapper: wrapper.Wrapper) -> None: """The launcher should never require newer than main.py.""" - self.assertGreaterEqual( - main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD + assert ( + main.MIN_PYTHON_VERSION_HARD >= repo_wrapper.MIN_PYTHON_VERSION_HARD ) - self.assertGreaterEqual( - main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT + assert ( + main.MIN_PYTHON_VERSION_SOFT >= repo_wrapper.MIN_PYTHON_VERSION_SOFT ) # Make sure the versions are themselves in sync. - self.assertGreaterEqual( - self.wrapper.MIN_PYTHON_VERSION_SOFT, - self.wrapper.MIN_PYTHON_VERSION_HARD, + assert ( + repo_wrapper.MIN_PYTHON_VERSION_SOFT + >= repo_wrapper.MIN_PYTHON_VERSION_HARD ) - def test_init_parser(self): + def test_init_parser(self, repo_wrapper: wrapper.Wrapper) -> None: """Make sure 'init' GetParser works.""" - parser = self.wrapper.GetParser() + parser = repo_wrapper.GetParser() opts, args = parser.parse_args([]) - self.assertEqual([], args) - self.assertIsNone(opts.manifest_url) + assert args == [] + assert opts.manifest_url is None -class SetGitTrace2ParentSid(RepoWrapperTestCase): +class TestSetGitTrace2ParentSid: """Check SetGitTrace2ParentSid behavior.""" KEY = "GIT_TRACE2_PARENT_SID" VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$") - def test_first_set(self): + def test_first_set(self, repo_wrapper: wrapper.Wrapper) -> None: """Test env var not yet set.""" env = {} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) + repo_wrapper.SetGitTrace2ParentSid(env) + assert self.KEY in env value = env[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) + assert self.VALID_FORMAT.match(value) - def test_append(self): + def test_append(self, repo_wrapper: wrapper.Wrapper) -> None: """Test env var is appended.""" env = {self.KEY: "pfx"} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) + repo_wrapper.SetGitTrace2ParentSid(env) + assert self.KEY in env value = env[self.KEY] - self.assertTrue(value.startswith("pfx/")) - self.assertRegex(value[4:], self.VALID_FORMAT) + assert value.startswith("pfx/") + assert self.VALID_FORMAT.match(value[4:]) - def test_global_context(self): + def test_global_context(self, repo_wrapper: wrapper.Wrapper) -> None: """Check os.environ gets updated by default.""" os.environ.pop(self.KEY, None) - self.wrapper.SetGitTrace2ParentSid() - self.assertIn(self.KEY, os.environ) + repo_wrapper.SetGitTrace2ParentSid() + assert self.KEY in os.environ value = os.environ[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) + assert self.VALID_FORMAT.match(value) -class RunCommand(RepoWrapperTestCase): +class TestRunCommand: """Check run_command behavior.""" - def test_capture(self): + def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None: """Check capture_output handling.""" - ret = self.wrapper.run_command(["echo", "hi"], capture_output=True) + ret = repo_wrapper.run_command(["echo", "hi"], capture_output=True) # echo command appends OS specific linesep, but on Windows + Git Bash # we get UNIX ending, so we allow both. - self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"]) + assert ret.stdout in ["hi" + os.linesep, "hi\n"] - def test_check(self): + def test_check(self, repo_wrapper: wrapper.Wrapper) -> None: """Check check handling.""" - self.wrapper.run_command(["true"], check=False) - self.wrapper.run_command(["true"], check=True) - self.wrapper.run_command(["false"], check=False) - with self.assertRaises(subprocess.CalledProcessError): - self.wrapper.run_command(["false"], check=True) + repo_wrapper.run_command(["true"], check=False) + repo_wrapper.run_command(["true"], check=True) + repo_wrapper.run_command(["false"], check=False) + with pytest.raises(subprocess.CalledProcessError): + repo_wrapper.run_command(["false"], check=True) -class RunGit(RepoWrapperTestCase): +class TestRunGit: """Check run_git behavior.""" - def test_capture(self): + def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None: """Check capture_output handling.""" - ret = self.wrapper.run_git("--version") - self.assertIn("git", ret.stdout) + ret = repo_wrapper.run_git("--version") + assert "git" in ret.stdout - def test_check(self): + def test_check(self, repo_wrapper: wrapper.Wrapper) -> None: """Check check handling.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.run_git("--version-asdfasdf") - self.wrapper.run_git("--version-asdfasdf", check=False) + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper.run_git("--version-asdfasdf") + repo_wrapper.run_git("--version-asdfasdf", check=False) -class ParseGitVersion(RepoWrapperTestCase): +class TestParseGitVersion: """Check ParseGitVersion behavior.""" - def test_autoload(self): + def test_autoload(self, repo_wrapper: wrapper.Wrapper) -> None: """Check we can load the version from the live git.""" - ret = self.wrapper.ParseGitVersion() - self.assertIsNotNone(ret) + assert repo_wrapper.ParseGitVersion() is not None - def test_bad_ver(self): + def test_bad_ver(self, repo_wrapper: wrapper.Wrapper) -> None: """Check handling of bad git versions.""" - ret = self.wrapper.ParseGitVersion(ver_str="asdf") - self.assertIsNone(ret) + assert repo_wrapper.ParseGitVersion(ver_str="asdf") is None - def test_normal_ver(self): + def test_normal_ver(self, repo_wrapper: wrapper.Wrapper) -> None: """Check handling of normal git versions.""" - ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1") - self.assertEqual(2, ret.major) - self.assertEqual(25, ret.minor) - self.assertEqual(1, ret.micro) - self.assertEqual("2.25.1", ret.full) + ret = repo_wrapper.ParseGitVersion(ver_str="git version 2.25.1") + assert ret.major == 2 + assert ret.minor == 25 + assert ret.micro == 1 + assert ret.full == "2.25.1" - def test_extended_ver(self): + def test_extended_ver(self, repo_wrapper: wrapper.Wrapper) -> None: """Check handling of extended distro git versions.""" - ret = self.wrapper.ParseGitVersion( + ret = repo_wrapper.ParseGitVersion( ver_str="git version 1.30.50.696.g5e7596f4ac-goog" ) - self.assertEqual(1, ret.major) - self.assertEqual(30, ret.minor) - self.assertEqual(50, ret.micro) - self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full) + assert ret.major == 1 + assert ret.minor == 30 + assert ret.micro == 50 + assert ret.full == "1.30.50.696.g5e7596f4ac-goog" -class CheckGitVersion(RepoWrapperTestCase): +class TestCheckGitVersion: """Check _CheckGitVersion behavior.""" - def test_unknown(self): + def test_unknown(self, repo_wrapper: wrapper.Wrapper) -> None: """Unknown versions should abort.""" with mock.patch.object( - self.wrapper, "ParseGitVersion", return_value=None + repo_wrapper, "ParseGitVersion", return_value=None ): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper._CheckGitVersion() - def test_old(self): + def test_old(self, repo_wrapper: wrapper.Wrapper) -> None: """Old versions should abort.""" with mock.patch.object( - self.wrapper, + repo_wrapper, "ParseGitVersion", - return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"), + return_value=repo_wrapper.GitVersion(1, 0, 0, "1.0.0"), ): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper._CheckGitVersion() - def test_new(self): + def test_new(self, repo_wrapper: wrapper.Wrapper) -> None: """Newer versions should run fine.""" with mock.patch.object( - self.wrapper, + repo_wrapper, "ParseGitVersion", - return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"), + return_value=repo_wrapper.GitVersion(100, 0, 0, "100.0.0"), ): - self.wrapper._CheckGitVersion() + repo_wrapper._CheckGitVersion() -class Requirements(RepoWrapperTestCase): +class TestRequirements: """Check Requirements handling.""" - def test_missing_file(self): + def test_missing_file(self, repo_wrapper: wrapper.Wrapper) -> None: """Don't crash if the file is missing (old version).""" - self.assertIsNone( - self.wrapper.Requirements.from_dir(utils_for_test.THIS_DIR) + assert ( + repo_wrapper.Requirements.from_dir(utils_for_test.THIS_DIR) is None ) - self.assertIsNone( - self.wrapper.Requirements.from_file( + assert ( + repo_wrapper.Requirements.from_file( utils_for_test.THIS_DIR / "xxxxxxxxxxxxxxxxxxxxxxxx" ) + is None ) - def test_corrupt_data(self): + def test_corrupt_data(self, repo_wrapper: wrapper.Wrapper) -> None: """If the file can't be parsed, don't blow up.""" - self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) - self.assertIsNone(self.wrapper.Requirements.from_data(b"x")) + assert repo_wrapper.Requirements.from_file(__file__) is None + assert repo_wrapper.Requirements.from_data(b"x") is None - def test_valid_data(self): + def test_valid_data(self, repo_wrapper: wrapper.Wrapper) -> None: """Make sure we can parse the file we ship.""" - self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}")) + assert repo_wrapper.Requirements.from_data(b"{}") is not None rootdir = utils_for_test.THIS_DIR.parent - self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) - self.assertIsNotNone( - self.wrapper.Requirements.from_file(rootdir / "requirements.json") + assert repo_wrapper.Requirements.from_dir(rootdir) is not None + assert ( + repo_wrapper.Requirements.from_file(rootdir / "requirements.json") + is not None ) - def test_format_ver(self): + def test_format_ver(self, repo_wrapper: wrapper.Wrapper) -> None: """Check format_ver can format.""" - self.assertEqual( - "1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3)) - ) - self.assertEqual("1", self.wrapper.Requirements._format_ver([1])) + assert repo_wrapper.Requirements._format_ver((1, 2, 3)) == "1.2.3" + assert repo_wrapper.Requirements._format_ver([1]) == "1" - def test_assert_all_unknown(self): + def test_assert_all_unknown(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_all works with incompatible file.""" - reqs = self.wrapper.Requirements({}) + reqs = repo_wrapper.Requirements({}) reqs.assert_all() - def test_assert_all_new_repo(self): + def test_assert_all_new_repo(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_all accepts new enough repo.""" - reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}}) + reqs = repo_wrapper.Requirements({"repo": {"hard": [1, 0]}}) reqs.assert_all() - def test_assert_all_old_repo(self): + def test_assert_all_old_repo(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_all rejects old repo.""" - reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}}) - with self.assertRaises(SystemExit): + reqs = repo_wrapper.Requirements({"repo": {"hard": [99999, 0]}}) + with pytest.raises(SystemExit): reqs.assert_all() - def test_assert_all_new_python(self): + def test_assert_all_new_python(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_all accepts new enough python.""" - reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}}) + reqs = repo_wrapper.Requirements({"python": {"hard": sys.version_info}}) reqs.assert_all() - def test_assert_all_old_python(self): + def test_assert_all_old_python(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_all rejects old python.""" - reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}}) - with self.assertRaises(SystemExit): + reqs = repo_wrapper.Requirements({"python": {"hard": [99999, 0]}}) + with pytest.raises(SystemExit): reqs.assert_all() - def test_assert_ver_unknown(self): + def test_assert_ver_unknown(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_ver works with incompatible file.""" - reqs = self.wrapper.Requirements({}) + reqs = repo_wrapper.Requirements({}) reqs.assert_ver("xxx", (1, 0)) - def test_assert_ver_new(self): + def test_assert_ver_new(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_ver allows new enough versions.""" - reqs = self.wrapper.Requirements( + reqs = repo_wrapper.Requirements( {"git": {"hard": [1, 0], "soft": [2, 0]}} ) reqs.assert_ver("git", (1, 0)) @@ -281,274 +318,279 @@ class Requirements(RepoWrapperTestCase): reqs.assert_ver("git", (2, 0)) reqs.assert_ver("git", (2, 5)) - def test_assert_ver_old(self): + def test_assert_ver_old(self, repo_wrapper: wrapper.Wrapper) -> None: """Check assert_ver rejects old versions.""" - reqs = self.wrapper.Requirements( + reqs = repo_wrapper.Requirements( {"git": {"hard": [1, 0], "soft": [2, 0]}} ) - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): reqs.assert_ver("git", (0, 5)) -class NeedSetupGnuPG(RepoWrapperTestCase): +class TestNeedSetupGnuPG: """Check NeedSetupGnuPG behavior.""" - def test_missing_dir(self): + def test_missing_dir(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None: """The ~/.repoconfig tree doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = os.path.join(tempdir, "foo") - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + repo_wrapper.home_dot_repo = str(tmp_path / "foo") + assert repo_wrapper.NeedSetupGnuPG() - def test_missing_keyring(self): + def test_missing_keyring( + self, tmp_path, repo_wrapper: wrapper.Wrapper + ) -> None: """The keyring-version file doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = tempdir - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + repo_wrapper.home_dot_repo = str(tmp_path) + assert repo_wrapper.NeedSetupGnuPG() - def test_empty_keyring(self): + def test_empty_keyring( + self, tmp_path, repo_wrapper: wrapper.Wrapper + ) -> None: """The keyring-version file exists, but is empty.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, "keyring-version"), "w"): - pass - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + repo_wrapper.home_dot_repo = str(tmp_path) + (tmp_path / "keyring-version").write_text("") + assert repo_wrapper.NeedSetupGnuPG() - def test_old_keyring(self): + def test_old_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None: """The keyring-version file exists, but it's old.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, "keyring-version"), "w") as fp: - fp.write("1.0\n") - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + repo_wrapper.home_dot_repo = str(tmp_path) + (tmp_path / "keyring-version").write_text("1.0\n") + assert repo_wrapper.NeedSetupGnuPG() - def test_new_keyring(self): + def test_new_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None: """The keyring-version file exists, and is up-to-date.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, "keyring-version"), "w") as fp: - fp.write("1000.0\n") - self.assertFalse(self.wrapper.NeedSetupGnuPG()) + repo_wrapper.home_dot_repo = str(tmp_path) + (tmp_path / "keyring-version").write_text("1000.0\n") + assert not repo_wrapper.NeedSetupGnuPG() -class SetupGnuPG(RepoWrapperTestCase): +class TestSetupGnuPG: """Check SetupGnuPG behavior.""" - def test_full(self): + def test_full(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None: """Make sure it works completely.""" - with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: - self.wrapper.home_dot_repo = tempdir - self.wrapper.gpg_dir = os.path.join( - self.wrapper.home_dot_repo, "gnupg" - ) - self.assertTrue(self.wrapper.SetupGnuPG(True)) - with open(os.path.join(tempdir, "keyring-version")) as fp: - data = fp.read() - self.assertEqual( - ".".join(str(x) for x in self.wrapper.KEYRING_VERSION), - data.strip(), - ) + repo_wrapper.home_dot_repo = str(tmp_path) + repo_wrapper.gpg_dir = str(tmp_path / "gnupg") + assert repo_wrapper.SetupGnuPG(True) + data = (tmp_path / "keyring-version").read_text() + assert ( + ".".join(str(x) for x in repo_wrapper.KEYRING_VERSION) + == data.strip() + ) -class VerifyRev(RepoWrapperTestCase): +class TestVerifyRev: """Check verify_rev behavior.""" - def test_verify_passes(self): + def test_verify_passes(self, repo_wrapper: wrapper.Wrapper) -> None: """Check when we have a valid signed tag.""" desc_result = subprocess.CompletedProcess([], 0, "v1.0\n", "") gpg_result = subprocess.CompletedProcess([], 0, "", "") with mock.patch.object( - self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + repo_wrapper, "run_git", side_effect=(desc_result, gpg_result) ): - ret = self.wrapper.verify_rev( + ret = repo_wrapper.verify_rev( "/", "refs/heads/stable", "1234", True ) - self.assertEqual("v1.0^0", ret) + assert ret == "v1.0^0" - def test_unsigned_commit(self): + def test_unsigned_commit(self, repo_wrapper: wrapper.Wrapper) -> None: """Check we fall back to signed tag when we have an unsigned commit.""" desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "") gpg_result = subprocess.CompletedProcess([], 0, "", "") with mock.patch.object( - self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + repo_wrapper, "run_git", side_effect=(desc_result, gpg_result) ): - ret = self.wrapper.verify_rev( + ret = repo_wrapper.verify_rev( "/", "refs/heads/stable", "1234", True ) - self.assertEqual("v1.0^0", ret) + assert ret == "v1.0^0" - def test_verify_fails(self): + def test_verify_fails(self, repo_wrapper: wrapper.Wrapper) -> None: """Check we fall back to signed tag when we have an unsigned commit.""" desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "") gpg_result = RuntimeError with mock.patch.object( - self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + repo_wrapper, "run_git", side_effect=(desc_result, gpg_result) ): - with self.assertRaises(RuntimeError): - self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True) + with pytest.raises(RuntimeError): + repo_wrapper.verify_rev("/", "refs/heads/stable", "1234", True) -class GitCheckoutTestCase(RepoWrapperTestCase): - """Tests that use a real/small git checkout.""" - - GIT_DIR = None - REV_LIST = None - - @classmethod - def setUpClass(cls): - # Create a repo to operate on, but do it once per-class. - cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests") - cls.GIT_DIR = cls.tempdirobj.name - run_git = wrapper.Wrapper().run_git - - remote = os.path.join(cls.GIT_DIR, "remote") - os.mkdir(remote) - - utils_for_test.init_git_tree(remote) - run_git("commit", "--allow-empty", "-minit", cwd=remote) - run_git("branch", "stable", cwd=remote) - run_git("tag", "v1.0", cwd=remote) - run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote) - cls.REV_LIST = run_git( - "rev-list", "HEAD", cwd=remote - ).stdout.splitlines() - - run_git("init", cwd=cls.GIT_DIR) - run_git( - "fetch", - remote, - "+refs/heads/*:refs/remotes/origin/*", - cwd=cls.GIT_DIR, - ) - - @classmethod - def tearDownClass(cls): - if not cls.tempdirobj: - return - - cls.tempdirobj.cleanup() - - -class ResolveRepoRev(GitCheckoutTestCase): +class TestResolveRepoRev: """Check resolve_repo_rev behavior.""" - def test_explicit_branch(self): + def test_explicit_branch( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check refs/heads/branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev( - self.GIT_DIR, "refs/heads/stable" + rrev, lrev = repo_wrapper.resolve_repo_rev( + git_checkout.git_dir, "refs/heads/stable" ) - self.assertEqual("refs/heads/stable", rrev) - self.assertEqual(self.REV_LIST[1], lrev) + assert rrev == "refs/heads/stable" + assert lrev == git_checkout.rev_list[1] - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown") + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper.resolve_repo_rev( + git_checkout.git_dir, "refs/heads/unknown" + ) - def test_explicit_tag(self): + def test_explicit_tag( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check refs/tags/tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev( - self.GIT_DIR, "refs/tags/v1.0" + rrev, lrev = repo_wrapper.resolve_repo_rev( + git_checkout.git_dir, "refs/tags/v1.0" ) - self.assertEqual("refs/tags/v1.0", rrev) - self.assertEqual(self.REV_LIST[1], lrev) + assert rrev == "refs/tags/v1.0" + assert lrev == git_checkout.rev_list[1] - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown") + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper.resolve_repo_rev( + git_checkout.git_dir, "refs/tags/unknown" + ) - def test_branch_name(self): + def test_branch_name( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable") - self.assertEqual("refs/heads/stable", rrev) - self.assertEqual(self.REV_LIST[1], lrev) + rrev, lrev = repo_wrapper.resolve_repo_rev( + git_checkout.git_dir, "stable" + ) + assert rrev == "refs/heads/stable" + assert lrev == git_checkout.rev_list[1] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main") - self.assertEqual("refs/heads/main", rrev) - self.assertEqual(self.REV_LIST[0], lrev) + rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "main") + assert rrev == "refs/heads/main" + assert lrev == git_checkout.rev_list[0] - def test_tag_name(self): + def test_tag_name( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0") - self.assertEqual("refs/tags/v1.0", rrev) - self.assertEqual(self.REV_LIST[1], lrev) + rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "v1.0") + assert rrev == "refs/tags/v1.0" + assert lrev == git_checkout.rev_list[1] - def test_full_commit(self): + def test_full_commit( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check specific commit argument.""" - commit = self.REV_LIST[0] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(commit, rrev) - self.assertEqual(commit, lrev) + commit = git_checkout.rev_list[0] + rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit) + assert rrev == commit + assert lrev == commit - def test_partial_commit(self): + def test_partial_commit( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check specific (partial) commit argument.""" - commit = self.REV_LIST[0][0:20] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(self.REV_LIST[0], rrev) - self.assertEqual(self.REV_LIST[0], lrev) + commit = git_checkout.rev_list[0][0:20] + rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit) + assert rrev == git_checkout.rev_list[0] + assert lrev == git_checkout.rev_list[0] - def test_unknown(self): + def test_unknown( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Check unknown ref/commit argument.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya") + with pytest.raises(repo_wrapper.CloneFailure): + repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "boooooooya") -class CheckRepoVerify(RepoWrapperTestCase): +class TestCheckRepoVerify: """Check check_repo_verify behavior.""" - def test_no_verify(self): + def test_no_verify(self, repo_wrapper: wrapper.Wrapper) -> None: """Always fail with --no-repo-verify.""" - self.assertFalse(self.wrapper.check_repo_verify(False)) + assert not repo_wrapper.check_repo_verify(False) - def test_gpg_initialized(self): + def test_gpg_initialized( + self, + repo_wrapper: wrapper.Wrapper, + ) -> None: """Should pass if gpg is setup already.""" with mock.patch.object( - self.wrapper, "NeedSetupGnuPG", return_value=False + repo_wrapper, "NeedSetupGnuPG", return_value=False ): - self.assertTrue(self.wrapper.check_repo_verify(True)) + assert repo_wrapper.check_repo_verify(True) - def test_need_gpg_setup(self): + def test_need_gpg_setup( + self, + repo_wrapper: wrapper.Wrapper, + ) -> None: """Should pass/fail based on gpg setup.""" with mock.patch.object( - self.wrapper, "NeedSetupGnuPG", return_value=True + repo_wrapper, "NeedSetupGnuPG", return_value=True ): - with mock.patch.object(self.wrapper, "SetupGnuPG") as m: + with mock.patch.object(repo_wrapper, "SetupGnuPG") as m: m.return_value = True - self.assertTrue(self.wrapper.check_repo_verify(True)) + assert repo_wrapper.check_repo_verify(True) m.return_value = False - self.assertFalse(self.wrapper.check_repo_verify(True)) + assert not repo_wrapper.check_repo_verify(True) -class CheckRepoRev(GitCheckoutTestCase): +class TestCheckRepoRev: """Check check_repo_rev behavior.""" - def test_verify_works(self): + def test_verify_works( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Should pass when verification passes.""" with mock.patch.object( - self.wrapper, "check_repo_verify", return_value=True + repo_wrapper, "check_repo_verify", return_value=True ): with mock.patch.object( - self.wrapper, "verify_rev", return_value="12345" + repo_wrapper, "verify_rev", return_value="12345" ): - rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable") - self.assertEqual("refs/heads/stable", rrev) - self.assertEqual("12345", lrev) + rrev, lrev = repo_wrapper.check_repo_rev( + git_checkout.git_dir, "stable" + ) + assert rrev == "refs/heads/stable" + assert lrev == "12345" - def test_verify_fails(self): + def test_verify_fails( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Should fail when verification fails.""" with mock.patch.object( - self.wrapper, "check_repo_verify", return_value=True + repo_wrapper, "check_repo_verify", return_value=True ): with mock.patch.object( - self.wrapper, "verify_rev", side_effect=RuntimeError + repo_wrapper, "verify_rev", side_effect=RuntimeError ): - with self.assertRaises(RuntimeError): - self.wrapper.check_repo_rev(self.GIT_DIR, "stable") + with pytest.raises(RuntimeError): + repo_wrapper.check_repo_rev(git_checkout.git_dir, "stable") - def test_verify_ignore(self): + def test_verify_ignore( + self, + repo_wrapper: wrapper.Wrapper, + git_checkout: GitCheckout, + ) -> None: """Should pass when verification is disabled.""" with mock.patch.object( - self.wrapper, "verify_rev", side_effect=RuntimeError + repo_wrapper, "verify_rev", side_effect=RuntimeError ): - rrev, lrev = self.wrapper.check_repo_rev( - self.GIT_DIR, "stable", repo_verify=False + rrev, lrev = repo_wrapper.check_repo_rev( + git_checkout.git_dir, "stable", repo_verify=False ) - self.assertEqual("refs/heads/stable", rrev) - self.assertEqual(self.REV_LIST[1], lrev) + assert rrev == "refs/heads/stable" + assert lrev == git_checkout.rev_list[1]