Refactor download.get_file_content.

This commit is contained in:
Chris Hunt 2019-07-20 22:51:10 -04:00
parent 82dbcdae87
commit 82ef9d67e2
2 changed files with 41 additions and 23 deletions

View File

@ -636,29 +636,30 @@ def get_file_content(url, comes_from=None, session=None):
"get_file_content() missing 1 required keyword argument: 'session'"
)
match = _scheme_re.search(url)
if match:
scheme = match.group(1).lower()
if (scheme == 'file' and comes_from and
comes_from.startswith('http')):
scheme = _get_url_scheme(url)
if scheme in ['http', 'https']:
# FIXME: catch some errors
resp = session.get(url)
resp.raise_for_status()
return resp.url, resp.text
elif scheme == 'file':
if comes_from and comes_from.startswith('http'):
raise InstallationError(
'Requirements file %s references URL %s, which is local'
% (comes_from, url))
if scheme == 'file':
path = url.split(':', 1)[1]
path = path.replace('\\', '/')
match = _url_slash_drive_re.match(path)
if match:
path = match.group(1) + ':' + path.split('|', 1)[1]
path = urllib_parse.unquote(path)
if path.startswith('/'):
path = '/' + path.lstrip('/')
url = path
else:
# FIXME: catch some errors
resp = session.get(url)
resp.raise_for_status()
return resp.url, resp.text
path = url.split(':', 1)[1]
path = path.replace('\\', '/')
match = _url_slash_drive_re.match(path)
if match:
path = match.group(1) + ':' + path.split('|', 1)[1]
path = urllib_parse.unquote(path)
if path.startswith('/'):
path = '/' + path.lstrip('/')
url = path
try:
with open(url, 'rb') as f:
content = auto_decode(f.read())
@ -669,16 +670,22 @@ def get_file_content(url, comes_from=None, session=None):
return url, content
# Matches a leading 'http:', 'https:' or 'file:' prefix, case-insensitively;
# group 1 captures the scheme name without the colon.
_scheme_re = re.compile(r'^(http|https|file):', re.I)
# Matches any number of leading slashes followed by one letter and a '|'
# (e.g. '/c|'), case-insensitively; group 1 captures the letter.
# NOTE(review): presumably the letter is a Windows drive written as 'c|' -- confirm.
_url_slash_drive_re = re.compile(r'/*([a-z])\|', re.I)
def _get_url_scheme(url):
# type: (Union[str, Text]) -> Optional[Text]
if ':' not in url:
return None
return url.split(':', 1)[0].lower()
def is_url(name):
    # type: (Union[str, Text]) -> bool
    """Return True if *name* looks like a URL.

    The text before the first ':' in *name* is taken as the scheme (via
    _get_url_scheme, which lower-cases it).  The name counts as a URL when
    that scheme is 'http', 'https', 'file', 'ftp' or one of
    vcs.all_schemes; a name without any ':' is never a URL.
    """
    # Use the shared helper as the single source of scheme extraction
    # instead of repeating the "':' in name" / split logic inline.
    scheme = _get_url_scheme(name)
    if scheme is None:
        return False
    return scheme in ['http', 'https', 'file', 'ftp'] + vcs.all_schemes

View File

@ -16,6 +16,7 @@ from pip._internal.download import (
PipSession,
SafeFileCache,
_download_http_url,
_get_url_scheme,
parse_content_disposition,
sanitize_content_filename,
unpack_file_url,
@ -291,6 +292,16 @@ def test_download_http_url__no_directory_traversal(tmpdir):
assert actual == ['out_dir_file']
@pytest.mark.parametrize("url,expected", [
    ('http://localhost:8080/', 'http'),
    ('file:c:/path/to/file', 'file'),
    ('file:/dev/null', 'file'),
    # The scheme is lower-cased regardless of how the URL spells it.
    ('HTTPS://localhost/', 'https'),
    # A non-empty string without ':' has no scheme to report.
    ('path/to/file', None),
    ('', None),
])
def test__get_url_scheme(url, expected):
    """_get_url_scheme returns the lower-cased scheme, or None without ':'."""
    assert _get_url_scheme(url) == expected
@pytest.mark.parametrize("url,win_expected,non_win_expected", [
('file:tmp', 'tmp', 'tmp'),
('file:c:/path/to/file', r'C:\path\to\file', 'c:/path/to/file'),