Compare commits

...

12 Commits
v2.62 ... main

Author SHA1 Message Date
Miyako.Enei
8869a30283 project: Drop --no-deref from update-ref --stdin
repo calls `git update-ref --stdin` when updating multiple refs during
repo init and repo sync. Historically, `--no-deref` was also passed.

Older Git 2.17 which we still support rejects the combination of
`--stdin` and `--no-deref`, emitting a usage error even when the stdin
input is valid.

The `--no-deref` option is only meaningful when updating symbolic refs
such as HEAD. The stdin-based update-ref path only operates on explicit
refs (tags, remote refs, alternates) and never symbolic refs.

Remove the unnecessary option to restore compatibility with Git 2.17
while preserving identical behavior on newer Git versions.

Tested with:
  - Git 2.17.1
  - Git 2.34.1

Change-Id: I22001de03800f5699b26a40bc1fb1fec002ed048
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/571721
Reviewed-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Enei <miyako.enei@alpsalpine.com>
Tested-by: Enei <miyako.enei@alpsalpine.com>
2026-04-12 16:39:28 -07:00
Gavin Mak
3b0eebeccf project: implement stateless sync pruning logic
Implement in-situ shallow re-fetching and garbage collection logic.
Enables repositories with sync-strategy="stateless" to reclaim disk
space by running reflog expire and git gc --prune=now if the working
tree is clean and has no local commits.

Bug: 498730431
Change-Id: I940bdc9b74da29d3f7b13566667dcddea769ebd3
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/568463
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
2026-04-09 12:09:40 -07:00
Gavin Mak
00991bfb42 manifest: Add sync-strategy attribute to project elements
The only supported sync-strategy is "stateless". The intent is to keep
the local workspace as small as possible by not keeping history during
syncs. This prevents disk space waste for projects with large binaries
where we only care about the current version.

A follow up change will implement the logic.

Bug: 498730431
Change-Id: I84a436a9ca2492893163c6cfda6c28dc62a568f0
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/568462
Tested-by: Gavin Mak <gavinmak@google.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
2026-04-09 12:09:28 -07:00
Nasser Grainawi
e8338b54bd tests: Convert forall subcmd test to pytest
Rewrite tests/test_subcmds_forall.py from unittest.TestCase to pytest
function-style tests to match the surrounding test suite conventions.

Replace setUp/tearDown and class-based helpers with tmp_path-based
setup, switch stdout capture to contextlib.redirect_stdout, and keep the
existing behavior checks intact (all eight projects are invoked exactly
once).

Change-Id: I9243f3461aa6850f867bdb864f4a34c442f817f6
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569821
Reviewed-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Nasser Grainawi <nasser.grainawi@oss.qualcomm.com>
Tested-by: Nasser Grainawi <nasser.grainawi@oss.qualcomm.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
2026-04-09 11:51:47 -07:00
Gavin Mak
951666fb23 gc: Fix hang during repack in partial clones
Add `--missing=allow-promisor` to `git rev-list` calls in
`repack_projects`. This prevents Git from auto-fetching missing objects
from the promisor remote, which can cause stalls due to sequential
network requests.

Also add a Git version check to ensure Git is at least 2.17.0 before
running `--repack`, as `--missing=allow-promisor` was introduced in that
version.

Bug: 500133631
Change-Id: I2dcf9b46fac4c6a53a3c2a46f06f61d6aec40f2f
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/570361
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
Reviewed-by: Sam Saccone <samccone@google.com>
2026-04-08 12:54:39 -07:00
Carlos Fernandez
854b330967 test_wrapper: add test for repo script executable permission
Add a test to verify that the repo launcher script has the
executable bit set, guarding against accidental permission changes.

Change-Id: I314658b57ed174673188fbbc5962d9fdeefac97d
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569242
Reviewed-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Carlos Fernandez <carlosfsanz@meta.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Carlos Fernandez <carlosfsanz@meta.com>
2026-04-06 14:16:04 -07:00
Mike Frysinger
654690e1b8 tests: convert more tests to pytest
Change-Id: Id4d48b61dc435564c336385bbc4944eb475d1942
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569443
Tested-by: Mike Frysinger <vapier@google.com>
Commit-Queue: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
2026-04-06 11:36:39 -07:00
Mike Frysinger
ac2be4c089 tests: convert __file__ usage to pathlib
Change-Id: I2408b0ac97629f0d5fc92779b78bf1ff159a6f83
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569442
Reviewed-by: Gavin Mak <gavinmak@google.com>
Tested-by: Mike Frysinger <vapier@google.com>
Commit-Queue: Mike Frysinger <vapier@google.com>
2026-04-06 11:15:58 -07:00
Mike Frysinger
3d819e8e3e tests: unify fixture() helper with Path constant
Change-Id: I63751042391f5cc3e06af7067bc83d67bd0716dc
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/569441
Tested-by: Mike Frysinger <vapier@google.com>
Commit-Queue: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
2026-04-06 11:15:14 -07:00
Carlos Fernandez
573983948a Fix all flake8 warnings from newer flake8-bugbear and flake8-comprehensions
Address warnings introduced by flake8-bugbear 24.12.12 and
flake8-comprehensions 3.16.0:

- C408: Replace dict()/list() calls with literal {} and []
- C413: Remove unnecessary list() around sorted()
- C414: Remove unnecessary list() inside sorted()
- C419: Suppress intentional list comprehension in all() (noqa)
- B001: Replace bare except with except Exception
- B006: Replace mutable default arguments with None
- B010: Replace setattr() with direct attribute assignment
- B017: Use RuntimeError instead of Exception in tests
- B019: Suppress lru_cache on methods for long-lived objects (noqa)
- B033: Remove duplicate item in set literal

Change-Id: If4693d3e946200bbc22f689f7b94da604addcb80
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/566321
Tested-by: Carlos Fernandez <carlosfsanz@meta.com>
Commit-Queue: Carlos Fernandez <carlosfsanz@meta.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
2026-04-03 07:50:52 -07:00
Gavin Mak
3f3c681a02 project: Refactor GetHead to use symbolic-ref first
Simplify branch resolution and optimize unborn branch detection by
prioritizing symbolic-ref over rev-parse.

Change-Id: Ic62dcb87cd051dafb00d520b1157be2e32abc2ab
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/563222
Reviewed-by: Nasser Grainawi <nasser.grainawi@oss.qualcomm.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
2026-04-02 13:57:05 -07:00
Sam Saccone
242e97d9dd Implement command forgiveness with autocorrect
Similar to `git`, when a user types an unknown command like `repo tart`,
we now use `difflib.get_close_matches` to suggest similar commands.

If `help.autocorrect` is set in the git config, it will optionally
prompt the user to automatically run the assumed command, or wait
for a configured delay before executing it.

Verification Steps:
1. Created a dummy repo project locally.
2. Verified `help.autocorrect=0|false|off|no|show` suggests
   command and exits.
3. Verified `help.autocorrect=1|true|on|yes|immediate`
   automatically runs suggestion.
4. Verified `help.autocorrect=<number>` runs after
   `<number>*0.1` seconds.
5. Verified `help.autocorrect=never` exits immediately without
   suggestions.
6. Verified `help.autocorrect=prompt` asks user to accept [y/n]
   and handles correctly.

BUG: b/489753302

Change-Id: I6dcd63229cbd7badf5404459b48690c68f5b4857
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/558021
Tested-by: Sam Saccone <samccone@google.com>
Commit-Queue: Sam Saccone <samccone@google.com>
Reviewed-by: Mike Frysinger <vapier@google.com>
2026-03-24 15:47:37 -07:00
24 changed files with 2033 additions and 1403 deletions

View File

@@ -73,18 +73,19 @@ following DTD:
project*,
copyfile*,
linkfile*)>
<!ATTLIST project name CDATA #REQUIRED>
<!ATTLIST project path CDATA #IMPLIED>
<!ATTLIST project remote IDREF #IMPLIED>
<!ATTLIST project revision CDATA #IMPLIED>
<!ATTLIST project dest-branch CDATA #IMPLIED>
<!ATTLIST project groups CDATA #IMPLIED>
<!ATTLIST project sync-c CDATA #IMPLIED>
<!ATTLIST project sync-s CDATA #IMPLIED>
<!ATTLIST project sync-tags CDATA #IMPLIED>
<!ATTLIST project upstream CDATA #IMPLIED>
<!ATTLIST project clone-depth CDATA #IMPLIED>
<!ATTLIST project force-path CDATA #IMPLIED>
<!ATTLIST project name CDATA #REQUIRED>
<!ATTLIST project path CDATA #IMPLIED>
<!ATTLIST project remote IDREF #IMPLIED>
<!ATTLIST project revision CDATA #IMPLIED>
<!ATTLIST project dest-branch CDATA #IMPLIED>
<!ATTLIST project groups CDATA #IMPLIED>
<!ATTLIST project sync-c CDATA #IMPLIED>
<!ATTLIST project sync-s CDATA #IMPLIED>
<!ATTLIST project sync-tags CDATA #IMPLIED>
<!ATTLIST project upstream CDATA #IMPLIED>
<!ATTLIST project clone-depth CDATA #IMPLIED>
<!ATTLIST project force-path CDATA #IMPLIED>
<!ATTLIST project sync-strategy CDATA #IMPLIED>
<!ELEMENT annotation EMPTY>
<!ATTLIST annotation name CDATA #REQUIRED>
@@ -389,6 +390,22 @@ rather than the `name` attribute. This attribute only applies to the
local mirrors syncing, it will be ignored when syncing the projects in a
client working directory.
Attribute `sync-strategy`: Set the sync strategy used when fetching this
project. Currently the only supported value is `stateless`. When set to
`stateless`, repo will run a reflog expiration and aggressive garbage collection
at the end of the sync process. This is useful for projects that contain
large binary files and use `clone-depth="1"`, where garbage can accumulate
as binaries are added, deleted, or modified across successive syncs.
During a stateless sync, repo checks the following before cleaning up:
1. The project does not share an object directory with other projects.
2. The working tree is clean (no uncommitted changes, no untracked files).
3. There are no unpushed local commits.
4. There is no Git stash.
If any of these conditions are not met, repo falls back to a standard
sync without garbage collection.
### Element extend-project
Modify the attributes of the named project.

View File

@@ -47,7 +47,7 @@ logger = RepoLogger(__file__)
class _GitCall:
@functools.lru_cache(maxsize=None)
@functools.lru_cache(maxsize=None) # noqa: B019
def version_tuple(self):
ret = Wrapper().ParseGitVersion()
if ret is None:
@@ -95,7 +95,7 @@ def RepoSourceVersion():
ver = ver[1:]
else:
ver = "unknown"
setattr(RepoSourceVersion, "version", ver)
RepoSourceVersion.version = ver
return ver
@@ -611,7 +611,7 @@ class GitCommandError(GitError):
self.git_stderr = git_stderr
@property
@functools.lru_cache(maxsize=None)
@functools.lru_cache(maxsize=None) # noqa: B019
def suggestion(self):
"""Returns helpful next steps for the given stderr."""
if not self.git_stderr:

View File

@@ -42,7 +42,7 @@ SYNC_STATE_PREFIX = "repo.syncstate."
ID_RE = re.compile(r"^[0-9a-f]{40}$")
REVIEW_CACHE = dict()
REVIEW_CACHE = {}
def IsChange(rev):
@@ -111,7 +111,7 @@ class GitConfig:
return cls(configfile=os.path.join(gitdir, "config"), defaults=defaults)
def __init__(self, configfile, defaults=None, jsonFile=None):
self.file = configfile
self.file = str(configfile)
self.defaults = defaults
self._cache_dict = None
self._section_dict = None

126
main.py
View File

@@ -19,6 +19,7 @@ People shouldn't run this directly; instead, they should use the `repo` wrapper
which takes care of execing this entry point.
"""
import difflib
import getpass
import json
import netrc
@@ -29,6 +30,7 @@ import signal
import sys
import textwrap
import time
from typing import Optional
import urllib.request
from repo_logging import RepoLogger
@@ -292,6 +294,102 @@ class _Repo:
result = run()
return result
def _autocorrect_command_name(
self, name: str, config: RepoConfig
) -> Optional[str]:
"""Autocorrect command name based on user's git config."""
close_commands = difflib.get_close_matches(
name, self.commands.keys(), n=5, cutoff=0.7
)
if not close_commands:
logger.error(
"repo: '%s' is not a repo command. See 'repo help'.", name
)
return None
assumed = close_commands[0]
autocorrect = config.GetString("help.autocorrect")
# If there are multiple close matches, git won't automatically run one.
# We'll always prompt instead of guessing.
if len(close_commands) > 1:
autocorrect = "prompt"
# Handle git configuration boolean values:
# 0, "false", "off", "no", "show": show suggestion (default)
# 1, "true", "on", "yes", "immediate": run suggestion immediately
# "never": don't run or show any suggested command
# "prompt": show the suggestion and prompt for confirmation
# positive number > 1: run suggestion after specified deciseconds
if autocorrect is None:
autocorrect = "0"
autocorrect = autocorrect.lower()
if autocorrect in ("0", "false", "off", "no", "show"):
autocorrect = 0
elif autocorrect in ("true", "on", "yes", "immediate"):
autocorrect = -1 # immediate
elif autocorrect == "never":
return None
elif autocorrect == "prompt":
logger.warning(
"You called a repo command named "
"'%s', which does not exist.",
name,
)
try:
resp = input(f"Run '{assumed}' instead [y/N]? ")
if resp.lower().startswith("y"):
return assumed
except (KeyboardInterrupt, EOFError):
pass
return None
else:
try:
autocorrect = int(autocorrect)
except ValueError:
autocorrect = 0
if autocorrect != 0:
if autocorrect < 0:
logger.warning(
"You called a repo command named "
"'%s', which does not exist.\n"
"Continuing assuming that "
"you meant '%s'.",
name,
assumed,
)
else:
delay = autocorrect * 0.1
logger.warning(
"You called a repo command named "
"'%s', which does not exist.\n"
"Continuing in %.1f seconds, assuming "
"that you meant '%s'.",
name,
delay,
assumed,
)
try:
time.sleep(delay)
except KeyboardInterrupt:
return None
return assumed
logger.error(
"repo: '%s' is not a repo command. See 'repo help'.", name
)
logger.warning(
"The most similar command%s\n\t%s",
"s are" if len(close_commands) > 1 else " is",
"\n\t".join(close_commands),
)
return None
def _RunLong(self, name, gopts, argv, git_trace2_event_log):
"""Execute the (longer running) requested subcommand."""
result = 0
@@ -306,20 +404,22 @@ class _Repo:
outer_client=outer_client,
)
try:
cmd = self.commands[name](
repodir=self.repodir,
client=repo_client,
manifest=repo_client.manifest,
outer_client=outer_client,
outer_manifest=outer_client.manifest,
git_event_log=git_trace2_event_log,
if name not in self.commands:
corrected_name = self._autocorrect_command_name(
name, outer_client.globalConfig
)
except KeyError:
logger.error(
"repo: '%s' is not a repo command. See 'repo help'.", name
)
return 1
if not corrected_name:
return 1
name = corrected_name
cmd = self.commands[name](
repodir=self.repodir,
client=repo_client,
manifest=repo_client.manifest,
outer_client=outer_client,
outer_manifest=outer_client.manifest,
git_event_log=git_trace2_event_log,
)
Editor.globalConfig = cmd.client.globalConfig

View File

@@ -1,5 +1,5 @@
.\" DO NOT MODIFY THIS FILE! It was generated by help2man.
.TH REPO "1" "March 2026" "repo manifest" "Repo Manual"
.TH REPO "1" "April 2026" "repo manifest" "Repo Manual"
.SH NAME
repo \- repo manifest - manual page for repo manifest
.SH SYNOPSIS
@@ -165,15 +165,32 @@ IDREF #IMPLIED>
.TP
<!ATTLIST project revision
CDATA #IMPLIED>
.TP
<!ATTLIST project dest\-branch
CDATA #IMPLIED>
.TP
<!ATTLIST project groups
CDATA #IMPLIED>
.TP
<!ATTLIST project sync\-c
CDATA #IMPLIED>
.TP
<!ATTLIST project sync\-s
CDATA #IMPLIED>
.TP
<!ATTLIST project sync\-tags
CDATA #IMPLIED>
.TP
<!ATTLIST project upstream
CDATA #IMPLIED>
.TP
<!ATTLIST project clone\-depth
CDATA #IMPLIED>
.TP
<!ATTLIST project force\-path
CDATA #IMPLIED>
.IP
<!ATTLIST project dest\-branch CDATA #IMPLIED>
<!ATTLIST project groups CDATA #IMPLIED>
<!ATTLIST project sync\-c CDATA #IMPLIED>
<!ATTLIST project sync\-s CDATA #IMPLIED>
<!ATTLIST project sync\-tags CDATA #IMPLIED>
<!ATTLIST project upstream CDATA #IMPLIED>
<!ATTLIST project clone\-depth CDATA #IMPLIED>
<!ATTLIST project force\-path CDATA #IMPLIED>
<!ATTLIST project sync\-strategy CDATA #IMPLIED>
.IP
<!ELEMENT annotation EMPTY>
<!ATTLIST annotation name CDATA #REQUIRED>
@@ -469,6 +486,21 @@ mirror repository according to its `path` attribute (if supplied) rather than
the `name` attribute. This attribute only applies to the local mirrors syncing,
it will be ignored when syncing the projects in a client working directory.
.PP
Attribute `sync\-strategy`: Set the sync strategy used when fetching this
project. Currently the only supported value is `stateless`. When set to
`stateless`, repo will run a reflog expiration and aggressive garbage collection
at the end of the sync process. This is useful for projects that contain large
binary files and use `clone\-depth="1"`, where garbage can accumulate as binaries
are added, deleted, or modified across successive syncs.
.PP
During a stateless sync, repo checks the following before cleaning up: 1. The
project does not share an object directory with other projects. 2. The working
tree is clean (no uncommitted changes, no untracked files). 3. There are no
unpushed local commits. 4. There is no Git stash.
.PP
If any of these conditions are not met, repo falls back to a standard sync
without garbage collection.
.PP
Element extend\-project
.PP
Modify the attributes of the named project.

View File

@@ -759,14 +759,17 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
if p.clone_depth:
e.setAttribute("clone-depth", str(p.clone_depth))
if p.sync_strategy:
e.setAttribute("sync-strategy", str(p.sync_strategy))
self._output_manifest_project_extras(p, e)
if p.subprojects:
subprojects = {subp.name for subp in p.subprojects}
output_projects(p, e, list(sorted(subprojects)))
output_projects(p, e, sorted(subprojects))
projects = {p.name for p in self._paths.values() if not p.parent}
output_projects(None, root, list(sorted(projects)))
output_projects(None, root, sorted(projects))
if self._repo_hooks_project:
root.appendChild(doc.createTextNode(""))
@@ -823,7 +826,6 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
"submanifest",
# These are children of 'project' nodes.
"annotation",
"project",
"copyfile",
"linkfile",
}
@@ -1939,6 +1941,8 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
% (self.manifestFile, clone_depth)
)
sync_strategy = node.getAttribute("sync-strategy") or None
dest_branch = (
node.getAttribute("dest-branch") or self._default.destBranchExpr
)
@@ -1985,6 +1989,7 @@ https://gerrit.googlesource.com/git-repo/+/HEAD/docs/manifest-format.md
sync_s=sync_s,
sync_tags=sync_tags,
clone_depth=clone_depth,
sync_strategy=sync_strategy,
upstream=upstream,
parent=parent,
dest_branch=dest_branch,

View File

@@ -225,7 +225,7 @@ class ReviewableBranch:
@property
def unabbrev_commits(self):
r = dict()
r = {}
for commit in self.project.bare_git.rev_list(
not_rev(self.base), R_HEADS + self.name, "--"
):
@@ -553,11 +553,12 @@ class Project:
revisionExpr,
revisionId,
rebase=True,
groups=set(),
groups=None,
sync_c=False,
sync_s=False,
sync_tags=True,
clone_depth=None,
sync_strategy=None,
upstream=None,
parent=None,
use_git_worktrees=False,
@@ -605,11 +606,12 @@ class Project:
self.SetRevision(revisionExpr, revisionId=revisionId)
self.rebase = rebase
self.groups = groups
self.groups = groups if groups is not None else set()
self.sync_c = sync_c
self.sync_s = sync_s
self.sync_tags = sync_tags
self.clone_depth = clone_depth
self.sync_strategy = sync_strategy
self.upstream = upstream
self.parent = parent
# NB: Do not use this setting in __init__ to change behavior so that the
@@ -627,6 +629,7 @@ class Project:
self.linkfiles = {}
self.annotations = []
self.dest_branch = dest_branch
self.stateless_prune_needed = False
# This will be filled in if a project is later identified to be the
# project containing repo hooks.
@@ -756,6 +759,18 @@ class Project:
return True
return False
def HasStash(self) -> bool:
"""Returns True if there is a stash in the repository."""
p = GitCommand(
self,
["rev-parse", "--verify", "refs/stash"],
bare=True,
capture_stdout=True,
capture_stderr=True,
log_as_error=False,
)
return p.Wait() == 0
_userident_name = None
_userident_email = None
@@ -943,7 +958,7 @@ class Project:
out.important("prior sync failed; rebase still in progress")
out.nl()
paths = list()
paths = []
paths.extend(di.keys())
paths.extend(df.keys())
paths.extend(do)
@@ -1239,6 +1254,67 @@ class Project:
logger.error("error: Cannot extract archive %s: %s", tarpath, e)
return False
def _ShouldStatelessPrune(
self, use_superproject: Optional[bool] = None
) -> bool:
"""Determines if a stateless prune should be performed.
Stateless pruning reclaims space by running a reflog expiration and
garbage collection instead of an incremental fetch. It is only performed
if the repository is clean and has no local-only state.
"""
if not self.Exists:
return False
if self._CheckForImmutableRevision(use_superproject=use_superproject):
return False
# Query the target hash from remote to see if we are up-to-date.
target_hash = None
if IsId(self.revisionExpr):
target_hash = self.revisionExpr
else:
output = self._LsRemote(self.upstream or self.revisionExpr)
if output:
target_hash = output.splitlines()[0].split()[0]
if not target_hash:
return False
try:
local_head = self.bare_git.rev_parse("HEAD")
except GitError:
local_head = None
if target_hash == local_head:
return False
# Skip if sharing objects with other projects.
shares_objdir = self.UseAlternates or self.use_git_worktrees
if not shares_objdir:
for p in self.manifest.GetProjectsWithName(self.name):
if p != self and p.objdir == self.objdir:
shares_objdir = True
break
if shares_objdir:
return False
# Skip if HEAD contains any unpushed local commits.
try:
local_commits = self.bare_git.rev_list(
"--count", "HEAD", "--not", "--remotes", "--tags"
)
if int(local_commits[0]) > 0:
return False
except (GitError, IndexError, ValueError):
return False
if self.IsDirty(consider_untracked=True) or self.HasStash():
return False
return True
def Sync_NetworkHalf(
self,
quiet=False,
@@ -1257,7 +1333,7 @@ class Project:
submodules=False,
ssh_proxy=None,
clone_filter=None,
partial_clone_exclude=set(),
partial_clone_exclude=None,
clone_filter_for_depth=None,
):
"""Perform only the network IO portion of the sync process.
@@ -1310,10 +1386,17 @@ class Project:
if clone_bundle and os.path.exists(self.objdir):
clone_bundle = False
if partial_clone_exclude is None:
partial_clone_exclude = set()
if self.name in partial_clone_exclude:
clone_bundle = True
clone_filter = None
if self.sync_strategy == "stateless" and self._ShouldStatelessPrune(
use_superproject
):
self.stateless_prune_needed = True
if is_new is None:
is_new = not self.Exists
if is_new:
@@ -1598,6 +1681,23 @@ class Project:
def _dosubmodules():
self._SyncSubmodules(quiet=True)
def _doprune() -> None:
"""Expire reflogs and run prune-now GC for stateless sync."""
GitCommand(
self,
["reflog", "expire", "--expire=all", "--all"],
bare=True,
).Wait()
p = GitCommand(
self,
["gc", "--prune=now"],
bare=True,
capture_stdout=True,
capture_stderr=True,
)
if p.Wait() != 0:
logger.warning("warn: %s: stateless gc failed", self.name)
head = self.work_git.GetHead()
if head.startswith(R_HEADS):
branch = head[len(R_HEADS) :]
@@ -1643,6 +1743,8 @@ class Project:
fail(e)
return
self._CopyAndLinkFiles()
if self.stateless_prune_needed:
syncbuf.later2(self, _doprune, not verbose)
return
if head == revid:
@@ -1789,6 +1891,9 @@ class Project:
if submodules:
syncbuf.later1(self, _dosubmodules, not verbose)
if self.stateless_prune_needed:
syncbuf.later2(self, _doprune, not verbose)
def AddCopyFile(self, src, dest, topdir):
"""Mark |src| for copying to |dest| (relative to |topdir|).
@@ -2568,7 +2673,7 @@ class Project:
if update_ref_cmds:
GitCommand(
self,
["update-ref", "--no-deref", "--stdin"],
["update-ref", "--stdin"],
bare=True,
input="".join(update_ref_cmds),
).Wait()
@@ -2816,7 +2921,7 @@ class Project:
)
GitCommand(
self,
["update-ref", "--no-deref", "--stdin"],
["update-ref", "--stdin"],
bare=True,
input=delete_cmds,
log_as_error=False,
@@ -3964,30 +4069,14 @@ class Project:
def GetHead(self):
"""Return the ref that HEAD points to."""
try:
symbolic_head = self.rev_parse("--symbolic-full-name", HEAD)
if symbolic_head == HEAD:
# Detached HEAD. Return the commit SHA instead.
return self.rev_parse(HEAD)
return symbolic_head
except GitError as e:
# `git rev-parse --symbolic-full-name HEAD` will fail for unborn
# branches, so try symbolic-ref before falling back to raw file
# parsing.
try:
p = GitCommand(
self._project,
["symbolic-ref", "-q", HEAD],
bare=True,
gitdir=self._gitdir,
capture_stdout=True,
capture_stderr=True,
log_as_error=False,
)
if p.Wait() == 0:
return p.stdout.rstrip("\n")
except GitError:
pass
return self.symbolic_ref("-q", HEAD, log_as_error=False)
except GitError:
pass
try:
# If symbolic-ref fails, try to treat as detached HEAD.
return self.rev_parse(HEAD)
except GitError as e:
logger.warning(
"project %s: unparseable HEAD; trying to recover.\n"
"Check that HEAD ref in .git/HEAD is valid. The error "

View File

@@ -106,7 +106,7 @@ def check_path(opts: argparse.Namespace, path: Path) -> bool:
def check_paths(opts: argparse.Namespace, paths: list[Path]) -> bool:
"""Check all the paths."""
# NB: Use list comprehension and not a generator so we check all paths.
return all([check_path(opts, x) for x in paths])
return all([check_path(opts, x) for x in paths]) # noqa: C419
def find_files(opts: argparse.Namespace) -> list[Path]:

View File

@@ -48,10 +48,10 @@ wheel: <
version: "version:3.0.7"
>
# Required by pytest==8.3.4
# Required by pytest==8.3.4 and flake8-bugbear==24.12.12
wheel: <
name: "infra/python/wheels/attrs-py2_py3"
version: "version:21.4.0"
name: "infra/python/wheels/attrs-py3"
version: "version:24.2.0"
>
# NB: Keep in sync with constraints.txt.
@@ -119,6 +119,16 @@ wheel: <
version: "version:2.10.0"
>
wheel: <
name: "infra/python/wheels/flake8-bugbear-py3"
version: "version:24.12.12"
>
wheel: <
name: "infra/python/wheels/flake8-comprehensions-py3"
version: "version:3.16.0"
>
wheel: <
name: "infra/python/wheels/isort-py3"
version: "version:5.10.1"

2
ssh.py
View File

@@ -149,7 +149,7 @@ class ProxyManager:
while True:
try:
procs.pop(0)
except: # noqa: E722
except IndexError:
break
def close(self):

View File

@@ -16,6 +16,7 @@ import os
from typing import List, Set
from command import Command
from git_command import git_require
from git_command import GitCommand
import platform_utils
from progress import Progress
@@ -204,6 +205,7 @@ class Gc(Command):
[
"rev-list",
"--objects",
"--missing=allow-promisor",
f"--remotes={project.remote.name}",
"--filter=blob:none",
"--tags",
@@ -215,7 +217,12 @@ class Gc(Command):
# Get all local objects and pack them.
local_head_objects_cmd = GitCommand(
project,
["rev-list", "--objects", "HEAD^{tree}"],
[
"rev-list",
"--objects",
"--missing=allow-promisor",
"HEAD^{tree}",
],
capture_stdout=True,
verify_command=True,
)
@@ -224,6 +231,7 @@ class Gc(Command):
[
"rev-list",
"--objects",
"--missing=allow-promisor",
"--all",
"--reflog",
"--indexed-objects",
@@ -297,7 +305,8 @@ class Gc(Command):
if ret != 0:
return ret
if not opt.repack:
return
if opt.repack:
git_require((2, 17, 0), fail=True, msg="--repack")
ret = self.repack_projects(projects, opt)
return self.repack_projects(projects, opt)
return ret

View File

@@ -93,7 +93,7 @@ contain a line that matches both expressions:
pt = getattr(parser.values, "cmd_argv", None)
if pt is None:
pt = []
setattr(parser.values, "cmd_argv", pt)
parser.values.cmd_argv = pt
if opt_str == "-(":
pt.append("(")

View File

@@ -59,7 +59,7 @@ Displays detailed usage information about a command.
def PrintAllCommandsBody(self):
print("The complete list of recognized repo commands is:")
commandNames = list(sorted(all_commands))
commandNames = sorted(all_commands)
self._PrintCommands(commandNames)
print(
"See 'repo help <command>' for more information on a "
@@ -74,10 +74,8 @@ Displays detailed usage information about a command.
def PrintCommonCommandsBody(self):
print("The most commonly used repo commands are:")
commandNames = list(
sorted(
name for name, command in all_commands.items() if command.COMMON
)
commandNames = sorted(
name for name, command in all_commands.items() if command.COMMON
)
self._PrintCommands(commandNames)

View File

@@ -947,7 +947,7 @@ later is required to fix a server side protocol bug.
"sync_dict"
] = multiprocessing.Manager().dict()
objdir_project_map = dict()
objdir_project_map = {}
for index, project in enumerate(projects):
objdir_project_map.setdefault(project.objdir, []).append(index)
projects_list = list(objdir_project_map.values())
@@ -2657,7 +2657,7 @@ later is required to fix a server side protocol bug.
if previously_pending_relpaths == pending_relpaths:
stalled_projects_str = "\n".join(
f" - {path}"
for path in sorted(list(pending_relpaths))
for path in sorted(pending_relpaths)
)
logger.error(
"The following projects failed and could "

View File

@@ -14,23 +14,17 @@
"""Unittests for the color.py module."""
import os
import pytest
import utils_for_test
import color
import git_config
def fixture(*paths: str) -> str:
"""Return a path relative to test/fixtures."""
return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
@pytest.fixture
def coloring() -> color.Coloring:
"""Create a Coloring object for testing."""
config_fixture = fixture("test.gitconfig")
config_fixture = utils_for_test.FIXTURES_DIR / "test.gitconfig"
config = git_config.GitConfig(config_fixture)
color.SetDefaultColoring("true")
return color.Coloring(config, "status")

View File

@@ -14,24 +14,19 @@
"""Unittests for the git_config.py module."""
import os
from pathlib import Path
from typing import Any
import pytest
import utils_for_test
import git_config
def fixture_path(*paths: str) -> str:
"""Return a path relative to test/fixtures."""
return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
@pytest.fixture
def readonly_config() -> git_config.GitConfig:
"""Create a GitConfig object using the test.gitconfig fixture."""
config_fixture = fixture_path("test.gitconfig")
config_fixture = utils_for_test.FIXTURES_DIR / "test.gitconfig"
return git_config.GitConfig(config_fixture)
@@ -63,7 +58,7 @@ def test_get_string_with_true_value(
def test_get_string_from_missing_file() -> None:
"""Test missing config file."""
config_fixture = fixture_path("not.present.gitconfig")
config_fixture = utils_for_test.FIXTURES_DIR / "not.present.gitconfig"
config = git_config.GitConfig(config_fixture)
val = config.GetString("empty")
assert val is None

View File

@@ -18,17 +18,24 @@ import contextlib
import io
import json
import os
import re
import socket
import tempfile
import threading
import unittest
from typing import Any, Dict, List, Optional
from unittest import mock
import pytest
import git_trace2_event_log
import platform_utils
def serverLoggingThread(socket_path, server_ready, received_traces):
def server_logging_thread(
socket_path: str,
server_ready: threading.Condition,
received_traces: List[str],
) -> None:
"""Helper function to receive logs over a Unix domain socket.
Appends received messages on the provided socket and appends to
@@ -57,405 +64,425 @@ def serverLoggingThread(socket_path, server_ready, received_traces):
received_traces.extend(data.decode("utf-8").splitlines())
class EventLogTestCase(unittest.TestCase):
"""TestCase for the EventLog module."""
PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"
PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"
def setUp(self):
"""Load the event_log module every time."""
self._event_log = None
# By default we initialize with the expected case where
# repo launches us (so GIT_TRACE2_PARENT_SID is set).
env = {
self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
}
self._event_log = git_trace2_event_log.EventLog(env=env)
self._log_data = None
@pytest.fixture
def event_log() -> git_trace2_event_log.EventLog:
"""Fixture for the EventLog module."""
# By default we initialize with the expected case where
# repo launches us (so GIT_TRACE2_PARENT_SID is set).
env = {PARENT_SID_KEY: PARENT_SID_VALUE}
return git_trace2_event_log.EventLog(env=env)
def verifyCommonKeys(
self, log_entry, expected_event_name=None, full_sid=True
def verify_common_keys(
log_entry: Dict[str, Any],
expected_event_name: Optional[str] = None,
full_sid: bool = True,
) -> None:
"""Helper function to verify common event log keys."""
assert "event" in log_entry
assert "sid" in log_entry
assert "thread" in log_entry
assert "time" in log_entry
# Do basic data format validation.
if expected_event_name:
assert expected_event_name == log_entry["event"]
if full_sid:
assert re.match(FULL_SID_REGEX, log_entry["sid"])
else:
assert re.match(SELF_SID_REGEX, log_entry["sid"])
assert re.match(r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$", log_entry["time"])
def read_log(log_path: str) -> List[Dict[str, Any]]:
"""Helper function to read log data into a list."""
log_data = []
with open(log_path, mode="rb") as f:
for line in f:
log_data.append(json.loads(line))
return log_data
def remove_prefix(s: str, prefix: str) -> str:
"""Return a copy string after removing |prefix| from |s|, if present or
the original string."""
if s.startswith(prefix):
return s[len(prefix) :]
else:
return s
def test_initial_state_with_parent_sid(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
assert re.match(FULL_SID_REGEX, event_log.full_sid)
def test_initial_state_no_parent_sid() -> None:
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
# Setup an empty environment dict (no parent sid).
event_log = git_trace2_event_log.EventLog(env={})
assert re.match(SELF_SID_REGEX, event_log.full_sid)
def test_version_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test 'version' event data is valid.
Verify that the 'version' event is written even when no other
events are added.
Expected event log:
<version event>
"""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
# A log with no added events should only have the version entry.
assert len(log_data) == 1
version_event = log_data[0]
verify_common_keys(version_event, expected_event_name="version")
# Check for 'version' event specific fields.
assert "evt" in version_event
assert "exe" in version_event
# Verify "evt" version field is a string.
assert isinstance(version_event["evt"], str)
def test_start_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'start' event data is valid.
Expected event log:
<version event>
<start event>
"""
event_log.StartEvent([])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
start_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
assert "argv" in start_event
assert isinstance(start_event["argv"], list)
def test_exit_event_result_none(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'exit' event data is valid when result is None.
We expect None result to be converted to 0 in the exit event data.
Expected event log:
<version event>
<exit event>
"""
event_log.ExitEvent(None)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
exit_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
assert "code" in exit_event
# 'None' result should convert to 0 (successful) return code.
assert exit_event["code"] == 0
def test_exit_event_result_integer(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'exit' event data is valid when result is an integer.
Expected event log:
<version event>
<exit event>
"""
event_log.ExitEvent(2)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
exit_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
assert "code" in exit_event
assert exit_event["code"] == 2
def test_command_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'command' event data is valid.
Expected event log:
<version event>
<command event>
"""
event_log.CommandEvent(name="repo", subcommands=["init", "this"])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
command_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(command_event, expected_event_name="cmd_name")
# Check for 'command' event specific fields.
assert "name" in command_event
assert command_event["name"] == "repo-init-this"
def test_def_params_event_repo_config(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'def_params' event data outputs only repo config keys.
Expected event log:
<version event>
<def_param event>
<def_param event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "true",
"repo.partialclonefilter": "blob:none",
}
event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 3
def_param_events = log_data[1:]
verify_common_keys(log_data[0], expected_event_name="version")
for event in def_param_events:
verify_common_keys(event, expected_event_name="def_param")
# Check for 'def_param' event specific fields.
assert "param" in event
assert "value" in event
assert event["param"].startswith("repo.")
def test_def_params_event_no_repo_config(
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test 'def_params' event data won't output non-repo config keys.
Expected event log:
<version event>
"""
config = {
"git.foo": "bar",
"git.core.foo2": "baz",
}
event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 1
verify_common_keys(log_data[0], expected_event_name="version")
def test_data_event_config(event_log: git_trace2_event_log.EventLog) -> None:
"""Test 'data' event data outputs all config keys.
Expected event log:
<version event>
<data event>
<data event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "false",
"repo.syncstate.superproject.hassuperprojecttag": "true",
"repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
}
prefix_value = "prefix"
event_log.LogDataConfigEvents(config, prefix_value)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 5
data_events = log_data[1:]
verify_common_keys(log_data[0], expected_event_name="version")
for event in data_events:
verify_common_keys(event)
# Check for 'data' event specific fields.
assert "key" in event
assert "value" in event
key = event["key"]
key = remove_prefix(key, f"{prefix_value}/")
value = event["value"]
assert event_log.GetDataEventName(value) == event["event"]
assert key in config
assert value == config[key]
def test_error_event(event_log: git_trace2_event_log.EventLog) -> None:
"""Test and validate 'error' event data is valid.
Expected event log:
<version event>
<error event>
"""
msg = "invalid option: --cahced"
fmt = "invalid option: %s"
event_log.ErrorEvent(msg, fmt)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = event_log.Write(path=tempdir)
log_data = read_log(log_path)
assert len(log_data) == 2
error_event = log_data[1]
verify_common_keys(log_data[0], expected_event_name="version")
verify_common_keys(error_event, expected_event_name="error")
# Check for 'error' event specific fields.
assert "msg" in error_event
assert "fmt" in error_event
assert error_event["msg"] == f"RepoErrorEvent:{msg}"
assert error_event["fmt"] == f"RepoErrorEvent:{fmt}"
def test_write_with_filename(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with a path to a file exits with None."""
assert event_log.Write(path="path/to/file") is None
def test_write_with_git_config(
tmp_path,
event_log: git_trace2_event_log.EventLog,
) -> None:
"""Test Write() uses the git config path when 'git config' call succeeds."""
with mock.patch.object(
event_log,
"_GetEventTargetPath",
return_value=str(tmp_path),
):
"""Helper function to verify common event log keys."""
self.assertIn("event", log_entry)
self.assertIn("sid", log_entry)
self.assertIn("thread", log_entry)
self.assertIn("time", log_entry)
assert os.path.dirname(event_log.Write()) == str(tmp_path)
# Do basic data format validation.
if expected_event_name:
self.assertEqual(expected_event_name, log_entry["event"])
if full_sid:
self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX)
else:
self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX)
self.assertRegex(
log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$"
def test_write_no_git_config(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with no git config variable present exits with None."""
with mock.patch.object(event_log, "_GetEventTargetPath", return_value=None):
assert event_log.Write() is None
def test_write_non_string(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with non-string type for |path| throws TypeError."""
with pytest.raises(TypeError):
event_log.Write(path=1234)
@pytest.mark.skipif(
not hasattr(socket, "AF_UNIX"), reason="Requires AF_UNIX sockets"
)
def test_write_socket(event_log: git_trace2_event_log.EventLog) -> None:
"""Test Write() with Unix domain socket and validate received traces."""
received_traces: List[str] = []
with tempfile.TemporaryDirectory(prefix="test_server_sockets") as tempdir:
socket_path = os.path.join(tempdir, "server.sock")
server_ready = threading.Condition()
# Start "server" listening on Unix domain socket at socket_path.
server_thread = threading.Thread(
target=server_logging_thread,
args=(socket_path, server_ready, received_traces),
)
try:
server_thread.start()
def readLog(self, log_path):
"""Helper function to read log data into a list."""
log_data = []
with open(log_path, mode="rb") as f:
for line in f:
log_data.append(json.loads(line))
return log_data
with server_ready:
server_ready.wait(timeout=120)
def remove_prefix(self, s, prefix):
"""Return a copy string after removing |prefix| from |s|, if present or
the original string."""
if s.startswith(prefix):
return s[len(prefix) :]
else:
return s
event_log.StartEvent([])
path = event_log.Write(path=f"af_unix:{socket_path}")
finally:
server_thread.join(timeout=5)
def test_initial_state_with_parent_sid(self):
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
self.assertRegex(self._event_log.full_sid, self.FULL_SID_REGEX)
def test_initial_state_no_parent_sid(self):
"""Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
# Setup an empty environment dict (no parent sid).
self._event_log = git_trace2_event_log.EventLog(env={})
self.assertRegex(self._event_log.full_sid, self.SELF_SID_REGEX)
def test_version_event(self):
"""Test 'version' event data is valid.
Verify that the 'version' event is written even when no other
events are addded.
Expected event log:
<version event>
"""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
# A log with no added events should only have the version entry.
self.assertEqual(len(self._log_data), 1)
version_event = self._log_data[0]
self.verifyCommonKeys(version_event, expected_event_name="version")
# Check for 'version' event specific fields.
self.assertIn("evt", version_event)
self.assertIn("exe", version_event)
# Verify "evt" version field is a string.
self.assertIsInstance(version_event["evt"], str)
def test_start_event(self):
"""Test and validate 'start' event data is valid.
Expected event log:
<version event>
<start event>
"""
self._event_log.StartEvent([])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
start_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
self.assertIn("argv", start_event)
self.assertTrue(isinstance(start_event["argv"], list))
def test_exit_event_result_none(self):
"""Test 'exit' event data is valid when result is None.
We expect None result to be converted to 0 in the exit event data.
Expected event log:
<version event>
<exit event>
"""
self._event_log.ExitEvent(None)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
self.assertIn("code", exit_event)
# 'None' result should convert to 0 (successful) return code.
self.assertEqual(exit_event["code"], 0)
def test_exit_event_result_integer(self):
"""Test 'exit' event data is valid when result is an integer.
Expected event log:
<version event>
<exit event>
"""
self._event_log.ExitEvent(2)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields.
self.assertIn("code", exit_event)
self.assertEqual(exit_event["code"], 2)
def test_command_event(self):
"""Test and validate 'command' event data is valid.
Expected event log:
<version event>
<command event>
"""
self._event_log.CommandEvent(name="repo", subcommands=["init", "this"])
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
command_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(command_event, expected_event_name="cmd_name")
# Check for 'command' event specific fields.
self.assertIn("name", command_event)
self.assertEqual(command_event["name"], "repo-init-this")
def test_def_params_event_repo_config(self):
"""Test 'def_params' event data outputs only repo config keys.
Expected event log:
<version event>
<def_param event>
<def_param event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "true",
"repo.partialclonefilter": "blob:none",
}
self._event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 3)
def_param_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in def_param_events:
self.verifyCommonKeys(event, expected_event_name="def_param")
# Check for 'def_param' event specific fields.
self.assertIn("param", event)
self.assertIn("value", event)
self.assertTrue(event["param"].startswith("repo."))
def test_def_params_event_no_repo_config(self):
"""Test 'def_params' event data won't output non-repo config keys.
Expected event log:
<version event>
"""
config = {
"git.foo": "bar",
"git.core.foo2": "baz",
}
self._event_log.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 1)
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
def test_data_event_config(self):
"""Test 'data' event data outputs all config keys.
Expected event log:
<version event>
<data event>
<data event>
"""
config = {
"git.foo": "bar",
"repo.partialclone": "false",
"repo.syncstate.superproject.hassuperprojecttag": "true",
"repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
}
prefix_value = "prefix"
self._event_log.LogDataConfigEvents(config, prefix_value)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 5)
data_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in data_events:
self.verifyCommonKeys(event)
# Check for 'data' event specific fields.
self.assertIn("key", event)
self.assertIn("value", event)
key = event["key"]
key = self.remove_prefix(key, f"{prefix_value}/")
value = event["value"]
self.assertEqual(
self._event_log.GetDataEventName(value), event["event"]
)
self.assertTrue(key in config and value == config[key])
def test_error_event(self):
"""Test and validate 'error' event data is valid.
Expected event log:
<version event>
<error event>
"""
msg = "invalid option: --cahced"
fmt = "invalid option: %s"
self._event_log.ErrorEvent(msg, fmt)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log.Write(path=tempdir)
self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2)
error_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(error_event, expected_event_name="error")
# Check for 'error' event specific fields.
self.assertIn("msg", error_event)
self.assertIn("fmt", error_event)
self.assertEqual(error_event["msg"], f"RepoErrorEvent:{msg}")
self.assertEqual(error_event["fmt"], f"RepoErrorEvent:{fmt}")
def test_write_with_filename(self):
"""Test Write() with a path to a file exits with None."""
self.assertIsNone(self._event_log.Write(path="path/to/file"))
def test_write_with_git_config(self):
"""Test Write() uses the git config path when 'git config' call
succeeds."""
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
with mock.patch.object(
self._event_log,
"_GetEventTargetPath",
return_value=tempdir,
):
self.assertEqual(
os.path.dirname(self._event_log.Write()), tempdir
)
def test_write_no_git_config(self):
"""Test Write() with no git config variable present exits with None."""
with mock.patch.object(
self._event_log, "_GetEventTargetPath", return_value=None
):
self.assertIsNone(self._event_log.Write())
def test_write_non_string(self):
"""Test Write() with non-string type for |path| throws TypeError."""
with self.assertRaises(TypeError):
self._event_log.Write(path=1234)
@unittest.skipIf(not hasattr(socket, "AF_UNIX"), "Requires AF_UNIX sockets")
def test_write_socket(self):
"""Test Write() with Unix domain socket for |path| and validate received
traces."""
received_traces = []
with tempfile.TemporaryDirectory(
prefix="test_server_sockets"
) as tempdir:
socket_path = os.path.join(tempdir, "server.sock")
server_ready = threading.Condition()
# Start "server" listening on Unix domain socket at socket_path.
server_thread = threading.Thread(
target=serverLoggingThread,
args=(socket_path, server_ready, received_traces),
)
try:
server_thread.start()
with server_ready:
server_ready.wait(timeout=120)
self._event_log.StartEvent([])
path = self._event_log.Write(path=f"af_unix:{socket_path}")
finally:
server_thread.join(timeout=5)
self.assertEqual(path, f"af_unix:stream:{socket_path}")
self.assertEqual(len(received_traces), 2)
version_event = json.loads(received_traces[0])
start_event = json.loads(received_traces[1])
self.verifyCommonKeys(version_event, expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
self.assertIn("argv", start_event)
self.assertIsInstance(start_event["argv"], list)
assert path == f"af_unix:stream:{socket_path}"
assert len(received_traces) == 2
version_event = json.loads(received_traces[0])
start_event = json.loads(received_traces[1])
verify_common_keys(version_event, expected_event_name="version")
verify_common_keys(start_event, expected_event_name="start")
# Check for 'start' event specific fields.
assert "argv" in start_event
assert isinstance(start_event["argv"], list)
class EventLogVerboseTestCase(unittest.TestCase):
class TestEventLogVerbose:
"""TestCase for the EventLog module verbose logging."""
def setUp(self):
self._event_log = git_trace2_event_log.EventLog(env={})
def test_write_socket_error_no_verbose(self):
def test_write_socket_error_no_verbose(self) -> None:
"""Test Write() suppression of socket errors when not verbose."""
self._event_log.verbose = False
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = False
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch("socket.socket", side_effect=OSError):
self._event_log.Write(path="af_unix:stream:/tmp/test_sock")
self.assertEqual(mock_stderr.getvalue(), "")
event_log.Write(path="af_unix:stream:/tmp/test_sock")
assert mock_stderr.getvalue() == ""
def test_write_socket_error_verbose(self):
def test_write_socket_error_verbose(self) -> None:
"""Test Write() printing of socket errors when verbose."""
self._event_log.verbose = True
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = True
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"socket.socket", side_effect=OSError("Mock error")
):
self._event_log.Write(path="af_unix:stream:/tmp/test_sock")
self.assertIn(
"git trace2 logging failed: Mock error",
mock_stderr.getvalue(),
event_log.Write(path="af_unix:stream:/tmp/test_sock")
assert (
"git trace2 logging failed: Mock error"
in mock_stderr.getvalue()
)
def test_write_file_error_no_verbose(self):
def test_write_file_error_no_verbose(self) -> None:
"""Test Write() suppression of file errors when not verbose."""
self._event_log.verbose = False
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = False
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"tempfile.NamedTemporaryFile", side_effect=FileExistsError
):
self._event_log.Write(path="/tmp")
self.assertEqual(mock_stderr.getvalue(), "")
event_log.Write(path="/tmp")
assert mock_stderr.getvalue() == ""
def test_write_file_error_verbose(self):
def test_write_file_error_verbose(self) -> None:
"""Test Write() printing of file errors when verbose."""
self._event_log.verbose = True
event_log = git_trace2_event_log.EventLog(env={})
event_log.verbose = True
with contextlib.redirect_stderr(
io.StringIO()
) as mock_stderr, mock.patch(
"tempfile.NamedTemporaryFile",
side_effect=FileExistsError("Mock error"),
):
self._event_log.Write(path="/tmp")
self.assertIn(
"git trace2 logging failed: FileExistsError",
mock_stderr.getvalue(),
event_log.Write(path="/tmp")
assert (
"git trace2 logging failed: FileExistsError"
in mock_stderr.getvalue()
)

166
tests/test_main.py Normal file
View File

@@ -0,0 +1,166 @@
# Copyright (C) 2026 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the main repo script and subcommand routing."""
from unittest import mock
import pytest
from main import _Repo
@pytest.fixture(name="repo")
def fixture_repo():
repo = _Repo("repodir")
# Overriding the command list here ensures that we are only testing
# against a fixed set of commands, reducing fragility to new
# subcommands being added to the main repo tool.
repo.commands = {"start": None, "sync": None, "smart": None}
return repo
@pytest.fixture(name="mock_config")
def fixture_mock_config():
return mock.MagicMock()
@mock.patch("time.sleep")
def test_autocorrect_delay(mock_sleep, repo, mock_config):
"""Test autocorrect with positive delay."""
mock_config.GetString.return_value = "10"
res = repo._autocorrect_command_name("tart", mock_config)
mock_config.GetString.assert_called_with("help.autocorrect")
mock_sleep.assert_called_with(1.0)
assert res == "start"
@mock.patch("time.sleep")
def test_autocorrect_delay_one(mock_sleep, repo, mock_config):
"""Test autocorrect with '1' (0.1s delay, not immediate)."""
mock_config.GetString.return_value = "1"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_called_with(0.1)
assert res == "start"
@mock.patch("time.sleep", side_effect=KeyboardInterrupt())
def test_autocorrect_delay_interrupt(mock_sleep, repo, mock_config):
"""Test autocorrect handles KeyboardInterrupt during delay."""
mock_config.GetString.return_value = "10"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_called_with(1.0)
assert res is None
@mock.patch("time.sleep")
def test_autocorrect_immediate(mock_sleep, repo, mock_config):
"""Test autocorrect with immediate/negative delay."""
# Test numeric negative.
mock_config.GetString.return_value = "-1"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_not_called()
assert res == "start"
# Test string boolean "true".
mock_config.GetString.return_value = "true"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_not_called()
assert res == "start"
# Test string boolean "yes".
mock_config.GetString.return_value = "YES"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_not_called()
assert res == "start"
# Test string boolean "immediate".
mock_config.GetString.return_value = "Immediate"
res = repo._autocorrect_command_name("tart", mock_config)
mock_sleep.assert_not_called()
assert res == "start"
def test_autocorrect_zero_or_show(repo, mock_config):
"""Test autocorrect with zero delay (suggestions only)."""
# Test numeric zero.
mock_config.GetString.return_value = "0"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None
# Test string boolean "false".
mock_config.GetString.return_value = "False"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None
# Test string boolean "show".
mock_config.GetString.return_value = "show"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None
def test_autocorrect_never(repo, mock_config):
"""Test autocorrect with 'never'."""
mock_config.GetString.return_value = "never"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None
@mock.patch("builtins.input", return_value="y")
def test_autocorrect_prompt_yes(mock_input, repo, mock_config):
"""Test autocorrect with prompt and user answers yes."""
mock_config.GetString.return_value = "prompt"
res = repo._autocorrect_command_name("tart", mock_config)
assert res == "start"
@mock.patch("builtins.input", return_value="n")
def test_autocorrect_prompt_no(mock_input, repo, mock_config):
"""Test autocorrect with prompt and user answers no."""
mock_config.GetString.return_value = "prompt"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None
@mock.patch("builtins.input", return_value="y")
def test_autocorrect_multiple_candidates(mock_input, repo, mock_config):
"""Test autocorrect with multiple matches forces a prompt."""
mock_config.GetString.return_value = "10" # Normally just delay
# 'snart' matches both 'start' and 'smart' with > 0.7 ratio
res = repo._autocorrect_command_name("snart", mock_config)
# Because there are multiple candidates, it should prompt
mock_input.assert_called_once()
assert res == "start"
@mock.patch("builtins.input", side_effect=KeyboardInterrupt())
def test_autocorrect_prompt_interrupt(mock_input, repo, mock_config):
"""Test autocorrect with prompt and user interrupts."""
mock_config.GetString.return_value = "prompt"
res = repo._autocorrect_command_name("tart", mock_config)
assert res is None

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,7 @@ import subprocess
import tempfile
from typing import Optional
import unittest
from unittest import mock
import utils_for_test
@@ -565,3 +566,120 @@ class ManifestPropertiesFetchedCorrectly(unittest.TestCase):
fakeproj.config.SetString("manifest.platform", "auto")
self.assertEqual(fakeproj.manifest_platform, "auto")
class StatelessSyncTests(unittest.TestCase):
"""Tests for stateless sync strategy."""
def _get_project(self, tempdir):
manifest = mock.MagicMock()
manifest.manifestProject.depth = None
manifest.manifestProject.dissociate = False
manifest.manifestProject.clone_filter = None
manifest.is_multimanifest = False
manifest.manifestProject.config.GetBoolean.return_value = False
remote = mock.MagicMock()
remote.name = "origin"
remote.url = "http://"
proj = project.Project(
manifest=manifest,
name="test-project",
remote=remote,
gitdir=os.path.join(tempdir, ".git"),
objdir=os.path.join(tempdir, ".git"),
worktree=tempdir,
relpath="test-project",
revisionExpr="1234abcd",
revisionId=None,
sync_strategy="stateless",
)
proj._CheckForImmutableRevision = mock.MagicMock(return_value=False)
proj._LsRemote = mock.MagicMock(
return_value="1234abcd\trefs/heads/main\n"
)
proj.bare_git = mock.MagicMock()
proj.bare_git.rev_parse.return_value = "5678abcd"
proj.bare_git.rev_list.return_value = ["0"]
proj.IsDirty = mock.MagicMock(return_value=False)
proj.GetBranches = mock.MagicMock(return_value=[])
proj.DeleteWorktree = mock.MagicMock()
proj._InitGitDir = mock.MagicMock()
proj._RemoteFetch = mock.MagicMock(return_value=True)
proj._InitRemote = mock.MagicMock()
proj._InitMRef = mock.MagicMock()
return proj
def test_sync_network_half_stateless_prune_needed(self):
"""Test stateless sync queues prune when needed."""
with utils_for_test.TempGitTree() as tempdir:
proj = self._get_project(tempdir)
res = proj.Sync_NetworkHalf()
self.assertTrue(res.success)
proj.DeleteWorktree.assert_not_called()
self.assertTrue(proj.stateless_prune_needed)
proj._RemoteFetch.assert_called_once()
def test_sync_local_half_stateless_prune(self):
"""Test stateless GC pruning is queued in Sync_LocalHalf."""
with utils_for_test.TempGitTree() as tempdir:
proj = self._get_project(tempdir)
proj.stateless_prune_needed = True
proj._Checkout = mock.MagicMock()
proj._InitWorkTree = mock.MagicMock()
proj.IsRebaseInProgress = mock.MagicMock(return_value=False)
proj.IsCherryPickInProgress = mock.MagicMock(return_value=False)
proj.bare_ref = mock.MagicMock()
proj.bare_ref.all = {}
proj.GetRevisionId = mock.MagicMock(return_value="1234abcd")
proj._CopyAndLinkFiles = mock.MagicMock()
proj.work_git = mock.MagicMock()
proj.work_git.GetHead.return_value = "5678abcd"
syncbuf = project.SyncBuffer(proj.config)
with mock.patch("project.GitCommand") as mock_git_cmd:
mock_cmd_instance = mock.MagicMock()
mock_cmd_instance.Wait.return_value = 0
mock_git_cmd.return_value = mock_cmd_instance
proj.Sync_LocalHalf(syncbuf)
syncbuf.Finish()
self.assertEqual(mock_git_cmd.call_count, 2)
mock_git_cmd.assert_any_call(
proj, ["reflog", "expire", "--expire=all", "--all"], bare=True
)
mock_git_cmd.assert_any_call(
proj,
["gc", "--prune=now"],
bare=True,
capture_stdout=True,
capture_stderr=True,
)
def test_sync_network_half_stateless_skips_if_stash(self):
"""Test stateless sync skips if stash exists."""
with utils_for_test.TempGitTree() as tempdir:
proj = self._get_project(tempdir)
proj.HasStash = mock.MagicMock(return_value=True)
res = proj.Sync_NetworkHalf()
self.assertTrue(res.success)
self.assertFalse(getattr(proj, "stateless_prune_needed", False))
def test_sync_network_half_stateless_skips_if_local_commits(self):
"""Test stateless sync skips if there are local-only commits."""
with utils_for_test.TempGitTree() as tempdir:
proj = self._get_project(tempdir)
proj.bare_git.rev_list.return_value = ["1"]
res = proj.Sync_NetworkHalf()
self.assertTrue(res.success)
self.assertFalse(getattr(proj, "stateless_prune_needed", False))

View File

@@ -14,11 +14,9 @@
"""Unittests for the forall subcmd."""
from io import StringIO
import os
from shutil import rmtree
import tempfile
import unittest
import contextlib
import io
from pathlib import Path
from unittest import mock
import utils_for_test
@@ -28,111 +26,81 @@ import project
import subcmds
class AllCommands(unittest.TestCase):
"""Check registered all_commands."""
def _create_manifest_with_8_projects(
topdir: Path,
) -> manifest_xml.XmlManifest:
"""Create a setup of 8 projects to execute forall."""
repodir = topdir / ".repo"
manifest_dir = repodir / "manifests"
manifest_file = repodir / manifest_xml.MANIFEST_FILE_NAME
def setUp(self):
"""Common setup."""
self.tempdirobj = tempfile.TemporaryDirectory(prefix="forall_tests")
self.tempdir = self.tempdirobj.name
self.repodir = os.path.join(self.tempdir, ".repo")
self.manifest_dir = os.path.join(self.repodir, "manifests")
self.manifest_file = os.path.join(
self.repodir, manifest_xml.MANIFEST_FILE_NAME
)
self.local_manifest_dir = os.path.join(
self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME
)
os.mkdir(self.repodir)
os.mkdir(self.manifest_dir)
repodir.mkdir()
manifest_dir.mkdir()
def tearDown(self):
"""Common teardown."""
rmtree(self.tempdir, ignore_errors=True)
# Set up a manifest git dir for parsing to work.
gitdir = repodir / "manifests.git"
gitdir.mkdir()
(gitdir / "config").write_text(
"""[remote "origin"]
url = https://localhost:0/manifest
verbose = false
"""
)
def getXmlManifestWith8Projects(self):
"""Create and return a setup of 8 projects with enough dummy
files and setup to execute forall."""
# Add the manifest data.
manifest_file.write_text(
"""
<manifest>
<remote name="origin" fetch="http://localhost" />
<default remote="origin" revision="refs/heads/main" />
<project name="project1" path="tests/path1" />
<project name="project2" path="tests/path2" />
<project name="project3" path="tests/path3" />
<project name="project4" path="tests/path4" />
<project name="project5" path="tests/path5" />
<project name="project6" path="tests/path6" />
<project name="project7" path="tests/path7" />
<project name="project8" path="tests/path8" />
</manifest>
""",
encoding="utf-8",
)
# Set up a manifest git dir for parsing to work
gitdir = os.path.join(self.repodir, "manifests.git")
os.mkdir(gitdir)
with open(os.path.join(gitdir, "config"), "w") as fp:
fp.write(
"""[remote "origin"]
url = https://localhost:0/manifest
verbose = false
"""
)
# Set up 8 empty projects to match the manifest.
for x in range(1, 9):
(repodir / "projects" / "tests" / f"path{x}.git").mkdir(parents=True)
(repodir / "project-objects" / f"project{x}.git").mkdir(parents=True)
git_path = topdir / "tests" / f"path{x}"
utils_for_test.init_git_tree(git_path)
# Add the manifest data
manifest_data = """
<manifest>
<remote name="origin" fetch="http://localhost" />
<default remote="origin" revision="refs/heads/main" />
<project name="project1" path="tests/path1" />
<project name="project2" path="tests/path2" />
<project name="project3" path="tests/path3" />
<project name="project4" path="tests/path4" />
<project name="project5" path="tests/path5" />
<project name="project6" path="tests/path6" />
<project name="project7" path="tests/path7" />
<project name="project8" path="tests/path8" />
</manifest>
"""
with open(self.manifest_file, "w", encoding="utf-8") as fp:
fp.write(manifest_data)
return manifest_xml.XmlManifest(str(repodir), str(manifest_file))
# Set up 8 empty projects to match the manifest
for x in range(1, 9):
os.makedirs(
os.path.join(
self.repodir, "projects/tests/path" + str(x) + ".git"
)
)
os.makedirs(
os.path.join(
self.repodir, "project-objects/project" + str(x) + ".git"
)
)
git_path = os.path.join(self.tempdir, "tests/path" + str(x))
utils_for_test.init_git_tree(git_path)
return manifest_xml.XmlManifest(self.repodir, self.manifest_file)
def test_forall_all_projects_called_once(tmp_path: Path) -> None:
"""Test that all projects get a command run once each."""
manifest = _create_manifest_with_8_projects(tmp_path)
# Use mock to capture stdout from the forall run
@unittest.mock.patch("sys.stdout", new_callable=StringIO)
def test_forall_all_projects_called_once(self, mock_stdout):
"""Test that all projects get a command run once each."""
cmd = subcmds.forall.Forall()
cmd.manifest = manifest
manifest_with_8_projects = self.getXmlManifestWith8Projects()
# Use echo project names as the test of forall.
opts, args = cmd.OptionParser.parse_args(["-c", "echo $REPO_PROJECT"])
opts.verbose = False
cmd = subcmds.forall.Forall()
cmd.manifest = manifest_with_8_projects
# Use echo project names as the test of forall
opts, args = cmd.OptionParser.parse_args(["-c", "echo $REPO_PROJECT"])
opts.verbose = False
# Mock to not have the Execute fail on remote check
with contextlib.redirect_stdout(io.StringIO()) as stdout:
# Mock to not have the Execute fail on remote check.
with mock.patch.object(
project.Project, "GetRevisionId", return_value="refs/heads/main"
):
# Run the forall command
# Run the forall command.
cmd.Execute(opts, args)
# Verify that we got every project name in the prints
for x in range(1, 9):
self.assertIn("project" + str(x), mock_stdout.getvalue())
output = stdout.getvalue()
# Verify that we got every project name in the output.
for x in range(1, 9):
assert f"project{x}" in output
# Split the captured output into lines to count them
line_count = 0
for line in mock_stdout.getvalue().split("\n"):
# A commented out print to stderr as a reminder
# that stdout is mocked, include sys and uncomment if needed
# print(line, file=sys.stderr)
if len(line) > 0:
line_count += 1
# Verify that we didn't get more lines than expected
assert line_count == 8
# Split the captured output into lines to count them.
line_count = sum(1 for x in output.splitlines() if x)
# Verify that we didn't get more lines than expected.
assert line_count == 8

View File

@@ -14,9 +14,10 @@
"""Unittests for the subcmds/upload.py module."""
import unittest
from unittest import mock
import pytest
from error import GitError
from error import UploadError
from subcmds import upload
@@ -26,45 +27,39 @@ class UnexpectedError(Exception):
"""An exception not expected by upload command."""
class UploadCommand(unittest.TestCase):
"""Check registered all_commands."""
# A stub people list (reviewers, cc).
_STUB_PEOPLE = ([], [])
def setUp(self):
self.cmd = upload.Upload()
self.branch = mock.MagicMock()
self.people = mock.MagicMock()
self.opt, _ = self.cmd.OptionParser.parse_args([])
mock.patch.object(
self.cmd, "_AppendAutoList", return_value=None
).start()
mock.patch.object(self.cmd, "git_event_log").start()
def tearDown(self):
mock.patch.stopall()
@pytest.fixture
def cmd() -> upload.Upload:
"""Fixture to provide an Upload command instance with mocked methods."""
cmd = upload.Upload()
with mock.patch.object(
cmd, "_AppendAutoList", return_value=None
), mock.patch.object(cmd, "git_event_log"):
yield cmd
def test_UploadAndReport_UploadError(self):
"""Check UploadExitError raised when UploadError encountered."""
side_effect = UploadError("upload error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(upload.UploadExitError):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_GitError(self):
"""Check UploadExitError raised when GitError encountered."""
side_effect = GitError("some git error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(upload.UploadExitError):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_UploadError(cmd: upload.Upload) -> None:
"""Check UploadExitError raised when UploadError encountered."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=UploadError("")):
with pytest.raises(upload.UploadExitError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
def test_UploadAndReport_UnhandledError(self):
"""Check UnexpectedError passed through."""
side_effect = UnexpectedError("some os error")
with mock.patch.object(
self.cmd, "_UploadBranch", side_effect=side_effect
):
with self.assertRaises(type(side_effect)):
self.cmd._UploadAndReport(self.opt, [self.branch], self.people)
def test_UploadAndReport_GitError(cmd: upload.Upload) -> None:
"""Check UploadExitError raised when GitError encountered."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=GitError("")):
with pytest.raises(upload.UploadExitError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)
def test_UploadAndReport_UnhandledError(cmd: upload.Upload) -> None:
"""Check UnexpectedError passed through."""
opt, _ = cmd.OptionParser.parse_args([])
with mock.patch.object(cmd, "_UploadBranch", side_effect=UnexpectedError):
with pytest.raises(UnexpectedError):
cmd._UploadAndReport(opt, [mock.MagicMock()], _STUB_PEOPLE)

View File

@@ -19,267 +19,303 @@ import os
import re
import subprocess
import sys
import tempfile
import unittest
from unittest import mock
import pytest
import utils_for_test
import main
import wrapper
def fixture(*paths):
"""Return a path relative to tests/fixtures."""
return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
@pytest.fixture(autouse=True)
def reset_wrapper() -> None:
"""Reset the wrapper module every time."""
wrapper.Wrapper.cache_clear()
class RepoWrapperTestCase(unittest.TestCase):
"""TestCase for the wrapper module."""
def setUp(self):
"""Load the wrapper module every time."""
wrapper.Wrapper.cache_clear()
self.wrapper = wrapper.Wrapper()
@pytest.fixture
def repo_wrapper() -> wrapper.Wrapper:
"""Fixture for the wrapper module."""
return wrapper.Wrapper()
class RepoWrapperUnitTest(RepoWrapperTestCase):
class GitCheckout:
"""Class to hold git checkout info for tests."""
def __init__(self, git_dir, rev_list):
self.git_dir = git_dir
self.rev_list = rev_list
@pytest.fixture(scope="module")
def git_checkout(tmp_path_factory) -> GitCheckout:
"""Fixture for tests that use a real/small git checkout.
Create a repo to operate on, but do it once per-test-run.
"""
tempdir = tmp_path_factory.mktemp("repo-rev-tests")
run_git = wrapper.Wrapper().run_git
remote = os.path.join(tempdir, "remote")
os.mkdir(remote)
utils_for_test.init_git_tree(remote)
run_git("commit", "--allow-empty", "-minit", cwd=remote)
run_git("branch", "stable", cwd=remote)
run_git("tag", "v1.0", cwd=remote)
run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
rev_list = run_git("rev-list", "HEAD", cwd=remote).stdout.splitlines()
run_git("init", cwd=tempdir)
run_git(
"fetch",
remote,
"+refs/heads/*:refs/remotes/origin/*",
cwd=tempdir,
)
yield GitCheckout(tempdir, rev_list)
class TestRepoWrapper:
"""Tests helper functions in the repo wrapper"""
def test_version(self):
def test_version(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure _Version works."""
with self.assertRaises(SystemExit) as e:
with pytest.raises(SystemExit) as e:
with mock.patch("sys.stdout", new_callable=io.StringIO) as stdout:
with mock.patch(
"sys.stderr", new_callable=io.StringIO
) as stderr:
self.wrapper._Version()
self.assertEqual(0, e.exception.code)
self.assertEqual("", stderr.getvalue())
self.assertIn("repo launcher version", stdout.getvalue())
repo_wrapper._Version()
assert e.value.code == 0
assert stderr.getvalue() == ""
assert "repo launcher version" in stdout.getvalue()
def test_python_constraints(self):
def test_python_constraints(self, repo_wrapper: wrapper.Wrapper) -> None:
"""The launcher should never require newer than main.py."""
self.assertGreaterEqual(
main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD
assert (
main.MIN_PYTHON_VERSION_HARD >= repo_wrapper.MIN_PYTHON_VERSION_HARD
)
self.assertGreaterEqual(
main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT
assert (
main.MIN_PYTHON_VERSION_SOFT >= repo_wrapper.MIN_PYTHON_VERSION_SOFT
)
# Make sure the versions are themselves in sync.
self.assertGreaterEqual(
self.wrapper.MIN_PYTHON_VERSION_SOFT,
self.wrapper.MIN_PYTHON_VERSION_HARD,
assert (
repo_wrapper.MIN_PYTHON_VERSION_SOFT
>= repo_wrapper.MIN_PYTHON_VERSION_HARD
)
def test_init_parser(self):
def test_repo_script_is_executable(self) -> None:
"""The repo launcher script should be executable."""
repo_path = utils_for_test.THIS_DIR.parent / "repo"
assert os.access(repo_path, os.X_OK), f"{repo_path} is not executable"
def test_init_parser(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure 'init' GetParser works."""
parser = self.wrapper.GetParser()
parser = repo_wrapper.GetParser()
opts, args = parser.parse_args([])
self.assertEqual([], args)
self.assertIsNone(opts.manifest_url)
assert args == []
assert opts.manifest_url is None
class SetGitTrace2ParentSid(RepoWrapperTestCase):
class TestSetGitTrace2ParentSid:
"""Check SetGitTrace2ParentSid behavior."""
KEY = "GIT_TRACE2_PARENT_SID"
VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$")
def test_first_set(self):
def test_first_set(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Test env var not yet set."""
env = {}
self.wrapper.SetGitTrace2ParentSid(env)
self.assertIn(self.KEY, env)
repo_wrapper.SetGitTrace2ParentSid(env)
assert self.KEY in env
value = env[self.KEY]
self.assertRegex(value, self.VALID_FORMAT)
assert self.VALID_FORMAT.match(value)
def test_append(self):
def test_append(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Test env var is appended."""
env = {self.KEY: "pfx"}
self.wrapper.SetGitTrace2ParentSid(env)
self.assertIn(self.KEY, env)
repo_wrapper.SetGitTrace2ParentSid(env)
assert self.KEY in env
value = env[self.KEY]
self.assertTrue(value.startswith("pfx/"))
self.assertRegex(value[4:], self.VALID_FORMAT)
assert value.startswith("pfx/")
assert self.VALID_FORMAT.match(value[4:])
def test_global_context(self):
def test_global_context(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check os.environ gets updated by default."""
os.environ.pop(self.KEY, None)
self.wrapper.SetGitTrace2ParentSid()
self.assertIn(self.KEY, os.environ)
repo_wrapper.SetGitTrace2ParentSid()
assert self.KEY in os.environ
value = os.environ[self.KEY]
self.assertRegex(value, self.VALID_FORMAT)
assert self.VALID_FORMAT.match(value)
class RunCommand(RepoWrapperTestCase):
class TestRunCommand:
"""Check run_command behavior."""
def test_capture(self):
def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check capture_output handling."""
ret = self.wrapper.run_command(["echo", "hi"], capture_output=True)
ret = repo_wrapper.run_command(["echo", "hi"], capture_output=True)
# echo command appends OS specific linesep, but on Windows + Git Bash
# we get UNIX ending, so we allow both.
self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"])
assert ret.stdout in ["hi" + os.linesep, "hi\n"]
def test_check(self):
def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check check handling."""
self.wrapper.run_command(["true"], check=False)
self.wrapper.run_command(["true"], check=True)
self.wrapper.run_command(["false"], check=False)
with self.assertRaises(subprocess.CalledProcessError):
self.wrapper.run_command(["false"], check=True)
repo_wrapper.run_command(["true"], check=False)
repo_wrapper.run_command(["true"], check=True)
repo_wrapper.run_command(["false"], check=False)
with pytest.raises(subprocess.CalledProcessError):
repo_wrapper.run_command(["false"], check=True)
class RunGit(RepoWrapperTestCase):
class TestRunGit:
"""Check run_git behavior."""
def test_capture(self):
def test_capture(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check capture_output handling."""
ret = self.wrapper.run_git("--version")
self.assertIn("git", ret.stdout)
ret = repo_wrapper.run_git("--version")
assert "git" in ret.stdout
def test_check(self):
def test_check(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check check handling."""
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.run_git("--version-asdfasdf")
self.wrapper.run_git("--version-asdfasdf", check=False)
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.run_git("--version-asdfasdf")
repo_wrapper.run_git("--version-asdfasdf", check=False)
class ParseGitVersion(RepoWrapperTestCase):
class TestParseGitVersion:
"""Check ParseGitVersion behavior."""
def test_autoload(self):
def test_autoload(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we can load the version from the live git."""
ret = self.wrapper.ParseGitVersion()
self.assertIsNotNone(ret)
assert repo_wrapper.ParseGitVersion() is not None
def test_bad_ver(self):
def test_bad_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of bad git versions."""
ret = self.wrapper.ParseGitVersion(ver_str="asdf")
self.assertIsNone(ret)
assert repo_wrapper.ParseGitVersion(ver_str="asdf") is None
def test_normal_ver(self):
def test_normal_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of normal git versions."""
ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1")
self.assertEqual(2, ret.major)
self.assertEqual(25, ret.minor)
self.assertEqual(1, ret.micro)
self.assertEqual("2.25.1", ret.full)
ret = repo_wrapper.ParseGitVersion(ver_str="git version 2.25.1")
assert ret.major == 2
assert ret.minor == 25
assert ret.micro == 1
assert ret.full == "2.25.1"
def test_extended_ver(self):
def test_extended_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check handling of extended distro git versions."""
ret = self.wrapper.ParseGitVersion(
ret = repo_wrapper.ParseGitVersion(
ver_str="git version 1.30.50.696.g5e7596f4ac-goog"
)
self.assertEqual(1, ret.major)
self.assertEqual(30, ret.minor)
self.assertEqual(50, ret.micro)
self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full)
assert ret.major == 1
assert ret.minor == 30
assert ret.micro == 50
assert ret.full == "1.30.50.696.g5e7596f4ac-goog"
class CheckGitVersion(RepoWrapperTestCase):
class TestCheckGitVersion:
"""Check _CheckGitVersion behavior."""
def test_unknown(self):
def test_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Unknown versions should abort."""
with mock.patch.object(
self.wrapper, "ParseGitVersion", return_value=None
repo_wrapper, "ParseGitVersion", return_value=None
):
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion()
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper._CheckGitVersion()
def test_old(self):
def test_old(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Old versions should abort."""
with mock.patch.object(
self.wrapper,
repo_wrapper,
"ParseGitVersion",
return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"),
return_value=repo_wrapper.GitVersion(1, 0, 0, "1.0.0"),
):
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion()
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper._CheckGitVersion()
def test_new(self):
def test_new(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Newer versions should run fine."""
with mock.patch.object(
self.wrapper,
repo_wrapper,
"ParseGitVersion",
return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"),
return_value=repo_wrapper.GitVersion(100, 0, 0, "100.0.0"),
):
self.wrapper._CheckGitVersion()
repo_wrapper._CheckGitVersion()
class Requirements(RepoWrapperTestCase):
class TestRequirements:
"""Check Requirements handling."""
def test_missing_file(self):
def test_missing_file(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Don't crash if the file is missing (old version)."""
testdir = os.path.dirname(os.path.realpath(__file__))
self.assertIsNone(self.wrapper.Requirements.from_dir(testdir))
self.assertIsNone(
self.wrapper.Requirements.from_file(
os.path.join(testdir, "xxxxxxxxxxxxxxxxxxxxxxxx")
assert (
repo_wrapper.Requirements.from_dir(utils_for_test.THIS_DIR) is None
)
assert (
repo_wrapper.Requirements.from_file(
utils_for_test.THIS_DIR / "xxxxxxxxxxxxxxxxxxxxxxxx"
)
is None
)
def test_corrupt_data(self):
def test_corrupt_data(self, repo_wrapper: wrapper.Wrapper) -> None:
"""If the file can't be parsed, don't blow up."""
self.assertIsNone(self.wrapper.Requirements.from_file(__file__))
self.assertIsNone(self.wrapper.Requirements.from_data(b"x"))
assert repo_wrapper.Requirements.from_file(__file__) is None
assert repo_wrapper.Requirements.from_data(b"x") is None
def test_valid_data(self):
def test_valid_data(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure we can parse the file we ship."""
self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}"))
rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir))
self.assertIsNotNone(
self.wrapper.Requirements.from_file(
os.path.join(rootdir, "requirements.json")
)
assert repo_wrapper.Requirements.from_data(b"{}") is not None
rootdir = utils_for_test.THIS_DIR.parent
assert repo_wrapper.Requirements.from_dir(rootdir) is not None
assert (
repo_wrapper.Requirements.from_file(rootdir / "requirements.json")
is not None
)
def test_format_ver(self):
def test_format_ver(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check format_ver can format."""
self.assertEqual(
"1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3))
)
self.assertEqual("1", self.wrapper.Requirements._format_ver([1]))
assert repo_wrapper.Requirements._format_ver((1, 2, 3)) == "1.2.3"
assert repo_wrapper.Requirements._format_ver([1]) == "1"
def test_assert_all_unknown(self):
def test_assert_all_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all works with incompatible file."""
reqs = self.wrapper.Requirements({})
reqs = repo_wrapper.Requirements({})
reqs.assert_all()
def test_assert_all_new_repo(self):
def test_assert_all_new_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all accepts new enough repo."""
reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}})
reqs = repo_wrapper.Requirements({"repo": {"hard": [1, 0]}})
reqs.assert_all()
def test_assert_all_old_repo(self):
def test_assert_all_old_repo(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all rejects old repo."""
reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit):
reqs = repo_wrapper.Requirements({"repo": {"hard": [99999, 0]}})
with pytest.raises(SystemExit):
reqs.assert_all()
def test_assert_all_new_python(self):
def test_assert_all_new_python(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all accepts new enough python."""
reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}})
reqs = repo_wrapper.Requirements({"python": {"hard": sys.version_info}})
reqs.assert_all()
def test_assert_all_old_python(self):
def test_assert_all_old_python(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_all rejects old python."""
reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit):
reqs = repo_wrapper.Requirements({"python": {"hard": [99999, 0]}})
with pytest.raises(SystemExit):
reqs.assert_all()
def test_assert_ver_unknown(self):
def test_assert_ver_unknown(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver works with incompatible file."""
reqs = self.wrapper.Requirements({})
reqs = repo_wrapper.Requirements({})
reqs.assert_ver("xxx", (1, 0))
def test_assert_ver_new(self):
def test_assert_ver_new(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver allows new enough versions."""
reqs = self.wrapper.Requirements(
reqs = repo_wrapper.Requirements(
{"git": {"hard": [1, 0], "soft": [2, 0]}}
)
reqs.assert_ver("git", (1, 0))
@@ -287,274 +323,279 @@ class Requirements(RepoWrapperTestCase):
reqs.assert_ver("git", (2, 0))
reqs.assert_ver("git", (2, 5))
def test_assert_ver_old(self):
def test_assert_ver_old(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check assert_ver rejects old versions."""
reqs = self.wrapper.Requirements(
reqs = repo_wrapper.Requirements(
{"git": {"hard": [1, 0], "soft": [2, 0]}}
)
with self.assertRaises(SystemExit):
with pytest.raises(SystemExit):
reqs.assert_ver("git", (0, 5))
class NeedSetupGnuPG(RepoWrapperTestCase):
class TestNeedSetupGnuPG:
"""Check NeedSetupGnuPG behavior."""
def test_missing_dir(self):
def test_missing_dir(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The ~/.repoconfig tree doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = os.path.join(tempdir, "foo")
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path / "foo")
assert repo_wrapper.NeedSetupGnuPG()
def test_missing_keyring(self):
def test_missing_keyring(
self, tmp_path, repo_wrapper: wrapper.Wrapper
) -> None:
"""The keyring-version file doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
assert repo_wrapper.NeedSetupGnuPG()
def test_empty_keyring(self):
def test_empty_keyring(
self, tmp_path, repo_wrapper: wrapper.Wrapper
) -> None:
"""The keyring-version file exists, but is empty."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w"):
pass
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("")
assert repo_wrapper.NeedSetupGnuPG()
def test_old_keyring(self):
def test_old_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The keyring-version file exists, but it's old."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write("1.0\n")
self.assertTrue(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("1.0\n")
assert repo_wrapper.NeedSetupGnuPG()
def test_new_keyring(self):
def test_new_keyring(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""The keyring-version file exists, and is up-to-date."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write("1000.0\n")
self.assertFalse(self.wrapper.NeedSetupGnuPG())
repo_wrapper.home_dot_repo = str(tmp_path)
(tmp_path / "keyring-version").write_text("1000.0\n")
assert not repo_wrapper.NeedSetupGnuPG()
class SetupGnuPG(RepoWrapperTestCase):
class TestSetupGnuPG:
"""Check SetupGnuPG behavior."""
def test_full(self):
def test_full(self, tmp_path, repo_wrapper: wrapper.Wrapper) -> None:
"""Make sure it works completely."""
with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir
self.wrapper.gpg_dir = os.path.join(
self.wrapper.home_dot_repo, "gnupg"
)
self.assertTrue(self.wrapper.SetupGnuPG(True))
with open(os.path.join(tempdir, "keyring-version")) as fp:
data = fp.read()
self.assertEqual(
".".join(str(x) for x in self.wrapper.KEYRING_VERSION),
data.strip(),
)
repo_wrapper.home_dot_repo = str(tmp_path)
repo_wrapper.gpg_dir = str(tmp_path / "gnupg")
assert repo_wrapper.SetupGnuPG(True)
data = (tmp_path / "keyring-version").read_text()
assert (
".".join(str(x) for x in repo_wrapper.KEYRING_VERSION)
== data.strip()
)
class VerifyRev(RepoWrapperTestCase):
class TestVerifyRev:
"""Check verify_rev behavior."""
def test_verify_passes(self):
def test_verify_passes(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check when we have a valid signed tag."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0\n", "")
gpg_result = subprocess.CompletedProcess([], 0, "", "")
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
ret = self.wrapper.verify_rev(
ret = repo_wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
assert ret == "v1.0^0"
def test_unsigned_commit(self):
def test_unsigned_commit(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we fall back to signed tag when we have an unsigned commit."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
gpg_result = subprocess.CompletedProcess([], 0, "", "")
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
ret = self.wrapper.verify_rev(
ret = repo_wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
assert ret == "v1.0^0"
def test_verify_fails(self):
def test_verify_fails(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Check we fall back to signed tag when we have an unsigned commit."""
desc_result = subprocess.CompletedProcess([], 0, "v1.0-10-g1234\n", "")
gpg_result = Exception
gpg_result = RuntimeError
with mock.patch.object(
self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
repo_wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
with self.assertRaises(Exception):
self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
with pytest.raises(RuntimeError):
repo_wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
class GitCheckoutTestCase(RepoWrapperTestCase):
"""Tests that use a real/small git checkout."""
GIT_DIR = None
REV_LIST = None
@classmethod
def setUpClass(cls):
# Create a repo to operate on, but do it once per-class.
cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests")
cls.GIT_DIR = cls.tempdirobj.name
run_git = wrapper.Wrapper().run_git
remote = os.path.join(cls.GIT_DIR, "remote")
os.mkdir(remote)
utils_for_test.init_git_tree(remote)
run_git("commit", "--allow-empty", "-minit", cwd=remote)
run_git("branch", "stable", cwd=remote)
run_git("tag", "v1.0", cwd=remote)
run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
cls.REV_LIST = run_git(
"rev-list", "HEAD", cwd=remote
).stdout.splitlines()
run_git("init", cwd=cls.GIT_DIR)
run_git(
"fetch",
remote,
"+refs/heads/*:refs/remotes/origin/*",
cwd=cls.GIT_DIR,
)
@classmethod
def tearDownClass(cls):
if not cls.tempdirobj:
return
cls.tempdirobj.cleanup()
class ResolveRepoRev(GitCheckoutTestCase):
class TestResolveRepoRev:
"""Check resolve_repo_rev behavior."""
def test_explicit_branch(self):
def test_explicit_branch(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check refs/heads/branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(
self.GIT_DIR, "refs/heads/stable"
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/heads/stable"
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/heads/unknown"
)
def test_explicit_tag(self):
def test_explicit_tag(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check refs/tags/tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(
self.GIT_DIR, "refs/tags/v1.0"
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/tags/v1.0"
)
self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/tags/v1.0"
assert lrev == git_checkout.rev_list[1]
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "refs/tags/unknown"
)
def test_branch_name(self):
def test_branch_name(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable")
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(
git_checkout.git_dir, "stable"
)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main")
self.assertEqual("refs/heads/main", rrev)
self.assertEqual(self.REV_LIST[0], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "main")
assert rrev == "refs/heads/main"
assert lrev == git_checkout.rev_list[0]
def test_tag_name(self):
def test_tag_name(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0")
self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "v1.0")
assert rrev == "refs/tags/v1.0"
assert lrev == git_checkout.rev_list[1]
def test_full_commit(self):
def test_full_commit(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check specific commit argument."""
commit = self.REV_LIST[0]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
self.assertEqual(commit, rrev)
self.assertEqual(commit, lrev)
commit = git_checkout.rev_list[0]
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
assert rrev == commit
assert lrev == commit
def test_partial_commit(self):
def test_partial_commit(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check specific (partial) commit argument."""
commit = self.REV_LIST[0][0:20]
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
self.assertEqual(self.REV_LIST[0], rrev)
self.assertEqual(self.REV_LIST[0], lrev)
commit = git_checkout.rev_list[0][0:20]
rrev, lrev = repo_wrapper.resolve_repo_rev(git_checkout.git_dir, commit)
assert rrev == git_checkout.rev_list[0]
assert lrev == git_checkout.rev_list[0]
def test_unknown(self):
def test_unknown(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Check unknown ref/commit argument."""
with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya")
with pytest.raises(repo_wrapper.CloneFailure):
repo_wrapper.resolve_repo_rev(git_checkout.git_dir, "boooooooya")
class CheckRepoVerify(RepoWrapperTestCase):
class TestCheckRepoVerify:
"""Check check_repo_verify behavior."""
def test_no_verify(self):
def test_no_verify(self, repo_wrapper: wrapper.Wrapper) -> None:
"""Always fail with --no-repo-verify."""
self.assertFalse(self.wrapper.check_repo_verify(False))
assert not repo_wrapper.check_repo_verify(False)
def test_gpg_initialized(self):
def test_gpg_initialized(
self,
repo_wrapper: wrapper.Wrapper,
) -> None:
"""Should pass if gpg is setup already."""
with mock.patch.object(
self.wrapper, "NeedSetupGnuPG", return_value=False
repo_wrapper, "NeedSetupGnuPG", return_value=False
):
self.assertTrue(self.wrapper.check_repo_verify(True))
assert repo_wrapper.check_repo_verify(True)
def test_need_gpg_setup(self):
def test_need_gpg_setup(
self,
repo_wrapper: wrapper.Wrapper,
) -> None:
"""Should pass/fail based on gpg setup."""
with mock.patch.object(
self.wrapper, "NeedSetupGnuPG", return_value=True
repo_wrapper, "NeedSetupGnuPG", return_value=True
):
with mock.patch.object(self.wrapper, "SetupGnuPG") as m:
with mock.patch.object(repo_wrapper, "SetupGnuPG") as m:
m.return_value = True
self.assertTrue(self.wrapper.check_repo_verify(True))
assert repo_wrapper.check_repo_verify(True)
m.return_value = False
self.assertFalse(self.wrapper.check_repo_verify(True))
assert not repo_wrapper.check_repo_verify(True)
class CheckRepoRev(GitCheckoutTestCase):
class TestCheckRepoRev:
"""Check check_repo_rev behavior."""
def test_verify_works(self):
def test_verify_works(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should pass when verification passes."""
with mock.patch.object(
self.wrapper, "check_repo_verify", return_value=True
repo_wrapper, "check_repo_verify", return_value=True
):
with mock.patch.object(
self.wrapper, "verify_rev", return_value="12345"
repo_wrapper, "verify_rev", return_value="12345"
):
rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual("12345", lrev)
rrev, lrev = repo_wrapper.check_repo_rev(
git_checkout.git_dir, "stable"
)
assert rrev == "refs/heads/stable"
assert lrev == "12345"
def test_verify_fails(self):
def test_verify_fails(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should fail when verification fails."""
with mock.patch.object(
self.wrapper, "check_repo_verify", return_value=True
repo_wrapper, "check_repo_verify", return_value=True
):
with mock.patch.object(
self.wrapper, "verify_rev", side_effect=Exception
repo_wrapper, "verify_rev", side_effect=RuntimeError
):
with self.assertRaises(Exception):
self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
with pytest.raises(RuntimeError):
repo_wrapper.check_repo_rev(git_checkout.git_dir, "stable")
def test_verify_ignore(self):
def test_verify_ignore(
self,
repo_wrapper: wrapper.Wrapper,
git_checkout: GitCheckout,
) -> None:
"""Should pass when verification is disabled."""
with mock.patch.object(
self.wrapper, "verify_rev", side_effect=Exception
repo_wrapper, "verify_rev", side_effect=RuntimeError
):
rrev, lrev = self.wrapper.check_repo_rev(
self.GIT_DIR, "stable", repo_verify=False
rrev, lrev = repo_wrapper.check_repo_rev(
git_checkout.git_dir, "stable", repo_verify=False
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev)
assert rrev == "refs/heads/stable"
assert lrev == git_checkout.rev_list[1]

View File

@@ -27,6 +27,11 @@ from typing import Optional, Union
import git_command
THIS_FILE = Path(__file__).resolve()
THIS_DIR = THIS_FILE.parent
FIXTURES_DIR = THIS_DIR / "fixtures"
def init_git_tree(
path: Union[str, Path],
ref_format: Optional[str] = None,