diff --git a/project.py b/project.py index b26e2a6ac..9e8d8605a 100644 --- a/project.py +++ b/project.py @@ -28,7 +28,7 @@ import sys import tarfile import tempfile import time -from typing import List, NamedTuple, Optional +from typing import Dict, List, NamedTuple, Optional import urllib.parse from color import Coloring @@ -651,6 +651,50 @@ class Project: return self.relpath return os.path.join(self.manifest.path_prefix, self.relpath) + def GetEnvVars(self, local: bool = True) -> Dict[str, str]: + """Get project-context environment variables. + + Args: + local: If True, REPO_PATH is relative to the local (sub)manifest. + If False, it is relative to the outermost manifest. + + Returns: + A dictionary mapping environment variable names to their values. + + Environment Variables: + See the Environment section in `repo help forall` or + `subcmds/forall.py` for details on the available variables. + Note that `forall.py` also documents some extra variables that are + specific to how the `repo forall` command iterates over projects + (e.g., `REPO_COUNT` and `REPO_I`). + """ + env = {} + + def setenv(name, val): + if val is None: + val = "" + env[name] = val + + setenv("REPO_PROJECT", self.name) + setenv("REPO_OUTERPATH", self.manifest.path_prefix) + setenv("REPO_INNERPATH", self.relpath) + setenv("REPO_PATH", self.RelPath(local=local)) + setenv("REPO_REMOTE", self.remote.name) + + try: + lrev = "" if self.manifest.IsMirror else self.GetRevisionId() + except ManifestInvalidRevisionError: + lrev = "" + setenv("REPO_LREV", lrev) + setenv("REPO_RREV", self.revisionExpr) + setenv("REPO_UPSTREAM", self.upstream) + setenv("REPO_DEST_BRANCH", self.dest_branch) + + for annotation in self.annotations: + setenv(f"REPO__{annotation.name}", annotation.value) + + return env + def SetRevision(self, revisionExpr, revisionId=None): """Set revisionId based on revision expression and id""" self.revisionExpr = revisionExpr diff --git a/subcmds/forall.py b/subcmds/forall.py index 2304e4382..4560f0d27 100644 --- a/subcmds/forall.py +++ b/subcmds/forall.py @@ -25,7 +25,6 @@ from color import Coloring from command import Command from command import DEFAULT_LOCAL_JOBS from command import MirrorSafeCommand -from error import ManifestInvalidRevisionError from repo_logging import RepoLogger @@ -339,25 +338,8 @@ def DoWork(project, mirror, opt, cmd, shell, cnt, config): val = "" env[name] = val - setenv("REPO_PROJECT", project.name) - setenv("REPO_OUTERPATH", project.manifest.path_prefix) - setenv("REPO_INNERPATH", project.relpath) - setenv("REPO_PATH", project.RelPath(local=opt.this_manifest_only)) - setenv("REPO_REMOTE", project.remote.name) - try: - # If we aren't in a fully synced state and we don't have the ref the - # manifest wants, then this will fail. Ignore it for the purposes of - # this code. - lrev = "" if mirror else project.GetRevisionId() - except ManifestInvalidRevisionError: - lrev = "" - setenv("REPO_LREV", lrev) - setenv("REPO_RREV", project.revisionExpr) - setenv("REPO_UPSTREAM", project.upstream) - setenv("REPO_DEST_BRANCH", project.dest_branch) - setenv("REPO_I", str(cnt + 1)) - for annotation in project.annotations: - setenv("REPO__%s" % (annotation.name), annotation.value) + env.update(project.GetEnvVars(local=opt.this_manifest_only)) + env["REPO_I"] = str(cnt + 1) if mirror: setenv("GIT_DIR", project.gitdir) diff --git a/tests/test_project.py b/tests/test_project.py index e3235f9c8..af8bea4f0 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1007,3 +1007,89 @@ class SyncOptimizationTests(unittest.TestCase): self.assertTrue(res) mock_git_cmd.assert_not_called() + + +class GetEnvVarsTests(unittest.TestCase): + """Tests for GetEnvVars project environment variable generation.""" + + def _get_project(self, tempdir, revisionExpr="main"): + proj = _create_mock_project(tempdir, revisionExpr=revisionExpr) + proj.GetRevisionId = mock.MagicMock(return_value="1234abcd") + return proj + + def test_get_env_vars_basic(self): + """Test that all basic environment variables are set correctly.""" + with utils_for_test.TempGitTree() as tempdir: + proj = self._get_project(tempdir) + proj.manifest.path_prefix = "sub-manifest" + proj.upstream = "upstream-branch" + proj.dest_branch = "dest-branch" + + env = proj.GetEnvVars(local=True) + + self.assertEqual(env["REPO_PROJECT"], "test-project") + self.assertEqual(env["REPO_OUTERPATH"], "sub-manifest") + self.assertEqual(env["REPO_INNERPATH"], "test-project") + self.assertEqual(env["REPO_PATH"], "test-project") + self.assertEqual(env["REPO_REMOTE"], "origin") + self.assertEqual(env["REPO_LREV"], "1234abcd") + self.assertEqual(env["REPO_RREV"], "main") + self.assertEqual(env["REPO_UPSTREAM"], "upstream-branch") + self.assertEqual(env["REPO_DEST_BRANCH"], "dest-branch") + + def test_get_env_vars_non_local(self): + """Test environment variables generation with local=False.""" + with utils_for_test.TempGitTree() as tempdir: + proj = self._get_project(tempdir) + proj.manifest.path_prefix = "sub-manifest" + + env = proj.GetEnvVars(local=False) + + # REPO_PATH should be relative to outermost manifest + # (sub-manifest/test-project) + self.assertEqual(env["REPO_PATH"], "sub-manifest/test-project") + + def test_get_env_vars_mirror(self): + """Test environment variables generation in mirror mode.""" + with utils_for_test.TempGitTree() as tempdir: + proj = self._get_project(tempdir) + proj.manifest.IsMirror = True + + env = proj.GetEnvVars() + + # In mirror mode, REPO_LREV should be empty, and GetRevisionId must + # not be called + self.assertEqual(env["REPO_LREV"], "") + proj.GetRevisionId.assert_not_called() + + def test_get_env_vars_annotations(self): + """Test that project annotations are added correctly.""" + with utils_for_test.TempGitTree() as tempdir: + proj = self._get_project(tempdir) + + annotation1 = mock.MagicMock() + annotation1.name = "key1" + annotation1.value = "value1" + + annotation2 = mock.MagicMock() + annotation2.name = "key2" + annotation2.value = "value2" + + proj.annotations = [annotation1, annotation2] + + env = proj.GetEnvVars() + + self.assertEqual(env["REPO__key1"], "value1") + self.assertEqual(env["REPO__key2"], "value2") + + def test_get_env_vars_invalid_revision_graceful(self): + """Test that invalid revision error is handled gracefully.""" + with utils_for_test.TempGitTree() as tempdir: + proj = self._get_project(tempdir) + proj.GetRevisionId.side_effect = error.ManifestInvalidRevisionError( + "revision not found" + ) + + env = proj.GetEnvVars() + + self.assertEqual(env["REPO_LREV"], "")