Complete type annotations in `pip/_internal/network` (#10184)

This commit is contained in:
Harutaka Kawamura 2021-07-23 19:27:28 +09:00 committed by GitHub
parent 0fb0e3b547
commit f8a7439528
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 168 deletions

View File

@ -36,8 +36,7 @@ except Exception as exc:
keyring = None
def get_keyring_auth(url, username):
# type: (Optional[str], Optional[str]) -> Optional[AuthInfo]
def get_keyring_auth(url: Optional[str], username: Optional[str]) -> Optional[AuthInfo]:
"""Return the tuple auth for a given url from keyring."""
global keyring
if not url or not keyring:
@ -71,20 +70,20 @@ def get_keyring_auth(url, username):
class MultiDomainBasicAuth(AuthBase):
def __init__(self, prompting=True, index_urls=None):
# type: (bool, Optional[List[str]]) -> None
def __init__(
self, prompting: bool = True, index_urls: Optional[List[str]] = None
) -> None:
self.prompting = prompting
self.index_urls = index_urls
self.passwords = {} # type: Dict[str, AuthInfo]
self.passwords: Dict[str, AuthInfo] = {}
# When the user is prompted to enter credentials and keyring is
# available, we will offer to save them. If the user accepts,
# this value is set to the credentials they entered. After the
# request authenticates, the caller should call
# ``save_credentials`` to save these.
self._credentials_to_save = None # type: Optional[Credentials]
self._credentials_to_save: Optional[Credentials] = None
def _get_index_url(self, url):
# type: (str) -> Optional[str]
def _get_index_url(self, url: str) -> Optional[str]:
"""Return the original index URL matching the requested URL.
Cached or dynamically generated credentials may work against
@ -106,9 +105,8 @@ class MultiDomainBasicAuth(AuthBase):
return u
return None
def _get_new_credentials(self, original_url, allow_netrc=True,
allow_keyring=False):
# type: (str, bool, bool) -> AuthInfo
def _get_new_credentials(self, original_url: str, allow_netrc: bool = True,
allow_keyring: bool = False) -> AuthInfo:
"""Find and return credentials for the specified URL."""
# Split the credentials and netloc from the url.
url, netloc, url_user_password = split_auth_netloc_from_url(
@ -157,8 +155,9 @@ class MultiDomainBasicAuth(AuthBase):
return username, password
def _get_url_and_credentials(self, original_url):
# type: (str) -> Tuple[str, Optional[str], Optional[str]]
def _get_url_and_credentials(
self, original_url: str
) -> Tuple[str, Optional[str], Optional[str]]:
"""Return the credentials to use for the provided URL.
If allowed, netrc and keyring may be used to obtain the
@ -197,8 +196,7 @@ class MultiDomainBasicAuth(AuthBase):
return url, username, password
def __call__(self, req):
# type: (Request) -> Request
def __call__(self, req: Request) -> Request:
# Get credentials for this request
url, username, password = self._get_url_and_credentials(req.url)
@ -215,8 +213,9 @@ class MultiDomainBasicAuth(AuthBase):
return req
# Factored out to allow for easy patching in tests
def _prompt_for_password(self, netloc):
# type: (str) -> Tuple[Optional[str], Optional[str], bool]
def _prompt_for_password(
self, netloc: str
) -> Tuple[Optional[str], Optional[str], bool]:
username = ask_input(f"User for {netloc}: ")
if not username:
return None, None, False
@ -227,14 +226,12 @@ class MultiDomainBasicAuth(AuthBase):
return username, password, True
# Factored out to allow for easy patching in tests
def _should_save_password_to_keyring(self):
# type: () -> bool
def _should_save_password_to_keyring(self) -> bool:
if not keyring:
return False
return ask("Save credentials to keyring [y/N]: ", ["y", "n"]) == "y"
def handle_401(self, resp, **kwargs):
# type: (Response, **Any) -> Response
def handle_401(self, resp: Response, **kwargs: Any) -> Response:
# We only care about 401 responses, anything else we want to just
# pass through the actual response
if resp.status_code != 401:
@ -286,16 +283,14 @@ class MultiDomainBasicAuth(AuthBase):
return new_resp
def warn_on_401(self, resp, **kwargs):
# type: (Response, **Any) -> None
def warn_on_401(self, resp: Response, **kwargs: Any) -> None:
"""Response callback to warn about incorrect credentials."""
if resp.status_code == 401:
logger.warning(
'401 Error, Credentials not correct for %s', resp.request.url,
)
def save_credentials(self, resp, **kwargs):
# type: (Response, **Any) -> None
def save_credentials(self, resp: Response, **kwargs: Any) -> None:
"""Response callback to save credentials on success."""
assert keyring is not None, "should never reach here without keyring"
if not keyring:

View File

@ -13,14 +13,12 @@ from pip._internal.utils.filesystem import adjacent_tmp_file, replace
from pip._internal.utils.misc import ensure_dir
def is_from_cache(response):
# type: (Response) -> bool
def is_from_cache(response: Response) -> bool:
return getattr(response, "from_cache", False)
@contextmanager
def suppressed_cache_errors():
# type: () -> Iterator[None]
def suppressed_cache_errors() -> Iterator[None]:
"""If we can't access the cache then we can just skip caching and process
requests as if caching wasn't enabled.
"""
@ -36,14 +34,12 @@ class SafeFileCache(BaseCache):
not be accessible or writable.
"""
def __init__(self, directory):
# type: (str) -> None
def __init__(self, directory: str) -> None:
assert directory is not None, "Cache directory must not be None."
super().__init__()
self.directory = directory
def _get_cache_path(self, name):
# type: (str) -> str
def _get_cache_path(self, name: str) -> str:
# From cachecontrol.caches.file_cache.FileCache._fn, brought into our
# class for backwards-compatibility and to avoid using a non-public
# method.
@ -51,15 +47,13 @@ class SafeFileCache(BaseCache):
parts = list(hashed[:5]) + [hashed]
return os.path.join(self.directory, *parts)
def get(self, key):
# type: (str) -> Optional[bytes]
def get(self, key: str) -> Optional[bytes]:
path = self._get_cache_path(key)
with suppressed_cache_errors():
with open(path, 'rb') as f:
return f.read()
def set(self, key, value):
# type: (str, bytes) -> None
def set(self, key: str, value: bytes) -> None:
path = self._get_cache_path(key)
with suppressed_cache_errors():
ensure_dir(os.path.dirname(path))
@ -69,8 +63,7 @@ class SafeFileCache(BaseCache):
replace(f.name, path)
def delete(self, key):
# type: (str) -> None
def delete(self, key: str) -> None:
path = self._get_cache_path(key)
with suppressed_cache_errors():
os.remove(path)

View File

@ -20,8 +20,7 @@ from pip._internal.utils.misc import format_size, redact_auth_from_url, splitext
logger = logging.getLogger(__name__)
def _get_http_response_size(resp):
# type: (Response) -> Optional[int]
def _get_http_response_size(resp: Response) -> Optional[int]:
try:
return int(resp.headers['content-length'])
except (ValueError, KeyError, TypeError):
@ -29,11 +28,10 @@ def _get_http_response_size(resp):
def _prepare_download(
resp, # type: Response
link, # type: Link
progress_bar # type: str
):
# type: (...) -> Iterable[bytes]
resp: Response,
link: Link,
progress_bar: str
) -> Iterable[bytes]:
total_length = _get_http_response_size(resp)
if link.netloc == PyPI.file_storage_domain:
@ -72,16 +70,14 @@ def _prepare_download(
)(chunks)
def sanitize_content_filename(filename):
# type: (str) -> str
def sanitize_content_filename(filename: str) -> str:
"""
Sanitize the "filename" value from a Content-Disposition header.
"""
return os.path.basename(filename)
def parse_content_disposition(content_disposition, default_filename):
# type: (str, str) -> str
def parse_content_disposition(content_disposition: str, default_filename: str) -> str:
"""
Parse the "filename" value from a Content-Disposition header, and
return the default filename if the result is empty.
@ -95,8 +91,7 @@ def parse_content_disposition(content_disposition, default_filename):
return filename or default_filename
def _get_http_response_filename(resp, link):
# type: (Response, Link) -> str
def _get_http_response_filename(resp: Response, link: Link) -> str:
"""Get an ideal filename from the given HTTP response, falling back to
the link filename if not provided.
"""
@ -105,7 +100,7 @@ def _get_http_response_filename(resp, link):
content_disposition = resp.headers.get('content-disposition')
if content_disposition:
filename = parse_content_disposition(content_disposition, filename)
ext = splitext(filename)[1] # type: Optional[str]
ext: Optional[str] = splitext(filename)[1]
if not ext:
ext = mimetypes.guess_extension(
resp.headers.get('content-type', '')
@ -119,8 +114,7 @@ def _get_http_response_filename(resp, link):
return filename
def _http_get_download(session, link):
# type: (PipSession, Link) -> Response
def _http_get_download(session: PipSession, link: Link) -> Response:
target_url = link.url.split('#', 1)[0]
resp = session.get(target_url, headers=HEADERS, stream=True)
raise_for_status(resp)
@ -130,15 +124,13 @@ def _http_get_download(session, link):
class Downloader:
def __init__(
self,
session, # type: PipSession
progress_bar, # type: str
):
# type: (...) -> None
session: PipSession,
progress_bar: str,
) -> None:
self._session = session
self._progress_bar = progress_bar
def __call__(self, link, location):
# type: (Link, str) -> Tuple[str, str]
def __call__(self, link: Link, location: str) -> Tuple[str, str]:
"""Download the file given by link into location."""
try:
resp = _http_get_download(self._session, link)
@ -164,15 +156,15 @@ class BatchDownloader:
def __init__(
self,
session, # type: PipSession
progress_bar, # type: str
):
# type: (...) -> None
session: PipSession,
progress_bar: str,
) -> None:
self._session = session
self._progress_bar = progress_bar
def __call__(self, links, location):
# type: (Iterable[Link], str) -> Iterable[Tuple[Link, Tuple[str, str]]]
def __call__(
self, links: Iterable[Link], location: str
) -> Iterable[Tuple[Link, Tuple[str, str]]]:
"""Download the files given by links into location."""
for link in links:
try:

View File

@ -20,8 +20,7 @@ class HTTPRangeRequestUnsupported(Exception):
pass
def dist_from_wheel_url(name, url, session):
# type: (str, str, PipSession) -> Distribution
def dist_from_wheel_url(name: str, url: str, session: PipSession) -> Distribution:
"""Return a pkg_resources.Distribution from the given wheel URL.
This uses HTTP range requests to only fetch the potion of the wheel
@ -47,8 +46,9 @@ class LazyZipOverHTTP:
during initialization.
"""
def __init__(self, url, session, chunk_size=CONTENT_CHUNK_SIZE):
# type: (str, PipSession, int) -> None
def __init__(
self, url: str, session: PipSession, chunk_size: int = CONTENT_CHUNK_SIZE
) -> None:
head = session.head(url, headers=HEADERS)
raise_for_status(head)
assert head.status_code == 200
@ -56,42 +56,36 @@ class LazyZipOverHTTP:
self._length = int(head.headers['Content-Length'])
self._file = NamedTemporaryFile()
self.truncate(self._length)
self._left = [] # type: List[int]
self._right = [] # type: List[int]
self._left: List[int] = []
self._right: List[int] = []
if 'bytes' not in head.headers.get('Accept-Ranges', 'none'):
raise HTTPRangeRequestUnsupported('range request is not supported')
self._check_zip()
@property
def mode(self):
# type: () -> str
def mode(self) -> str:
"""Opening mode, which is always rb."""
return 'rb'
@property
def name(self):
# type: () -> str
def name(self) -> str:
"""Path to the underlying file."""
return self._file.name
def seekable(self):
# type: () -> bool
def seekable(self) -> bool:
"""Return whether random access is supported, which is True."""
return True
def close(self):
# type: () -> None
def close(self) -> None:
"""Close the file."""
self._file.close()
@property
def closed(self):
# type: () -> bool
def closed(self) -> bool:
"""Whether the file is closed."""
return self._file.closed
def read(self, size=-1):
# type: (int) -> bytes
def read(self, size: int = -1) -> bytes:
"""Read up to size bytes from the object and return them.
As a convenience, if size is unspecified or -1,
@ -105,13 +99,11 @@ class LazyZipOverHTTP:
self._download(start, stop-1)
return self._file.read(size)
def readable(self):
# type: () -> bool
def readable(self) -> bool:
"""Return whether the file is readable, which is True."""
return True
def seek(self, offset, whence=0):
# type: (int, int) -> int
def seek(self, offset: int, whence: int = 0) -> int:
"""Change stream position and return the new absolute position.
Seek to offset relative position indicated by whence:
@ -121,13 +113,11 @@ class LazyZipOverHTTP:
"""
return self._file.seek(offset, whence)
def tell(self):
# type: () -> int
def tell(self) -> int:
"""Return the current position."""
return self._file.tell()
def truncate(self, size=None):
# type: (Optional[int]) -> int
def truncate(self, size: Optional[int] = None) -> int:
"""Resize the stream to the given size in bytes.
If size is unspecified resize to the current position.
@ -137,23 +127,19 @@ class LazyZipOverHTTP:
"""
return self._file.truncate(size)
def writable(self):
# type: () -> bool
def writable(self) -> bool:
"""Return False."""
return False
def __enter__(self):
# type: () -> LazyZipOverHTTP
def __enter__(self) -> "LazyZipOverHTTP":
self._file.__enter__()
return self
def __exit__(self, *exc):
# type: (*Any) -> Optional[bool]
def __exit__(self, *exc: Any) -> Optional[bool]:
return self._file.__exit__(*exc)
@contextmanager
def _stay(self):
# type: ()-> Iterator[None]
def _stay(self) -> Iterator[None]:
"""Return a context manager keeping the position.
At the end of the block, seek back to original position.
@ -164,8 +150,7 @@ class LazyZipOverHTTP:
finally:
self.seek(pos)
def _check_zip(self):
# type: () -> None
def _check_zip(self) -> None:
"""Check and download until the file is a valid ZIP."""
end = self._length - 1
for start in reversed(range(0, end, self._chunk_size)):
@ -180,8 +165,9 @@ class LazyZipOverHTTP:
else:
break
def _stream_response(self, start, end, base_headers=HEADERS):
# type: (int, int, Dict[str, str]) -> Response
def _stream_response(
self, start: int, end: int, base_headers: Dict[str, str] = HEADERS
) -> Response:
"""Return HTTP response to a range request from start to end."""
headers = base_headers.copy()
headers['Range'] = f'bytes={start}-{end}'
@ -189,8 +175,9 @@ class LazyZipOverHTTP:
headers['Cache-Control'] = 'no-cache'
return self._session.get(self._url, headers=headers, stream=True)
def _merge(self, start, end, left, right):
# type: (int, int, int, int) -> Iterator[Tuple[int, int]]
def _merge(
self, start: int, end: int, left: int, right: int
) -> Iterator[Tuple[int, int]]:
"""Return an iterator of intervals to be fetched.
Args:
@ -210,8 +197,7 @@ class LazyZipOverHTTP:
yield i, end
self._left[left:right], self._right[left:right] = [start], [end]
def _download(self, start, end):
# type: (int, int) -> None
def _download(self, start: int, end: int) -> None:
"""Download bytes from start to end inclusively."""
with self._stay():
left = bisect_left(self._right, start)

View File

@ -55,7 +55,7 @@ SecureOrigin = Tuple[str, str, Optional[Union[int, str]]]
warnings.filterwarnings("ignore", category=InsecureRequestWarning)
SECURE_ORIGINS = [
SECURE_ORIGINS: List[SecureOrigin] = [
# protocol, hostname, port
# Taken from Chrome's list of secure origins (See: http://bit.ly/1qrySKC)
("https", "*", "*"),
@ -65,7 +65,7 @@ SECURE_ORIGINS = [
("file", "*", None),
# ssh is always secure.
("ssh", "*", "*"),
] # type: List[SecureOrigin]
]
# These are environment variables present when running under various
@ -87,8 +87,7 @@ CI_ENVIRONMENT_VARIABLES = (
)
def looks_like_ci():
# type: () -> bool
def looks_like_ci() -> bool:
"""
Return whether it looks like pip is running under CI.
"""
@ -98,18 +97,17 @@ def looks_like_ci():
return any(name in os.environ for name in CI_ENVIRONMENT_VARIABLES)
def user_agent():
# type: () -> str
def user_agent() -> str:
"""
Return a string representing the user agent.
"""
data = {
data: Dict[str, Any] = {
"installer": {"name": "pip", "version": __version__},
"python": platform.python_version(),
"implementation": {
"name": platform.python_implementation(),
},
} # type: Dict[str, Any]
}
if data["implementation"]["name"] == 'CPython':
data["implementation"]["version"] = platform.python_version()
@ -200,14 +198,13 @@ class LocalFSAdapter(BaseAdapter):
def send(
self,
request, # type: PreparedRequest
stream=False, # type: bool
timeout=None, # type: Optional[Union[float, Tuple[float, float]]]
verify=True, # type: Union[bool, str]
cert=None, # type: Optional[Union[str, Tuple[str, str]]]
proxies=None, # type:Optional[Mapping[str, str]]
):
# type: (...) -> Response
request: PreparedRequest,
stream: bool = False,
timeout: Optional[Union[float, Tuple[float, float]]] = None,
verify: Union[bool, str] = True,
cert: Optional[Union[str, Tuple[str, str]]] = None,
proxies: Optional[Mapping[str, str]] = None,
) -> Response:
pathname = url_to_path(request.url)
resp = Response()
@ -233,8 +230,7 @@ class LocalFSAdapter(BaseAdapter):
return resp
def close(self):
# type: () -> None
def close(self) -> None:
pass
@ -242,12 +238,11 @@ class InsecureHTTPAdapter(HTTPAdapter):
def cert_verify(
self,
conn, # type: ConnectionPool
url, # type: str
verify, # type: Union[bool, str]
cert, # type: Optional[Union[str, Tuple[str, str]]]
):
# type: (...) -> None
conn: ConnectionPool,
url: str,
verify: Union[bool, str],
cert: Optional[Union[str, Tuple[str, str]]],
) -> None:
super().cert_verify(conn=conn, url=url, verify=False, cert=cert)
@ -255,29 +250,27 @@ class InsecureCacheControlAdapter(CacheControlAdapter):
def cert_verify(
self,
conn, # type: ConnectionPool
url, # type: str
verify, # type: Union[bool, str]
cert, # type: Optional[Union[str, Tuple[str, str]]]
):
# type: (...) -> None
conn: ConnectionPool,
url: str,
verify: Union[bool, str],
cert: Optional[Union[str, Tuple[str, str]]],
) -> None:
super().cert_verify(conn=conn, url=url, verify=False, cert=cert)
class PipSession(requests.Session):
timeout = None # type: Optional[int]
timeout: Optional[int] = None
def __init__(
self,
*args, # type: Any
retries=0, # type: int
cache=None, # type: Optional[str]
trusted_hosts=(), # type: Sequence[str]
index_urls=None, # type: Optional[List[str]]
**kwargs, # type: Any
):
# type: (...) -> None
*args: Any,
retries: int = 0,
cache: Optional[str] = None,
trusted_hosts: Sequence[str] = (),
index_urls: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""
:param trusted_hosts: Domains not to emit warnings for when not using
HTTPS.
@ -286,7 +279,7 @@ class PipSession(requests.Session):
# Namespace the attribute with "pip_" just in case to prevent
# possible conflicts with the base class.
self.pip_trusted_origins = [] # type: List[Tuple[str, Optional[int]]]
self.pip_trusted_origins: List[Tuple[str, Optional[int]]] = []
# Attach our User Agent to the request
self.headers["User-Agent"] = user_agent()
@ -348,16 +341,16 @@ class PipSession(requests.Session):
for host in trusted_hosts:
self.add_trusted_host(host, suppress_logging=True)
def update_index_urls(self, new_index_urls):
# type: (List[str]) -> None
def update_index_urls(self, new_index_urls: List[str]) -> None:
"""
:param new_index_urls: New index urls to update the authentication
handler with.
"""
self.auth.index_urls = new_index_urls
def add_trusted_host(self, host, source=None, suppress_logging=False):
# type: (str, Optional[str], bool) -> None
def add_trusted_host(
self, host: str, source: Optional[str] = None, suppress_logging: bool = False
) -> None:
"""
:param host: It is okay to provide a host that has previously been
added.
@ -385,14 +378,12 @@ class PipSession(requests.Session):
self._trusted_host_adapter
)
def iter_secure_origins(self):
# type: () -> Iterator[SecureOrigin]
def iter_secure_origins(self) -> Iterator[SecureOrigin]:
yield from SECURE_ORIGINS
for host, port in self.pip_trusted_origins:
yield ('*', host, '*' if port is None else port)
def is_secure_origin(self, location):
# type: (Link) -> bool
def is_secure_origin(self, location: Link) -> bool:
# Determine if this url used a secure transport mechanism
parsed = urllib.parse.urlparse(str(location))
origin_protocol, origin_host, origin_port = (
@ -457,8 +448,7 @@ class PipSession(requests.Session):
return False
def request(self, method, url, *args, **kwargs):
# type: (str, str, *Any, **Any) -> Response
def request(self, method: str, url: str, *args: Any, **kwargs: Any) -> Response:
# Allow setting a default timeout on a session
kwargs.setdefault("timeout", self.timeout)

View File

@ -23,11 +23,10 @@ from pip._internal.exceptions import NetworkConnectionError
# you're not asking for a compressed file and will then decompress it
# before sending because if that's the case I don't think it'll ever be
# possible to make this work.
HEADERS = {'Accept-Encoding': 'identity'} # type: Dict[str, str]
HEADERS: Dict[str, str] = {'Accept-Encoding': 'identity'}
def raise_for_status(resp):
# type: (Response) -> None
def raise_for_status(resp: Response) -> None:
http_error_msg = ''
if isinstance(resp.reason, bytes):
# We attempt to decode utf-8 first because some servers
@ -53,8 +52,9 @@ def raise_for_status(resp):
raise NetworkConnectionError(http_error_msg, response=resp)
def response_chunks(response, chunk_size=CONTENT_CHUNK_SIZE):
# type: (Response, int) -> Iterator[bytes]
def response_chunks(
response: Response, chunk_size: int = CONTENT_CHUNK_SIZE
) -> Iterator[bytes]:
"""Given a requests Response, provide the data chunks.
"""
try:

View File

@ -21,15 +21,21 @@ class PipXmlrpcTransport(xmlrpc.client.Transport):
object.
"""
def __init__(self, index_url, session, use_datetime=False):
# type: (str, PipSession, bool) -> None
def __init__(
self, index_url: str, session: PipSession, use_datetime: bool = False
) -> None:
super().__init__(use_datetime)
index_parts = urllib.parse.urlparse(index_url)
self._scheme = index_parts.scheme
self._session = session
def request(self, host, handler, request_body, verbose=False):
# type: (_HostType, str, bytes, bool) -> Tuple[_Marshallable, ...]
def request(
self,
host: "_HostType",
handler: str,
request_body: bytes,
verbose: bool = False,
) -> Tuple["_Marshallable", ...]:
assert isinstance(host, str)
parts = (self._scheme, host, handler, None, None, None)
url = urllib.parse.urlunparse(parts)