Refactor the logic in self_outdated_check

- Move the lookup within `selfcheck/*.json` files into a method on
  `SelfCheckState`.
- Factor out `PackageFinder` interaction into a separate function.
- Rename variables to more clearly reflect what they're for.

Co-Authored-By: Pradyun Gedam <pradyunsg@gmail.com>
This commit is contained in:
Pradyun Gedam 2022-03-11 12:37:14 +00:00
parent ea0d976d9f
commit 4f35572ae7
2 changed files with 88 additions and 74 deletions

View File

@ -5,7 +5,7 @@ import logging
import optparse
import os.path
import sys
from typing import Any, Dict
from typing import Any, Dict, Optional
from pip._vendor.packaging.version import parse as parse_version
@ -17,7 +17,7 @@ from pip._internal.network.session import PipSession
from pip._internal.utils.filesystem import adjacent_tmp_file, check_path_owner, replace
from pip._internal.utils.misc import ensure_dir
SELFCHECK_DATE_FMT = "%Y-%m-%dT%H:%M:%SZ"
_DATE_FMT = "%Y-%m-%dT%H:%M:%SZ"
logger = logging.getLogger(__name__)
@ -31,17 +31,17 @@ def _get_statefile_name(key: str) -> str:
class SelfCheckState:
def __init__(self, cache_dir: str) -> None:
self.state: Dict[str, Any] = {}
self.statefile_path = None
self._state: Dict[str, Any] = {}
self._statefile_path = None
# Try to load the existing state
if cache_dir:
self.statefile_path = os.path.join(
self._statefile_path = os.path.join(
cache_dir, "selfcheck", _get_statefile_name(self.key)
)
try:
with open(self.statefile_path, encoding="utf-8") as statefile:
self.state = json.load(statefile)
with open(self._statefile_path, encoding="utf-8") as statefile:
self._state = json.load(statefile)
except (OSError, ValueError, KeyError):
# Explicitly suppressing exceptions, since we don't want to
# error out if the cache file is invalid.
@ -51,36 +51,57 @@ class SelfCheckState:
def key(self) -> str:
return sys.prefix
def save(self, pypi_version: str, current_time: datetime.datetime) -> None:
def get(self, current_time: datetime.datetime) -> Optional[str]:
"""Check if we have a not-outdated version loaded already."""
if not self._state:
return None
if "last_check" not in self._state:
return None
if "pypi_version" not in self._state:
return None
seven_days_in_seconds = 7 * 24 * 60 * 60
# Determine if we need to refresh the state
last_check = datetime.datetime.strptime(self._state["last_check"], _DATE_FMT)
seconds_since_last_check = (current_time - last_check).total_seconds()
if seconds_since_last_check > seven_days_in_seconds:
return None
return self._state["pypi_version"]
def set(self, pypi_version: str, current_time: datetime.datetime) -> None:
# If we do not have a path to cache in, don't bother saving.
if not self.statefile_path:
if not self._statefile_path:
return
# Check to make sure that we own the directory
if not check_path_owner(os.path.dirname(self.statefile_path)):
if not check_path_owner(os.path.dirname(self._statefile_path)):
return
# Now that we've ensured the directory is owned by this user, we'll go
# ahead and make sure that all our directories are created.
ensure_dir(os.path.dirname(self.statefile_path))
ensure_dir(os.path.dirname(self._statefile_path))
state = {
# Include the key so it's easy to tell which pip wrote the
# file.
"key": self.key,
"last_check": current_time.strftime(SELFCHECK_DATE_FMT),
"last_check": current_time.strftime(_DATE_FMT),
"pypi_version": pypi_version,
}
text = json.dumps(state, sort_keys=True, separators=(",", ":"))
with adjacent_tmp_file(self.statefile_path) as f:
with adjacent_tmp_file(self._statefile_path) as f:
f.write(text.encode())
try:
# Since we have a prefix-specific state file, we can just
# overwrite whatever is there, no need to check.
replace(f.name, self.statefile_path)
replace(f.name, self._statefile_path)
except OSError:
# Best effort.
pass
@ -96,6 +117,35 @@ def was_installed_by_pip(pkg: str) -> bool:
return dist is not None and "pip" == dist.installer
def _get_current_remote_pip_version(
session: PipSession, options: optparse.Values
) -> str:
# Lets use PackageFinder to see what the latest pip version is
link_collector = LinkCollector.create(
session,
options=options,
suppress_no_index=True,
)
# Pass allow_yanked=False so we don't suggest upgrading to a
# yanked version.
selection_prefs = SelectionPreferences(
allow_yanked=False,
allow_all_prereleases=False, # Explicitly set to False
)
finder = PackageFinder.create(
link_collector=link_collector,
selection_prefs=selection_prefs,
use_deprecated_html5lib=("html5lib" in options.deprecated_features_enabled),
)
best_candidate = finder.find_best_candidate("pip").best_candidate
if best_candidate is None:
return
return str(best_candidate.version)
def pip_self_version_check(session: PipSession, options: optparse.Values) -> None:
"""Check for an update for pip.
@ -107,61 +157,25 @@ def pip_self_version_check(session: PipSession, options: optparse.Values) -> Non
if not installed_dist:
return
pip_version = installed_dist.version
pypi_version = None
local_version = installed_dist.version
try:
state = SelfCheckState(cache_dir=options.cache_dir)
current_time = datetime.datetime.utcnow()
# Determine if we need to refresh the state
if "last_check" in state.state and "pypi_version" in state.state:
last_check = datetime.datetime.strptime(
state.state["last_check"], SELFCHECK_DATE_FMT
)
if (current_time - last_check).total_seconds() < 7 * 24 * 60 * 60:
pypi_version = state.state["pypi_version"]
remote_version_str = state.get(current_time)
# Refresh the version if we need to or just see if we need to warn
if pypi_version is None:
# Lets use PackageFinder to see what the latest pip version is
link_collector = LinkCollector.create(
session,
options=options,
suppress_no_index=True,
)
if remote_version_str is None:
remote_version_str = _get_current_remote_pip_version(session, options)
state.set(remote_version_str, current_time)
# Pass allow_yanked=False so we don't suggest upgrading to a
# yanked version.
selection_prefs = SelectionPreferences(
allow_yanked=False,
allow_all_prereleases=False, # Explicitly set to False
)
finder = PackageFinder.create(
link_collector=link_collector,
selection_prefs=selection_prefs,
use_deprecated_html5lib=(
"html5lib" in options.deprecated_features_enabled
),
)
best_candidate = finder.find_best_candidate("pip").best_candidate
if best_candidate is None:
return
pypi_version = str(best_candidate.version)
# save that we've performed a check
state.save(pypi_version, current_time)
remote_version = parse_version(pypi_version)
remote_version = parse_version(remote_version_str)
local_version_is_older = (
pip_version < remote_version
and pip_version.base_version != remote_version.base_version
local_version < remote_version
and local_version.base_version != remote_version.base_version
and was_installed_by_pip("pip")
)
# Determine if our pypi_version is older
if not local_version_is_older:
return
@ -178,8 +192,8 @@ def pip_self_version_check(session: PipSession, options: optparse.Values) -> Non
"You are using pip version %s; however, version %s is "
"available.\nYou should consider upgrading via the "
"'%s install --upgrade pip' command.",
pip_version,
pypi_version,
local_version,
remote_version_str,
pip_cmd,
)
except Exception:

View File

@ -131,8 +131,8 @@ def test_pip_self_version_check(
monkeypatch.setattr(logger, "debug", mock.Mock())
fake_state = mock.Mock(
state={"last_check": stored_time, "pypi_version": installed_ver},
save=mock.Mock(),
get=mock.Mock(return_value=None),
set=mock.Mock(),
)
monkeypatch.setattr(self_outdated_check, "SelfCheckState", lambda **kw: fake_state)
@ -146,16 +146,16 @@ def test_pip_self_version_check(
):
pip_self_version_check(PipSession(), _options())
# See that we saved the correct version
# See that we set the correct version
if check_if_upgrade_required:
assert fake_state.save.call_args_list == [
assert fake_state.set.call_args_list == [
mock.call(new_ver, datetime.datetime(1970, 1, 9, 10, 00, 00)),
]
elif installed_ver:
# Make sure no Exceptions
assert not cast(mock.Mock, logger.debug).call_args_list
# See that save was not called
assert fake_state.save.call_args_list == []
assert not cast(mock.Mock, logger.warning).call_args_list
# See that set was not called
assert fake_state.set.call_args_list == []
# Ensure we warn the user or not
if check_warn_logs:
@ -188,8 +188,8 @@ def _get_statefile_path(cache_dir: str, key: str) -> str:
def test_self_check_state_no_cache_dir() -> None:
state = SelfCheckState(cache_dir="")
assert state.state == {}
assert state.statefile_path is None
assert state._state == {}
assert state._statefile_path is None
def test_self_check_state_key_uses_sys_prefix(monkeypatch: pytest.MonkeyPatch) -> None:
@ -225,8 +225,8 @@ def test_self_check_state_reads_expected_statefile(
monkeypatch.setattr(sys, "prefix", key)
state = self_outdated_check.SelfCheckState(str(cache_dir))
assert state.state["last_check"] == last_check
assert state.state["pypi_version"] == pypi_version
assert state._state["last_check"] == last_check
assert state._state["pypi_version"] == pypi_version
def test_self_check_state_writes_expected_statefile(
@ -238,20 +238,20 @@ def test_self_check_state_writes_expected_statefile(
statefile_path = _get_statefile_path(str(cache_dir), key)
last_check = datetime.datetime.strptime(
"1970-01-02T11:00:00Z", self_outdated_check.SELFCHECK_DATE_FMT
"1970-01-02T11:00:00Z", self_outdated_check._DATE_FMT
)
pypi_version = "1.0"
monkeypatch.setattr(sys, "prefix", key)
state = self_outdated_check.SelfCheckState(str(cache_dir))
state.save(pypi_version, last_check)
state.set(pypi_version, last_check)
with open(statefile_path) as f:
saved = json.load(f)
expected = {
"key": key,
"last_check": last_check.strftime(self_outdated_check.SELFCHECK_DATE_FMT),
"last_check": last_check.strftime(self_outdated_check._DATE_FMT),
"pypi_version": pypi_version,
}
assert expected == saved