diff --git a/src/pip/_internal/resolution/resolvelib/found_candidates.py b/src/pip/_internal/resolution/resolvelib/found_candidates.py index c9b21727a..491290466 100644 --- a/src/pip/_internal/resolution/resolvelib/found_candidates.py +++ b/src/pip/_internal/resolution/resolvelib/found_candidates.py @@ -1,89 +1,51 @@ +import functools +import itertools + from pip._vendor.six.moves import collections_abc # type: ignore from pip._internal.utils.compat import lru_cache from pip._internal.utils.typing import MYPY_CHECK_RUNNING if MYPY_CHECK_RUNNING: - from typing import Callable, Iterator, Optional, Set + from typing import Any, Callable, Iterator, Optional, Set from pip._vendor.packaging.version import _BaseVersion from .base import Candidate -class _InstalledFirstCandidatesIterator(collections_abc.Iterator): - """Iterator for ``FoundCandidates``. - - This iterator is used when the resolver prefers to keep the version of an - already-installed package. The already-installed candidate is always - returned first. Candidates from index are accessed only when the resolver - wants them, and the already-installed version is excluded from them. - """ - def __init__( - self, - get_others, # type: Callable[[], Iterator[Candidate]] - installed, # type: Optional[Candidate] - ): - self._installed = installed - self._get_others = get_others - self._others = None # type: Optional[Iterator[Candidate]] - self._returned = set() # type: Set[_BaseVersion] - - def __next__(self): - # type: () -> Candidate - if self._installed and self._installed.version not in self._returned: - self._returned.add(self._installed.version) - return self._installed - if self._others is None: - self._others = self._get_others() - cand = next(self._others) - while cand.version in self._returned: - cand = next(self._others) - self._returned.add(cand.version) - return cand - - next = __next__ # XXX: Python 2. +def _deduplicated_by_version(candidates): + # type: (Iterator[Candidate]) -> Iterator[Candidate] + returned = set() # type: Set[_BaseVersion] + for candidate in candidates: + if candidate.version in returned: + continue + returned.add(candidate.version) + yield candidate -class _InstalledReplacesCandidatesIterator(collections_abc.Iterator): +def _replaces_sort_key(installed, candidate): + # type: (Candidate, Candidate) -> Any + return (candidate.version, candidate is installed) + + +def _insert_installed(installed, others): + # type: (Candidate, Iterator[Candidate]) -> Iterator[Candidate] """Iterator for ``FoundCandidates``. This iterator is used when the resolver prefers to upgrade an already-installed package. Candidates from index are returned in their normal ordering, except replaced when the version is already installed. + + The sort key prefers the installed candidate over candidates of the same + version from the index, so it is chosen on de-duplication. """ - def __init__( - self, - get_others, # type: Callable[[], Iterator[Candidate]] - installed, # type: Optional[Candidate] - ): - self._installed = installed - self._get_others = get_others - self._others = None # type: Optional[Iterator[Candidate]] - self._returned = set() # type: Set[_BaseVersion] - - def __next__(self): - # type: () -> Candidate - if self._others is None: - self._others = self._get_others() - try: - cand = next(self._others) - while cand.version in self._returned: - cand = next(self._others) - if self._installed and cand.version == self._installed.version: - cand = self._installed - except StopIteration: - # Return the already-installed candidate as the last item if its - # version does not exist on the index. - if not self._installed: - raise - if self._installed.version in self._returned: - raise - cand = self._installed - self._returned.add(cand.version) - return cand - - next = __next__ # XXX: Python 2. + candidates = sorted( + itertools.chain(others, [installed]), + key=functools.partial(_replaces_sort_key, installed), + reverse=True, + ) + return iter(candidates) class FoundCandidates(collections_abc.Sequence): @@ -106,22 +68,27 @@ class FoundCandidates(collections_abc.Sequence): def __getitem__(self, index): # type: (int) -> Candidate - # Implemented to satisfy the ABC check, This is not needed by the + # Implemented to satisfy the ABC check. This is not needed by the # resolver, and should not be used by the provider either (for # performance reasons). raise NotImplementedError("don't do this") def __iter__(self): # type: () -> Iterator[Candidate] - if self._prefers_installed: - klass = _InstalledFirstCandidatesIterator + if not self._installed: + candidates = self._get_others() + elif self._prefers_installed: + candidates = itertools.chain([self._installed], self._get_others()) else: - klass = _InstalledReplacesCandidatesIterator - return klass(self._get_others, self._installed) + candidates = _insert_installed(self._installed, self._get_others()) + return _deduplicated_by_version(candidates) @lru_cache(maxsize=1) def __len__(self): # type: () -> int + # Implement to satisfy the ABC check and used in tests. This is not + # needed by the resolver, and should not be used by the provider either + # (for performance reasons). return sum(1 for _ in self) @lru_cache(maxsize=1)