diff --git a/pyproject.toml b/pyproject.toml index fa09b94..cfbfdcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,3 @@ -[build-system] -requires = ["pdm-backend"] -build-backend = "pdm.backend" - [project] name = "stack-pr" authors = [ @@ -20,20 +16,40 @@ classifiers = [ "Intended Audience :: Developers", "Topic :: Software Development :: Version Control :: Git", "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Environment :: Console", + "Topic :: Utilities", "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] # Version is dynamically set by pdm by the SCM version dynamic = ["version"] -dependencies = [] +dependencies = ["typing_extensions; python_version<\"3.13\""] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-mock", + "mypy", + "ruff", +] [project.urls] -Homepage = "https://github.com/modularml/stack-pr" -Repository = "https://github.com/modularml/stack-pr" -"Bug Tracker" = "https://github.com/modularml/stack-pr/issues" +Homepage = "https://github.com/modular/stack-pr" +Repository = "https://github.com/modular/stack-pr" +"Bug Tracker" = "https://github.com/modular/stack-pr/issues" [project.scripts] stack-pr = "stack_pr.cli:main" +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + [tool.pdm] distribution = true @@ -52,4 +68,134 @@ pdm = ">=2.17.1,<2.18" [tool.pixi.tasks] [tool.pixi.dependencies] -python = ">=3.8" +python = "3.9.*" + +[tool.ruff] + +# Same as Black. +line-length = 88 + +# Assume Python 3.9 +target-version = "py39" + +[tool.ruff.lint] +# Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "N", # pep8-naming + "SIM", # flake8-simplify + "RUF", # ruff-specific rules + "W", # pycodestyle warnings + "YTT", # flake8-2020 + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except + "FBT", # flake8-boolean-trap + "A", # flake8-builtins + "C", # flake8-comprehensions + "DTZ", # flake8-datetimez + "T10", # flake8-debugger + "ISC", # flake8-implicit-str-concat + "G", # flake8-logging-format + "INP", # flake8-no-pep420 + "PIE", # flake8-pie + "T20", # flake8-print + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SLF", # flake8-self + "TID", # flake8-tidy-imports + "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "ERA", # eradicate + "PD", # pandas-vet + "PGH", # pygrep-hooks + "PL", # pylint + "TRY", # tryceratops + "FA100", #future-rewritable-type-annotation + "PYI036", # bad-exit-annotation + # "COM812", # trailing-comma (this one makes ruff formater mad) +] + +# Ignore specific rules +ignore = [ + # This is just too complex to do anything about when invoking gh suprocess. + "S603", # subprocess call with untrusted input + + # We use some of the strings for output in the CLI and want specific formatting. + "TRY003", # Avoid specifying long messages outside exception class + + # We forward kwargs a lot and ruff doesn't like it. (more than 5) + "PLR0913", # Too many arguments + + # FIXME: We use print statements in the CLI instead of stderr/stdout (this may change) + "T201", # allow print statements + + # FIXME: Some of our strings are long for CLI output. We should refactor them. + "E501", # Line too long +] + +# Borrowing a rustism and allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.lint.isort] +known-first-party = ["stack_pr"] + +[tool.ruff.lint.mccabe] +# Unlike Flake8, default to a complexity level of 16. +max-complexity = 16 + +[tool.ruff.lint.per-file-ignores] +# Ignore unused imports in __init__.py files +"__init__.py" = ["F401"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" +inline-quotes = "double" +multiline-quotes = "double" + +[tool.ruff.lint.extend-per-file-ignores] +# Allow certain useful patterns in tests +"tests/**/*.py" = [ + "S101", # allow assert statements within if statements + "ARG", # allow unused arguments + "FBT", # allow boolean traps + "ANN401", # allow Any in tests + "T201", # allow print statements in tests + "PLR2004", # allow magic value comparisons in tests +] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +testpaths = ["tests"] +python_files = ["test_*.py"] +addopts = "-v" + +[tool.mypy] +# Let's be strict. +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_optional = true diff --git a/src/stack_pr/__main__.py b/src/stack_pr/__main__.py index 0a5b791..f512a1e 100644 --- a/src/stack_pr/__main__.py +++ b/src/stack_pr/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from stack_pr.cli import main if __name__ == "__main__": diff --git a/src/stack_pr/cli.py b/src/stack_pr/cli.py index 45a9c3d..d068998 100755 --- a/src/stack_pr/cli.py +++ b/src/stack_pr/cli.py @@ -47,14 +47,21 @@ # # ===----------------------------------------------------------------------=== # +from __future__ import annotations + import argparse import configparser import json +import logging import os import re +import sys +from dataclasses import dataclass from functools import cache +from logging import getLogger +from pathlib import Path +from re import Pattern from subprocess import SubprocessError -from typing import List, NamedTuple, Optional, Pattern from stack_pr.git import ( branch_exists, @@ -66,9 +73,10 @@ from stack_pr.shell_commands import ( get_command_output, run_shell_command, - set_show_commands, ) +logger = getLogger(__name__) + # A bunch of regexps for parsing commit messages and PR descriptions RE_RAW_COMMIT_ID = re.compile(r"^(?P[a-f0-9]+)$", re.MULTILINE) RE_RAW_AUTHOR = re.compile( @@ -184,6 +192,7 @@ # ===----------------------------------------------------------------------=== # # Class to work with git commit contents # ===----------------------------------------------------------------------=== # +@dataclass class CommitHeader: """ Represents the information extracted from `git rev-list --header` @@ -192,12 +201,12 @@ class CommitHeader: # The unparsed output from git rev-list --header raw_header: str - def __init__(self, raw_header: str): - self.raw_header = raw_header - def _search_group(self, regex: Pattern[str], group: str) -> str: m = regex.search(self.raw_header) - assert m + if m is None: + raise ValueError( + f"Required field '{group}' not found in commit header: {self.raw_header}" + ) return m.group(group) def tree(self) -> str: @@ -209,7 +218,7 @@ def title(self) -> str: def commit_id(self) -> str: return self._search_group(RE_RAW_COMMIT_ID, "commit") - def parents(self) -> List[str]: + def parents(self) -> list[str]: return [m.group("commit") for m in RE_RAW_PARENT.finditer(self.raw_header)] def author(self) -> str: @@ -230,18 +239,18 @@ def commit_msg(self) -> str: # ===----------------------------------------------------------------------=== # # Class to work with PR stack entries # ===----------------------------------------------------------------------=== # +@dataclass class StackEntry: """ Represents an entry in a stack of PRs and contains associated info, such as linked PR, head and base branches, original git commit. """ - def __init__(self, commit: CommitHeader): - self.commit = commit - self._pr: Optional[str] = None - self._base: Optional[str] = None - self._head: Optional[str] = None - self.need_update: bool = False + commit: CommitHeader + _pr: str | None = None + _base: str | None = None + _head: str | None = None + need_update: bool = False @property def pr(self) -> str: @@ -250,7 +259,7 @@ def pr(self) -> str: return self._pr @pr.setter - def pr(self, pr: str): + def pr(self, pr: str) -> None: self._pr = pr def has_pr(self) -> bool: @@ -263,32 +272,30 @@ def head(self) -> str: return self._head @head.setter - def head(self, head: str): + def head(self, head: str) -> None: self._head = head def has_head(self) -> bool: return self._head is not None @property - def base(self) -> str: - if self._base is None: - raise ValueError("base is not set") + def base(self) -> str | None: return self._base @base.setter - def base(self, base: str): + def base(self, base: str | None) -> None: self._base = base + def has_base(self) -> bool: + return self._base is not None + def has_missing_info(self) -> bool: return None in (self._pr, self._head, self._base) - def pprint(self, links: bool): + def pprint(self, *, links: bool) -> str: s = b(self.commit.commit_id()[:8]) pr_string = None - if self.has_pr(): - pr_string = blue("#" + last(self.pr)) - else: - pr_string = red("no PR") + pr_string = blue("#" + last(self.pr)) if self.has_pr() else red("no PR") branch_string = None if self._head or self._base: head_str = green(self._head) if self._head else red(str(self._head)) @@ -309,10 +316,10 @@ def pprint(self, links: bool): return s - def __repr__(self): - return self.pprint(False) + def __repr__(self) -> str: + return self.pprint(links=False) - def read_metadata(self): + def read_metadata(self) -> None: self.commit.commit_msg() x = RE_STACK_INFO_LINE.search(self.commit.commit_msg()) if not x: @@ -326,7 +333,7 @@ def read_metadata(self): # ===----------------------------------------------------------------------=== # -class bcolors: +class ShellColors: HEADER = "\033[95m" OKBLUE = "\033[94m" OKCYAN = "\033[96m" @@ -338,28 +345,28 @@ class bcolors: UNDERLINE = "\033[4m" -def b(s: str): - return bcolors.BOLD + s + bcolors.ENDC +def b(s: str) -> str: + return ShellColors.BOLD + s + ShellColors.ENDC -def h(s: str): - return bcolors.HEADER + s + bcolors.ENDC +def h(s: str) -> str: + return ShellColors.HEADER + s + ShellColors.ENDC -def green(s: str): - return bcolors.OKGREEN + s + bcolors.ENDC +def green(s: str) -> str: + return ShellColors.OKGREEN + s + ShellColors.ENDC -def blue(s: str): - return bcolors.OKBLUE + s + bcolors.ENDC +def blue(s: str) -> str: + return ShellColors.OKBLUE + s + ShellColors.ENDC -def red(s: str): - return bcolors.FAIL + s + bcolors.ENDC +def red(s: str) -> str: + return ShellColors.FAIL + s + ShellColors.ENDC # https://gist.github.com/egmontkob/eb114294efbcd5adb1944c9f3cb5feda -def link(location: str, text: str): +def link(location: str, text: str) -> str: """ Emits a link to the terminal using the terminal hyperlink specification. @@ -368,19 +375,23 @@ def link(location: str, text: str): return f"\033]8;;{location}\033\\{text}\033]8;;\033\\" -def error(msg): +def error(msg: str) -> None: print(red("\nERROR: ") + msg) -# TODO: replace this with modular.utils.logging -def log(msg, level=0): - print(msg) +def log(msg: str, *, level: int = 1) -> None: + if level <= 1: + print(msg) + elif level == 1: + logger.info(msg) + elif level >= 2: # noqa: PLR2004 + logger.debug(msg) # ===----------------------------------------------------------------------=== # # Common utility functions # ===----------------------------------------------------------------------=== # -def split_header(s: str) -> List[CommitHeader]: +def split_header(s: str) -> list[CommitHeader]: return [CommitHeader(h) for h in s.split("\0")[:-1]] @@ -389,7 +400,7 @@ def last(ref: str, sep: str = "/") -> str: # TODO: Move to 'modular.utils.git' -def is_ancestor(commit1: str, commit2: str, verbose: bool) -> bool: +def is_ancestor(commit1: str, commit2: str, *, verbose: bool) -> bool: """ Returns true if 'commit1' is an ancestor of 'commit2'. """ @@ -412,16 +423,16 @@ def is_repo_clean() -> bool: return not bool(changes) -def get_stack(base: str, head: str, verbose: bool) -> List[StackEntry]: - if not is_ancestor(base, head, verbose): +def get_stack(base: str, head: str, *, verbose: bool) -> list[StackEntry]: + if not is_ancestor(base, head, verbose=verbose): error( f"{base} is not an ancestor of {head}.\n" "Could not find commits for the stack." ) - exit(1) + sys.exit(1) # Find list of commits since merge base. - st: List[StackEntry] = [] + st: list[StackEntry] = [] stack = ( split_header( get_command_output(["git", "rev-list", "--header", "^" + base, head]) @@ -437,14 +448,14 @@ def get_stack(base: str, head: str, verbose: bool) -> List[StackEntry]: return st -def set_base_branches(st: List[StackEntry], target: str): - prev_branch = target +def set_base_branches(st: list[StackEntry], target: str) -> None: + prev_branch: str | None = target for e in st: - e.base, prev_branch = prev_branch, e._head + e.base, prev_branch = prev_branch, e.head -def verify(st: List[StackEntry], check_base: bool = False): - log(h("Verifying stack info"), level=1) +def verify(st: list[StackEntry], *, check_base: bool = False) -> None: + log(h("Verifying stack info")) for index, e in enumerate(st): if e.has_missing_info(): error(ERROR_STACKINFO_MISSING.format(**locals())) @@ -500,13 +511,13 @@ def verify(st: List[StackEntry], check_base: bool = False): raise RuntimeError -def print_stack(st: List[StackEntry], links: bool, level=1): +def print_stack(st: list[StackEntry], *, links: bool, level: int = 1) -> None: log(b("Stack:"), level=level) for e in reversed(st): - log(" * " + e.pprint(links), level=level) + log(" * " + e.pprint(links=links), level=level) -def draft_bitmask_type(value: str) -> List[bool]: +def draft_bitmask_type(value: str) -> list[bool]: # Validate that only 0s and 1s are present if value and not set(value).issubset({"0", "1"}): raise argparse.ArgumentTypeError("Bitmask must only contain 0s and 1s.") @@ -518,19 +529,27 @@ def draft_bitmask_type(value: str) -> List[bool]: # ===----------------------------------------------------------------------=== # # SUBMIT # ===----------------------------------------------------------------------=== # -def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) -> bool: +def add_or_update_metadata(e: StackEntry, *, needs_rebase: bool, verbose: bool) -> bool: if needs_rebase: + if not e.has_base() or not e.has_head(): + error("Stack entry has no base or head branch") + raise RuntimeError + run_shell_command( [ "git", "rebase", - e.base, - e.head, + e.base or "", + e.head or "", "--committer-date-is-author-date", ], quiet=not verbose, ) else: + if not e.has_head(): + error("Stack entry has no head branch") + raise RuntimeError + run_shell_command(["git", "checkout", e.head], quiet=not verbose) commit_msg = e.commit.commit_msg() @@ -549,7 +568,7 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) -> return True -def fix_branch_name_template(branch_name_template: str): +def fix_branch_name_template(branch_name_template: str) -> str: if "$ID" not in branch_name_template: return f"{branch_name_template}/$ID" @@ -557,15 +576,14 @@ def fix_branch_name_template(branch_name_template: str): @cache -def get_branch_name_base(branch_name_template: str): +def get_branch_name_base(branch_name_template: str) -> str: username = get_gh_username() current_branch_name = get_current_branch_name() branch_name_base = branch_name_template.replace("$USERNAME", username) - branch_name_base = branch_name_base.replace("$BRANCH", current_branch_name) - return branch_name_base + return branch_name_base.replace("$BRANCH", current_branch_name) -def get_branch_id(branch_name_template: str, branch_name: str): +def get_branch_id(branch_name_template: str, branch_name: str) -> str | None: branch_name_base = get_branch_name_base(branch_name_template) pattern = branch_name_base.replace(r"$ID", r"(\d+)") match = re.search(pattern, branch_name) @@ -574,23 +592,21 @@ def get_branch_id(branch_name_template: str, branch_name: str): return None -def generate_branch_name(branch_name_template: str, branch_id: int): +def generate_branch_name(branch_name_template: str, branch_id: int) -> str: branch_name_base = get_branch_name_base(branch_name_template) - branch_name = branch_name_base.replace(r"$ID", branch_id) - return branch_name + return branch_name_base.replace(r"$ID", str(branch_id)) -def get_taken_branch_ids(refs: List[str], branch_name_template: str) -> List[int]: - branch_ids = list(get_branch_id(branch_name_template, ref) for ref in refs) - branch_ids = [int(branch_id) for branch_id in branch_ids if branch_id is not None] - return branch_ids +def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int]: + branch_ids = [get_branch_id(branch_name_template, ref) for ref in refs] + return [int(branch_id) for branch_id in branch_ids if branch_id is not None] -def generate_available_branch_name(refs: List[str], branch_name_template: str) -> str: +def generate_available_branch_name(refs: list[str], branch_name_template: str) -> str: branch_ids = get_taken_branch_ids(refs, branch_name_template) max_ref_num = max(branch_ids) if branch_ids else 0 new_branch_id = max_ref_num + 1 - return generate_branch_name(branch_name_template, str(new_branch_id)) + return generate_branch_name(branch_name_template, new_branch_id) def get_available_branch_name(remote: str, branch_name_template: str) -> str: @@ -606,18 +622,18 @@ def get_available_branch_name(remote: str, branch_name_template: str) -> str: ] ).split() - refs = list([ref.strip("'") for ref in refs]) + refs = [ref.strip("'") for ref in refs] return generate_available_branch_name(refs, branch_name_template) def get_next_available_branch_name(branch_name_template: str, name: str) -> str: - id = get_branch_id(branch_name_template, name) - return generate_branch_name(branch_name_template, str(int(id) + 1)) + branch_id = get_branch_id(branch_name_template, name) + return generate_branch_name(branch_name_template, int(branch_id or 0) + 1) def set_head_branches( - st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str -): + st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str +) -> None: """Set the head ref for each stack entry if it doesn't already have one.""" run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) @@ -630,10 +646,12 @@ def set_head_branches( def init_local_branches( - st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str -): - log(h("Initializing local branches"), level=1) - set_head_branches(st, remote, verbose, branch_name_template) + st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str +) -> None: + log(h("Initializing local branches")) + set_head_branches( + st, remote, verbose=verbose, branch_name_template=branch_name_template + ) for e in st: run_shell_command( ["git", "checkout", e.commit.commit_id(), "-B", e.head], @@ -641,42 +659,50 @@ def init_local_branches( ) -def push_branches(st: List[StackEntry], remote, verbose: bool): - log(h("Updating remote branches"), level=1) +def push_branches(st: list[StackEntry], remote: str, *, verbose: bool) -> None: + log(h("Updating remote branches")) cmd = ["git", "push", "-f", remote] cmd.extend([f"{e.head}:{e.head}" for e in st]) run_shell_command(cmd, quiet=not verbose) -def print_cmd_failure_details(exc: SubprocessError): - cmd_stdout = ( - exc.stdout.decode("utf-8").replace("\\n", "\n").replace("\\t", "\t") - if exc.stdout - else None - ) - cmd_stderr = ( - exc.stderr.decode("utf-8").replace("\\n", "\n").replace("\\t", "\t") - if exc.stderr - else None - ) - print(f"Exitcode: {exc.returncode}") +def print_cmd_failure_details(exc: SubprocessError) -> None: + # Test if SubprocessError subclass has stdout and stderr attributes + if hasattr(exc, "stdout") and exc.stdout: + cmd_stdout = ( + exc.stdout.decode("utf-8").replace("\\n", "\n").replace("\\t", "\t") + ) + else: + cmd_stdout = None + + if hasattr(exc, "stderr") and exc.stderr: + cmd_stderr = ( + exc.stderr.decode("utf-8").replace("\\n", "\n").replace("\\t", "\t") + ) + else: + cmd_stderr = None + + print(f"Exitcode: {exc.returncode if hasattr(exc, 'returncode') else 'unknown'}") print(f"Stdout: {cmd_stdout}") print(f"Stderr: {cmd_stderr}") -def create_pr(e: StackEntry, is_draft: bool, reviewer: str = ""): +def create_pr(e: StackEntry, *, is_draft: bool, reviewer: str = "") -> None: # Don't do anything if the PR already exists if e.has_pr(): return + if not e.has_base() or not e.has_head(): + error("Stack entry has no base or head branch") + raise RuntimeError log(h("Creating PR " + green(f"'{e.head}' -> '{e.base}'")), level=1) cmd = [ "gh", "pr", "create", "-B", - e.base, + e.base or "", "-H", - e.head, + e.head or "", "-t", e.commit.title(), "-F", @@ -697,7 +723,7 @@ def create_pr(e: StackEntry, is_draft: bool, reviewer: str = ""): e.pr = r.split()[-1] -def generate_toc(st: List[StackEntry], current: str) -> str: +def generate_toc(st: list[StackEntry], current: str) -> str: def toc_entry(se: StackEntry) -> str: pr_id = last(se.pr) arrow = "__->__" if pr_id == current else "" @@ -707,14 +733,14 @@ def toc_entry(se: StackEntry) -> str: return f"Stacked PRs:\n{''.join(entries)}\n" -def get_current_pr_body(e: StackEntry): +def get_current_pr_body(e: StackEntry) -> str: out = get_command_output( ["gh", "pr", "view", e.pr, "--json", "body"], ) - return json.loads(out)["body"].strip() + return str(json.loads(out)["body"] or "").strip() -def add_cross_links(st: List[StackEntry], keep_body: bool, verbose: bool): +def add_cross_links(st: list[StackEntry], *, keep_body: bool, verbose: bool) -> None: for e in st: pr_id = last(e.pr) pr_toc = generate_toc(st, pr_id) @@ -745,11 +771,15 @@ def add_cross_links(st: List[StackEntry], keep_body: bool, verbose: bool): ] ) - run_shell_command( - ["gh", "pr", "edit", e.pr, "-t", title, "-F", "-", "-B", e.base], - input="\n".join(pr_body).encode(), - quiet=not verbose, - ) + if e.has_base(): + run_shell_command( + ["gh", "pr", "edit", e.pr, "-t", title, "-F", "-", "-B", e.base or ""], + input="\n".join(pr_body).encode(), + quiet=not verbose, + ) + else: + error("Stack entry has no base branch") + raise RuntimeError # Temporarily set base branches of existing PRs to the bottom of the stack. @@ -775,7 +805,9 @@ def add_cross_links(st: List[StackEntry], keep_body: bool, verbose: bool): # # To avoid this, we temporarily set all base branches to point to 'main' - once # all the branches are pushed we can set the actual base branches. -def reset_remote_base_branches(st: List[StackEntry], target: str, verbose: bool): +def reset_remote_base_branches( + st: list[StackEntry], target: str, *, verbose: bool +) -> None: log(h("Resetting remote base branches"), level=1) for e in filter(lambda e: e.has_pr(), st): @@ -794,23 +826,24 @@ def reset_remote_base_branches(st: List[StackEntry], target: str, verbose: bool) # base (e.g. explicit hash of the commit) - but most probably nobody ever would # need that. def should_update_local_base( - head: str, base: str, remote: str, target: str, verbose: bool -): + head: str, base: str, remote: str, target: str, *, verbose: bool +) -> bool: base_hash = get_command_output(["git", "rev-parse", base]) target_hash = get_command_output(["git", "rev-parse", f"{remote}/{target}"]) return ( - is_ancestor(base, f"{remote}/{target}", verbose) - and is_ancestor(f"{remote}/{target}", head, verbose) + is_ancestor(base, f"{remote}/{target}", verbose=verbose) + and is_ancestor(f"{remote}/{target}", head, verbose=verbose) and base_hash != target_hash ) -def update_local_base(base: str, remote: str, target: str, verbose: bool): +def update_local_base(base: str, remote: str, target: str, *, verbose: bool) -> None: log(h(f"Updating local branch {base} to {remote}/{target}"), level=1) run_shell_command(["git", "rebase", f"{remote}/{target}", base], quiet=not verbose) -class CommonArgs(NamedTuple): +@dataclass +class CommonArgs: """Class to help type checkers and separate implementation for CLI args.""" base: str @@ -822,7 +855,7 @@ class CommonArgs(NamedTuple): branch_name_template: str @classmethod - def from_args(cls, args: argparse.Namespace) -> "CommonArgs": + def from_args(cls, args: argparse.Namespace) -> CommonArgs: return cls( args.base, args.head, @@ -834,18 +867,21 @@ def from_args(cls, args: argparse.Namespace) -> "CommonArgs": ) -# If the base isn't explicitly specified, find the merge base between -# 'origin/main' and 'head'. -# -# E.g. in the example below we want to include commits E and F into the stack, -# and to do that we pick B as our base: -# -# --> a ----> b ----> c ----> d -# (main) \ (origin/main) -# \ -# ---> e ----> f -# (head) def deduce_base(args: CommonArgs) -> CommonArgs: + """Deduce the base branch from the head and target branches. + + If the base isn't explicitly specified, find the merge base between + 'origin/main' and 'head'. + + E.g. in the example below we want to include commits E and F into the stack, + and to do that we pick B as our base: + + --> a ----> b ----> c ----> d + (main) \\ (origin/main) + \\ + ---> e ----> f + (head) + """ if args.base: return args deduced_base = get_command_output( @@ -862,7 +898,7 @@ def deduce_base(args: CommonArgs) -> CommonArgs: ) -def print_tips_after_export(st: List[StackEntry], args: CommonArgs): +def print_tips_after_export(st: list[StackEntry], args: CommonArgs) -> None: stack_size = len(st) if stack_size == 0: return @@ -880,56 +916,76 @@ def print_tips_after_export(st: List[StackEntry], args: CommonArgs): # ===----------------------------------------------------------------------=== # def command_submit( args: CommonArgs, + *, draft: bool, reviewer: str, keep_body: bool, - draft_bitmask: List[bool] = None, -): + draft_bitmask: list[bool] | None = None, +) -> None: + """Entry point for 'submit' command. + + Args: + args: CommonArgs object containing command line arguments. + draft: Boolean flag indicating if the PRs should be created as drafts. + reviewer: String representing the reviewer of the PRs. + keep_body: Boolean flag indicating if the body of the PRs should be kept. + draft_bitmask: List of boolean values indicating if each PR should be created as + a draft. + """ log(h("SUBMIT"), level=1) - current_branch = get_current_branch_name() if should_update_local_base( - args.head, args.base, args.remote, args.target, args.verbose + head=args.head, + base=args.base, + remote=args.remote, + target=args.target, + verbose=args.verbose, ): - update_local_base(args.base, args.remote, args.target, args.verbose) + update_local_base( + base=args.base, remote=args.remote, target=args.target, verbose=args.verbose + ) run_shell_command(["git", "checkout", current_branch], quiet=not args.verbose) # Determine what commits belong to the stack - st = get_stack(args.base, args.head, args.verbose) + st = get_stack(base=args.base, head=args.head, verbose=args.verbose) if not st: - log(h("Empty stack!"), level=1) - log(h(blue("SUCCESS!")), level=1) + log(h("Empty stack!")) + log(h(blue("SUCCESS!"))) return if (draft_bitmask is not None) and (len(draft_bitmask) != len(st)): - log( - h("Draft bitmask passed to 'submit' doesn't match number of PRs!"), - level=1, - ) + log(h("Draft bitmask passed to 'submit' doesn't match number of PRs!")) return # Create local branches and initialize base and head fields in the stack # elements - init_local_branches(st, args.remote, args.verbose, args.branch_name_template) + init_local_branches( + st, + args.remote, + verbose=args.verbose, + branch_name_template=args.branch_name_template, + ) set_base_branches(st, args.target) - print_stack(st, args.hyperlinks) + print_stack(st, links=args.hyperlinks) # If the current branch contains commits from the stack, we will need to # rebase it in the end since the commits will be modified. top_branch = st[-1].head - need_to_rebase_current = is_ancestor(top_branch, current_branch, args.verbose) + need_to_rebase_current = is_ancestor( + top_branch, current_branch, verbose=args.verbose + ) - reset_remote_base_branches(st, args.target, args.verbose) + reset_remote_base_branches(st, target=args.target, verbose=args.verbose) # Push local branches to remote - push_branches(st, args.remote, args.verbose) + push_branches(st, remote=args.remote, verbose=args.verbose) # Now we have all the branches, so we can create the corresponding PRs log(h("Submitting PRs"), level=1) for e_idx, e in enumerate(st): is_pr_draft = draft or ((draft_bitmask is not None) and draft_bitmask[e_idx]) - create_pr(e, is_pr_draft, reviewer) + create_pr(e, is_draft=is_pr_draft, reviewer=reviewer) # Verify consistency in everything we have so far verify(st) @@ -939,15 +995,17 @@ def command_submit( needs_rebase = False for e in st: try: - needs_rebase = add_or_update_metadata(e, needs_rebase, args.verbose) + needs_rebase = add_or_update_metadata( + e, needs_rebase=needs_rebase, verbose=args.verbose + ) except Exception: error(ERROR_CANT_UPDATE_META.format(**locals())) raise - push_branches(st, args.remote, args.verbose) + push_branches(st, remote=args.remote, verbose=args.verbose) log(h("Adding cross-links to PRs"), level=1) - add_cross_links(st, keep_body, args.verbose) + add_cross_links(st, keep_body=keep_body, verbose=args.verbose) if need_to_rebase_current: log(h(f"Rebasing the original branch '{current_branch}'"), level=1) @@ -965,7 +1023,7 @@ def command_submit( log(h(f"Checking out the original branch '{current_branch}'"), level=1) run_shell_command(["git", "checkout", current_branch], quiet=not args.verbose) - delete_local_branches(st, args.verbose) + delete_local_branches(st, verbose=args.verbose) print_tips_after_export(st, args) log(h(blue("SUCCESS!")), level=1) @@ -973,8 +1031,8 @@ def command_submit( # ===----------------------------------------------------------------------=== # # LAND # ===----------------------------------------------------------------------=== # -def rebase_pr(e: StackEntry, remote: str, target: str, verbose: bool): - log(b("Rebasing ") + e.pprint(False), level=2) +def rebase_pr(e: StackEntry, remote: str, target: str, *, verbose: bool) -> None: + log(b("Rebasing ") + e.pprint(links=False), level=2) # Rebase the head branch to the most recent 'origin/main' run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) cmd = ["git", "checkout", f"{remote}/{e.head}", "-B", e.head] @@ -1001,8 +1059,8 @@ def rebase_pr(e: StackEntry, remote: str, target: str, verbose: bool): ) -def land_pr(e: StackEntry, remote: str, target: str, verbose: bool): - log(b("Landing ") + e.pprint(False), level=2) +def land_pr(e: StackEntry, remote: str, target: str, *, verbose: bool) -> None: + log(b("Landing ") + e.pprint(links=False), level=2) # Rebase the head branch to the most recent 'origin/main' run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) cmd = ["git", "checkout", f"{remote}/{e.head}", "-B", e.head] @@ -1032,7 +1090,7 @@ def land_pr(e: StackEntry, remote: str, target: str, verbose: bool): ) -def delete_local_branches(st: List[StackEntry], verbose: bool): +def delete_local_branches(st: list[StackEntry], *, verbose: bool) -> None: log(h("Deleting local branches"), level=1) # Delete local branches cmd = ["git", "branch", "-D"] @@ -1041,8 +1099,8 @@ def delete_local_branches(st: List[StackEntry], verbose: bool): def delete_remote_branches( - st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str -): + st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str +) -> None: log(h("Deleting remote branches"), level=1) run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) @@ -1067,19 +1125,25 @@ def delete_remote_branches( # ===----------------------------------------------------------------------=== # # Entry point for 'land' command # ===----------------------------------------------------------------------=== # -def command_land(args: CommonArgs): +def command_land(args: CommonArgs) -> None: log(h("LAND"), level=1) current_branch = get_current_branch_name() if should_update_local_base( - args.head, args.base, args.remote, args.target, args.verbose + head=args.head, + base=args.base, + remote=args.remote, + target=args.target, + verbose=args.verbose, ): - update_local_base(args.base, args.remote, args.target, args.verbose) + update_local_base( + base=args.base, remote=args.remote, target=args.target, verbose=args.verbose + ) run_shell_command(["git", "checkout", current_branch], quiet=not args.verbose) # Determine what commits belong to the stack - st = get_stack(args.base, args.head, args.verbose) + st = get_stack(base=args.base, head=args.head, verbose=args.verbose) if not st: log(h("Empty stack!"), level=1) log(h(blue("SUCCESS!")), level=1) @@ -1089,21 +1153,21 @@ def command_land(args: CommonArgs): # already be there from the metadata that commits need to have by that # point. set_base_branches(st, args.target) - print_stack(st, args.hyperlinks) + print_stack(st, links=args.hyperlinks) # Verify that the stack is correct before trying to land it. verify(st, check_base=True) # All good, land the bottommost PR! - land_pr(st[0], args.remote, args.target, args.verbose) + land_pr(st[0], remote=args.remote, target=args.target, verbose=args.verbose) # The rest of the stack now needs to be rebased. if len(st) > 1: log(h("Rebasing the rest of the stack"), level=1) prs_to_rebase = st[1:] - print_stack(prs_to_rebase, args.hyperlinks) + print_stack(prs_to_rebase, links=args.hyperlinks, level=1) for e in prs_to_rebase: - rebase_pr(e, args.remote, args.target, args.verbose) + rebase_pr(e, remote=args.remote, target=args.target, verbose=args.verbose) # Change the target of the new bottom-most PR in the stack to 'target' run_shell_command( ["gh", "pr", "edit", prs_to_rebase[0].pr, "-B", args.target], @@ -1113,7 +1177,7 @@ def command_land(args: CommonArgs): # Delete local and remote stack branches run_shell_command(["git", "checkout", current_branch], quiet=not args.verbose) - delete_local_branches(st, args.verbose) + delete_local_branches(st, verbose=args.verbose) # If local branch {target} exists, rebase it on the remote/target if branch_exists(args.target): @@ -1126,23 +1190,45 @@ def command_land(args: CommonArgs): quiet=not args.verbose, ) - log(h(blue("SUCCESS!")), level=1) + log(h(blue("SUCCESS!"))) # ===----------------------------------------------------------------------=== # # ABANDON # ===----------------------------------------------------------------------=== # -def strip_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) -> str: +def strip_metadata(e: StackEntry, *, needs_rebase: bool, verbose: bool) -> str: + """Strip the stack metadata from the commit message and amend the commit. + + Args: + e: StackEntry object representing the commit to strip metadata from. + needs_rebase: Boolean flag indicating if the commit needs to be rebased. + verbose: Boolean flag indicating if verbose output should be printed. + + Returns: + The SHA of the commit after stripping the metadata. + """ m = e.commit.commit_msg() m = RE_STACK_INFO_LINE.sub("", m) if needs_rebase: + if not e.has_base() or not e.has_head(): + error("Stack entry has no base or head branch") + raise RuntimeError run_shell_command( - ["git", "rebase", e.base, e.head, "--committer-date-is-author-date"], + [ + "git", + "rebase", + e.base or "", + e.head or "", + "--committer-date-is-author-date", + ], quiet=not verbose, ) else: - run_shell_command(["git", "checkout", e.head], quiet=not verbose) + if not e.has_head(): + error("Stack entry has no head branch") + raise RuntimeError + run_shell_command(["git", "checkout", e.head or ""], quiet=not verbose) run_shell_command( ["git", "commit", "--amend", "-F", "-"], @@ -1156,20 +1242,25 @@ def strip_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) -> str: # ===----------------------------------------------------------------------=== # # Entry point for 'abandon' command # ===----------------------------------------------------------------------=== # -def command_abandon(args: CommonArgs): - log(h("ABANDON"), level=1) - st = get_stack(args.base, args.head, args.verbose) +def command_abandon(args: CommonArgs) -> None: + log(h("ABANDON")) + st = get_stack(base=args.base, head=args.head, verbose=args.verbose) if not st: - log(h("Empty stack!"), level=1) - log(h(blue("SUCCESS!")), level=1) + log(h("Empty stack!")) + log(h(blue("SUCCESS!"))) return current_branch = get_current_branch_name() - init_local_branches(st, args.remote, args.verbose, args.branch_name_template) + init_local_branches( + st, + remote=args.remote, + verbose=args.verbose, + branch_name_template=args.branch_name_template, + ) set_base_branches(st, args.target) - print_stack(st, args.hyperlinks) + print_stack(st, links=args.hyperlinks) - log(h("Stripping stack metadata from commit messages"), level=1) + log(h("Stripping stack metadata from commit messages")) last_hash = "" # The first commit doesn't need to be rebased since its will not change. @@ -1177,83 +1268,90 @@ def command_abandon(args: CommonArgs): # changed as we strip the metadata from the commit messages. need_rebase = False for e in st: - last_hash = strip_metadata(e, need_rebase, args.verbose) + last_hash = strip_metadata(e, needs_rebase=need_rebase, verbose=args.verbose) need_rebase = True - log(h("Rebasing the current branch on top of updated top branch"), level=1) + log(h("Rebasing the current branch on top of updated top branch")) run_shell_command( ["git", "rebase", last_hash, current_branch], quiet=not args.verbose ) - delete_local_branches(st, args.verbose) - delete_remote_branches(st, args.remote, args.verbose, args.branch_name_template) - log(h(blue("SUCCESS!")), level=1) + delete_local_branches(st, verbose=args.verbose) + delete_remote_branches( + st, + remote=args.remote, + verbose=args.verbose, + branch_name_template=args.branch_name_template, + ) + log(h(blue("SUCCESS!"))) # ===----------------------------------------------------------------------=== # # VIEW # ===----------------------------------------------------------------------=== # -def print_tips_after_view(st: List[StackEntry], args: CommonArgs): +def print_tips_after_view(st: list[StackEntry], args: CommonArgs) -> None: stack_size = len(st) if stack_size == 0: return - ready_to_land = all([not e.has_missing_info() for e in st]) + ready_to_land = all(not e.has_missing_info() for e in st) top_commit = args.head if top_commit == "HEAD": top_commit = get_current_branch_name() if ready_to_land: - log(b("\nThis stack is ready to land!"), level=1) + log(b("\nThis stack is ready to land!")) log(UPDATE_STACK_TIP.format(**locals())) log(LAND_STACK_TIP.format(**locals())) return # Stack is not ready to land, suggest exporting it first - log( - b("\nThis stack can't be landed yet, you need to export it first."), - level=1, - ) + log(b("\nThis stack can't be landed yet, you need to export it first.")) log(EXPORT_STACK_TIP.format(**locals())) # ===----------------------------------------------------------------------=== # # Entry point for 'view' command # ===----------------------------------------------------------------------=== # -def command_view(args: CommonArgs): - log(h("VIEW"), level=1) +def command_view(args: CommonArgs) -> None: + log(h("VIEW")) if should_update_local_base( - args.head, args.base, args.remote, args.target, args.verbose + head=args.head, + base=args.base, + remote=args.remote, + target=args.target, + verbose=args.verbose, ): log( red( f"\nWarning: Local '{args.base}' is behind" f" '{args.remote}/{args.target}'!" ), - level=1, ) log( ("Consider updating your local branch by running the following commands:"), - level=1, ) log( b(f" git rebase {args.remote}/{args.target} {args.base}"), - level=1, ) log( b(f" git checkout {get_current_branch_name()}\n"), - level=1, ) - st = get_stack(args.base, args.head, args.verbose) + st = get_stack(base=args.base, head=args.head, verbose=args.verbose) - set_head_branches(st, args.remote, args.verbose, args.branch_name_template) - set_base_branches(st, args.target) - print_stack(st, args.hyperlinks) + set_head_branches( + st, + remote=args.remote, + verbose=args.verbose, + branch_name_template=args.branch_name_template, + ) + set_base_branches(st, target=args.target) + print_stack(st, links=args.hyperlinks) print_tips_after_view(st, args) - log(h(blue("SUCCESS!")), level=1) + log(h(blue("SUCCESS!"))) # ===----------------------------------------------------------------------=== # @@ -1362,14 +1460,14 @@ def create_argparser( return parser -def load_config(config_file): +def load_config(config_file: str) -> configparser.ConfigParser: config = configparser.ConfigParser() - if os.path.isfile(config_file): + if Path(config_file).is_file(): config.read(config_file) return config -def main(): +def main() -> None: # noqa: PLR0912 config_file = os.getenv("STACKPR_CONFIG", ".stack-pr.cfg") config = load_config(config_file) @@ -1386,8 +1484,7 @@ def main(): common_args = CommonArgs.from_args(args) if common_args.verbose: - # Output shell commands that we run if verbose=True - set_show_commands(True) + logger.setLevel(logging.DEBUG) check_gh_installed() @@ -1405,9 +1502,9 @@ def main(): if args.command in ["submit", "export"]: command_submit( common_args, - args.draft, - args.reviewer, - args.keep_body, + draft=args.draft, + reviewer=args.reviewer, + keep_body=args.keep_body, draft_bitmask=args.draft_bitmask, ) elif args.command == "land": @@ -1417,7 +1514,8 @@ def main(): elif args.command == "view": command_view(common_args) else: - raise Exception(f"Unknown command {args.command}") + print(h(red("Unknown command: " + args.command))) + return except Exception as exc: # If something failed, checkout the original branch run_shell_command( diff --git a/src/stack_pr/git.py b/src/stack_pr/git.py index 9dd45af..3cb6ce3 100644 --- a/src/stack_pr/git.py +++ b/src/stack_pr/git.py @@ -1,41 +1,39 @@ +from __future__ import annotations + import re -import shutil import string import subprocess +from collections.abc import Sequence +from dataclasses import dataclass from pathlib import Path -from typing import Dict, Optional, Sequence, Set -from .shell_commands import get_command_output, run_shell_command +from stack_pr.shell_commands import get_command_output, run_shell_command class GitError(Exception): pass -username_override = None +# Git constants +GIT_NOT_A_REPO_ERROR = 128 +GIT_SHA_LENGTH = 40 -def override_username(username: str): - """Override username for testing purposes. Call with None to reset.""" - global username_override - username_override = username +@dataclass +class GitConfig: + """ + Configuration for git operations. + """ + username_override: str | None = None -def fetch_checkout_commit( - repo_dir: Path, ref: str, quiet: bool, remote: str = "origin" -): - """Helper function to quickly fetch and checkout a new ref. + def set_username_override(self, username: str | None) -> None: + """Override username for testing purposes. Call with None to reset.""" + self.username_override = username - Args: - repo_dir: path to an existing git repository. - ref: a tag, brach, or (full) commit SHA. - remote: git remote to use. Default: "origin". - """ - run_shell_command( - ["git", "fetch", "--depth=1", remote, ref], cwd=repo_dir, quiet=quiet - ) - run_shell_command(["git", "checkout", "FETCH_HEAD"], cwd=repo_dir, quiet=quiet) +# Create a singleton instance +git_config = GitConfig() def is_full_git_sha(s: str) -> bool: @@ -44,50 +42,14 @@ def is_full_git_sha(s: str) -> bool: The string needs to consist of 40 lowercase hex characters. """ - if len(s) != 40: + if len(s) != GIT_SHA_LENGTH: return False digits = set(string.hexdigits.lower()) return all(c in digits for c in s) -def shallow_clone( - clone_dir: Path, url: str, ref: str, quiet: bool, remove_git: bool = False -): - """Clone the given repo without any git history. - - This makes the cloning faster for repos with large histories. - - Args: - clone_dir: path to the new clone directory. It is created if it doesn't - already exist. - url: repository url to clone from. - ref: a tag, brach, or (full) commit SHA. - remove_git: remove the .git directory after cloning. - - Raises: - FileExistsError: if clone_dir exists and is not an empty directory. - """ - - if clone_dir.exists(): - if not clone_dir.is_dir() or any(clone_dir.iterdir()): - raise FileExistsError( - f"Clone directory already exists and is not empty: {clone_dir}" - ) - else: - clone_dir.mkdir(parents=True) - - run_shell_command(["git", "init"], cwd=clone_dir, quiet=quiet) - run_shell_command( - ["git", "remote", "add", "origin", url], cwd=clone_dir, quiet=quiet - ) - fetch_checkout_commit(clone_dir, ref, quiet) - - if remove_git: - shutil.rmtree(clone_dir / ".git") - - -def branch_exists(branch: str, repo_dir: Optional[Path] = None) -> bool: +def branch_exists(branch: str, repo_dir: Path | None = None) -> bool: """Returns whether a branch with the given name exists. Args: @@ -114,7 +76,7 @@ def branch_exists(branch: str, repo_dir: Optional[Path] = None) -> bool: raise GitError("Not inside a valid git repository.") -def get_current_branch_name(repo_dir: Optional[Path] = None) -> str: +def get_current_branch_name(repo_dir: Path | None = None) -> str: """Returns the name of the branch currently checked out. Args: @@ -134,14 +96,14 @@ def get_current_branch_name(repo_dir: Optional[Path] = None) -> str: ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=repo_dir ).strip() except subprocess.CalledProcessError as e: - if e.returncode == 128: + if e.returncode == GIT_NOT_A_REPO_ERROR: raise GitError("Not inside a valid git repository.") from e raise def get_uncommitted_changes( - repo_dir: Optional[Path] = None, -) -> Dict[str, Sequence[str]]: + repo_dir: Path | None = None, +) -> dict[str, list[str]]: """Return a dictionary of uncommitted changes. Args: @@ -159,11 +121,11 @@ def get_uncommitted_changes( try: out = get_command_output(["git", "status", "--porcelain"], cwd=repo_dir) except subprocess.CalledProcessError as e: - if e.returncode == 128: + if e.returncode == GIT_NOT_A_REPO_ERROR: raise GitError("Not inside a valid git repository.") from None raise - changes = {} + changes: dict[str, list[str]] = {} for line in out.splitlines(): # First two chars are the status, changed path starts at 4th character. changes.setdefault(line[:2], []).append(line[3:]) @@ -171,7 +133,7 @@ def get_uncommitted_changes( # TODO: enforce this as a module dependency -def check_gh_installed(): +def check_gh_installed() -> None: """Check if the gh tool is installed. Raises: @@ -187,7 +149,6 @@ def check_gh_installed(): ) from err -# TODO: figure out how to test this def get_gh_username() -> str: """Return the current github username. @@ -197,10 +158,10 @@ def get_gh_username() -> str: Current github username as a string. Raises: - GitError: if called outside a git repo, or. + GitError: if called outside a git repo. """ - if username_override is not None: - return username_override + if git_config.username_override is not None: + return git_config.username_override user_query = get_command_output( [ @@ -223,7 +184,7 @@ def get_gh_username() -> str: def get_changed_files( - base: Optional[str] = None, repo_dir: Optional[Path] = None + base: str | None = None, repo_dir: Path | None = None ) -> Sequence[Path]: """Get the list of files changed between this commit and the base commit. @@ -242,8 +203,8 @@ def get_changed_files( def get_changed_dirs( - base: Optional[str] = None, repo_dir: Optional[Path] = None -) -> Set[Path]: + base: str | None = None, repo_dir: Path | None = None +) -> set[Path]: """Get the list of top-level directories changed between this commit and the base commit. diff --git a/src/stack_pr/py.typed b/src/stack_pr/py.typed new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/src/stack_pr/py.typed @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/stack_pr/shell_commands.py b/src/stack_pr/shell_commands.py index 8481de5..ac12e30 100644 --- a/src/stack_pr/shell_commands.py +++ b/src/stack_pr/shell_commands.py @@ -1,19 +1,31 @@ +from __future__ import annotations + import subprocess +import sys +from collections.abc import Iterable +from logging import getLogger from pathlib import Path -from typing import Any, Iterable, Union -ShellCommand = Iterable[Union[str, Path]] +if sys.version_info >= (3, 13): + # Unpack moved to typing + from typing import Any, Union +else: + from typing import Union -SHOW_COMMANDS = False + from typing_extensions import Any -def set_show_commands(val: bool): - global SHOW_COMMANDS - SHOW_COMMANDS = val +logger = getLogger(__name__) + +ShellCommand = Iterable[Union[str, Path]] def run_shell_command( - cmd: ShellCommand, *, quiet: bool, check: bool = True, **kwargs: Any + cmd: ShellCommand, + *, + quiet: bool, + check: bool = True, + **kwargs: Any, # noqa: ANN401 ) -> subprocess.CompletedProcess: """Runs a shell command using the arguments provided. @@ -32,15 +44,16 @@ def run_shell_command( if "shell" in kwargs: raise ValueError("shell support has been removed") _ = subprocess.list2cmdline(cmd) - kwargs.update({"check": check}) if quiet: kwargs.update({"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}) - if SHOW_COMMANDS: - print(f"Running: {cmd}") - return subprocess.run(list(map(str, cmd)), **kwargs) + logger.debug("Running: %s", cmd) + return subprocess.run(list(map(str, cmd)), **kwargs, check=check) -def get_command_output(cmd: ShellCommand, **kwargs: Any) -> str: +def get_command_output( + cmd: ShellCommand, + **kwargs: Any, # noqa: ANN401 +) -> str: """A wrapper over run_shell_command that captures stdout into a string. Args: @@ -57,4 +70,4 @@ def get_command_output(cmd: ShellCommand, **kwargs: Any) -> str: if "capture_output" in kwargs: raise ValueError("Cannot pass capture_output when using get_command_output") proc = run_shell_command(cmd, capture_output=True, quiet=False, **kwargs) - return proc.stdout.decode("utf-8").rstrip() + return str(proc.stdout.decode("utf-8").rstrip()) diff --git a/tests/test_misc.py b/tests/test_misc.py index 6fc22b8..93d8ee9 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -3,56 +3,58 @@ sys.path.append(str(Path(__file__).parent.parent / "src")) -from stack_pr.git import override_username +import pytest + from stack_pr.cli import ( - get_branch_id, + generate_available_branch_name, generate_branch_name, - get_taken_branch_ids, + get_branch_id, get_gh_username, - generate_available_branch_name, + get_taken_branch_ids, ) - -import pytest +from stack_pr.git import git_config @pytest.fixture(scope="module") -def username(): - override_username("TestBot") +def username() -> str: + git_config.set_username_override("TestBot") return get_gh_username() @pytest.mark.parametrize( - "template,branch_name,expected", + ("template", "branch_name", "expected"), [ ("feature-$ID-desc", "feature-123-desc", "123"), ("$USERNAME/stack/$ID", "{username}/stack/99", "99"), ("$USERNAME/stack/$ID", "refs/remote/origin/{username}/stack/99", "99"), ], ) -def test_get_branch_id(username, template, branch_name, expected): +def test_get_branch_id( + username: str, template: str, branch_name: str, expected: str +) -> None: branch_name = branch_name.format(username=username) assert get_branch_id(template, branch_name) == expected @pytest.mark.parametrize( - "template,branch_name", + ("template", "branch_name"), [ ("feature/$ID/desc", "feature/abc/desc"), ("feature/$ID/desc", "wrong/format"), ("$USERNAME/stack/$ID", "{username}/main/99"), ], ) -def test_get_branch_id_no_match(username, template, branch_name): +def test_get_branch_id_no_match(username: str, template: str, branch_name: str) -> None: branch_name = branch_name.format(username=username) assert get_branch_id(template, branch_name) is None -def test_generate_branch_name(): +def test_generate_branch_name() -> None: template = "feature/$ID/description" - assert generate_branch_name(template, "123") == "feature/123/description" + assert generate_branch_name(template, 123) == "feature/123/description" -def test_get_taken_branch_ids(): +def test_get_taken_branch_ids() -> None: template = "$USERNAME/stack/$ID" refs = [ "refs/remotes/origin/TestBot/stack/104", @@ -71,7 +73,7 @@ def test_get_taken_branch_ids(): assert get_taken_branch_ids(refs, template) == [104, 134] -def test_generate_available_branch_name(): +def test_generate_available_branch_name() -> None: template = "$USERNAME/stack/$ID" refs = [ "refs/remotes/origin/TestBot/stack/104",