This commit is contained in:
Neil Botelho 2023-11-28 09:35:02 +01:00 committed by GitHub
commit 5a48418847
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 101 additions and 41 deletions

1
news/12388.feature.rst Normal file
View File

@ -0,0 +1 @@
Add parallel download support to ``BatchDownloader``, enabled via the new ``--parallel-downloads`` option.

View File

@ -761,6 +761,19 @@ check_build_deps: Callable[..., Option] = partial(
help="Check the build dependencies when PEP517 is used.",
)
# --parallel-downloads: thread count for BatchDownloader; None means the
# user did not pass the flag and sequential downloading is used.
parallel_downloads: Callable[..., Option] = partial(
    Option,
    "--parallel-downloads",
    dest="parallel_downloads",
    type="int",
    metavar="n",
    default=None,
    help=(
        # Trailing space matters: the two literals are concatenated.
        "Use up to <n> threads to download packages in parallel. "
        "<n> must be greater than 0."
    ),
)
def _handle_no_use_pep517(
option: Option, opt: str, value: str, parser: OptionParser

View File

@ -118,6 +118,10 @@ class SessionCommandMixin(CommandContextMixIn):
ssl_context = None
else:
ssl_context = None
if "parallel_downloads" in options.__dict__:
parallel_downloads = options.parallel_downloads
else:
parallel_downloads = None
session = PipSession(
cache=os.path.join(cache_dir, "http-v2") if cache_dir else None,
@ -125,6 +129,7 @@ class SessionCommandMixin(CommandContextMixIn):
trusted_hosts=options.trusted_hosts,
index_urls=self._get_index_urls(options),
ssl_context=ssl_context,
parallel_downloads=parallel_downloads,
)
# Handle custom ca-bundles from the user

View File

@ -7,6 +7,7 @@ from pip._internal.cli import cmdoptions
from pip._internal.cli.cmdoptions import make_target_python
from pip._internal.cli.req_command import RequirementCommand, with_cleanup
from pip._internal.cli.status_codes import SUCCESS
from pip._internal.exceptions import CommandError
from pip._internal.operations.build.build_tracker import get_build_tracker
from pip._internal.req.req_install import check_legacy_setup_py_options
from pip._internal.utils.misc import ensure_dir, normalize_path, write_output
@ -52,6 +53,7 @@ class DownloadCommand(RequirementCommand):
self.cmd_opts.add_option(cmdoptions.no_use_pep517())
self.cmd_opts.add_option(cmdoptions.check_build_deps())
self.cmd_opts.add_option(cmdoptions.ignore_requires_python())
self.cmd_opts.add_option(cmdoptions.parallel_downloads())
self.cmd_opts.add_option(
"-d",
@ -76,6 +78,11 @@ class DownloadCommand(RequirementCommand):
@with_cleanup
def run(self, options: Values, args: List[str]) -> int:
if (options.parallel_downloads is not None) and (
options.parallel_downloads < 1
):
raise CommandError("Value of '--parallel-downloads' must be greater than 0")
options.ignore_installed = True
# editable doesn't really make sense for `pip download`, but the bowels
# of the RequirementSet code require that property.

View File

@ -74,6 +74,7 @@ class InstallCommand(RequirementCommand):
self.cmd_opts.add_option(cmdoptions.constraints())
self.cmd_opts.add_option(cmdoptions.no_deps())
self.cmd_opts.add_option(cmdoptions.pre())
self.cmd_opts.add_option(cmdoptions.parallel_downloads())
self.cmd_opts.add_option(cmdoptions.editable())
self.cmd_opts.add_option(
@ -267,6 +268,10 @@ class InstallCommand(RequirementCommand):
if options.use_user_site and options.target_dir is not None:
raise CommandError("Can not combine '--user' and '--target'")
if (options.parallel_downloads is not None) and (
options.parallel_downloads < 1
):
raise CommandError("Value of '--parallel-downloads' must be greater than 0")
# Check whether the environment we're installing into is externally
# managed, as specified in PEP 668. Specifying --root, --target, or
# --prefix disables the check, since there's no reliable way to locate

View File

@ -4,6 +4,8 @@ import email.message
import logging
import mimetypes
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Iterable, Optional, Tuple
from pip._vendor.requests.models import CONTENT_CHUNK_SIZE, Response
@ -119,6 +121,27 @@ def _http_get_download(session: PipSession, link: Link) -> Response:
return resp
def _download(
    link: Link, location: str, session: PipSession, progress_bar: str
) -> Tuple[str, str]:
    """Fetch the file behind *link* into the *location* directory.

    Returns a ``(filepath, content_type)`` pair for the downloaded file.
    A NetworkConnectionError is logged at critical level and re-raised.
    """
    try:
        response = _http_get_download(session, link)
    except NetworkConnectionError as exc:
        assert exc.response is not None
        logger.critical(
            "HTTP error %s while getting %s", exc.response.status_code, link
        )
        raise
    target = os.path.join(location, _get_http_response_filename(response, link))
    # Stream the response body to disk chunk by chunk.
    with open(target, "wb") as fp:
        for chunk in _prepare_download(response, link, progress_bar):
            fp.write(chunk)
    return target, response.headers.get("Content-Type", "")
class Downloader:
def __init__(
self,
@ -130,24 +153,7 @@ class Downloader:
def __call__(self, link: Link, location: str) -> Tuple[str, str]:
"""Download the file given by link into location."""
try:
resp = _http_get_download(self._session, link)
except NetworkConnectionError as e:
assert e.response is not None
logger.critical(
"HTTP error %s while getting %s", e.response.status_code, link
)
raise
filename = _get_http_response_filename(resp, link)
filepath = os.path.join(location, filename)
chunks = _prepare_download(resp, link, self._progress_bar)
with open(filepath, "wb") as content_file:
for chunk in chunks:
content_file.write(chunk)
content_type = resp.headers.get("Content-Type", "")
return filepath, content_type
return _download(link, location, self._session, self._progress_bar)
class BatchDownloader:
@ -159,28 +165,33 @@ class BatchDownloader:
self._session = session
self._progress_bar = progress_bar
def _sequential_download(
    self, link: Link, location: str, progress_bar: str
) -> Tuple[Link, Tuple[str, str]]:
    """Download one link and pair it with its (filepath, content-type) result."""
    result = _download(link, location, self._session, progress_bar)
    return link, result
def _download_parallel(
    self, links: Iterable[Link], location: str, max_workers: int
) -> Iterable[Tuple[Link, Tuple[str, str]]]:
    """Download *links* concurrently using up to *max_workers* threads.

    Per-file progress bars are disabled ("off") because interleaved
    output from several threads would be unreadable.  The map result is
    materialized inside the ``with`` block so every download has
    completed (or raised) before the executor shuts down.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Distinct name: the original local shadowed this method's own
        # name (`_download_parallel`), which was confusing.
        download_one = partial(
            self._sequential_download, location=location, progress_bar="off"
        )
        return list(pool.map(download_one, links))
def __call__(
self, links: Iterable[Link], location: str
) -> Iterable[Tuple[Link, Tuple[str, str]]]:
"""Download the files given by links into location."""
for link in links:
try:
resp = _http_get_download(self._session, link)
except NetworkConnectionError as e:
assert e.response is not None
logger.critical(
"HTTP error %s while getting %s",
e.response.status_code,
link,
)
raise
filename = _get_http_response_filename(resp, link)
filepath = os.path.join(location, filename)
chunks = _prepare_download(resp, link, self._progress_bar)
with open(filepath, "wb") as content_file:
for chunk in chunks:
content_file.write(chunk)
content_type = resp.headers.get("Content-Type", "")
yield link, (filepath, content_type)
links = list(links)
max_workers = self._session.parallel_downloads
if max_workers == 1 or len(links) == 1:
# TODO: set minimum number of links to perform parallel download
for link in links:
yield self._sequential_download(link, location, self._progress_bar)
else:
results = self._download_parallel(links, location, max_workers)
for result in results:
yield result

View File

@ -326,6 +326,7 @@ class PipSession(requests.Session):
trusted_hosts: Sequence[str] = (),
index_urls: Optional[List[str]] = None,
ssl_context: Optional["SSLContext"] = None,
parallel_downloads: Optional[int] = None,
**kwargs: Any,
) -> None:
"""
@ -362,12 +363,24 @@ class PipSession(requests.Session):
backoff_factor=0.25,
) # type: ignore
# Used to set the number of parallel downloads in
# pip._internal.network.BatchDownloader and to set pool_connections in
# the HTTPAdapter, preventing the connection pool from hitting the
# default (10) limit and emitting 'Connection pool is full' warnings.
self.parallel_downloads = (
parallel_downloads if (parallel_downloads is not None) else 1
)
pool_maxsize = max(self.parallel_downloads, 10)
# Our Insecure HTTPAdapter disables HTTPS validation. It does not
# support caching so we'll use it for all http:// URLs.
# If caching is disabled, we will also use it for
# https:// hosts that we've marked as ignoring
# TLS errors for (trusted-hosts).
insecure_adapter = InsecureHTTPAdapter(max_retries=retries)
insecure_adapter = InsecureHTTPAdapter(
max_retries=retries,
pool_connections=pool_maxsize,
pool_maxsize=pool_maxsize,
)
# We want to _only_ cache responses on securely fetched origins or when
# the host is specified as trusted. We do this because
@ -385,7 +398,12 @@ class PipSession(requests.Session):
max_retries=retries,
)
else:
secure_adapter = HTTPAdapter(max_retries=retries, ssl_context=ssl_context)
secure_adapter = HTTPAdapter(
max_retries=retries,
ssl_context=ssl_context,
pool_connections=pool_maxsize,
pool_maxsize=pool_maxsize,
)
self._trusted_host_adapter = insecure_adapter
self.mount("https://", secure_adapter)