Skip to content

Commit c077564

Browse files
[update-checkout] finish adding type hints
1 parent 713ee94 commit c077564

File tree

5 files changed

+61
-48
lines changed

5 files changed

+61
-48
lines changed

utils/update_checkout/tests/test_locked_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def _update_arguments_with_fake_path(repo_name: str, path: str) -> UpdateArgumen
1616
reset_to_remote=False,
1717
clean=False,
1818
stash=False,
19-
cross_repos_pr=False,
19+
cross_repos_pr={},
2020
output_prefix="",
2121
verbose=False,
2222
)

utils/update_checkout/update_checkout/git_command.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636

3737
def __str__(self):
3838
return (
39-
f"[{self.repo_name}] {Git._quote_command(self.command)} "
39+
f"[{self.repo_name}] '{Git._quote_command(self.command)}' "
4040
f"returned ({self.returncode}) with the following {self.stderr}."
4141
)
4242

utils/update_checkout/update_checkout/parallel_runner.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import sys
22
from multiprocessing import cpu_count
33
import time
4-
from typing import Callable, List, Any, Tuple, Union
4+
from typing import Callable, List, Any, Optional, Tuple, Union
55
from threading import Lock, Thread, Event
66
from concurrent.futures import ThreadPoolExecutor
77
import shutil
88

9-
from swift.utils.update_checkout.update_checkout.git_command import GitException
9+
from .git_command import GitException
1010

11-
from .runner_arguments import RunnerArguments, AdditionalSwiftSourcesArguments
11+
from .runner_arguments import RunnerArguments, AdditionalSwiftSourcesArguments, UpdateArguments
1212

1313

1414
class TaskTracker:
@@ -52,32 +52,38 @@ def done_task_counter(self) -> int:
5252
class MonitoredFunction:
5353
def __init__(
5454
self,
55-
fn: Callable,
55+
fn: Callable[..., Union[Exception]],
5656
task_tracker: TaskTracker,
5757
):
58-
self.fn = fn
58+
self._fn = fn
5959
self._task_tracker = task_tracker
6060

6161
def __call__(self, *args: Union[RunnerArguments, AdditionalSwiftSourcesArguments]):
6262
task_name = args[0].repo_name
6363
self._task_tracker.mark_task_as_running(task_name)
6464
result = None
6565
try:
66-
result = self.fn(*args)
66+
result = self._fn(*args)
6767
except Exception as e:
6868
print(e)
6969
finally:
7070
self._task_tracker.mark_task_as_done(task_name)
7171
return result
7272

7373

74-
class ParallelRunner:
74+
class ParallelRunner():
7575
def __init__(
7676
self,
77-
fn: Callable,
78-
pool_args: List[Union[RunnerArguments, AdditionalSwiftSourcesArguments]],
77+
fn: Callable[..., None],
78+
pool_args: Union[List[UpdateArguments], List[AdditionalSwiftSourcesArguments]],
7979
n_threads: int = 0,
8080
):
81+
def run_safely(*args, **kwargs):
82+
try:
83+
fn(*args, **kwargs)
84+
except GitException as e:
85+
return e
86+
8187
if n_threads == 0:
8288
# Limit the number of threads as the performance regresses if the
8389
# number is too high.
@@ -86,7 +92,8 @@ def __init__(
8692
self._monitor_polling_period = 0.1
8793
self._terminal_width = shutil.get_terminal_size().columns
8894
self._pool_args = pool_args
89-
self._fn = fn
95+
self._fn_name = fn.__name__
96+
self._fn = run_safely
9097
self._output_prefix = pool_args[0].output_prefix
9198
self._nb_repos = len(pool_args)
9299
self._stop_event = Event()
@@ -95,8 +102,8 @@ def __init__(
95102
self._task_tracker = TaskTracker()
96103
self._monitored_fn = MonitoredFunction(self._fn, self._task_tracker)
97104

98-
def run(self) -> List[Any]:
99-
print(f"Running ``{self._fn.__name__}`` with up to {self._n_threads} processes.")
105+
def run(self) -> List[Union[None, Exception]]:
106+
print(f"Running ``{self._fn_name}`` with up to {self._n_threads} processes.")
100107
if self._verbose:
101108
with ThreadPoolExecutor(max_workers=self._n_threads) as pool:
102109
results = list(pool.map(self._fn, self._pool_args, timeout=1800))
@@ -131,7 +138,9 @@ def _monitor(self):
131138
sys.stdout.flush()
132139

133140
@staticmethod
134-
def check_results(results, operation: str) -> int:
141+
def check_results(
142+
results: Optional[List[Union[GitException, Exception, Any]]], operation: str
143+
) -> int:
135144
"""Check the results of ParallelRunner and print the failures."""
136145

137146
fail_count = 0
@@ -143,11 +152,8 @@ def check_results(results, operation: str) -> int:
143152
if fail_count == 0:
144153
print(f"======{operation} FAILURES======")
145154
fail_count += 1
146-
if isinstance(r, str):
155+
if isinstance(r, (GitException, Exception)):
147156
print(r)
148157
continue
149-
if isinstance(r, GitException):
150-
print(str(r))
151-
continue
152158
print(r)
153159
return fail_count

utils/update_checkout/update_checkout/runner_arguments.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, Optional
33

44
from .cli_arguments import CliArguments
55

@@ -15,12 +15,12 @@ class UpdateArguments(RunnerArguments):
1515
source_root: str
1616
config: Dict[str, Any]
1717
scheme_map: Any
18-
tag: str
18+
tag: Optional[str]
1919
timestamp: Any
2020
reset_to_remote: bool
2121
clean: bool
2222
stash: bool
23-
cross_repos_pr: bool
23+
cross_repos_pr: Dict[str, str]
2424

2525
@dataclass
2626
class AdditionalSwiftSourcesArguments(RunnerArguments):

utils/update_checkout/update_checkout/update_checkout.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import sys
1616
import traceback
1717
from multiprocessing import freeze_support
18-
from typing import Any, Dict, Optional, Set, List, Union
18+
from typing import Any, Dict, Hashable, Optional, Set, List, Union
1919

2020
from .cli_arguments import CliArguments
21-
from .git_command import Git
21+
from .git_command import Git, GitException
2222
from .runner_arguments import AdditionalSwiftSourcesArguments, UpdateArguments
2323
from .parallel_runner import ParallelRunner
2424

@@ -74,7 +74,7 @@ def get_branch_for_repo(
7474
config: Dict[str, Any],
7575
repo_name: str,
7676
scheme_name: str,
77-
scheme_map: Dict[str, str],
77+
scheme_map: Optional[Dict[str, str]],
7878
cross_repos_pr: Dict[str, str],
7979
):
8080
"""Infer, fetch, and return a branch corresponding to a given PR, otherwise
@@ -85,7 +85,7 @@ def get_branch_for_repo(
8585
config (Dict[str, Any]): deserialized `update-checkout-config.json`
8686
repo_name (str): name of the repository for checking out the branch
8787
scheme_name (str): name of the scheme to look up in the config
88-
scheme_map (Dict[str, str]): map of repo names to branches to check out
88+
scheme_map (Dict[str, str] | None): map of repo names to branches to check out
8989
cross_repos_pr (Dict[str, str]): map of repo ids to PRs to check out
9090
9191
Returns:
@@ -239,8 +239,8 @@ def run_for_repo_and_each_submodule_rec(args: List[str]):
239239
# Otherwise there was some other error, and we need to handle
240240
# it like other command errors.
241241
Git.run(repo_path, ["symbolic-ref", "-q", "HEAD"])
242-
except Exception as e:
243-
if e.ret == 1:
242+
except GitException as e:
243+
if e.returncode == 1:
244244
detached_head = True
245245
else:
246246
raise # Pass this error up the chain.
@@ -268,19 +268,17 @@ def run_for_repo_and_each_submodule_rec(args: List[str]):
268268
prefix=prefix,
269269
)
270270
except Exception:
271-
(type, value, tb) = sys.exc_info()
272271
if verbose:
273272
print('Error on repo "%s": %s' % (repo_path, traceback.format_exc()))
274-
return value
273+
raise
275274

276275

277-
def get_timestamp_to_match(match_timestamp, source_root):
278-
# type: (str | None, str) -> str | None
276+
def get_timestamp_to_match(match_timestamp: bool, source_root: str):
279277
"""Computes a timestamp of the last commit on the current branch in
280278
the `swift` repository.
281279
282280
Args:
283-
match_timestamp (str | None): value of `--match-timestamp` to check.
281+
match_timestamp (bool): value of `--match-timestamp` to check.
284282
source_root (str): directory that contains sources of the Swift project.
285283
286284
Returns:
@@ -295,7 +293,7 @@ def get_timestamp_to_match(match_timestamp, source_root):
295293
return output
296294

297295

298-
def get_scheme_map(config: Dict[str, Any], scheme_name: str):
296+
def get_scheme_map(config: Dict[str, Any], scheme_name: str) -> Optional[Dict[str, str]]:
299297
"""Find a mapping from repository IDs to branches in the config.
300298
301299
Args:
@@ -342,7 +340,7 @@ def _is_any_repository_locked(pool_args: List[UpdateArguments]) -> Set[str]:
342340
locked_repositories.add(repo_name)
343341
return locked_repositories
344342

345-
def _move_llvm_project_to_first_index(pool_args: List[Union[UpdateArguments, AdditionalSwiftSourcesArguments]]):
343+
def _move_llvm_project_to_first_index(pool_args: Union[List[UpdateArguments], List[AdditionalSwiftSourcesArguments]]):
346344
llvm_project_idx = None
347345
for i in range(len(pool_args)):
348346
if pool_args[i].repo_name == "llvm-project":
@@ -351,7 +349,13 @@ def _move_llvm_project_to_first_index(pool_args: List[Union[UpdateArguments, Add
351349
if llvm_project_idx is not None:
352350
pool_args.insert(0, pool_args.pop(llvm_project_idx))
353351

354-
def update_all_repositories(args: CliArguments, config, scheme_name, scheme_map, cross_repos_pr):
352+
def update_all_repositories(
353+
args: CliArguments,
354+
config: Dict[str, Any],
355+
scheme_name: str,
356+
scheme_map: Optional[Dict[str, Any]],
357+
cross_repos_pr: Dict[str, str],
358+
):
355359
pool_args: List[UpdateArguments] = []
356360
timestamp = get_timestamp_to_match(args.match_timestamp, args.source_root)
357361
for repo_name in config['repos'].keys():
@@ -392,7 +396,7 @@ def update_all_repositories(args: CliArguments, config, scheme_name, scheme_map,
392396
locked_repositories: set[str] = _is_any_repository_locked(pool_args)
393397
if len(locked_repositories) > 0:
394398
return [
395-
f"'{repo_name}' is locked by git. Cannot update it."
399+
Exception(f"'{repo_name}' is locked by git. Cannot update it.")
396400
for repo_name in locked_repositories
397401
]
398402
_move_llvm_project_to_first_index(pool_args)
@@ -484,7 +488,7 @@ def obtain_all_additional_swift_sources(
484488
else:
485489
remote = config['https-clone-pattern'] % remote_repo_id
486490

487-
repo_branch = None
491+
repo_branch: Optional[str] = None
488492
repo_not_in_scheme = False
489493
if scheme_name:
490494
for v in config['branch-schemes'].values():
@@ -500,6 +504,9 @@ def obtain_all_additional_swift_sources(
500504
repo_branch = scheme_name
501505
if repo_not_in_scheme:
502506
continue
507+
508+
if repo_branch is None:
509+
raise RuntimeError("repo_branch is None")
503510

504511
new_args = AdditionalSwiftSourcesArguments(
505512
args=args,
@@ -570,7 +577,7 @@ def print_repo_hashes(args: CliArguments, config: Dict[str, Any]):
570577
print("{:<35}: {:<35}".format(repo_name, repo_hash))
571578

572579

573-
def merge_no_duplicates(a: dict, b: dict) -> dict:
580+
def merge_no_duplicates(a: Dict[Hashable, Any], b: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
574581
result = {**a}
575582
for key, value in b.items():
576583
if key in a:
@@ -580,7 +587,7 @@ def merge_no_duplicates(a: dict, b: dict) -> dict:
580587
return result
581588

582589

583-
def merge_config(config: dict, new_config: dict) -> dict:
590+
def merge_config(config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
584591
"""
585592
Merge two configs, with a 'last-wins' strategy.
586593
@@ -619,7 +626,7 @@ def validate_config(config: Dict[str, Any]):
619626
'too.'.format(scheme_name))
620627

621628
# Then make sure the alias names used by our branches are unique.
622-
seen = dict()
629+
seen: Dict[str, Any] = dict()
623630
for (scheme_name, scheme) in config['branch-schemes'].items():
624631
aliases = scheme['aliases']
625632
for alias in aliases:
@@ -631,7 +638,7 @@ def validate_config(config: Dict[str, Any]):
631638
seen[alias] = scheme_name
632639

633640

634-
def full_target_name(repo_path, repository, target):
641+
def full_target_name(repo_path: str, repository: str, target: str) -> str:
635642
tag, _, _ = Git.run(repo_path, ["tag", "-l", target], fatal=True)
636643
if tag == target:
637644
return tag
@@ -645,13 +652,13 @@ def full_target_name(repo_path, repository, target):
645652
raise RuntimeError('Cannot determine if %s is a branch or a tag' % target)
646653

647654

648-
def skip_list_for_platform(config: Dict[str, Any], all_repos: List[str]) -> List[str]:
655+
def skip_list_for_platform(config: Dict[str, Any], all_repos: bool) -> List[str]:
649656
"""Computes a list of repositories to skip when updating or cloning, if not
650657
overridden by `--all-repositories` CLI argument.
651658
652659
Args:
653660
config (Dict[str, Any]): deserialized `update-checkout-config.json`
654-
all_repos (List[str]): repositories not required for current platform.
661+
all_repos (bool): include all repositories.
655662
656663
Returns:
657664
List[str]: a resulting list of repositories to skip or empty list if
@@ -677,7 +684,7 @@ def skip_list_for_platform(config: Dict[str, Any], all_repos: List[str]) -> List
677684
return skip_list
678685

679686

680-
def main():
687+
def main() -> int:
681688
freeze_support()
682689
args = CliArguments.parse_args()
683690

@@ -704,7 +711,7 @@ def main():
704711
config = merge_config(config, json.load(f))
705712
validate_config(config)
706713

707-
cross_repos_pr = {}
714+
cross_repos_pr: Dict[str, str] = {}
708715
if args.github_comment:
709716
regex_pr = r'(apple/[-a-zA-Z0-9_]+/pull/\d+'\
710717
r'|apple/[-a-zA-Z0-9_]+#\d+'\
@@ -755,11 +762,11 @@ def main():
755762

756763
if args.dump_hashes:
757764
dump_repo_hashes(args, config)
758-
return (None, None)
765+
return 0
759766

760767
if args.dump_hashes_config:
761768
dump_repo_hashes(args, config, args.dump_hashes_config)
762-
return (None, None)
769+
return 0
763770

764771
# Quick check whether somebody is calling update in an empty directory
765772
directory_contents = os.listdir(args.source_root)

0 commit comments

Comments
 (0)