1
1
Fork 0
mirror of https://github.com/pypa/pip synced 2023-12-13 21:30:23 +01:00

Move all remaining type comments to annotations

Use the com2ann tool to convert remaining comments to annotations. Now,
no type comments remain.

https://github.com/ilevkivskyi/com2ann

Some types are not available at runtime (e.g. Literal) or require a
forward reference and so were quoted.
This commit is contained in:
Jon Dufresne 2021-08-27 18:29:39 -07:00
parent 7aaea4e218
commit 44a034a131
8 changed files with 237 additions and 393 deletions

View file

@ -28,8 +28,7 @@ class Hashes:
""" """
def __init__(self, hashes=None): def __init__(self, hashes: Dict[str, List[str]] = None) -> None:
# type: (Dict[str, List[str]]) -> None
""" """
:param hashes: A dict of algorithm names pointing to lists of allowed :param hashes: A dict of algorithm names pointing to lists of allowed
hex digests hex digests
@ -41,8 +40,7 @@ class Hashes:
allowed[alg] = sorted(keys) allowed[alg] = sorted(keys)
self._allowed = allowed self._allowed = allowed
def __and__(self, other): def __and__(self, other: "Hashes") -> "Hashes":
# type: (Hashes) -> Hashes
if not isinstance(other, Hashes): if not isinstance(other, Hashes):
return NotImplemented return NotImplemented
@ -62,21 +60,14 @@ class Hashes:
return Hashes(new) return Hashes(new)
@property @property
def digest_count(self): def digest_count(self) -> int:
# type: () -> int
return sum(len(digests) for digests in self._allowed.values()) return sum(len(digests) for digests in self._allowed.values())
def is_hash_allowed( def is_hash_allowed(self, hash_name: str, hex_digest: str) -> bool:
self,
hash_name, # type: str
hex_digest, # type: str
):
# type: (...) -> bool
"""Return whether the given hex digest is allowed.""" """Return whether the given hex digest is allowed."""
return hex_digest in self._allowed.get(hash_name, []) return hex_digest in self._allowed.get(hash_name, [])
def check_against_chunks(self, chunks): def check_against_chunks(self, chunks: Iterator[bytes]) -> None:
# type: (Iterator[bytes]) -> None
"""Check good hashes against ones built from iterable of chunks of """Check good hashes against ones built from iterable of chunks of
data. data.
@ -99,12 +90,10 @@ class Hashes:
return return
self._raise(gots) self._raise(gots)
def _raise(self, gots): def _raise(self, gots: Dict[str, "_Hash"]) -> "NoReturn":
# type: (Dict[str, _Hash]) -> NoReturn
raise HashMismatch(self._allowed, gots) raise HashMismatch(self._allowed, gots)
def check_against_file(self, file): def check_against_file(self, file: BinaryIO) -> None:
# type: (BinaryIO) -> None
"""Check good hashes against a file-like object """Check good hashes against a file-like object
Raise HashMismatch if none match. Raise HashMismatch if none match.
@ -112,24 +101,20 @@ class Hashes:
""" """
return self.check_against_chunks(read_chunks(file)) return self.check_against_chunks(read_chunks(file))
def check_against_path(self, path): def check_against_path(self, path: str) -> None:
# type: (str) -> None
with open(path, "rb") as file: with open(path, "rb") as file:
return self.check_against_file(file) return self.check_against_file(file)
def __bool__(self): def __bool__(self) -> bool:
# type: () -> bool
"""Return whether I know any known-good hashes.""" """Return whether I know any known-good hashes."""
return bool(self._allowed) return bool(self._allowed)
def __eq__(self, other): def __eq__(self, other: object) -> bool:
# type: (object) -> bool
if not isinstance(other, Hashes): if not isinstance(other, Hashes):
return NotImplemented return NotImplemented
return self._allowed == other._allowed return self._allowed == other._allowed
def __hash__(self): def __hash__(self) -> int:
# type: () -> int
return hash( return hash(
",".join( ",".join(
sorted( sorted(
@ -149,13 +134,11 @@ class MissingHashes(Hashes):
""" """
def __init__(self): def __init__(self) -> None:
# type: () -> None
"""Don't offer the ``hashes`` kwarg.""" """Don't offer the ``hashes`` kwarg."""
# Pass our favorite hash in to generate a "gotten hash". With the # Pass our favorite hash in to generate a "gotten hash". With the
# empty list, it will never match, so an error will always raise. # empty list, it will never match, so an error will always raise.
super().__init__(hashes={FAVORITE_HASH: []}) super().__init__(hashes={FAVORITE_HASH: []})
def _raise(self, gots): def _raise(self, gots: Dict[str, "_Hash"]) -> "NoReturn":
# type: (Dict[str, _Hash]) -> NoReturn
raise HashMissing(gots[FAVORITE_HASH].hexdigest()) raise HashMissing(gots[FAVORITE_HASH].hexdigest())

View file

@ -70,8 +70,7 @@ VersionInfo = Tuple[int, int, int]
NetlocTuple = Tuple[str, Tuple[Optional[str], Optional[str]]] NetlocTuple = Tuple[str, Tuple[Optional[str], Optional[str]]]
def get_pip_version(): def get_pip_version() -> str:
# type: () -> str
pip_pkg_dir = os.path.join(os.path.dirname(__file__), "..", "..") pip_pkg_dir = os.path.join(os.path.dirname(__file__), "..", "..")
pip_pkg_dir = os.path.abspath(pip_pkg_dir) pip_pkg_dir = os.path.abspath(pip_pkg_dir)
@ -82,8 +81,7 @@ def get_pip_version():
) )
def normalize_version_info(py_version_info): def normalize_version_info(py_version_info: Tuple[int, ...]) -> Tuple[int, int, int]:
# type: (Tuple[int, ...]) -> Tuple[int, int, int]
""" """
Convert a tuple of ints representing a Python version to one of length Convert a tuple of ints representing a Python version to one of length
three. three.
@ -102,8 +100,7 @@ def normalize_version_info(py_version_info):
return cast("VersionInfo", py_version_info) return cast("VersionInfo", py_version_info)
def ensure_dir(path): def ensure_dir(path: str) -> None:
# type: (str) -> None
"""os.path.makedirs without EEXIST.""" """os.path.makedirs without EEXIST."""
try: try:
os.makedirs(path) os.makedirs(path)
@ -113,8 +110,7 @@ def ensure_dir(path):
raise raise
def get_prog(): def get_prog() -> str:
# type: () -> str
try: try:
prog = os.path.basename(sys.argv[0]) prog = os.path.basename(sys.argv[0])
if prog in ("__main__.py", "-c"): if prog in ("__main__.py", "-c"):
@ -129,13 +125,11 @@ def get_prog():
# Retry every half second for up to 3 seconds # Retry every half second for up to 3 seconds
# Tenacity raises RetryError by default, explicitly raise the original exception # Tenacity raises RetryError by default, explicitly raise the original exception
@retry(reraise=True, stop=stop_after_delay(3), wait=wait_fixed(0.5)) @retry(reraise=True, stop=stop_after_delay(3), wait=wait_fixed(0.5))
def rmtree(dir, ignore_errors=False): def rmtree(dir: str, ignore_errors: bool = False) -> None:
# type: (str, bool) -> None
shutil.rmtree(dir, ignore_errors=ignore_errors, onerror=rmtree_errorhandler) shutil.rmtree(dir, ignore_errors=ignore_errors, onerror=rmtree_errorhandler)
def rmtree_errorhandler(func, path, exc_info): def rmtree_errorhandler(func: Callable[..., Any], path: str, exc_info: ExcInfo) -> None:
# type: (Callable[..., Any], str, ExcInfo) -> None
"""On Windows, the files in .svn are read-only, so when rmtree() tries to """On Windows, the files in .svn are read-only, so when rmtree() tries to
remove them, an exception is thrown. We catch that here, remove the remove them, an exception is thrown. We catch that here, remove the
read-only attribute, and hopefully continue without problems.""" read-only attribute, and hopefully continue without problems."""
@ -155,8 +149,7 @@ def rmtree_errorhandler(func, path, exc_info):
raise raise
def display_path(path): def display_path(path: str) -> str:
# type: (str) -> str
"""Gives the display value for a given path, making it relative to cwd """Gives the display value for a given path, making it relative to cwd
if possible.""" if possible."""
path = os.path.normcase(os.path.abspath(path)) path = os.path.normcase(os.path.abspath(path))
@ -165,8 +158,7 @@ def display_path(path):
return path return path
def backup_dir(dir, ext=".bak"): def backup_dir(dir: str, ext: str = ".bak") -> str:
# type: (str, str) -> str
"""Figure out the name of a directory to back up the given dir to """Figure out the name of a directory to back up the given dir to
(adding .bak, .bak2, etc)""" (adding .bak, .bak2, etc)"""
n = 1 n = 1
@ -177,16 +169,14 @@ def backup_dir(dir, ext=".bak"):
return dir + extension return dir + extension
def ask_path_exists(message, options): def ask_path_exists(message: str, options: Iterable[str]) -> str:
# type: (str, Iterable[str]) -> str
for action in os.environ.get("PIP_EXISTS_ACTION", "").split(): for action in os.environ.get("PIP_EXISTS_ACTION", "").split():
if action in options: if action in options:
return action return action
return ask(message, options) return ask(message, options)
def _check_no_input(message): def _check_no_input(message: str) -> None:
# type: (str) -> None
"""Raise an error if no input is allowed.""" """Raise an error if no input is allowed."""
if os.environ.get("PIP_NO_INPUT"): if os.environ.get("PIP_NO_INPUT"):
raise Exception( raise Exception(
@ -194,8 +184,7 @@ def _check_no_input(message):
) )
def ask(message, options): def ask(message: str, options: Iterable[str]) -> str:
# type: (str, Iterable[str]) -> str
"""Ask the message interactively, with the given possible responses""" """Ask the message interactively, with the given possible responses"""
while 1: while 1:
_check_no_input(message) _check_no_input(message)
@ -210,22 +199,19 @@ def ask(message, options):
return response return response
def ask_input(message): def ask_input(message: str) -> str:
# type: (str) -> str
"""Ask for input interactively.""" """Ask for input interactively."""
_check_no_input(message) _check_no_input(message)
return input(message) return input(message)
def ask_password(message): def ask_password(message: str) -> str:
# type: (str) -> str
"""Ask for a password interactively.""" """Ask for a password interactively."""
_check_no_input(message) _check_no_input(message)
return getpass.getpass(message) return getpass.getpass(message)
def strtobool(val): def strtobool(val: str) -> int:
# type: (str) -> int
"""Convert a string representation of truth to true (1) or false (0). """Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
@ -241,8 +227,7 @@ def strtobool(val):
raise ValueError(f"invalid truth value {val!r}") raise ValueError(f"invalid truth value {val!r}")
def format_size(bytes): def format_size(bytes: float) -> str:
# type: (float) -> str
if bytes > 1000 * 1000: if bytes > 1000 * 1000:
return "{:.1f} MB".format(bytes / 1000.0 / 1000) return "{:.1f} MB".format(bytes / 1000.0 / 1000)
elif bytes > 10 * 1000: elif bytes > 10 * 1000:
@ -253,8 +238,7 @@ def format_size(bytes):
return "{} bytes".format(int(bytes)) return "{} bytes".format(int(bytes))
def tabulate(rows): def tabulate(rows: Iterable[Iterable[Any]]) -> Tuple[List[str], List[int]]:
# type: (Iterable[Iterable[Any]]) -> Tuple[List[str], List[int]]
"""Return a list of formatted rows and a list of column sizes. """Return a list of formatted rows and a list of column sizes.
For example:: For example::
@ -285,8 +269,7 @@ def is_installable_dir(path: str) -> bool:
return False return False
def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE): def read_chunks(file: BinaryIO, size: int = io.DEFAULT_BUFFER_SIZE) -> Iterator[bytes]:
# type: (BinaryIO, int) -> Iterator[bytes]
"""Yield pieces of data from a file-like object until EOF.""" """Yield pieces of data from a file-like object until EOF."""
while True: while True:
chunk = file.read(size) chunk = file.read(size)
@ -295,8 +278,7 @@ def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
yield chunk yield chunk
def normalize_path(path, resolve_symlinks=True): def normalize_path(path: str, resolve_symlinks: bool = True) -> str:
# type: (str, bool) -> str
""" """
Convert a path to its canonical, case-normalized, absolute version. Convert a path to its canonical, case-normalized, absolute version.
@ -309,8 +291,7 @@ def normalize_path(path, resolve_symlinks=True):
return os.path.normcase(path) return os.path.normcase(path)
def splitext(path): def splitext(path: str) -> Tuple[str, str]:
# type: (str) -> Tuple[str, str]
"""Like os.path.splitext, but take off .tar too""" """Like os.path.splitext, but take off .tar too"""
base, ext = posixpath.splitext(path) base, ext = posixpath.splitext(path)
if base.lower().endswith(".tar"): if base.lower().endswith(".tar"):
@ -319,8 +300,7 @@ def splitext(path):
return base, ext return base, ext
def renames(old, new): def renames(old: str, new: str) -> None:
# type: (str, str) -> None
"""Like os.renames(), but handles renaming across devices.""" """Like os.renames(), but handles renaming across devices."""
# Implementation borrowed from os.renames(). # Implementation borrowed from os.renames().
head, tail = os.path.split(new) head, tail = os.path.split(new)
@ -337,8 +317,7 @@ def renames(old, new):
pass pass
def is_local(path): def is_local(path: str) -> bool:
# type: (str) -> bool
""" """
Return True if path is within sys.prefix, if we're running in a virtualenv. Return True if path is within sys.prefix, if we're running in a virtualenv.
@ -352,8 +331,7 @@ def is_local(path):
return path.startswith(normalize_path(sys.prefix)) return path.startswith(normalize_path(sys.prefix))
def dist_is_local(dist): def dist_is_local(dist: Distribution) -> bool:
# type: (Distribution) -> bool
""" """
Return True if given Distribution object is installed locally Return True if given Distribution object is installed locally
(i.e. within current virtualenv). (i.e. within current virtualenv).
@ -364,16 +342,14 @@ def dist_is_local(dist):
return is_local(dist_location(dist)) return is_local(dist_location(dist))
def dist_in_usersite(dist): def dist_in_usersite(dist: Distribution) -> bool:
# type: (Distribution) -> bool
""" """
Return True if given Distribution is installed in user site. Return True if given Distribution is installed in user site.
""" """
return dist_location(dist).startswith(normalize_path(user_site)) return dist_location(dist).startswith(normalize_path(user_site))
def dist_in_site_packages(dist): def dist_in_site_packages(dist: Distribution) -> bool:
# type: (Distribution) -> bool
""" """
Return True if given Distribution is installed in Return True if given Distribution is installed in
sysconfig.get_python_lib(). sysconfig.get_python_lib().
@ -381,8 +357,7 @@ def dist_in_site_packages(dist):
return dist_location(dist).startswith(normalize_path(site_packages)) return dist_location(dist).startswith(normalize_path(site_packages))
def dist_is_editable(dist): def dist_is_editable(dist: Distribution) -> bool:
# type: (Distribution) -> bool
""" """
Return True if given Distribution is an editable install. Return True if given Distribution is an editable install.
""" """
@ -394,14 +369,13 @@ def dist_is_editable(dist):
def get_installed_distributions( def get_installed_distributions(
local_only=True, # type: bool local_only: bool = True,
skip=stdlib_pkgs, # type: Container[str] skip: Container[str] = stdlib_pkgs,
include_editables=True, # type: bool include_editables: bool = True,
editables_only=False, # type: bool editables_only: bool = False,
user_only=False, # type: bool user_only: bool = False,
paths=None, # type: Optional[List[str]] paths: Optional[List[str]] = None,
): ) -> List[Distribution]:
# type: (...) -> List[Distribution]
"""Return a list of installed Distribution objects. """Return a list of installed Distribution objects.
Left for compatibility until direct pkg_resources uses are refactored out. Left for compatibility until direct pkg_resources uses are refactored out.
@ -423,8 +397,7 @@ def get_installed_distributions(
return [cast(_Dist, dist)._dist for dist in dists] return [cast(_Dist, dist)._dist for dist in dists]
def get_distribution(req_name): def get_distribution(req_name: str) -> Optional[Distribution]:
# type: (str) -> Optional[Distribution]
"""Given a requirement name, return the installed Distribution object. """Given a requirement name, return the installed Distribution object.
This searches from *all* distributions available in the environment, to This searches from *all* distributions available in the environment, to
@ -441,8 +414,7 @@ def get_distribution(req_name):
return cast(_Dist, dist)._dist return cast(_Dist, dist)._dist
def egg_link_path(dist): def egg_link_path(dist: Distribution) -> Optional[str]:
# type: (Distribution) -> Optional[str]
""" """
Return the path for the .egg-link file if it exists, otherwise, None. Return the path for the .egg-link file if it exists, otherwise, None.
@ -477,8 +449,7 @@ def egg_link_path(dist):
return None return None
def dist_location(dist): def dist_location(dist: Distribution) -> str:
# type: (Distribution) -> str
""" """
Get the site-packages location of this distribution. Generally Get the site-packages location of this distribution. Generally
this is dist.location, except in the case of develop-installed this is dist.location, except in the case of develop-installed
@ -493,17 +464,15 @@ def dist_location(dist):
return normalize_path(dist.location) return normalize_path(dist.location)
def write_output(msg, *args): def write_output(msg: Any, *args: Any) -> None:
# type: (Any, Any) -> None
logger.info(msg, *args) logger.info(msg, *args)
class StreamWrapper(StringIO): class StreamWrapper(StringIO):
orig_stream = None # type: TextIO orig_stream: TextIO = None
@classmethod @classmethod
def from_stream(cls, orig_stream): def from_stream(cls, orig_stream: TextIO) -> "StreamWrapper":
# type: (TextIO) -> StreamWrapper
cls.orig_stream = orig_stream cls.orig_stream = orig_stream
return cls() return cls()
@ -515,8 +484,7 @@ class StreamWrapper(StringIO):
@contextlib.contextmanager @contextlib.contextmanager
def captured_output(stream_name): def captured_output(stream_name: str) -> Iterator[StreamWrapper]:
# type: (str) -> Iterator[StreamWrapper]
"""Return a context manager used by captured_stdout/stdin/stderr """Return a context manager used by captured_stdout/stdin/stderr
that temporarily replaces the sys stream *stream_name* with a StringIO. that temporarily replaces the sys stream *stream_name* with a StringIO.
@ -530,8 +498,7 @@ def captured_output(stream_name):
setattr(sys, stream_name, orig_stdout) setattr(sys, stream_name, orig_stdout)
def captured_stdout(): def captured_stdout() -> ContextManager[StreamWrapper]:
# type: () -> ContextManager[StreamWrapper]
"""Capture the output of sys.stdout: """Capture the output of sys.stdout:
with captured_stdout() as stdout: with captured_stdout() as stdout:
@ -543,8 +510,7 @@ def captured_stdout():
return captured_output("stdout") return captured_output("stdout")
def captured_stderr(): def captured_stderr() -> ContextManager[StreamWrapper]:
# type: () -> ContextManager[StreamWrapper]
""" """
See captured_stdout(). See captured_stdout().
""" """
@ -552,16 +518,14 @@ def captured_stderr():
# Simulates an enum # Simulates an enum
def enum(*sequential, **named): def enum(*sequential: Any, **named: Any) -> Type[Any]:
# type: (*Any, **Any) -> Type[Any]
enums = dict(zip(sequential, range(len(sequential))), **named) enums = dict(zip(sequential, range(len(sequential))), **named)
reverse = {value: key for key, value in enums.items()} reverse = {value: key for key, value in enums.items()}
enums["reverse_mapping"] = reverse enums["reverse_mapping"] = reverse
return type("Enum", (), enums) return type("Enum", (), enums)
def build_netloc(host, port): def build_netloc(host: str, port: Optional[int]) -> str:
# type: (str, Optional[int]) -> str
""" """
Build a netloc from a host-port pair Build a netloc from a host-port pair
""" """
@ -573,8 +537,7 @@ def build_netloc(host, port):
return f"{host}:{port}" return f"{host}:{port}"
def build_url_from_netloc(netloc, scheme="https"): def build_url_from_netloc(netloc: str, scheme: str = "https") -> str:
# type: (str, str) -> str
""" """
Build a full URL from a netloc. Build a full URL from a netloc.
""" """
@ -584,8 +547,7 @@ def build_url_from_netloc(netloc, scheme="https"):
return f"{scheme}://{netloc}" return f"{scheme}://{netloc}"
def parse_netloc(netloc): def parse_netloc(netloc: str) -> Tuple[str, Optional[int]]:
# type: (str) -> Tuple[str, Optional[int]]
""" """
Return the host-port pair from a netloc. Return the host-port pair from a netloc.
""" """
@ -594,8 +556,7 @@ def parse_netloc(netloc):
return parsed.hostname, parsed.port return parsed.hostname, parsed.port
def split_auth_from_netloc(netloc): def split_auth_from_netloc(netloc: str) -> NetlocTuple:
# type: (str) -> NetlocTuple
""" """
Parse out and remove the auth information from a netloc. Parse out and remove the auth information from a netloc.
@ -608,7 +569,7 @@ def split_auth_from_netloc(netloc):
# behaves if more than one @ is present (which can be checked using # behaves if more than one @ is present (which can be checked using
# the password attribute of urlsplit()'s return value). # the password attribute of urlsplit()'s return value).
auth, netloc = netloc.rsplit("@", 1) auth, netloc = netloc.rsplit("@", 1)
pw = None # type: Optional[str] pw: Optional[str] = None
if ":" in auth: if ":" in auth:
# Split from the left because that's how urllib.parse.urlsplit() # Split from the left because that's how urllib.parse.urlsplit()
# behaves if more than one : is present (which again can be checked # behaves if more than one : is present (which again can be checked
@ -624,8 +585,7 @@ def split_auth_from_netloc(netloc):
return netloc, (user, pw) return netloc, (user, pw)
def redact_netloc(netloc): def redact_netloc(netloc: str) -> str:
# type: (str) -> str
""" """
Replace the sensitive data in a netloc with "****", if it exists. Replace the sensitive data in a netloc with "****", if it exists.
@ -647,8 +607,9 @@ def redact_netloc(netloc):
) )
def _transform_url(url, transform_netloc): def _transform_url(
# type: (str, Callable[[str], Tuple[Any, ...]]) -> Tuple[str, NetlocTuple] url: str, transform_netloc: Callable[[str], Tuple[Any, ...]]
) -> Tuple[str, NetlocTuple]:
"""Transform and replace netloc in a url. """Transform and replace netloc in a url.
transform_netloc is a function taking the netloc and returning a transform_netloc is a function taking the netloc and returning a
@ -666,18 +627,15 @@ def _transform_url(url, transform_netloc):
return surl, cast("NetlocTuple", netloc_tuple) return surl, cast("NetlocTuple", netloc_tuple)
def _get_netloc(netloc): def _get_netloc(netloc: str) -> NetlocTuple:
# type: (str) -> NetlocTuple
return split_auth_from_netloc(netloc) return split_auth_from_netloc(netloc)
def _redact_netloc(netloc): def _redact_netloc(netloc: str) -> Tuple[str]:
# type: (str) -> Tuple[str,]
return (redact_netloc(netloc),) return (redact_netloc(netloc),)
def split_auth_netloc_from_url(url): def split_auth_netloc_from_url(url: str) -> Tuple[str, str, Tuple[str, str]]:
# type: (str) -> Tuple[str, str, Tuple[str, str]]
""" """
Parse a url into separate netloc, auth, and url with no auth. Parse a url into separate netloc, auth, and url with no auth.
@ -687,41 +645,31 @@ def split_auth_netloc_from_url(url):
return url_without_auth, netloc, auth return url_without_auth, netloc, auth
def remove_auth_from_url(url): def remove_auth_from_url(url: str) -> str:
# type: (str) -> str
"""Return a copy of url with 'username:password@' removed.""" """Return a copy of url with 'username:password@' removed."""
# username/pass params are passed to subversion through flags # username/pass params are passed to subversion through flags
# and are not recognized in the url. # and are not recognized in the url.
return _transform_url(url, _get_netloc)[0] return _transform_url(url, _get_netloc)[0]
def redact_auth_from_url(url): def redact_auth_from_url(url: str) -> str:
# type: (str) -> str
"""Replace the password in a given url with ****.""" """Replace the password in a given url with ****."""
return _transform_url(url, _redact_netloc)[0] return _transform_url(url, _redact_netloc)[0]
class HiddenText: class HiddenText:
def __init__( def __init__(self, secret: str, redacted: str) -> None:
self,
secret, # type: str
redacted, # type: str
):
# type: (...) -> None
self.secret = secret self.secret = secret
self.redacted = redacted self.redacted = redacted
def __repr__(self): def __repr__(self) -> str:
# type: (...) -> str
return "<HiddenText {!r}>".format(str(self)) return "<HiddenText {!r}>".format(str(self))
def __str__(self): def __str__(self) -> str:
# type: (...) -> str
return self.redacted return self.redacted
# This is useful for testing. # This is useful for testing.
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
# type: (Any) -> bool
if type(self) != type(other): if type(self) != type(other):
return False return False
@ -730,19 +678,16 @@ class HiddenText:
return self.secret == other.secret return self.secret == other.secret
def hide_value(value): def hide_value(value: str) -> HiddenText:
# type: (str) -> HiddenText
return HiddenText(value, redacted="****") return HiddenText(value, redacted="****")
def hide_url(url): def hide_url(url: str) -> HiddenText:
# type: (str) -> HiddenText
redacted = redact_auth_from_url(url) redacted = redact_auth_from_url(url)
return HiddenText(url, redacted=redacted) return HiddenText(url, redacted=redacted)
def protect_pip_from_modification_on_windows(modifying_pip): def protect_pip_from_modification_on_windows(modifying_pip: bool) -> None:
# type: (bool) -> None
"""Protection of pip.exe from modification on Windows """Protection of pip.exe from modification on Windows
On Windows, any operation modifying pip should be run as: On Windows, any operation modifying pip should be run as:
@ -768,14 +713,12 @@ def protect_pip_from_modification_on_windows(modifying_pip):
) )
def is_console_interactive(): def is_console_interactive() -> bool:
# type: () -> bool
"""Is this console interactive?""" """Is this console interactive?"""
return sys.stdin is not None and sys.stdin.isatty() return sys.stdin is not None and sys.stdin.isatty()
def hash_file(path, blocksize=1 << 20): def hash_file(path: str, blocksize: int = 1 << 20) -> Tuple[Any, int]:
# type: (str, int) -> Tuple[Any, int]
"""Return (hash, length) for path using hashlib.sha256()""" """Return (hash, length) for path using hashlib.sha256()"""
h = hashlib.sha256() h = hashlib.sha256()
@ -787,8 +730,7 @@ def hash_file(path, blocksize=1 << 20):
return h, length return h, length
def is_wheel_installed(): def is_wheel_installed() -> bool:
# type: () -> bool
""" """
Return whether the wheel package is installed. Return whether the wheel package is installed.
""" """
@ -800,8 +742,7 @@ def is_wheel_installed():
return True return True
def pairwise(iterable): def pairwise(iterable: Iterable[Any]) -> Iterator[Tuple[Any, Any]]:
# type: (Iterable[Any]) -> Iterator[Tuple[Any, Any]]
""" """
Return paired elements. Return paired elements.
@ -813,10 +754,9 @@ def pairwise(iterable):
def partition( def partition(
pred, # type: Callable[[T], bool] pred: Callable[[T], bool],
iterable, # type: Iterable[T] iterable: Iterable[T],
): ) -> Tuple[Iterable[T], Iterable[T]]:
# type: (...) -> Tuple[Iterable[T], Iterable[T]]
""" """
Use a predicate to partition entries into false entries and true entries, Use a predicate to partition entries into false entries and true entries,
like like

View file

@ -10,37 +10,29 @@ class KeyBasedCompareMixin:
__slots__ = ["_compare_key", "_defining_class"] __slots__ = ["_compare_key", "_defining_class"]
def __init__(self, key, defining_class): def __init__(self, key: Any, defining_class: Type["KeyBasedCompareMixin"]) -> None:
# type: (Any, Type[KeyBasedCompareMixin]) -> None
self._compare_key = key self._compare_key = key
self._defining_class = defining_class self._defining_class = defining_class
def __hash__(self): def __hash__(self) -> int:
# type: () -> int
return hash(self._compare_key) return hash(self._compare_key)
def __lt__(self, other): def __lt__(self, other: Any) -> bool:
# type: (Any) -> bool
return self._compare(other, operator.__lt__) return self._compare(other, operator.__lt__)
def __le__(self, other): def __le__(self, other: Any) -> bool:
# type: (Any) -> bool
return self._compare(other, operator.__le__) return self._compare(other, operator.__le__)
def __gt__(self, other): def __gt__(self, other: Any) -> bool:
# type: (Any) -> bool
return self._compare(other, operator.__gt__) return self._compare(other, operator.__gt__)
def __ge__(self, other): def __ge__(self, other: Any) -> bool:
# type: (Any) -> bool
return self._compare(other, operator.__ge__) return self._compare(other, operator.__ge__)
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
# type: (Any) -> bool
return self._compare(other, operator.__eq__) return self._compare(other, operator.__eq__)
def _compare(self, other, method): def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool:
# type: (Any, Callable[[Any, Any], bool]) -> bool
if not isinstance(other, self._defining_class): if not isinstance(other, self._defining_class):
return NotImplemented return NotImplemented

View file

@ -30,12 +30,11 @@ CommandArgs = List[Union[str, HiddenText]]
LOG_DIVIDER = "----------------------------------------" LOG_DIVIDER = "----------------------------------------"
def make_command(*args): def make_command(*args: Union[str, HiddenText, CommandArgs]) -> CommandArgs:
# type: (Union[str, HiddenText, CommandArgs]) -> CommandArgs
""" """
Create a CommandArgs object. Create a CommandArgs object.
""" """
command_args = [] # type: CommandArgs command_args: CommandArgs = []
for arg in args: for arg in args:
# Check for list instead of CommandArgs since CommandArgs is # Check for list instead of CommandArgs since CommandArgs is
# only known during type-checking. # only known during type-checking.
@ -48,8 +47,7 @@ def make_command(*args):
return command_args return command_args
def format_command_args(args): def format_command_args(args: Union[List[str], CommandArgs]) -> str:
# type: (Union[List[str], CommandArgs]) -> str
""" """
Format command arguments for display. Format command arguments for display.
""" """
@ -64,8 +62,7 @@ def format_command_args(args):
) )
def reveal_command_args(args): def reveal_command_args(args: Union[List[str], CommandArgs]) -> List[str]:
# type: (Union[List[str], CommandArgs]) -> List[str]
""" """
Return the arguments in their raw, unredacted form. Return the arguments in their raw, unredacted form.
""" """
@ -73,12 +70,11 @@ def reveal_command_args(args):
def make_subprocess_output_error( def make_subprocess_output_error(
cmd_args, # type: Union[List[str], CommandArgs] cmd_args: Union[List[str], CommandArgs],
cwd, # type: Optional[str] cwd: Optional[str],
lines, # type: List[str] lines: List[str],
exit_status, # type: int exit_status: int,
): ) -> str:
# type: (...) -> str
""" """
Create and return the error message to use to log a subprocess error Create and return the error message to use to log a subprocess error
with command output. with command output.
@ -109,19 +105,18 @@ def make_subprocess_output_error(
def call_subprocess( def call_subprocess(
cmd, # type: Union[List[str], CommandArgs] cmd: Union[List[str], CommandArgs],
show_stdout=False, # type: bool show_stdout: bool = False,
cwd=None, # type: Optional[str] cwd: Optional[str] = None,
on_returncode="raise", # type: Literal["raise", "warn", "ignore"] on_returncode: 'Literal["raise", "warn", "ignore"]' = "raise",
extra_ok_returncodes=None, # type: Optional[Iterable[int]] extra_ok_returncodes: Optional[Iterable[int]] = None,
command_desc=None, # type: Optional[str] command_desc: Optional[str] = None,
extra_environ=None, # type: Optional[Mapping[str, Any]] extra_environ: Optional[Mapping[str, Any]] = None,
unset_environ=None, # type: Optional[Iterable[str]] unset_environ: Optional[Iterable[str]] = None,
spinner=None, # type: Optional[SpinnerInterface] spinner: Optional[SpinnerInterface] = None,
log_failed_cmd=True, # type: Optional[bool] log_failed_cmd: Optional[bool] = True,
stdout_only=False, # type: Optional[bool] stdout_only: Optional[bool] = False,
): ) -> str:
# type: (...) -> str
""" """
Args: Args:
show_stdout: if true, use INFO to log the subprocess's stderr and show_stdout: if true, use INFO to log the subprocess's stderr and
@ -206,7 +201,7 @@ def call_subprocess(
proc.stdin.close() proc.stdin.close()
# In this mode, stdout and stderr are in the same pipe. # In this mode, stdout and stderr are in the same pipe.
while True: while True:
line = proc.stdout.readline() # type: str line: str = proc.stdout.readline()
if not line: if not line:
break break
line = line.rstrip() line = line.rstrip()
@ -271,8 +266,7 @@ def call_subprocess(
return output return output
def runner_with_spinner_message(message): def runner_with_spinner_message(message: str) -> Callable[..., None]:
# type: (str) -> Callable[..., None]
"""Provide a subprocess_runner that shows a spinner message. """Provide a subprocess_runner that shows a spinner message.
Intended for use with for pep517's Pep517HookCaller. Thus, the runner has Intended for use with for pep517's Pep517HookCaller. Thus, the runner has
@ -280,11 +274,10 @@ def runner_with_spinner_message(message):
""" """
def runner( def runner(
cmd, # type: List[str] cmd: List[str],
cwd=None, # type: Optional[str] cwd: Optional[str] = None,
extra_environ=None, # type: Optional[Mapping[str, Any]] extra_environ: Optional[Mapping[str, Any]] = None,
): ) -> None:
# type: (...) -> None
with open_spinner(message) as spinner: with open_spinner(message) as spinner:
call_subprocess( call_subprocess(
cmd, cmd,

View file

@ -52,8 +52,7 @@ SCP_REGEX = re.compile(
) )
def looks_like_hash(sha): def looks_like_hash(sha: str) -> bool:
# type: (str) -> bool
return bool(HASH_REGEX.match(sha)) return bool(HASH_REGEX.match(sha))
@ -74,12 +73,10 @@ class Git(VersionControl):
default_arg_rev = "HEAD" default_arg_rev = "HEAD"
@staticmethod @staticmethod
def get_base_rev_args(rev): def get_base_rev_args(rev: str) -> List[str]:
# type: (str) -> List[str]
return [rev] return [rev]
def is_immutable_rev_checkout(self, url, dest): def is_immutable_rev_checkout(self, url: str, dest: str) -> bool:
# type: (str, str) -> bool
_, rev_options = self.get_url_rev_options(hide_url(url)) _, rev_options = self.get_url_rev_options(hide_url(url))
if not rev_options.rev: if not rev_options.rev:
return False return False
@ -101,8 +98,7 @@ class Git(VersionControl):
return tuple(int(c) for c in match.groups()) return tuple(int(c) for c in match.groups())
@classmethod @classmethod
def get_current_branch(cls, location): def get_current_branch(cls, location: str) -> Optional[str]:
# type: (str) -> Optional[str]
""" """
Return the current branch, or None if HEAD isn't at a branch Return the current branch, or None if HEAD isn't at a branch
(e.g. detached HEAD). (e.g. detached HEAD).
@ -127,8 +123,7 @@ class Git(VersionControl):
return None return None
@classmethod @classmethod
def get_revision_sha(cls, dest, rev): def get_revision_sha(cls, dest: str, rev: str) -> Tuple[Optional[str], bool]:
# type: (str, str) -> Tuple[Optional[str], bool]
""" """
Return (sha_or_none, is_branch), where sha_or_none is a commit hash Return (sha_or_none, is_branch), where sha_or_none is a commit hash
if the revision names a remote branch or tag, otherwise None. if the revision names a remote branch or tag, otherwise None.
@ -174,8 +169,7 @@ class Git(VersionControl):
return (sha, False) return (sha, False)
@classmethod @classmethod
def _should_fetch(cls, dest, rev): def _should_fetch(cls, dest: str, rev: str) -> bool:
# type: (str, str) -> bool
""" """
Return true if rev is a ref or is a commit that we don't have locally. Return true if rev is a ref or is a commit that we don't have locally.
@ -198,8 +192,9 @@ class Git(VersionControl):
return True return True
@classmethod @classmethod
def resolve_revision(cls, dest, url, rev_options): def resolve_revision(
# type: (str, HiddenText, RevOptions) -> RevOptions cls, dest: str, url: HiddenText, rev_options: RevOptions
) -> RevOptions:
""" """
Resolve a revision to a new RevOptions object with the SHA1 of the Resolve a revision to a new RevOptions object with the SHA1 of the
branch, tag, or ref if found. branch, tag, or ref if found.
@ -243,8 +238,7 @@ class Git(VersionControl):
return rev_options return rev_options
@classmethod @classmethod
def is_commit_id_equal(cls, dest, name): def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool:
# type: (str, Optional[str]) -> bool
""" """
Return whether the current commit hash equals the given name. Return whether the current commit hash equals the given name.
@ -258,8 +252,7 @@ class Git(VersionControl):
return cls.get_revision(dest) == name return cls.get_revision(dest) == name
def fetch_new(self, dest, url, rev_options): def fetch_new(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
rev_display = rev_options.to_display() rev_display = rev_options.to_display()
logger.info("Cloning %s%s to %s", url, rev_display, display_path(dest)) logger.info("Cloning %s%s to %s", url, rev_display, display_path(dest))
if self.get_git_version() >= (2, 17): if self.get_git_version() >= (2, 17):
@ -314,8 +307,7 @@ class Git(VersionControl):
#: repo may contain submodules #: repo may contain submodules
self.update_submodules(dest) self.update_submodules(dest)
def switch(self, dest, url, rev_options): def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
self.run_command( self.run_command(
make_command("config", "remote.origin.url", url), make_command("config", "remote.origin.url", url),
cwd=dest, cwd=dest,
@ -325,8 +317,7 @@ class Git(VersionControl):
self.update_submodules(dest) self.update_submodules(dest)
def update(self, dest, url, rev_options): def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
# First fetch changes from the default remote # First fetch changes from the default remote
if self.get_git_version() >= (1, 9): if self.get_git_version() >= (1, 9):
# fetch tags in addition to everything else # fetch tags in addition to everything else
@ -341,8 +332,7 @@ class Git(VersionControl):
self.update_submodules(dest) self.update_submodules(dest)
@classmethod @classmethod
def get_remote_url(cls, location): def get_remote_url(cls, location: str) -> str:
# type: (str) -> str
""" """
Return URL of the first remote encountered. Return URL of the first remote encountered.
@ -372,8 +362,7 @@ class Git(VersionControl):
return cls._git_remote_to_pip_url(url.strip()) return cls._git_remote_to_pip_url(url.strip())
@staticmethod @staticmethod
def _git_remote_to_pip_url(url): def _git_remote_to_pip_url(url: str) -> str:
# type: (str) -> str
""" """
Convert a remote url from what git uses to what pip accepts. Convert a remote url from what git uses to what pip accepts.
@ -404,8 +393,7 @@ class Git(VersionControl):
raise RemoteNotValidError(url) raise RemoteNotValidError(url)
@classmethod @classmethod
def has_commit(cls, location, rev): def has_commit(cls, location: str, rev: str) -> bool:
# type: (str, str) -> bool
""" """
Check if rev is a commit that is available in the local repository. Check if rev is a commit that is available in the local repository.
""" """
@ -421,8 +409,7 @@ class Git(VersionControl):
return True return True
@classmethod @classmethod
def get_revision(cls, location, rev=None): def get_revision(cls, location: str, rev: Optional[str] = None) -> str:
# type: (str, Optional[str]) -> str
if rev is None: if rev is None:
rev = "HEAD" rev = "HEAD"
current_rev = cls.run_command( current_rev = cls.run_command(
@ -434,8 +421,7 @@ class Git(VersionControl):
return current_rev.strip() return current_rev.strip()
@classmethod @classmethod
def get_subdirectory(cls, location): def get_subdirectory(cls, location: str) -> Optional[str]:
# type: (str) -> Optional[str]
""" """
Return the path to Python project root, relative to the repo root. Return the path to Python project root, relative to the repo root.
Return None if the project root is in the repo root. Return None if the project root is in the repo root.
@ -453,8 +439,7 @@ class Git(VersionControl):
return find_path_to_project_root_from_repo_root(location, repo_root) return find_path_to_project_root_from_repo_root(location, repo_root)
@classmethod @classmethod
def get_url_rev_and_auth(cls, url): def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]:
# type: (str) -> Tuple[str, Optional[str], AuthInfo]
""" """
Prefixes stub URLs like 'user@hostname:user/repo.git' with 'ssh://'. Prefixes stub URLs like 'user@hostname:user/repo.git' with 'ssh://'.
That's required because although they use SSH they sometimes don't That's required because although they use SSH they sometimes don't
@ -485,8 +470,7 @@ class Git(VersionControl):
return url, rev, user_pass return url, rev, user_pass
@classmethod @classmethod
def update_submodules(cls, location): def update_submodules(cls, location: str) -> None:
# type: (str) -> None
if not os.path.exists(os.path.join(location, ".gitmodules")): if not os.path.exists(os.path.join(location, ".gitmodules")):
return return
cls.run_command( cls.run_command(
@ -495,8 +479,7 @@ class Git(VersionControl):
) )
@classmethod @classmethod
def get_repository_root(cls, location): def get_repository_root(cls, location: str) -> Optional[str]:
# type: (str) -> Optional[str]
loc = super().get_repository_root(location) loc = super().get_repository_root(location)
if loc: if loc:
return loc return loc
@ -521,8 +504,7 @@ class Git(VersionControl):
return os.path.normpath(r.rstrip("\r\n")) return os.path.normpath(r.rstrip("\r\n"))
@staticmethod @staticmethod
def should_add_vcs_url_prefix(repo_url): def should_add_vcs_url_prefix(repo_url: str) -> bool:
# type: (str) -> bool
"""In either https or ssh form, requirements must be prefixed with git+.""" """In either https or ssh form, requirements must be prefixed with git+."""
return True return True

View file

@ -34,18 +34,15 @@ class Subversion(VersionControl):
schemes = ("svn+ssh", "svn+http", "svn+https", "svn+svn", "svn+file") schemes = ("svn+ssh", "svn+http", "svn+https", "svn+svn", "svn+file")
@classmethod @classmethod
def should_add_vcs_url_prefix(cls, remote_url): def should_add_vcs_url_prefix(cls, remote_url: str) -> bool:
# type: (str) -> bool
return True return True
@staticmethod @staticmethod
def get_base_rev_args(rev): def get_base_rev_args(rev: str) -> List[str]:
# type: (str) -> List[str]
return ["-r", rev] return ["-r", rev]
@classmethod @classmethod
def get_revision(cls, location): def get_revision(cls, location: str) -> str:
# type: (str) -> str
""" """
Return the maximum revision for all files under a given location Return the maximum revision for all files under a given location
""" """
@ -74,8 +71,9 @@ class Subversion(VersionControl):
return str(revision) return str(revision)
@classmethod @classmethod
def get_netloc_and_auth(cls, netloc, scheme): def get_netloc_and_auth(
# type: (str, str) -> Tuple[str, Tuple[Optional[str], Optional[str]]] cls, netloc: str, scheme: str
) -> Tuple[str, Tuple[Optional[str], Optional[str]]]:
""" """
This override allows the auth information to be passed to svn via the This override allows the auth information to be passed to svn via the
--username and --password options instead of via the URL. --username and --password options instead of via the URL.
@ -88,8 +86,7 @@ class Subversion(VersionControl):
return split_auth_from_netloc(netloc) return split_auth_from_netloc(netloc)
@classmethod @classmethod
def get_url_rev_and_auth(cls, url): def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]:
# type: (str) -> Tuple[str, Optional[str], AuthInfo]
# hotfix the URL scheme after removing svn+ from svn+ssh:// readd it # hotfix the URL scheme after removing svn+ from svn+ssh:// readd it
url, rev, user_pass = super().get_url_rev_and_auth(url) url, rev, user_pass = super().get_url_rev_and_auth(url)
if url.startswith("ssh://"): if url.startswith("ssh://"):
@ -97,9 +94,10 @@ class Subversion(VersionControl):
return url, rev, user_pass return url, rev, user_pass
@staticmethod @staticmethod
def make_rev_args(username, password): def make_rev_args(
# type: (Optional[str], Optional[HiddenText]) -> CommandArgs username: Optional[str], password: Optional[HiddenText]
extra_args = [] # type: CommandArgs ) -> CommandArgs:
extra_args: CommandArgs = []
if username: if username:
extra_args += ["--username", username] extra_args += ["--username", username]
if password: if password:
@ -108,8 +106,7 @@ class Subversion(VersionControl):
return extra_args return extra_args
@classmethod @classmethod
def get_remote_url(cls, location): def get_remote_url(cls, location: str) -> str:
# type: (str) -> str
# In cases where the source is in a subdirectory, we have to look up in # In cases where the source is in a subdirectory, we have to look up in
# the location until we find a valid project root. # the location until we find a valid project root.
orig_location = location orig_location = location
@ -133,8 +130,7 @@ class Subversion(VersionControl):
return url return url
@classmethod @classmethod
def _get_svn_url_rev(cls, location): def _get_svn_url_rev(cls, location: str) -> Tuple[Optional[str], int]:
# type: (str) -> Tuple[Optional[str], int]
from pip._internal.exceptions import InstallationError from pip._internal.exceptions import InstallationError
entries_path = os.path.join(location, cls.dirname, "entries") entries_path = os.path.join(location, cls.dirname, "entries")
@ -184,13 +180,11 @@ class Subversion(VersionControl):
return url, rev return url, rev
@classmethod @classmethod
def is_commit_id_equal(cls, dest, name): def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool:
# type: (str, Optional[str]) -> bool
"""Always assume the versions don't match""" """Always assume the versions don't match"""
return False return False
def __init__(self, use_interactive=None): def __init__(self, use_interactive: bool = None) -> None:
# type: (bool) -> None
if use_interactive is None: if use_interactive is None:
use_interactive = is_console_interactive() use_interactive = is_console_interactive()
self.use_interactive = use_interactive self.use_interactive = use_interactive
@ -200,12 +194,11 @@ class Subversion(VersionControl):
# Special value definitions: # Special value definitions:
# None: Not evaluated yet. # None: Not evaluated yet.
# Empty tuple: Could not parse version. # Empty tuple: Could not parse version.
self._vcs_version = None # type: Optional[Tuple[int, ...]] self._vcs_version: Optional[Tuple[int, ...]] = None
super().__init__() super().__init__()
def call_vcs_version(self): def call_vcs_version(self) -> Tuple[int, ...]:
# type: () -> Tuple[int, ...]
"""Query the version of the currently installed Subversion client. """Query the version of the currently installed Subversion client.
:return: A tuple containing the parts of the version information or :return: A tuple containing the parts of the version information or
@ -233,8 +226,7 @@ class Subversion(VersionControl):
return parsed_version return parsed_version
def get_vcs_version(self): def get_vcs_version(self) -> Tuple[int, ...]:
# type: () -> Tuple[int, ...]
"""Return the version of the currently installed Subversion client. """Return the version of the currently installed Subversion client.
If the version of the Subversion client has already been queried, If the version of the Subversion client has already been queried,
@ -254,8 +246,7 @@ class Subversion(VersionControl):
self._vcs_version = vcs_version self._vcs_version = vcs_version
return vcs_version return vcs_version
def get_remote_call_options(self): def get_remote_call_options(self) -> CommandArgs:
# type: () -> CommandArgs
"""Return options to be used on calls to Subversion that contact the server. """Return options to be used on calls to Subversion that contact the server.
These options are applicable for the following ``svn`` subcommands used These options are applicable for the following ``svn`` subcommands used
@ -286,8 +277,7 @@ class Subversion(VersionControl):
return [] return []
def fetch_new(self, dest, url, rev_options): def fetch_new(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
rev_display = rev_options.to_display() rev_display = rev_options.to_display()
logger.info( logger.info(
"Checking out %s%s to %s", "Checking out %s%s to %s",
@ -305,8 +295,7 @@ class Subversion(VersionControl):
) )
self.run_command(cmd_args) self.run_command(cmd_args)
def switch(self, dest, url, rev_options): def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
cmd_args = make_command( cmd_args = make_command(
"switch", "switch",
self.get_remote_call_options(), self.get_remote_call_options(),
@ -316,8 +305,7 @@ class Subversion(VersionControl):
) )
self.run_command(cmd_args) self.run_command(cmd_args)
def update(self, dest, url, rev_options): def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
cmd_args = make_command( cmd_args = make_command(
"update", "update",
self.get_remote_call_options(), self.get_remote_call_options(),

View file

@ -49,8 +49,7 @@ logger = logging.getLogger(__name__)
AuthInfo = Tuple[Optional[str], Optional[str]] AuthInfo = Tuple[Optional[str], Optional[str]]
def is_url(name): def is_url(name: str) -> bool:
# type: (str) -> bool
""" """
Return true if the name looks like a URL. Return true if the name looks like a URL.
""" """
@ -60,8 +59,9 @@ def is_url(name):
return scheme in ["http", "https", "file", "ftp"] + vcs.all_schemes return scheme in ["http", "https", "file", "ftp"] + vcs.all_schemes
def make_vcs_requirement_url(repo_url, rev, project_name, subdir=None): def make_vcs_requirement_url(
# type: (str, str, str, Optional[str]) -> str repo_url: str, rev: str, project_name: str, subdir: Optional[str] = None
) -> str:
""" """
Return the URL for a VCS requirement. Return the URL for a VCS requirement.
@ -77,8 +77,9 @@ def make_vcs_requirement_url(repo_url, rev, project_name, subdir=None):
return req return req
def find_path_to_project_root_from_repo_root(location, repo_root): def find_path_to_project_root_from_repo_root(
# type: (str, str) -> Optional[str] location: str, repo_root: str
) -> Optional[str]:
""" """
Find the the Python project's root by searching up the filesystem from Find the the Python project's root by searching up the filesystem from
`location`. Return the path to project root relative to `repo_root`. `location`. Return the path to project root relative to `repo_root`.
@ -126,11 +127,10 @@ class RevOptions:
def __init__( def __init__(
self, self,
vc_class, # type: Type[VersionControl] vc_class: Type["VersionControl"],
rev=None, # type: Optional[str] rev: Optional[str] = None,
extra_args=None, # type: Optional[CommandArgs] extra_args: Optional[CommandArgs] = None,
): ) -> None:
# type: (...) -> None
""" """
Args: Args:
vc_class: a VersionControl subclass. vc_class: a VersionControl subclass.
@ -143,26 +143,23 @@ class RevOptions:
self.extra_args = extra_args self.extra_args = extra_args
self.rev = rev self.rev = rev
self.vc_class = vc_class self.vc_class = vc_class
self.branch_name = None # type: Optional[str] self.branch_name: Optional[str] = None
def __repr__(self): def __repr__(self) -> str:
# type: () -> str
return f"<RevOptions {self.vc_class.name}: rev={self.rev!r}>" return f"<RevOptions {self.vc_class.name}: rev={self.rev!r}>"
@property @property
def arg_rev(self): def arg_rev(self) -> Optional[str]:
# type: () -> Optional[str]
if self.rev is None: if self.rev is None:
return self.vc_class.default_arg_rev return self.vc_class.default_arg_rev
return self.rev return self.rev
def to_args(self): def to_args(self) -> CommandArgs:
# type: () -> CommandArgs
""" """
Return the VCS-specific command arguments. Return the VCS-specific command arguments.
""" """
args = [] # type: CommandArgs args: CommandArgs = []
rev = self.arg_rev rev = self.arg_rev
if rev is not None: if rev is not None:
args += self.vc_class.get_base_rev_args(rev) args += self.vc_class.get_base_rev_args(rev)
@ -170,15 +167,13 @@ class RevOptions:
return args return args
def to_display(self): def to_display(self) -> str:
# type: () -> str
if not self.rev: if not self.rev:
return "" return ""
return f" (to revision {self.rev})" return f" (to revision {self.rev})"
def make_new(self, rev): def make_new(self, rev: str) -> "RevOptions":
# type: (str) -> RevOptions
""" """
Make a copy of the current instance, but with a new rev. Make a copy of the current instance, but with a new rev.
@ -189,40 +184,34 @@ class RevOptions:
class VcsSupport: class VcsSupport:
_registry = {} # type: Dict[str, VersionControl] _registry: Dict[str, "VersionControl"] = {}
schemes = ["ssh", "git", "hg", "bzr", "sftp", "svn"] schemes = ["ssh", "git", "hg", "bzr", "sftp", "svn"]
def __init__(self): def __init__(self) -> None:
# type: () -> None
# Register more schemes with urlparse for various version control # Register more schemes with urlparse for various version control
# systems # systems
urllib.parse.uses_netloc.extend(self.schemes) urllib.parse.uses_netloc.extend(self.schemes)
super().__init__() super().__init__()
def __iter__(self): def __iter__(self) -> Iterator[str]:
# type: () -> Iterator[str]
return self._registry.__iter__() return self._registry.__iter__()
@property @property
def backends(self): def backends(self) -> List["VersionControl"]:
# type: () -> List[VersionControl]
return list(self._registry.values()) return list(self._registry.values())
@property @property
def dirnames(self): def dirnames(self) -> List[str]:
# type: () -> List[str]
return [backend.dirname for backend in self.backends] return [backend.dirname for backend in self.backends]
@property @property
def all_schemes(self): def all_schemes(self) -> List[str]:
# type: () -> List[str] schemes: List[str] = []
schemes = [] # type: List[str]
for backend in self.backends: for backend in self.backends:
schemes.extend(backend.schemes) schemes.extend(backend.schemes)
return schemes return schemes
def register(self, cls): def register(self, cls: Type["VersionControl"]) -> None:
# type: (Type[VersionControl]) -> None
if not hasattr(cls, "name"): if not hasattr(cls, "name"):
logger.warning("Cannot register VCS %s", cls.__name__) logger.warning("Cannot register VCS %s", cls.__name__)
return return
@ -230,13 +219,11 @@ class VcsSupport:
self._registry[cls.name] = cls() self._registry[cls.name] = cls()
logger.debug("Registered VCS backend: %s", cls.name) logger.debug("Registered VCS backend: %s", cls.name)
def unregister(self, name): def unregister(self, name: str) -> None:
# type: (str) -> None
if name in self._registry: if name in self._registry:
del self._registry[name] del self._registry[name]
def get_backend_for_dir(self, location): def get_backend_for_dir(self, location: str) -> Optional["VersionControl"]:
# type: (str) -> Optional[VersionControl]
""" """
Return a VersionControl object if a repository of that type is found Return a VersionControl object if a repository of that type is found
at the given directory. at the given directory.
@ -259,8 +246,7 @@ class VcsSupport:
inner_most_repo_path = max(vcs_backends, key=len) inner_most_repo_path = max(vcs_backends, key=len)
return vcs_backends[inner_most_repo_path] return vcs_backends[inner_most_repo_path]
def get_backend_for_scheme(self, scheme): def get_backend_for_scheme(self, scheme: str) -> Optional["VersionControl"]:
# type: (str) -> Optional[VersionControl]
""" """
Return a VersionControl object or None. Return a VersionControl object or None.
""" """
@ -269,8 +255,7 @@ class VcsSupport:
return vcs_backend return vcs_backend
return None return None
def get_backend(self, name): def get_backend(self, name: str) -> Optional["VersionControl"]:
# type: (str) -> Optional[VersionControl]
""" """
Return a VersionControl object or None. Return a VersionControl object or None.
""" """
@ -286,14 +271,13 @@ class VersionControl:
dirname = "" dirname = ""
repo_name = "" repo_name = ""
# List of supported schemes for this Version Control # List of supported schemes for this Version Control
schemes = () # type: Tuple[str, ...] schemes: Tuple[str, ...] = ()
# Iterable of environment variable names to pass to call_subprocess(). # Iterable of environment variable names to pass to call_subprocess().
unset_environ = () # type: Tuple[str, ...] unset_environ: Tuple[str, ...] = ()
default_arg_rev = None # type: Optional[str] default_arg_rev: Optional[str] = None
@classmethod @classmethod
def should_add_vcs_url_prefix(cls, remote_url): def should_add_vcs_url_prefix(cls, remote_url: str) -> bool:
# type: (str) -> bool
""" """
Return whether the vcs prefix (e.g. "git+") should be added to a Return whether the vcs prefix (e.g. "git+") should be added to a
repository's remote url when used in a requirement. repository's remote url when used in a requirement.
@ -301,8 +285,7 @@ class VersionControl:
return not remote_url.lower().startswith(f"{cls.name}:") return not remote_url.lower().startswith(f"{cls.name}:")
@classmethod @classmethod
def get_subdirectory(cls, location): def get_subdirectory(cls, location: str) -> Optional[str]:
# type: (str) -> Optional[str]
""" """
Return the path to Python project root, relative to the repo root. Return the path to Python project root, relative to the repo root.
Return None if the project root is in the repo root. Return None if the project root is in the repo root.
@ -310,16 +293,14 @@ class VersionControl:
return None return None
@classmethod @classmethod
def get_requirement_revision(cls, repo_dir): def get_requirement_revision(cls, repo_dir: str) -> str:
# type: (str) -> str
""" """
Return the revision string that should be used in a requirement. Return the revision string that should be used in a requirement.
""" """
return cls.get_revision(repo_dir) return cls.get_revision(repo_dir)
@classmethod @classmethod
def get_src_requirement(cls, repo_dir, project_name): def get_src_requirement(cls, repo_dir: str, project_name: str) -> str:
# type: (str, str) -> str
""" """
Return the requirement string to use to redownload the files Return the requirement string to use to redownload the files
currently at the given repository directory. currently at the given repository directory.
@ -343,8 +324,7 @@ class VersionControl:
return req return req
@staticmethod @staticmethod
def get_base_rev_args(rev): def get_base_rev_args(rev: str) -> List[str]:
# type: (str) -> List[str]
""" """
Return the base revision arguments for a vcs command. Return the base revision arguments for a vcs command.
@ -353,8 +333,7 @@ class VersionControl:
""" """
raise NotImplementedError raise NotImplementedError
def is_immutable_rev_checkout(self, url, dest): def is_immutable_rev_checkout(self, url: str, dest: str) -> bool:
# type: (str, str) -> bool
""" """
Return true if the commit hash checked out at dest matches Return true if the commit hash checked out at dest matches
the revision in url. the revision in url.
@ -368,8 +347,9 @@ class VersionControl:
return False return False
@classmethod @classmethod
def make_rev_options(cls, rev=None, extra_args=None): def make_rev_options(
# type: (Optional[str], Optional[CommandArgs]) -> RevOptions cls, rev: Optional[str] = None, extra_args: Optional[CommandArgs] = None
) -> RevOptions:
""" """
Return a RevOptions object. Return a RevOptions object.
@ -380,8 +360,7 @@ class VersionControl:
return RevOptions(cls, rev, extra_args=extra_args) return RevOptions(cls, rev, extra_args=extra_args)
@classmethod @classmethod
def _is_local_repository(cls, repo): def _is_local_repository(cls, repo: str) -> bool:
# type: (str) -> bool
""" """
posix absolute paths start with os.path.sep, posix absolute paths start with os.path.sep,
win32 ones start with drive (like c:\\folder) win32 ones start with drive (like c:\\folder)
@ -390,8 +369,9 @@ class VersionControl:
return repo.startswith(os.path.sep) or bool(drive) return repo.startswith(os.path.sep) or bool(drive)
@classmethod @classmethod
def get_netloc_and_auth(cls, netloc, scheme): def get_netloc_and_auth(
# type: (str, str) -> Tuple[str, Tuple[Optional[str], Optional[str]]] cls, netloc: str, scheme: str
) -> Tuple[str, Tuple[Optional[str], Optional[str]]]:
""" """
Parse the repository URL's netloc, and return the new netloc to use Parse the repository URL's netloc, and return the new netloc to use
along with auth information. along with auth information.
@ -410,8 +390,7 @@ class VersionControl:
return netloc, (None, None) return netloc, (None, None)
@classmethod @classmethod
def get_url_rev_and_auth(cls, url): def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]:
# type: (str) -> Tuple[str, Optional[str], AuthInfo]
""" """
Parse the repository URL to use, and return the URL, revision, Parse the repository URL to use, and return the URL, revision,
and auth info to use. and auth info to use.
@ -441,22 +420,22 @@ class VersionControl:
return url, rev, user_pass return url, rev, user_pass
@staticmethod @staticmethod
def make_rev_args(username, password): def make_rev_args(
# type: (Optional[str], Optional[HiddenText]) -> CommandArgs username: Optional[str], password: Optional[HiddenText]
) -> CommandArgs:
""" """
Return the RevOptions "extra arguments" to use in obtain(). Return the RevOptions "extra arguments" to use in obtain().
""" """
return [] return []
def get_url_rev_options(self, url): def get_url_rev_options(self, url: HiddenText) -> Tuple[HiddenText, RevOptions]:
# type: (HiddenText) -> Tuple[HiddenText, RevOptions]
""" """
Return the URL and RevOptions object to use in obtain(), Return the URL and RevOptions object to use in obtain(),
as a tuple (url, rev_options). as a tuple (url, rev_options).
""" """
secret_url, rev, user_pass = self.get_url_rev_and_auth(url.secret) secret_url, rev, user_pass = self.get_url_rev_and_auth(url.secret)
username, secret_password = user_pass username, secret_password = user_pass
password = None # type: Optional[HiddenText] password: Optional[HiddenText] = None
if secret_password is not None: if secret_password is not None:
password = hide_value(secret_password) password = hide_value(secret_password)
extra_args = self.make_rev_args(username, password) extra_args = self.make_rev_args(username, password)
@ -465,8 +444,7 @@ class VersionControl:
return hide_url(secret_url), rev_options return hide_url(secret_url), rev_options
@staticmethod @staticmethod
def normalize_url(url): def normalize_url(url: str) -> str:
# type: (str) -> str
""" """
Normalize a URL for comparison by unquoting it and removing any Normalize a URL for comparison by unquoting it and removing any
trailing slash. trailing slash.
@ -474,15 +452,13 @@ class VersionControl:
return urllib.parse.unquote(url).rstrip("/") return urllib.parse.unquote(url).rstrip("/")
@classmethod @classmethod
def compare_urls(cls, url1, url2): def compare_urls(cls, url1: str, url2: str) -> bool:
# type: (str, str) -> bool
""" """
Compare two repo URLs for identity, ignoring incidental differences. Compare two repo URLs for identity, ignoring incidental differences.
""" """
return cls.normalize_url(url1) == cls.normalize_url(url2) return cls.normalize_url(url1) == cls.normalize_url(url2)
def fetch_new(self, dest, url, rev_options): def fetch_new(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
""" """
Fetch a revision from a repository, in the case that this is the Fetch a revision from a repository, in the case that this is the
first fetch from the repository. first fetch from the repository.
@ -493,8 +469,7 @@ class VersionControl:
""" """
raise NotImplementedError raise NotImplementedError
def switch(self, dest, url, rev_options): def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
""" """
Switch the repo at ``dest`` to point to ``URL``. Switch the repo at ``dest`` to point to ``URL``.
@ -503,8 +478,7 @@ class VersionControl:
""" """
raise NotImplementedError raise NotImplementedError
def update(self, dest, url, rev_options): def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None:
# type: (str, HiddenText, RevOptions) -> None
""" """
Update an already-existing repo to the given ``rev_options``. Update an already-existing repo to the given ``rev_options``.
@ -514,8 +488,7 @@ class VersionControl:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def is_commit_id_equal(cls, dest, name): def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool:
# type: (str, Optional[str]) -> bool
""" """
Return whether the id of the current commit equals the given name. Return whether the id of the current commit equals the given name.
@ -525,8 +498,7 @@ class VersionControl:
""" """
raise NotImplementedError raise NotImplementedError
def obtain(self, dest, url): def obtain(self, dest: str, url: HiddenText) -> None:
# type: (str, HiddenText) -> None
""" """
Install or update in editable mode the package represented by this Install or update in editable mode the package represented by this
VersionControl object. VersionControl object.
@ -614,8 +586,7 @@ class VersionControl:
) )
self.switch(dest, url, rev_options) self.switch(dest, url, rev_options)
def unpack(self, location, url): def unpack(self, location: str, url: HiddenText) -> None:
# type: (str, HiddenText) -> None
""" """
Clean up current location and download the url repository Clean up current location and download the url repository
(and vcs infos) into location (and vcs infos) into location
@ -627,8 +598,7 @@ class VersionControl:
self.obtain(location, url=url) self.obtain(location, url=url)
@classmethod @classmethod
def get_remote_url(cls, location): def get_remote_url(cls, location: str) -> str:
# type: (str) -> str
""" """
Return the url used at location Return the url used at location
@ -638,8 +608,7 @@ class VersionControl:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def get_revision(cls, location): def get_revision(cls, location: str) -> str:
# type: (str) -> str
""" """
Return the current commit id of the files at the given location. Return the current commit id of the files at the given location.
""" """
@ -648,18 +617,17 @@ class VersionControl:
@classmethod @classmethod
def run_command( def run_command(
cls, cls,
cmd, # type: Union[List[str], CommandArgs] cmd: Union[List[str], CommandArgs],
show_stdout=True, # type: bool show_stdout: bool = True,
cwd=None, # type: Optional[str] cwd: Optional[str] = None,
on_returncode="raise", # type: Literal["raise", "warn", "ignore"] on_returncode: 'Literal["raise", "warn", "ignore"]' = "raise",
extra_ok_returncodes=None, # type: Optional[Iterable[int]] extra_ok_returncodes: Optional[Iterable[int]] = None,
command_desc=None, # type: Optional[str] command_desc: Optional[str] = None,
extra_environ=None, # type: Optional[Mapping[str, Any]] extra_environ: Optional[Mapping[str, Any]] = None,
spinner=None, # type: Optional[SpinnerInterface] spinner: Optional[SpinnerInterface] = None,
log_failed_cmd=True, # type: bool log_failed_cmd: bool = True,
stdout_only=False, # type: bool stdout_only: bool = False,
): ) -> str:
# type: (...) -> str
""" """
Run a VCS subcommand Run a VCS subcommand
This is simply a wrapper around call_subprocess that adds the VCS This is simply a wrapper around call_subprocess that adds the VCS
@ -701,8 +669,7 @@ class VersionControl:
) )
@classmethod @classmethod
def is_repository_directory(cls, path): def is_repository_directory(cls, path: str) -> bool:
# type: (str) -> bool
""" """
Return whether a directory path is a repository directory. Return whether a directory path is a repository directory.
""" """
@ -710,8 +677,7 @@ class VersionControl:
return os.path.exists(os.path.join(path, cls.dirname)) return os.path.exists(os.path.join(path, cls.dirname))
@classmethod @classmethod
def get_repository_root(cls, location): def get_repository_root(cls, location: str) -> Optional[str]:
# type: (str) -> Optional[str]
""" """
Return the "root" (top-level) directory controlled by the vcs, Return the "root" (top-level) directory controlled by the vcs,
or `None` if the directory is not in any. or `None` if the directory is not in any.