blob: 7fba29d034554fd71e7afb91a7c2aff640e99542 [file] [log] [blame]
#!/usr/bin/env python3
# git_utils.py
import subprocess
import sys
import os
from typing import List, Optional, Set, Dict, Tuple, Deque
from collections import deque
def run_command(
    cmd: List[str],
    check: bool = True,
    **kwargs,
) -> subprocess.CompletedProcess:
    """Run an external command with captured, UTF-8-decoded output.

    The child process inherits ``os.environ`` with ``LC_ALL=C`` forced (so
    tool output is not localized), overlaid with any ``env`` mapping the
    caller passes. Passing ``env=None`` is the same as omitting ``env``.

    Args:
        cmd: Command and arguments; ``cmd[0]`` is the executable.
        check: When True, a non-zero exit status prints the command and its
            captured stderr/stdout to stderr and exits the process with that
            status. When False, the result is returned whatever the status.
        **kwargs: Forwarded to ``subprocess.run`` (e.g. ``cwd``, ``input``).

    Returns:
        The ``CompletedProcess``, with ``stdout``/``stderr`` as ``str``.

    Raises:
        ValueError: If ``cmd`` is empty.
        SystemExit: On a failed checked command, a missing executable or
            path, or any other error raised while running the command.
    """
    if not cmd:
        raise ValueError("Command list empty.")
    executable = cmd[0]
    # Caller variables win over the forced locale. ``or {}`` makes an
    # explicit ``env=None`` mean "no overrides" instead of failing on ``**None``.
    extra_env = kwargs.pop('env', None) or {}
    current_env = {**os.environ, 'LC_ALL': 'C', **extra_env}
    try:
        return subprocess.run(
            cmd,
            check=check,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace',
            env=current_env,
            **kwargs)
    except subprocess.CalledProcessError as e:
        # subprocess.run only raises this when check=True, so always report and exit.
        print(f"Error running: {' '.join(cmd)}", file=sys.stderr)
        if e.stderr:
            print(f"Stderr:\n{e.stderr.strip()}", file=sys.stderr)
        if e.stdout:
            print(f"Stdout:\n{e.stdout.strip()}", file=sys.stderr)
        sys.exit(e.returncode)
    except FileNotFoundError as e:
        # Normally e.filename is the executable. If another path is missing
        # (e.g. cwd), name that path instead of blaming the executable.
        missing = e.filename if e.filename else executable
        print(f"Error: '{missing}' not found.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error running {' '.join(cmd)}: {e}", file=sys.stderr)
        sys.exit(1)
def run_git_command(args: List[str],
                    check: bool = True,
                    **kwargs) -> subprocess.CompletedProcess:
    """Run ``git`` with ``args`` through ``run_command``.

    ``check`` and every extra keyword argument are forwarded unchanged.
    """
    git_argv = ['git'] + args
    return run_command(git_argv, check=check, **kwargs)
def get_current_branch() -> Optional[str]:
    """Return the checked-out branch name, or None if it cannot be determined.

    Tries ``git symbolic-ref --short HEAD`` first and falls back to
    ``git branch --show-current``. None is returned when the fallback
    prints nothing or fails, and when an ``Exception`` is raised.
    ``run_command`` exits via SystemExit on some failures; that is not
    caught here.
    """
    try:
        symref = run_git_command(['symbolic-ref', '--short', 'HEAD'], check=False)
        if symref.returncode == 0:
            return symref.stdout.strip()
        shown = run_git_command(['branch', '--show-current'], check=False)
        if shown.returncode != 0:
            return None
        # Empty output becomes None.
        return shown.stdout.strip() or None
    except Exception:
        return None
def get_branch_parent(branch_name: str) -> Optional[str]:
    """Return the value of ``branch.<branch_name>.parent`` from git config.

    None is returned for an empty ``branch_name``, when ``git config`` exits
    non-zero (e.g. the key is unset), or when the value is blank.
    """
    if not branch_name:
        return None
    config_key = f'branch.{branch_name}.parent'
    lookup = run_git_command(['config', config_key], check=False)
    value = lookup.stdout.strip()
    if lookup.returncode != 0 or not value:
        return None
    return value
def get_all_branches() -> List[str]:
    """Return the names of all local branches (refs under ``refs/heads/``).

    Uses ``%(refname:lstrip=2)`` rather than ``%(refname:short)``. The latter
    abbreviates to the shortest *unambiguous* name, so a branch that shares
    its name with a tag would come back as ``heads/<name>``. That value then
    fails every ``branch.<name>.*`` config lookup made elsewhere.

    Returns:
        Branch names in the order git lists them; an empty list if none exist.
    """
    result = run_git_command(
        ['for-each-ref', '--format=%(refname:lstrip=2)', 'refs/heads/'])
    return [line for line in result.stdout.splitlines() if line.strip()]
def get_branch_children(parent_name: str,
                        all_local_branches: List[str]) -> List[str]:
    """Return the branches whose configured parent is ``parent_name``.

    Performs one ``get_branch_parent`` lookup per entry of
    ``all_local_branches``. The result keeps that list's order.
    """
    return [
        candidate for candidate in all_local_branches
        if get_branch_parent(candidate) == parent_name
    ]
def get_upstream_branch_name(branch_name: str) -> Optional[str]:
    """
    Gets the simple name (e.g., 'feature-x') of the upstream branch for the
    given local branch.

    Resolves ``<branch_name>@{u}`` with ``git rev-parse --symbolic-full-name``
    so the kind of upstream is explicit:

    * ``refs/remotes/<remote>/<name>`` -> ``<name>``. Everything after the
      remote name is kept, so ``origin/feat/x`` -> ``feat/x``. This assumes
      the remote name itself contains no ``/``.
    * ``refs/heads/<name>`` (the upstream is a local branch) -> ``<name>``.

    ``--abbrev-ref`` output cannot tell these two cases apart, and it can
    also emit a disambiguated ``remotes/origin/x`` form.

    Returns None if no upstream is configured or the ref has an unexpected
    shape. In the latter case a warning is printed to stderr.
    """
    # check=False: rev-parse exits non-zero when no upstream is configured.
    result = run_git_command(
        ['rev-parse', '--symbolic-full-name', f'{branch_name}@{{u}}'],
        check=False)
    upstream_ref = result.stdout.strip()
    if result.returncode != 0 or not upstream_ref:
        return None

    local_prefix = 'refs/heads/'
    remote_prefix = 'refs/remotes/'
    if upstream_ref.startswith(local_prefix):
        base_name = upstream_ref[len(local_prefix):]
        if base_name:
            return base_name
    elif upstream_ref.startswith(remote_prefix):
        upstream_full_name = upstream_ref[len(remote_prefix):]  # e.g. origin/feature-x
        remote_name, sep, base_name = upstream_full_name.partition('/')
        if remote_name and sep and base_name:
            return base_name

    print(
        f"Warning: Unexpected upstream format '{upstream_ref}' for branch '{branch_name}'. Cannot determine base name.",
        file=sys.stderr)
    return None
def get_ancestors(
    start_branch: str,
    mainline_branches: Set[str],
    all_local_branches: List[str],
) -> List[str]:
    """
    Traces direct lineage upwards via 'parent' config. Returns [parent, ..., base].

    The walk stops at the first parent that is unset, is in
    ``mainline_branches``, or is not in ``all_local_branches``. That stopping
    parent is not included in the result.

    Args:
        start_branch: Branch to start from; it is not included in the result.
        mainline_branches: Branches that terminate the walk.
        all_local_branches: Existing local branch names.

    Returns:
        Ancestors ordered nearest-first. The list is empty when
        ``start_branch`` has no eligible parent.

    Raises:
        ValueError: If a branch repeats along the walk (a parent cycle). The
            message spells out the whole path, e.g. ``a -> b -> c -> a``.
    """
    local_branches = set(all_local_branches)  # O(1) membership during the walk
    ancestors: List[str] = []
    visited_trace = {start_branch}
    current = start_branch
    while True:
        parent = get_branch_parent(current)
        if not parent or parent in mainline_branches or parent not in local_branches:
            break
        if parent in visited_trace:
            path = [start_branch] + ancestors + [parent]
            raise ValueError(
                f"Cycle detected tracing ancestors: {' -> '.join(path)}")
        ancestors.append(parent)
        visited_trace.add(parent)
        current = parent
    return ancestors
def get_descendants(
    start_branch: str,
    all_local_branches: List[str],
) -> List[str]:
    """
    Finds all descendants using BFS. Returns list (BFS level order).

    Each branch's configured parent is read once up front and inverted into
    a children map. The previous approach called ``get_branch_children`` for
    every visited node, which costs one git subprocess per branch per node,
    i.e. O(n^2). Siblings are visited in ``all_local_branches`` order.
    ``start_branch`` itself is not included.

    Raises:
        ValueError: If a child already appears on the path from
            ``start_branch`` to it (a parent cycle involving descendants).
    """
    children_of: Dict[str, List[str]] = {}
    for branch in all_local_branches:
        parent = get_branch_parent(branch)
        if parent:
            children_of.setdefault(parent, []).append(branch)

    descendants: List[str] = []
    queue: Deque[Tuple[str, List[str]]] = deque()
    enqueued: Set[str] = set()
    for child in children_of.get(start_branch, []):
        if child not in enqueued:  # guards against duplicate branch names
            enqueued.add(child)
            queue.append((child, [start_branch, child]))

    while queue:
        current_branch, path = queue.popleft()
        descendants.append(current_branch)
        for child in children_of.get(current_branch, []):
            if child in path:
                raise ValueError(
                    f"Cycle detected tracing descendants: {' -> '.join(path)} -> {child}"
                )
            if child not in enqueued:
                enqueued.add(child)
                queue.append((child, path + [child]))
    return descendants
def get_connected_branches(
    start_branch: str,
    mainline_branches: Set[str],
    all_local_branches: List[str],
) -> Set[str]:
    """
    Returns ``start_branch`` together with all its ancestors and descendants.

    An empty set is returned when ``start_branch`` is not in
    ``all_local_branches``. Ancestors are traced before descendants. Any
    ValueError those traversals raise on a cycle propagates.
    """
    if start_branch in all_local_branches:
        lineage_up = get_ancestors(start_branch, mainline_branches, all_local_branches)
        lineage_down = get_descendants(start_branch, all_local_branches)
        return {start_branch, *lineage_up, *lineage_down}
    return set()
def get_stack_base(
    start_branch: str,
    mainline_branches: Set[str],
    all_local_branches: List[str],
) -> str:
    """Return the furthest ancestor reached by ``get_ancestors``.

    When ``start_branch`` has no eligible ancestors it is its own base.
    Propagates the ValueError that ``get_ancestors`` raises on a cycle.
    """
    lineage = get_ancestors(start_branch, mainline_branches, all_local_branches)
    if not lineage:
        return start_branch
    return lineage[-1]
def get_stack_branches_ordered(
    start_branch: str,
    mainline_branches: Set[str],
    all_local_branches: List[str],
) -> List[str]:
    """Finds all branches in stack segment, ordered parent-first via BFS from base.

    The base comes from ``get_stack_base``, which propagates ValueError on a
    cycle. Each branch's parent is read once and inverted into a children
    map. Calling ``get_branch_children`` per node would cost O(n^2) git
    subprocesses. Siblings follow ``all_local_branches`` order.

    Returns:
        The base followed by everything reachable below it in BFS order. The
        list is empty when the base is not a local branch. A warning is
        printed to stderr if ``start_branch`` is local but was not reached.
    """
    stack_base = get_stack_base(start_branch, mainline_branches,
                                all_local_branches)
    if stack_base not in all_local_branches:
        return []

    children_of: Dict[str, List[str]] = {}
    for branch in all_local_branches:
        parent = get_branch_parent(branch)
        if parent:
            children_of.setdefault(parent, []).append(branch)

    ordered_stack: List[str] = []
    queue: Deque[str] = deque([stack_base])
    visited_bfs: Set[str] = {stack_base}
    while queue:
        current_branch = queue.popleft()
        ordered_stack.append(current_branch)
        for child in children_of.get(current_branch, []):
            if child not in visited_bfs:
                visited_bfs.add(child)
                queue.append(child)

    if start_branch not in visited_bfs and start_branch in all_local_branches:
        print(
            f"Warning: Target '{start_branch}' not reached from base '{stack_base}'. Config inconsistent?",
            file=sys.stderr)
    return ordered_stack
def topological_sort_branches() -> Tuple[List[str], Dict[str, Optional[str]]]:
    """
    Performs a topological sort of branches based on parent config.
    Returns the sorted list (parents before children) and the graph used.

    ``graph`` maps every local branch with a configured parent to that
    parent. Parentless local branches that are some branch's parent map to
    None. Branches with neither a parent nor children are omitted from both
    results. A configured parent that is not a local branch appears in the
    sorted list as a root but is not a key of ``graph``.

    Roots are emitted in sorted order and each node's children are sorted
    too, so the result does not depend on set iteration order. String
    hashing is randomized per process.

    Raises:
        ValueError: If the parent relation contains a cycle.
    """
    all_branches = get_all_branches()
    # Read each parent once. Re-reading every branch's parent for each
    # parentless branch was O(n^2) git subprocesses.
    parent_of = {branch: get_branch_parent(branch) for branch in all_branches}
    referenced_parents = {p for p in parent_of.values() if p}

    graph: Dict[str, Optional[str]] = {}
    nodes: Set[str] = set()
    for branch in all_branches:
        parent = parent_of[branch]
        if parent:
            graph[branch] = parent
            nodes.add(branch)
            nodes.add(parent)
        elif branch in referenced_parents:
            graph[branch] = None
            nodes.add(branch)

    in_degree: Dict[str, int] = {node: 0 for node in nodes}
    adj: Dict[str, List[str]] = {node: [] for node in nodes}
    for child, parent in graph.items():
        if parent in adj:  # roots map to None and add no edge
            adj[parent].append(child)
            in_degree[child] += 1

    queue: Deque[str] = deque(sorted(node for node in nodes if in_degree[node] == 0))
    sorted_order: List[str] = []
    while queue:
        parent = queue.popleft()
        sorted_order.append(parent)
        for child in sorted(adj[parent]):
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)

    if len(sorted_order) != len(nodes):
        remaining_nodes = {node for node, degree in in_degree.items() if degree > 0}
        raise ValueError(
            f"Cycle detected during topological sort involving: {remaining_nodes}")
    return sorted_order, graph