tests: convert more tests to pytest

Change-Id: Id4d48b61dc435564c336385bbc4944eb475d1942
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569443
Tested-by: Mike Frysinger <vapier@google.com>
Commit-Queue: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
This commit is contained in:
Mike Frysinger
2026-03-26 01:54:14 -04:00
committed by LUCI
parent ac2be4c089
commit 654690e1b8
4 changed files with 1295 additions and 1193 deletions
+390 -363
View File
@@ -18,17 +18,24 @@ import contextlib
import io
import json
import os
import re
import socket
import tempfile
import threading
import unittest
from typing import Any, Dict, List, Optional
from unittest import mock
import pytest
import git_trace2_event_log
import platform_utils
def serverLoggingThread(socket_path, server_ready, received_traces):
def server_logging_thread(
socket_path: str,
server_ready: threading.Condition,
received_traces: List[str],
) -> None:
"""Helper function to receive logs over a Unix domain socket.
Appends received messages on the provided socket and appends to
@@ -57,405 +64,425 @@ def serverLoggingThread(socket_path, server_ready, received_traces):
received_traces.extend(data.decode("utf-8").splitlines())
class EventLogTestCase(unittest.TestCase):
"""TestCase for the EventLog module."""
PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"
PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"
def setUp(self):
"""Load the event_log module every time."""
self._event_log = None
# By default we initialize with the expected case where
# repo launches us (so GIT_TRACE2_PARENT_SID is set).
env = {
self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
}
self._event_log = git_trace2_event_log.EventLog(env=env)
self._log_data = None
@pytest.fixture
def event_log() -> git_trace2_event_log.EventLog:
"""Fixture for the EventLog module."""
# By default we initialize with the expected case where
# repo launches us (so GIT_TRACE2_PARENT_SID is set).
env = {PARENT_SID_KEY: PARENT_SID_VALUE}
return git_trace2_event_log.EventLog(env=env)
def verifyCommonKeys(
self, log_entry, expected_event_name=None, full_sid=True
def verify_common_keys(
log_entry: Dict[str, Any],
expected_event_name: Optional[str] = None,
full_sid: bool = True,
) -> None:
"""Helper function to verify common event log keys."""
assert "event" in log_entry
assert "sid" in log_entry
assert "thread" in log_entry
assert "time" in log_entry
# Do basic data format validation.
if expected_event_name:
assert expected_event_name == log_entry["event"]
if full_sid:
assert re.match(FULL_SID_REGEX, log_entry["sid"])
else:
assert re.match(SELF_SID_REGEX, log_entry["sid"])
assert re.match(r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$", log_entry["time"])
def read_log(log_path: str) -> List[Dict[str, Any]]:
"""Helper function to read log data into a list."""
log_data = []
with open(log_path, mode="rb") as f:
for line in f:
log_data.append(json.loads(line))
return log_data
def remove_prefix(s: str, prefix: str) -> str:
"""Return a copy string after removing |prefix| from |s|, if present or
the original string."""
if s.startswith(prefix):
return s[len(prefix) :]
else:
return s
def test_initial_state_with_parent_sid(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
assert re.match(FULL_SID_REGEX, event_log.full_sid)
def test_initial_state_no_parent_sid() -> None:
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
# Setup an empty environment dict (no parent sid).
event_log = git_trace2_event_log.EventLog(env={})
assert re.match(SELF_SID_REGEX, event_log.full_sid)
def test_version_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test 'version' event data is valid.
Verify that the 'version' event is written even when no other
events are added.
Expected event log:
<version event>
"""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
# A log with no added events should only have the version entry.
assert len(log_data) == 1
version_event = log_data[0]
verify_common_keys(version_event, expected_event_name="version")
# Check for 'version' event specific fields.
assert "evt" in version_event
assert "exe" in version_event
# Verify "evt" version field is a string.
assert isinstance(version_event["evt"], str)
def test_start_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'start' event data is valid.
Expected event log:
<version event>
<start event>
"""
event_log.StartEvent([])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
start_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
assert "argv" in start_event
assert isinstance(start_event["argv"], list)
def test_exit_event_result_none(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'exit' event data is valid when result is None.
We expect None result to be converted to 0 in the exit event data.
Expected event log:
<version event>
<exit event>
"""
event_log.ExitEvent(None)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
exit_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
assert "code" in exit_event
# 'None' result should convert to 0 (successful) return code.
assert exit_event["code"] == 0
def test_exit_event_result_integer(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'exit' event data is valid when result is an integer.
Expected event log:
<version event>
<exit event>
"""
event_log.ExitEvent(2)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
exit_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
assert "code" in exit_event
assert exit_event["code"] == 2
def test_command_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'command' event data is valid.
Expected event log:
<version event>
<command event>
"""
event_log.CommandEvent(name="repo", subcommands=["init", "this"])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
command_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(command_event, expected_event_name="cmd_name")
# Check for 'command' event specific fields.
assert "name" in command_event
assert command_event["name"] == "repo-init-this"
def test_def_params_event_repo_config(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'def_params' event data outputs only repo config keys.
Expected event log:
<version event>
<def_param event>
<def_param event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "true",
"repo.partialclonefilter": "blob:none",
}
event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 3
def_param_events = log_data[1:]
verify_common_keys(log_data[0], expected_event_name="version")
for event in def_param_events:
verify_common_keys(event, expected_event_name="def_param")
# Check for 'def_param' event specific fields.
assert "param" in event
assert "value" in event
assert event["param"].startswith("repo.")
def test_def_params_event_no_repo_config(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'def_params' event data won't output non-repo config keys.
Expected event log:
<version event>
"""
config = {
"git.foo": "bar",
"git.core.foo2": "baz",
}
event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 1
verify_common_keys(log_data[0], expected_event_name="version")
def test_data_event_config(event_log: git_trace2_event_log.EventLog) -> None:
"""Test 'data' event data outputs all config keys.
Expected event log:
<version event>
<data event>
<data event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "false",
"repo.syncstate.superproject.hassuperprojecttag": "true",
"repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
}
prefix_value = "prefix"
event_log.LogDataConfigEvents(config, prefix_value)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 5
data_events = log_data[1:]
verify_common_keys(log_data[0], expected_event_name="version")
for event in data_events:
verify_common_keys(event)
# Check for 'data' event specific fields.
assert "key" in event
assert "value" in event
key = event["key"]
key = remove_prefix(key, f"{prefix_value}/")
value = event["value"]
assert event_log.GetDataEventName(value) == event["event"]
assert key in config
assert value == config[key]
def test_error_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'error' event data is valid.
Expected event log:
<version event>
<error event>
"""
msg = "invalid option: --cahced"
fmt = "invalid option: %s"
event_log.ErrorEvent(msg, fmt)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
error_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(error_event, expected_event_name="error")
# Check for 'error' event specific fields.
assert "msg" in error_event
assert "fmt" in error_event
assert error_event["msg"] == f"RepoErrorEvent:{msg}"
assert error_event["fmt"] == f"RepoErrorEvent:{fmt}"
def test_write_with_filename(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with a path to a file exits with None."""
assert event_log.Write(path="path/to/file") is None
def test_write_with_git_config(
tmp_path,
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test Write() uses the git config path when 'git config' call succeeds."""
with mock.patch.object(
event_log,
"_GetEventTargetPath",
return_value=str(tmp_path),
):
"""Helper function to verify common event log keys."""
self.assertIn("event", log_entry)
self.assertIn("sid", log_entry)
self.assertIn("thread", log_entry)
self.assertIn("time", log_entry)
assert os.path.dirname(event_log.Write()) == str(tmp_path)
# Do basic data format validation.
if expected_event_name:
self.assertEqual(expected_event_name, log_entry["event"])
if full_sid:
self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX)
else:
self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX)
self.assertRegex(
log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$"
def test_write_no_git_config(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with no git config variable present exits with None."""
with mock.patch.object(event_log, "_GetEventTargetPath", return_value=None):
assert event_log.Write() is None
def test_write_non_string(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with non-string type for |path| throws TypeError."""
with pytest.raises(TypeError):
event_log.Write(path=1234)
@pytest.mark.skipif(
not hasattr(socket, "AF_UNIX"), reason="Requires AF_UNIX sockets"
)
def test_write_socket(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with Unix domain socket and validate received traces."""
received_traces: List[str] = []
with tempfile.TemporaryDirectory(prefix="test_server_sockets") as tempdir:
socket_path = os.path.join(tempdir, "server.sock")
server_ready = threading.Condition()
# Start "server" listening on Unix domain socket at socket_path.
server_thread = threading.Thread(
target=server_logging_thread,
args=(socket_path, server_ready, received_traces),
)
try:
server_thread.start()
def readLog(self, log_path):
"""Helper function to read log data into a list."""
log_data = []
with open(log_path, mode="rb") as f:
for line in f:
log_data.append(json.loads(line))
return log_data
with server_ready:
server_ready.wait(timeout=120)
def remove_prefix(self, s, prefix):
"""Return a copy string after removing |prefix| from |s|, if present or
the original string."""
if s.startswith(prefix):
return s[len(prefix) :]
else:
return s
event_log.StartEvent([])
path = event_log.Write(path=f"af_unix:{socket_path}")
finally:
server_thread.join(timeout=5)
def test_initial_state_with_parent_sid(self):
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
self.assertRegex(self._event_log.full_sid, self.FULL_SID_REGEX)
def test_initial_state_no_parent_sid(self):
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
# Setup an empty environment dict (no parent sid).
self._event_log = git_trace2_event_log.EventLog(env={})
self.assertRegex(self._event_log.full_sid, self.SELF_SID_REGEX)
def test_version_event(self):
"""Test 'version' event data is valid.
Verify that the 'version' event is written even when no other
events are addded.
Expected event log:
<version event>
"""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
# A log with no added events should only have the version entry.
self.assertEqual(len(self._log_data), 1)
version_event = self._log_data[0]
self.verifyCommonKeys(version_event, expected_event_name="version")
# Check for 'version' event specific fields.
self.assertIn("evt", version_event)
self.assertIn("exe", version_event)
# Verify "evt" version field is a string.
self.assertIsInstance(version_event["evt"], str)
def test_start_event(self):
"""Test and validate 'start' event data is valid.
Expected event log:
<version event>
<start event>
"""
self._event_log.StartEvent([])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
start_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
self.assertIn("argv", start_event)
self.assertTrue(isinstance(start_event["argv"], list))
def test_exit_event_result_none(self):
"""Test 'exit' event data is valid when result is None.
We expect None result to be converted to 0 in the exit event data.
Expected event log:
<version event>
<exit event>
"""
self._event_log.ExitEvent(None)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
self.assertIn("code", exit_event)
# 'None' result should convert to 0 (successful) return code.
self.assertEqual(exit_event["code"], 0)
def test_exit_event_result_integer(self):
"""Test 'exit' event data is valid when result is an integer.
Expected event log:
<version event>
<exit event>
"""
self._event_log.ExitEvent(2)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
self.assertIn("code", exit_event)
self.assertEqual(exit_event["code"], 2)
def test_command_event(self):
"""Test and validate 'command' event data is valid.
Expected event log:
<version event>
<command event>
"""
self._event_log.CommandEvent(name="repo", subcommands=["init", "this"])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
command_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(command_event, expected_event_name="cmd_name")
# Check for 'command' event specific fields.
self.assertIn("name", command_event)
self.assertEqual(command_event["name"], "repo-init-this")
def test_def_params_event_repo_config(self):
"""Test 'def_params' event data outputs only repo config keys.
Expected event log:
<version event>
<def_param event>
<def_param event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "true",
"repo.partialclonefilter": "blob:none",
}
self._event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 3)
def_param_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in def_param_events:
self.verifyCommonKeys(event, expected_event_name="def_param")
# Check for 'def_param' event specific fields.
self.assertIn("param", event)
self.assertIn("value", event)
self.assertTrue(event["param"].startswith("repo."))
def test_def_params_event_no_repo_config(self):
"""Test 'def_params' event data won't output non-repo config keys.
Expected event log:
<version event>
"""
config = {
"git.foo": "bar",
"git.core.foo2": "baz",
}
self._event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 1)
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
def test_data_event_config(self):
"""Test 'data' event data outputs all config keys.
Expected event log:
<version event>
<data event>
<data event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "false",
"repo.syncstate.superproject.hassuperprojecttag": "true",
"repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
}
prefix_value = "prefix"
self._event_log.LogDataConfigEvents(config, prefix_value)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 5)
data_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in data_events:
self.verifyCommonKeys(event)
# Check for 'data' event specific fields.
self.assertIn("key", event)
self.assertIn("value", event)
key = event["key"]
key = self.remove_prefix(key, f"{prefix_value}/")
value = event["value"]
self.assertEqual(
self._event_log.GetDataEventName(value), event["event"]
)
self.assertTrue(key in config and value == config[key])
def test_error_event(self):
"""Test and validate 'error' event data is valid.
Expected event log:
<version event>
<error event>
"""
msg = "invalid option: --cahced"
fmt = "invalid option: %s"
self._event_log.ErrorEvent(msg, fmt)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
error_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(error_event, expected_event_name="error")
# Check for 'error' event specific fields.
self.assertIn("msg", error_event)
self.assertIn("fmt", error_event)
self.assertEqual(error_event["msg"], f"RepoErrorEvent:{msg}")
self.assertEqual(error_event["fmt"], f"RepoErrorEvent:{fmt}")
def test_write_with_filename(self):
"""Test Write() with a path to a file exits with None."""
self.assertIsNone(self._event_log.Write(path="path/to/file"))
def test_write_with_git_config(self):
"""Test Write() uses the git config path when 'git config' call
succeeds."""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
with mock.patch.object(
self._event_log,
"_GetEventTargetPath",
return_value=tempdir,
):
self.assertEqual(
os.path.dirname(self._event_log.Write()), tempdir
)
def test_write_no_git_config(self):
"""Test Write() with no git config variable present exits with None."""
with mock.patch.object(
self._event_log, "_GetEventTargetPath", return_value=None
):
self.assertIsNone(self._event_log.Write())
def test_write_non_string(self):
"""Test Write() with non-string type for |path| throws TypeError."""
with self.assertRaises(TypeError):
self._event_log.Write(path=1234)
@unittest.skipIf(not hasattr(socket, "AF_UNIX"), "Requires AF_UNIX sockets")
def test_write_socket(self):
"""Test Write() with Unix domain socket for |path| and validate received
traces."""
received_traces = []
with tempfile.TemporaryDirectory(
prefix="test_server_sockets"
) as tempdir:
socket_path = os.path.join(tempdir, "server.sock")
server_ready = threading.Condition()
# Start "server" listening on Unix domain socket at socket_path.
server_thread = threading.Thread(
target=serverLoggingThread,
args=(socket_path, server_ready, received_traces),
)
try:
server_thread.start()
with server_ready:
server_ready.wait(timeout=120)
self._event_log.StartEvent([])
path = self._event_log.Write(path=f"af_unix:{socket_path}")
finally:
server_thread.join(timeout=5)
self.assertEqual(path, f"af_unix:stream:{socket_path}")
self.assertEqual(len(received_traces), 2)
version_event = json.loads(received_traces[0])
start_event = json.loads(received_traces[1])
self.verifyCommonKeys(version_event, expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
self.assertIn("argv", start_event)
self.assertIsInstance(start_event["argv"], list)
assert path == f"af_unix:stream:{socket_path}"
assert len(received_traces) == 2
version_event = json.loads(received_traces[0])
start_event = json.loads(received_traces[1])
verify_common_keys(version_event, expected_event_name="version")
verify_common_keys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
assert "argv" in start_event
assert isinstance(start_event["argv"], list)
class EventLogVerboseTestCase(unittest.TestCase):
class TestEventLogVerbose:
"""TestCase for the EventLog module verbose logging."""
def setUp(self):
self._event_log = git_trace2_event_log.EventLog(env={})
def test_write_socket_error_no_verbose(self):
def test_write_socket_error_no_verbose(self) -> None:
"""Test Write() suppression of socket errors when not verbose."""
self._event_log.verbose = False
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = False
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch("socket.socket", side_effect=OSError):
self._event_log.Write(path="af_unix:stream:/tmp/test_sock")
self.assertEqual(mock_stderr.getvalue(), "")
event_log.Write(path="af_unix:stream:/tmp/test_sock")
assert mock_stderr.getvalue() == ""
def test_write_socket_error_verbose(self):
def test_write_socket_error_verbose(self) -> None:
"""Test Write() printing of socket errors when verbose."""
self._event_log.verbose = True
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = True
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"socket.socket", side_effect=OSError("Mock error")
):
self._event_log.Write(path="af_unix:stream:/tmp/test_sock")
self.assertIn(
"git trace2 logging failed: Mock error",
mock_stderr.getvalue(),
event_log.Write(path="af_unix:stream:/tmp/test_sock")
assert (
"git trace2 logging failed: Mock error"
in mock_stderr.getvalue()
)
def test_write_file_error_no_verbose(self):
def test_write_file_error_no_verbose(self) -> None:
"""Test Write() suppression of file errors when not verbose."""
self._event_log.verbose = False
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = False
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"tempfile.NamedTemporaryFile", side_effect=FileExistsError
):
self._event_log.Write(path="/tmp")
self.assertEqual(mock_stderr.getvalue(), "")
event_log.Write(path="/tmp")
assert mock_stderr.getvalue() == ""
def test_write_file_error_verbose(self):
def test_write_file_error_verbose(self) -> None:
"""Test Write() printing of file errors when verbose."""
self._event_log.verbose = True
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = True
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"tempfile.NamedTemporaryFile",
side_effect=FileExistsError("Mock error"),
):
self._event_log.Write(path="/tmp")
self.assertIn(
"git trace2 logging failed: FileExistsError",
mock_stderr.getvalue(),
event_log.Write(path="/tmp")
assert (
"git trace2 logging failed: FileExistsError"
in mock_stderr.getvalue()
)
+536 -498
View File
File diff suppressed because it is too large Load Diff
+33 -38
View File
@@ -14,9 +14,10 @@
"""Unittests for the subcmds/upload.py module."""
import unittest
from unittest import mock
import pytest
from error import GitError
from error import UploadError
from subcmds import upload
@@ -26,45 +27,39 @@ class UnexpectedError(Exception):
"""An exception not expected by upload command."""
class UploadCommand(unittest.TestCase):
"""Check registered all_commands."""
# A stub people list (reviewers, cc).
_STUB_PEOPLE = ([], [])
def setUp(self):
self.cmd = upload.Upload()
self.branch = mock.MagicMock()
self.people = mock.MagicMock()
self.opt, _ = self.cmd.OptionParser.parse_args([])
mock.patch.object(
self.cmd, "_AppendAutoList", return_value=None
).start()
mock.patch.object(self.cmd, "git_event_log").start()
def tearDown(self):
mock.patch.stopall()
@pytest.fixture
def cmd() -> upload.Upload:
"""Fixture to provide an Upload command instance with mocked methods."""
cmd = upload.Upload()
with mock.patch.object(
cmd, "_AppendAutoList", return_value=None
), mock.patch.object(cmd, "git_event_log"):
yield cmd
def test_UploadAndReport_UploadError(self):
"""Check UploadExitError raised when UploadError encountered."""
side_effect = UploadError("upload error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(upload.UploadExitError):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_GitError(self):
"""Check UploadExitError raised when GitError encountered."""
side_effect = GitError("some git error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(upload.UploadExitError):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_UploadError(cmd: upload.Upload) -> None:
"""Check UploadExitError raised when UploadError encountered."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=UploadError("")):
with pytest.raises(upload.UploadExitError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
def test_UploadAndReport_UnhandledError(self):
"""Check UnexpectedError passed through."""
side_effect = UnexpectedError("some os error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(type(side_effect)):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_GitError(cmd: upload.Upload) -> None:
"""Check UploadExitError raised when GitError encountered."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=GitError("")):
with pytest.raises(upload.UploadExitError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
def test_UploadAndReport_UnhandledError(cmd: upload.Upload) -> None:
"""Check UnexpectedError passed through."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=UnexpectedError):
with pytest.raises(UnexpectedError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
+336 -294
View File
@@ -19,261 +19,298 @@ import os
import re
import subprocess
import sys
import tempfile
import unittest
from unittest import mock
import pytest
import utils_for_test
import main
import wrapper
class RepoWrapperTestCase(unittest.TestCase):
"""TestCase for the wrapper module."""
def setUp(self):
"""Load the wrapper module every time."""
wrapper.Wrapper.cache_clear()
self.wrapper = wrapper.Wrapper()
@pytest.fixture(autouse=True)
def reset_wrapper() -> None:
"""Reset the wrapper module every time."""
wrapper.Wrapper.cache_clear()
class RepoWrapperUnitTest(RepoWrapperTestCase):
@pytest.fixture
def repo_wrapper() -> wrapper.Wrapper:
"""Fixture for the wrapper module."""
return wrapper.Wrapper()
class GitCheckout:
"""Class to hold git checkout info for tests."""
def __init__(self, git_dir, rev_list):
self.git_dir = git_dir
self.rev_list = rev_list
@pytest.fixture(scope="module")
def git_checkout(tmp_path_factory) -> GitCheckout:
"""Fixture for tests that use a real/small git checkout.
Create a repo to operate on, but do it once per-test-run.
"""
tempdir = tmp_path_factory.mktemp("repo-rev-tests")
run_git = wrapper.Wrapper().run_git
remote = os.path.join(tempdir, "remote")
os.mkdir(remote)
utils_for_test.init_git_tree(remote)
run_git("commit", "--allow-empty", "-minit", cwd=remote)
run_git("branch", "stable", cwd=remote)
run_git("tag", "v1.0", cwd=remote)
run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
rev_list = run_git("rev-list", "HEAD", cwd=remote).stdout.splitlines()
run_git("init", cwd=tempdir)
run_git(
"fetch",
remote,
"+refs/heads/*:refs/remotes/origin/*",
cwd=tempdir,
)
yield GitCheckout(tempdir, rev_list)
class TestRepoWrapper:
"""Tests helper functions in the repo wrapper"""
def test_version(self):
def test_version(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure _Version works."""
with self.assertRaises(SystemExit) as e:
with pytest.raises(SystemExit) as e:
with mock.patch("sys.stdout", new_callable=io.StringIO) as stdout:
with mock.patch(
"sys.stderr", new_callable=io.StringIO
) as stderr:
self.wrapper._Version()
self.assertEqual(0, e.exception.code)
self.assertEqual("", stderr.getvalue())
self.assertIn("repo launcher version", stdout.getvalue())
repo_wrapper._Version()
assert e.value.code == 0
assert stderr.getvalue() == ""
assert "repo launcher version" in stdout.getvalue()
def test_python_constraints(self):
def test_python_constraints(self, repo_wrapper: wrapper.Wrapper) -> None:
"""The launcher should never require newer than main.py."""
self.assertGreaterEqual(
main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD
assert (
main.MIN_PYTHON_VERSION_HARD >= repo_wrapper.MIN_PYTHON_VERSION_HARD
)
self.assertGreaterEqual(
main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT
assert (
main.MIN_PYTHON_VERSION_SOFT >= repo_wrapper.MIN_PYTHON_VERSION_SOFT
)
# Make sure the versions are themselves in sync.
self.assertGreaterEqual(
self.wrapper.MIN_PYTHON_VERSION_SOFT,
self.wrapper.MIN_PYTHON_VERSION_HARD,
assert (
repo_wrapper.MIN_PYTHON_VERSION_SOFT
>= repo_wrapper.MIN_PYTHON_VERSION_HARD
)
def test_init_parser(self):
def test_init_parser(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure 'init' GetParser works."""
parser = self.wrapper.GetParser()
parser = repo_wrapper.GetParser()
opts, args = parser.parse_args([])
self.assertEqual([], args)
self.assertIsNone(opts.manifest_url)
assert args == []
assert opts.manifest_url is None
class SetGitTrace2ParentSid(RepoWrapperTestCase):
class TestSetGitTrace2ParentSid:
"""Check SetGitTrace2ParentSid behavior."""
KEY = "GIT_TRACE2_PARENT_SID"
VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$")
def test_first_set(self):
def test_first_set(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Test env var not yet set."""
env = {}
self.wrapper.SetGitTrace2ParentSid(env)
self.assertIn(self.KEY, env)
repo_wrapper.SetGitTrace2ParentSid(env)
assert self.KEY in env
value = env[self.KEY]
self.assertRegex(value, self.VALID_FORMAT)
assert self.VALID_FORMAT.match(value)
def test_append(self):
def test_append(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Test env var is appended."""
env = {self.KEY: "pfx"}
self.wrapper.SetGitTrace2ParentSid(env)
self.assertIn(self.KEY, env)
repo_wrapper.SetGitTrace2ParentSid(env)
assert self.KEY in env
value = env[self.KEY]
self.assertTrue(value.startswith("pfx/"))
self.assertRegex(value[4:], self.VALID_FORMAT)
assert value.startswith("pfx/")
assert self.VALID_FORMAT.match(value[4:])
def test_global_context(self):
def test_global_context(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check os.environ gets updated by default."""
os.environ.pop(self.KEY, None)
self.wrapper.SetGitTrace2ParentSid()
self.assertIn(self.KEY, os.environ)
repo_wrapper.SetGitTrace2ParentSid()
assert self.KEY in os.environ
value = os.environ[self.KEY]
self.assertRegex(value, self.VALID_FORMAT)
assert self.VALID_FORMAT.match(value)
class RunCommand(RepoWrapperTestCase):
class TestRunCommand:
"""Check run_command behavior."""
def test_capture(self):
def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check capture_output handling."""
ret = self.wrapper.run_command(["echo", "hi"], capture_output=True)
ret = repo_wrapper.run_command(["echo", "hi"], capture_output=True)
# echo command appends OS specific linesep, but on Windows + Git Bash
# we get UNIX ending, so we allow both.
self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"])
assert ret.stdout in ["hi" + os.linesep, "hi\n"]
def test_check(self):
def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check check handling."""
self.wrapper.run_command(["true"], check=False)
self.wrapper.run_command(["true"], check=True)
self.wrapper.run_command(["false"], check=False)
with self.assertRaises(subprocess.CalledProcessError):
self.wrapper.run_command(["false"], check=True)
repo_wrapper.run_command(["true"], check=False)
repo_wrapper.run_command(["true"], check=True)
repo_wrapper.run_command(["false"], check=False)
with pytest.raises(subprocess.CalledProcessError):
repo_wrapper.run_command(["false"], check=True)
class RunGit(RepoWrapperTestCase):
class TestRunGit:
"""Check run_git behavior."""
def test_capture(self):
def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check capture_output handling."""
ret = self.wrapper.run_git("--version")
self.assertIn("git", ret.stdout)
ret = repo_wrapper.run_git("--version")
assert "git" in ret.stdout
def test_check(self):
def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check check handling."""
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.run_git("--version-asdfasdf")
self.wrapper.run_git("--version-asdfasdf", check=False)
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.run_git("--version-asdfasdf")
repo_wrapper.run_git("--version-asdfasdf", check=False)
class ParseGitVersion(RepoWrapperTestCase):
class TestParseGitVersion:
"""Check ParseGitVersion behavior."""
def test_autoload(self):
def test_autoload(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we can load the version from the live git."""
ret = self.wrapper.ParseGitVersion()
self.assertIsNotNone(ret)
assert repo_wrapper.ParseGitVersion() is not None
def test_bad_ver(self):
def test_bad_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of bad git versions."""
ret = self.wrapper.ParseGitVersion(ver_str="asdf")
self.assertIsNone(ret)
assert repo_wrapper.ParseGitVersion(ver_str="asdf") is None
def test_normal_ver(self):
def test_normal_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of normal git versions."""
ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1")
self.assertEqual(2, ret.major)
self.assertEqual(25, ret.minor)
self.assertEqual(1, ret.micro)
self.assertEqual("2.25.1", ret.full)
ret = repo_wrapper.ParseGitVersion(ver_str="git version 2.25.1")
assert ret.major == 2
assert ret.minor == 25
assert ret.micro == 1
assert ret.full == "2.25.1"
def test_extended_ver(self):
def test_extended_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of extended distro git versions."""
ret = self.wrapper.ParseGitVersion(
ret = repo_wrapper.ParseGitVersion(
ver_str="git version 1.30.50.696.g5e7596f4ac-goog"
)
self.assertEqual(1, ret.major)
self.assertEqual(30, ret.minor)
self.assertEqual(50, ret.micro)
self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full)
assert ret.major == 1
assert ret.minor == 30
assert ret.micro == 50
assert ret.full == "1.30.50.696.g5e7596f4ac-goog"
class CheckGitVersion(RepoWrapperTestCase):
class TestCheckGitVersion:
"""Check _CheckGitVersion behavior."""
def test_unknown(self):
def test_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Unknown versions should abort."""
with mock.patch.object(
self.wrapper, "ParseGitVersion", return_value=None
repo_wrapper, "ParseGitVersion", return_value=None
):
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion()
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper._CheckGitVersion()
def test_old(self):
def test_old(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Old versions should abort."""
with mock.patch.object(
self.wrapper,
repo_wrapper,
"ParseGitVersion",
return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"),
return_value=repo_wrapper.GitVersion(1, 0, 0, "1.0.0"),
):
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion()
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper._CheckGitVersion()
def test_new(self):
def test_new(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Newer versions should run fine."""
with mock.patch.object(
self.wrapper,
repo_wrapper,
"ParseGitVersion",
return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"),
return_value=repo_wrapper.GitVersion(100, 0, 0, "100.0.0"),
):
self.wrapper._CheckGitVersion()
repo_wrapper._CheckGitVersion()
class Requirements(RepoWrapperTestCase):
class TestRequirements:
"""Check Requirements handling."""
def test_missing_file(self):
def test_missing_file(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Don't crash if the file is missing (old version)."""
self.assertIsNone(
self.wrapper.Requirements.from_dir(utils_for_test.THIS_DIR)
assert (
repo_wrapper.Requirements.from_dir(utils_for_test.THIS_DIR) is None
)
self.assertIsNone(
self.wrapper.Requirements.from_file(
assert (
repo_wrapper.Requirements.from_file(
utils_for_test.THIS_DIR / "xxxxxxxxxxxxxxxxxxxxxxxx"
)
is None
)
def test_corrupt_data(self):
def test_corrupt_data(self, repo_wrapper: wrapper.Wrapper) -> None:
"""If the file can't be parsed, don't blow up."""
self.assertIsNone(self.wrapper.Requirements.from_file(__file__))
self.assertIsNone(self.wrapper.Requirements.from_data(b"x"))
assert repo_wrapper.Requirements.from_file(__file__) is None
assert repo_wrapper.Requirements.from_data(b"x") is None
def test_valid_data(self):
def test_valid_data(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure we can parse the file we ship."""
self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}"))
assert repo_wrapper.Requirements.from_data(b"{}") is not None
rootdir = utils_for_test.THIS_DIR.parent
self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir))
self.assertIsNotNone(
self.wrapper.Requirements.from_file(rootdir / "requirements.json")
assert repo_wrapper.Requirements.from_dir(rootdir) is not None
assert (
repo_wrapper.Requirements.from_file(rootdir / "requirements.json")
is not None
)
def test_format_ver(self):
def test_format_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check format_ver can format."""
self.assertEqual(
"1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3))
)
self.assertEqual("1", self.wrapper.Requirements._format_ver([1]))
assert repo_wrapper.Requirements._format_ver((1, 2, 3)) == "1.2.3"
assert repo_wrapper.Requirements._format_ver([1]) == "1"
def test_assert_all_unknown(self):
def test_assert_all_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all works with incompatible file."""
reqs = self.wrapper.Requirements({})
reqs = repo_wrapper.Requirements({})
reqs.assert_all()
def test_assert_all_new_repo(self):
def test_assert_all_new_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all accepts new enough repo."""
reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}})
reqs = repo_wrapper.Requirements({"repo": {"hard": [1, 0]}})
reqs.assert_all()
def test_assert_all_old_repo(self):
def test_assert_all_old_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all rejects old repo."""
reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit):
reqs = repo_wrapper.Requirements({"repo": {"hard": [99999, 0]}})
with pytest.raises(SystemExit):
reqs.assert_all()
def test_assert_all_new_python(self):
def test_assert_all_new_python(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all accepts new enough python."""
reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}})
reqs = repo_wrapper.Requirements({"python": {"hard": sys.version_info}})
reqs.assert_all()
def test_assert_all_old_python(self):
def test_assert_all_old_python(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all rejects old python."""
reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit):
reqs = repo_wrapper.Requirements({"python": {"hard": [99999, 0]}})
with pytest.raises(SystemExit):
reqs.assert_all()
def test_assert_ver_unknown(self):
def test_assert_ver_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver works with incompatible file."""
reqs = self.wrapper.Requirements({})
reqs = repo_wrapper.Requirements({})
reqs.assert_ver("xxx", (1, 0))
def test_assert_ver_new(self):
def test_assert_ver_new(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver allows new enough versions."""
reqs = self.wrapper.Requirements(
reqs = repo_wrapper.Requirements(
{"git": {"hard": [1, 0], "soft": [2, 0]}}
)
reqs.assert_ver("git", (1, 0))
@@ -281,274 +318,279 @@ class Requirements(RepoWrapperTestCase):
reqs.assert_ver("git", (2, 0))
reqs.assert_ver("git", (2, 5))
def test_assert_ver_old(self):
def test_assert_ver_old(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver rejects old versions."""
reqs = self.wrapper.Requirements(
reqs = repo_wrapper.Requirements(
{"git": {"hard": [1, 0], "soft": [2, 0]}}
)
with self.assertRaises(SystemExit):
with pytest.raises(SystemExit):
reqs.assert_ver("git", (0, 5))
class NeedSetupGnuPG(RepoWrapperTestCase):
class TestNeedSetupGnuPG:
"""Check NeedSetupGnuPG behavior."""
def test_missing_dir(self):
def test_missing_dir(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The ~/.repoconfig tree doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = os.path.join(tempdir, "foo")
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path / "foo")
assert repo_wrapper.NeedSetupGnuPG()
def test_missing_keyring(self):
def test_missing_keyring(
self, tmp_path, repo_wrapper: wrapper.Wrapper
) -> None:
"""The keyring-version file doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
assert repo_wrapper.NeedSetupGnuPG()
def test_empty_keyring(self):
def test_empty_keyring(
self, tmp_path, repo_wrapper: wrapper.Wrapper
) -> None:
"""The keyring-version file exists, but is empty."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w"):
pass
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("")
assert repo_wrapper.NeedSetupGnuPG()
def test_old_keyring(self):
def test_old_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The keyring-version file exists, but it's old."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write("1.0\n")
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("1.0\n")
assert repo_wrapper.NeedSetupGnuPG()
def test_new_keyring(self):
def test_new_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The keyring-version file exists, and is up-to-date."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write("1000.0\n")
self.assertFalse(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("1000.0\n")
assert not repo_wrapper.NeedSetupGnuPG()
class SetupGnuPG(RepoWrapperTestCase):
class TestSetupGnuPG:
"""Check SetupGnuPG behavior."""
def test_full(self):
def test_full(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure it works completely."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
self.wrapper.gpg_dir = os.path.join(
self.wrapper.home_dot_repo, "gnupg"
)
self.assertTrue(self.wrapper.SetupGnuPG(True))
with open(os.path.join(tempdir, "keyring-version")) as fp:
data = fp.read()
self.assertEqual(
".".join(str(x) for x in self.wrapper.KEYRING_VERSION),
data.strip(),
)
repo_wrapper.home_dot_repo = str(tmp_path)
repo_wrapper.gpg_dir = str(tmp_path / "gnupg")
assert repo_wrapper.SetupGnuPG(True)
data = (tmp_path / "keyring-version").read_text()
assert (
".".join(str(x) for x in repo_wrapper.KEYRING_VERSION)
== data.strip()
)
class VerifyRev(RepoWrapperTestCase):
class TestVerifyRev:
"""Check verify_rev behavior."""
def test_verify_passes(self):
def test_verify_passes(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check when we have a valid signed tag."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0\n", "")
gpg_result = subprocess.CompletedProcess([], 0, "", "")
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
ret = self.wrapper.verify_rev(
ret = repo_wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
assert ret == "v1.0^0"
def test_unsigned_commit(self):
def test_unsigned_commit(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we fall back to signed tag when we have an unsigned commit."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
gpg_result = subprocess.CompletedProcess([], 0, "", "")
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
ret = self.wrapper.verify_rev(
ret = repo_wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
assert ret == "v1.0^0"
def test_verify_fails(self):
def test_verify_fails(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we fall back to signed tag when we have an unsigned commit."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
gpg_result = RuntimeError
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
with self.assertRaises(RuntimeError):
self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
with pytest.raises(RuntimeError):
repo_wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
class GitCheckoutTestCase(RepoWrapperTestCase):
"""Tests that use a real/small git checkout."""
GIT_DIR = None
REV_LIST = None
@classmethod
def setUpClass(cls):
# Create a repo to operate on, but do it once per-class.
cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests")
cls.GIT_DIR = cls.tempdirobj.name
run_git = wrapper.Wrapper().run_git
remote = os.path.join(cls.GIT_DIR, "remote")
os.mkdir(remote)
utils_for_test.init_git_tree(remote)
run_git("commit", "--allow-empty", "-minit", cwd=remote)
run_git("branch", "stable", cwd=remote)
run_git("tag", "v1.0", cwd=remote)
run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
cls.REV_LIST = run_git(
"rev-list", "HEAD", cwd=remote
).stdout.splitlines()
run_git("init", cwd=cls.GIT_DIR)
run_git(
"fetch",
remote,
"+refs/heads/*:refs/remotes/origin/*",
cwd=cls.GIT_DIR,
)
@classmethod
def tearDownClass(cls):
if not cls.tempdirobj:
return
cls.tempdirobj.cleanup()
class ResolveRepoRev(GitCheckoutTestCase):
class TestResolveRepoRev:
"""Check resolve_repo_rev behavior."""
def test_explicit_branch(self):
def test_explicit_branch(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check refs/heads/branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(
self.GIT_DIR, "refs/heads/stable"
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/heads/stable"
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/heads/unknown"
)
def test_explicit_tag(self):
def test_explicit_tag(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check refs/tags/tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(
self.GIT_DIR, "refs/tags/v1.0"
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/tags/v1.0"
)
self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/tags/v1.0"
assert lrev == git_checkout.rev_list[1]
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/tags/unknown"
)
def test_branch_name(self):
def test_branch_name(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable")
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "stable"
)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main")
self.assertEqual("refs/heads/main", rrev)
self.assertEqual(self.REV_LIST[0], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "main")
assert rrev == "refs/heads/main"
assert lrev == git_checkout.rev_list[0]
def test_tag_name(self):
def test_tag_name(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0")
self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "v1.0")
assert rrev == "refs/tags/v1.0"
assert lrev == git_checkout.rev_list[1]
def test_full_commit(self):
def test_full_commit(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check specific commit argument."""
commit = self.REV_LIST[0]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
self.assertEqual(commit, rrev)
self.assertEqual(commit, lrev)
commit = git_checkout.rev_list[0]
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
assert rrev == commit
assert lrev == commit
def test_partial_commit(self):
def test_partial_commit(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check specific (partial) commit argument."""
commit = self.REV_LIST[0][0:20]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
self.assertEqual(self.REV_LIST[0], rrev)
self.assertEqual(self.REV_LIST[0], lrev)
commit = git_checkout.rev_list[0][0:20]
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
assert rrev == git_checkout.rev_list[0]
assert lrev == git_checkout.rev_list[0]
def test_unknown(self):
def test_unknown(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check unknown ref/commit argument."""
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "boooooooya")
class CheckRepoVerify(RepoWrapperTestCase):
class TestCheckRepoVerify:
"""Check check_repo_verify behavior."""
def test_no_verify(self):
def test_no_verify(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Always fail with --no-repo-verify."""
self.assertFalse(self.wrapper.check_repo_verify(False))
assert not repo_wrapper.check_repo_verify(False)
def test_gpg_initialized(self):
def test_gpg_initialized(
self,
repo_wrapper: wrapper.Wrapper,
) -> None:
"""Should pass if gpg is setup already."""
with mock.patch.object(
self.wrapper, "NeedSetupGnuPG", return_value=False
repo_wrapper, "NeedSetupGnuPG", return_value=False
):
self.assertTrue(self.wrapper.check_repo_verify(True))
assert repo_wrapper.check_repo_verify(True)
def test_need_gpg_setup(self):
def test_need_gpg_setup(
self,
repo_wrapper: wrapper.Wrapper,
) -> None:
"""Should pass/fail based on gpg setup."""
with mock.patch.object(
self.wrapper, "NeedSetupGnuPG", return_value=True
repo_wrapper, "NeedSetupGnuPG", return_value=True
):
with mock.patch.object(self.wrapper, "SetupGnuPG") as m:
with mock.patch.object(repo_wrapper, "SetupGnuPG") as m:
m.return_value = True
self.assertTrue(self.wrapper.check_repo_verify(True))
assert repo_wrapper.check_repo_verify(True)
m.return_value = False
self.assertFalse(self.wrapper.check_repo_verify(True))
assert not repo_wrapper.check_repo_verify(True)
class CheckRepoRev(GitCheckoutTestCase):
class TestCheckRepoRev:
"""Check check_repo_rev behavior."""
def test_verify_works(self):
def test_verify_works(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should pass when verification passes."""
with mock.patch.object(
self.wrapper, "check_repo_verify", return_value=True
repo_wrapper, "check_repo_verify", return_value=True
):
with mock.patch.object(
self.wrapper, "verify_rev", return_value="12345"
repo_wrapper, "verify_rev", return_value="12345"
):
rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual("12345", lrev)
rrev, lrev = repo_wrapper.check_repo_rev(
git_checkout.git_dir, "stable"
)
assert rrev == "refs/heads/stable"
assert lrev == "12345"
def test_verify_fails(self):
def test_verify_fails(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should fail when verification fails."""
with mock.patch.object(
self.wrapper, "check_repo_verify", return_value=True
repo_wrapper, "check_repo_verify", return_value=True
):
with mock.patch.object(
self.wrapper, "verify_rev", side_effect=RuntimeError
repo_wrapper, "verify_rev", side_effect=RuntimeError
):
with self.assertRaises(RuntimeError):
self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
with pytest.raises(RuntimeError):
repo_wrapper.check_repo_rev(git_checkout.git_dir, "stable")
def test_verify_ignore(self):
def test_verify_ignore(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should pass when verification is disabled."""
with mock.patch.object(
self.wrapper, "verify_rev", side_effect=RuntimeError
repo_wrapper, "verify_rev", side_effect=RuntimeError
):
rrev, lrev = self.wrapper.check_repo_rev(
self.GIT_DIR, "stable", repo_verify=False
rrev, lrev = repo_wrapper.check_repo_rev(
git_checkout.git_dir, "stable", repo_verify=False
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]