Merge branch 'main' into 2984-new-cache-lower-memory

This commit is contained in:
Itamar Turner-Trauring 2023-09-26 10:38:42 -04:00 committed by GitHub
commit cc14055336
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1512 additions and 40 deletions

View File

@ -28,19 +28,9 @@ It is possible to use the system trust store, instead of the bundled certifi
certificates for verifying HTTPS certificates. This approach will typically
support corporate proxy certificates without additional configuration.
In order to use system trust stores, you need to:
- Use Python 3.10 or newer.
- Install the {pypi}`truststore` package, in the Python environment you're
running pip in.
This is typically done by installing this package using a system package
manager or by using pip in {ref}`Hash-checking mode` for this package and
trusting the network using the `--trusted-host` flag.
In order to use system trust stores, you need to use Python 3.10 or newer.
```{pip-cli}
$ python -m pip install truststore
[...]
$ python -m pip install SomePackage --use-feature=truststore
[...]
Successfully installed SomePackage

View File

@ -0,0 +1 @@
Add truststore 0.8.0

View File

@ -184,6 +184,12 @@ def lint(session: nox.Session) -> None:
# git reset --hard origin/main
@nox.session
def vendoring(session: nox.Session) -> None:
# Ensure that the session Python is running 3.10+
# so that truststore can be installed correctly.
session.run(
"python", "-c", "import sys; sys.exit(1 if sys.version_info < (3, 10) else 0)"
)
session.install("vendoring~=1.2.0")
parser = argparse.ArgumentParser(prog="nox -s vendoring")

View File

@ -54,12 +54,8 @@ ignore_errors = True
# These vendored libraries use runtime magic to populate things and don't sit
# well with static typing out of the box. Eventually we should provide correct
# typing information for their public interface and remove these configs.
[mypy-pip._vendor.colorama]
follow_imports = skip
[mypy-pip._vendor.pkg_resources]
follow_imports = skip
[mypy-pip._vendor.progress.*]
follow_imports = skip
[mypy-pip._vendor.requests.*]
follow_imports = skip

View File

@ -58,12 +58,9 @@ def _create_truststore_ssl_context() -> Optional["SSLContext"]:
return None
try:
import truststore
except ImportError:
raise CommandError(
"To use the truststore feature, 'truststore' must be installed into "
"pip's current environment."
)
from pip._vendor import truststore
except ImportError as e:
raise CommandError(f"The truststore feature is unavailable: {e}")
return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

View File

@ -46,22 +46,29 @@ def create_vendor_txt_map() -> Dict[str, str]:
return dict(line.split("==", 1) for line in lines)
def get_module_from_module_name(module_name: str) -> ModuleType:
def get_module_from_module_name(module_name: str) -> Optional[ModuleType]:
# Module name can be uppercase in vendor.txt for some reason...
module_name = module_name.lower().replace("-", "_")
# PATCH: setuptools is actually only pkg_resources.
if module_name == "setuptools":
module_name = "pkg_resources"
__import__(f"pip._vendor.{module_name}", globals(), locals(), level=0)
return getattr(pip._vendor, module_name)
try:
__import__(f"pip._vendor.{module_name}", globals(), locals(), level=0)
return getattr(pip._vendor, module_name)
except ImportError:
# We allow 'truststore' to fail to import due
# to being unavailable on Python 3.9 and earlier.
if module_name == "truststore" and sys.version_info < (3, 10):
return None
raise
def get_vendor_version_from_module(module_name: str) -> Optional[str]:
module = get_module_from_module_name(module_name)
version = getattr(module, "__version__", None)
if not version:
if module and not version:
# Try to find version in debundled module info.
assert module.__file__ is not None
env = get_environment([os.path.dirname(module.__file__)])

View File

@ -117,4 +117,5 @@ if DEBUNDLED:
vendored("rich.traceback")
vendored("tenacity")
vendored("tomli")
vendored("truststore")
vendored("urllib3")

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2022 Seth Michael Larson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,13 @@
"""Verify certificates using native system trust stores"""
import sys as _sys
if _sys.version_info < (3, 10):
raise ImportError("truststore requires Python 3.10 or later")
from ._api import SSLContext, extract_from_ssl, inject_into_ssl # noqa: E402
del _api, _sys # type: ignore[name-defined] # noqa: F821
__all__ = ["SSLContext", "inject_into_ssl", "extract_from_ssl"]
__version__ = "0.8.0"

View File

@ -0,0 +1,302 @@
import os
import platform
import socket
import ssl
import typing
import _ssl # type: ignore[import]
from ._ssl_constants import (
_original_SSLContext,
_original_super_SSLContext,
_truststore_SSLContext_dunder_class,
_truststore_SSLContext_super_class,
)
if platform.system() == "Windows":
from ._windows import _configure_context, _verify_peercerts_impl
elif platform.system() == "Darwin":
from ._macos import _configure_context, _verify_peercerts_impl
else:
from ._openssl import _configure_context, _verify_peercerts_impl
if typing.TYPE_CHECKING:
from pip._vendor.typing_extensions import Buffer
# From typeshed/stdlib/ssl.pyi
_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]
def inject_into_ssl() -> None:
"""Injects the :class:`truststore.SSLContext` into the ``ssl``
module by replacing :class:`ssl.SSLContext`.
"""
setattr(ssl, "SSLContext", SSLContext)
# urllib3 holds on to its own reference of ssl.SSLContext
# so we need to replace that reference too.
try:
import pip._vendor.urllib3.util.ssl_ as urllib3_ssl
setattr(urllib3_ssl, "SSLContext", SSLContext)
except ImportError:
pass
def extract_from_ssl() -> None:
"""Restores the :class:`ssl.SSLContext` class to its original state"""
setattr(ssl, "SSLContext", _original_SSLContext)
try:
import pip._vendor.urllib3.util.ssl_ as urllib3_ssl
urllib3_ssl.SSLContext = _original_SSLContext
except ImportError:
pass
class SSLContext(_truststore_SSLContext_super_class): # type: ignore[misc]
"""SSLContext API that uses system certificates on all platforms"""
@property # type: ignore[misc]
def __class__(self) -> type:
# Dirty hack to get around isinstance() checks
# for ssl.SSLContext instances in aiohttp/trustme
# when using non-CPython implementations.
return _truststore_SSLContext_dunder_class or SSLContext
def __init__(self, protocol: int = None) -> None: # type: ignore[assignment]
self._ctx = _original_SSLContext(protocol)
class TruststoreSSLObject(ssl.SSLObject):
# This object exists because wrap_bio() doesn't
# immediately do the handshake so we need to do
# certificate verifications after SSLObject.do_handshake()
def do_handshake(self) -> None:
ret = super().do_handshake()
_verify_peercerts(self, server_hostname=self.server_hostname)
return ret
self._ctx.sslobject_class = TruststoreSSLObject
def wrap_socket(
self,
sock: socket.socket,
server_side: bool = False,
do_handshake_on_connect: bool = True,
suppress_ragged_eofs: bool = True,
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLSocket:
# Use a context manager here because the
# inner SSLContext holds on to our state
# but also does the actual handshake.
with _configure_context(self._ctx):
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
try:
_verify_peercerts(ssl_sock, server_hostname=server_hostname)
except Exception:
ssl_sock.close()
raise
return ssl_sock
def wrap_bio(
self,
incoming: ssl.MemoryBIO,
outgoing: ssl.MemoryBIO,
server_side: bool = False,
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLObject:
with _configure_context(self._ctx):
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
return ssl_obj
def load_verify_locations(
self,
cafile: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
capath: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
cadata: typing.Union[str, "Buffer", None] = None,
) -> None:
return self._ctx.load_verify_locations(
cafile=cafile, capath=capath, cadata=cadata
)
def load_cert_chain(
self,
certfile: _StrOrBytesPath,
keyfile: _StrOrBytesPath | None = None,
password: _PasswordType | None = None,
) -> None:
return self._ctx.load_cert_chain(
certfile=certfile, keyfile=keyfile, password=password
)
def load_default_certs(
self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
) -> None:
return self._ctx.load_default_certs(purpose)
def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_alpn_protocols(alpn_protocols)
def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_npn_protocols(npn_protocols)
def set_ciphers(self, __cipherlist: str) -> None:
return self._ctx.set_ciphers(__cipherlist)
def get_ciphers(self) -> typing.Any:
return self._ctx.get_ciphers()
def session_stats(self) -> dict[str, int]:
return self._ctx.session_stats()
def cert_store_stats(self) -> dict[str, int]:
raise NotImplementedError()
@typing.overload
def get_ca_certs(
self, binary_form: typing.Literal[False] = ...
) -> list[typing.Any]:
...
@typing.overload
def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
...
@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
...
def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
raise NotImplementedError()
@property
def check_hostname(self) -> bool:
return self._ctx.check_hostname
@check_hostname.setter
def check_hostname(self, value: bool) -> None:
self._ctx.check_hostname = value
@property
def hostname_checks_common_name(self) -> bool:
return self._ctx.hostname_checks_common_name
@hostname_checks_common_name.setter
def hostname_checks_common_name(self, value: bool) -> None:
self._ctx.hostname_checks_common_name = value
@property
def keylog_filename(self) -> str:
return self._ctx.keylog_filename
@keylog_filename.setter
def keylog_filename(self, value: str) -> None:
self._ctx.keylog_filename = value
@property
def maximum_version(self) -> ssl.TLSVersion:
return self._ctx.maximum_version
@maximum_version.setter
def maximum_version(self, value: ssl.TLSVersion) -> None:
_original_super_SSLContext.maximum_version.__set__( # type: ignore[attr-defined]
self._ctx, value
)
@property
def minimum_version(self) -> ssl.TLSVersion:
return self._ctx.minimum_version
@minimum_version.setter
def minimum_version(self, value: ssl.TLSVersion) -> None:
_original_super_SSLContext.minimum_version.__set__( # type: ignore[attr-defined]
self._ctx, value
)
@property
def options(self) -> ssl.Options:
return self._ctx.options
@options.setter
def options(self, value: ssl.Options) -> None:
_original_super_SSLContext.options.__set__( # type: ignore[attr-defined]
self._ctx, value
)
@property
def post_handshake_auth(self) -> bool:
return self._ctx.post_handshake_auth
@post_handshake_auth.setter
def post_handshake_auth(self, value: bool) -> None:
self._ctx.post_handshake_auth = value
@property
def protocol(self) -> ssl._SSLMethod:
return self._ctx.protocol
@property
def security_level(self) -> int:
return self._ctx.security_level
@property
def verify_flags(self) -> ssl.VerifyFlags:
return self._ctx.verify_flags
@verify_flags.setter
def verify_flags(self, value: ssl.VerifyFlags) -> None:
_original_super_SSLContext.verify_flags.__set__( # type: ignore[attr-defined]
self._ctx, value
)
@property
def verify_mode(self) -> ssl.VerifyMode:
return self._ctx.verify_mode
@verify_mode.setter
def verify_mode(self, value: ssl.VerifyMode) -> None:
_original_super_SSLContext.verify_mode.__set__( # type: ignore[attr-defined]
self._ctx, value
)
def _verify_peercerts(
sock_or_sslobj: ssl.SSLSocket | ssl.SSLObject, server_hostname: str | None
) -> None:
"""
Verifies the peer certificates from an SSLSocket or SSLObject
against the certificates in the OS trust store.
"""
sslobj: ssl.SSLObject = sock_or_sslobj # type: ignore[assignment]
try:
while not hasattr(sslobj, "get_unverified_chain"):
sslobj = sslobj._sslobj # type: ignore[attr-defined]
except AttributeError:
pass
# SSLObject.get_unverified_chain() returns 'None'
# if the peer sends no certificates. This is common
# for the server-side scenario.
unverified_chain: typing.Sequence[_ssl.Certificate] = (
sslobj.get_unverified_chain() or () # type: ignore[attr-defined]
)
cert_bytes = [cert.public_bytes(_ssl.ENCODING_DER) for cert in unverified_chain]
_verify_peercerts_impl(
sock_or_sslobj.context, cert_bytes, server_hostname=server_hostname
)

View File

@ -0,0 +1,501 @@
import contextlib
import ctypes
import platform
import ssl
import typing
from ctypes import (
CDLL,
POINTER,
c_bool,
c_char_p,
c_int32,
c_long,
c_uint32,
c_ulong,
c_void_p,
)
from ctypes.util import find_library
from ._ssl_constants import _set_ssl_context_verify_mode
_mac_version = platform.mac_ver()[0]
_mac_version_info = tuple(map(int, _mac_version.split(".")))
if _mac_version_info < (10, 8):
raise ImportError(
f"Only OS X 10.8 and newer are supported, not {_mac_version_info[0]}.{_mac_version_info[1]}"
)
def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
"""Loads a CDLL by name, falling back to known path on 10.16+"""
try:
# Big Sur is technically 11 but we use 10.16 due to the Big Sur
# beta being labeled as 10.16.
path: str | None
if _mac_version_info >= (10, 16):
path = macos10_16_path
else:
path = find_library(name)
if not path:
raise OSError # Caught and reraised as 'ImportError'
return CDLL(path, use_errno=True)
except OSError:
raise ImportError(f"The library {name} failed to load") from None
Security = _load_cdll(
"Security", "/System/Library/Frameworks/Security.framework/Security"
)
CoreFoundation = _load_cdll(
"CoreFoundation",
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
)
Boolean = c_bool
CFIndex = c_long
CFStringEncoding = c_uint32
CFData = c_void_p
CFString = c_void_p
CFArray = c_void_p
CFMutableArray = c_void_p
CFError = c_void_p
CFType = c_void_p
CFTypeID = c_ulong
CFTypeRef = POINTER(CFType)
CFAllocatorRef = c_void_p
OSStatus = c_int32
CFErrorRef = POINTER(CFError)
CFDataRef = POINTER(CFData)
CFStringRef = POINTER(CFString)
CFArrayRef = POINTER(CFArray)
CFMutableArrayRef = POINTER(CFMutableArray)
CFArrayCallBacks = c_void_p
CFOptionFlags = c_uint32
SecCertificateRef = POINTER(c_void_p)
SecPolicyRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SecTrustResultType = c_uint32
SecTrustOptionFlags = c_uint32
try:
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argtypes = [SecTrustRef, Boolean]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
Security.SecTrustEvaluate.restype = OSStatus
Security.SecPolicyCreateRevocation.argtypes = [CFOptionFlags]
Security.SecPolicyCreateRevocation.restype = SecPolicyRef
Security.SecPolicyCreateSSL.argtypes = [Boolean, CFStringRef]
Security.SecPolicyCreateSSL.restype = SecPolicyRef
Security.SecTrustCreateWithCertificates.argtypes = [
CFTypeRef,
CFTypeRef,
POINTER(SecTrustRef),
]
Security.SecTrustCreateWithCertificates.restype = OSStatus
Security.SecTrustGetTrustResult.argtypes = [
SecTrustRef,
POINTER(SecTrustResultType),
]
Security.SecTrustGetTrustResult.restype = OSStatus
Security.SecTrustRef = SecTrustRef # type: ignore[attr-defined]
Security.SecTrustResultType = SecTrustResultType # type: ignore[attr-defined]
Security.OSStatus = OSStatus # type: ignore[attr-defined]
kSecRevocationUseAnyAvailableMethod = 3
kSecRevocationRequirePositiveResponse = 8
CoreFoundation.CFRelease.argtypes = [CFTypeRef]
CoreFoundation.CFRelease.restype = None
CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef,
c_char_p,
CFStringEncoding,
]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef,
c_char_p,
CFIndex,
CFStringEncoding,
]
CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
CoreFoundation.CFDataCreate.restype = CFDataRef
CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
CoreFoundation.CFDataGetLength.restype = CFIndex
CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
CoreFoundation.CFDataGetBytePtr.restype = c_void_p
CoreFoundation.CFArrayCreate.argtypes = [
CFAllocatorRef,
POINTER(CFTypeRef),
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreate.restype = CFArrayRef
CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef,
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CoreFoundation.CFArrayAppendValue.restype = None
CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CoreFoundation.CFArrayGetCount.restype = CFIndex
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.CFErrorGetCode.argtypes = [CFErrorRef]
CoreFoundation.CFErrorGetCode.restype = CFIndex
CoreFoundation.CFErrorCopyDescription.argtypes = [CFErrorRef]
CoreFoundation.CFErrorCopyDescription.restype = CFStringRef
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll( # type: ignore[attr-defined]
CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll( # type: ignore[attr-defined]
CoreFoundation, "kCFTypeArrayCallBacks"
)
CoreFoundation.CFTypeRef = CFTypeRef # type: ignore[attr-defined]
CoreFoundation.CFArrayRef = CFArrayRef # type: ignore[attr-defined]
CoreFoundation.CFStringRef = CFStringRef # type: ignore[attr-defined]
CoreFoundation.CFErrorRef = CFErrorRef # type: ignore[attr-defined]
except AttributeError:
raise ImportError("Error initializing ctypes") from None
def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any:
"""
Raises an error if the OSStatus value is non-zero.
"""
if int(result) == 0:
return args
# Returns a CFString which we need to transform
# into a UTF-8 Python string.
error_message_cfstring = None
try:
error_message_cfstring = Security.SecCopyErrorMessageString(result, None)
# First step is convert the CFString into a C string pointer.
# We try the fast no-copy way first.
error_message_cfstring_c_void_p = ctypes.cast(
error_message_cfstring, ctypes.POINTER(ctypes.c_void_p)
)
message = CoreFoundation.CFStringGetCStringPtr(
error_message_cfstring_c_void_p, CFConst.kCFStringEncodingUTF8
)
# Quoting the Apple dev docs:
#
# "A pointer to a C string or NULL if the internal
# storage of theString does not allow this to be
# returned efficiently."
#
# So we need to get our hands dirty.
if message is None:
buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString(
error_message_cfstring_c_void_p,
buffer,
1024,
CFConst.kCFStringEncodingUTF8,
)
if not result:
raise OSError("Error copying C string from CFStringRef")
message = buffer.value
finally:
if error_message_cfstring is not None:
CoreFoundation.CFRelease(error_message_cfstring)
# If no message can be found for this status we come
# up with a generic one that forwards the status code.
if message is None or message == "":
message = f"SecureTransport operation returned a non-zero OSStatus: {result}"
raise ssl.SSLError(message)
Security.SecTrustCreateWithCertificates.errcheck = _handle_osstatus # type: ignore[assignment]
Security.SecTrustSetAnchorCertificates.errcheck = _handle_osstatus # type: ignore[assignment]
Security.SecTrustGetTrustResult.errcheck = _handle_osstatus # type: ignore[assignment]
class CFConst:
"""CoreFoundation constants"""
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
errSecIncompleteCertRevocationCheck = -67635
errSecHostNameMismatch = -67602
errSecCertificateExpired = -67818
errSecNotTrusted = -67843
def _bytes_to_cf_data_ref(value: bytes) -> CFDataRef: # type: ignore[valid-type]
return CoreFoundation.CFDataCreate( # type: ignore[no-any-return]
CoreFoundation.kCFAllocatorDefault, value, len(value)
)
def _bytes_to_cf_string(value: bytes) -> CFString:
"""
Given a Python binary data, create a CFString.
The string must be CFReleased by the caller.
"""
c_str = ctypes.c_char_p(value)
cf_str = CoreFoundation.CFStringCreateWithCString(
CoreFoundation.kCFAllocatorDefault,
c_str,
CFConst.kCFStringEncodingUTF8,
)
return cf_str # type: ignore[no-any-return]
def _cf_string_ref_to_str(cf_string_ref: CFStringRef) -> str | None: # type: ignore[valid-type]
"""
Creates a Unicode string from a CFString object. Used entirely for error
reporting.
Yes, it annoys me quite a lot that this function is this complex.
"""
string = CoreFoundation.CFStringGetCStringPtr(
cf_string_ref, CFConst.kCFStringEncodingUTF8
)
if string is None:
buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString(
cf_string_ref, buffer, 1024, CFConst.kCFStringEncodingUTF8
)
if not result:
raise OSError("Error copying C string from CFStringRef")
string = buffer.value
if string is not None:
string = string.decode("utf-8")
return string # type: ignore[no-any-return]
def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # type: ignore[valid-type]
"""Builds a CFArray of SecCertificateRefs from a list of DER-encoded certificates.
Responsibility of the caller to call CoreFoundation.CFRelease on the CFArray.
"""
cf_array = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cf_array:
raise MemoryError("Unable to allocate memory!")
for cert_data in certs:
cf_data = None
sec_cert_ref = None
try:
cf_data = _bytes_to_cf_data_ref(cert_data)
sec_cert_ref = Security.SecCertificateCreateWithData(
CoreFoundation.kCFAllocatorDefault, cf_data
)
CoreFoundation.CFArrayAppendValue(cf_array, sec_cert_ref)
finally:
if cf_data:
CoreFoundation.CFRelease(cf_data)
if sec_cert_ref:
CoreFoundation.CFRelease(sec_cert_ref)
return cf_array # type: ignore[no-any-return]
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
check_hostname = ctx.check_hostname
verify_mode = ctx.verify_mode
ctx.check_hostname = False
_set_ssl_context_verify_mode(ctx, ssl.CERT_NONE)
try:
yield
finally:
ctx.check_hostname = check_hostname
_set_ssl_context_verify_mode(ctx, verify_mode)
def _verify_peercerts_impl(
ssl_context: ssl.SSLContext,
cert_chain: list[bytes],
server_hostname: str | None = None,
) -> None:
certs = None
policies = None
trust = None
cf_error = None
try:
if server_hostname is not None:
cf_str_hostname = None
try:
cf_str_hostname = _bytes_to_cf_string(server_hostname.encode("ascii"))
ssl_policy = Security.SecPolicyCreateSSL(True, cf_str_hostname)
finally:
if cf_str_hostname:
CoreFoundation.CFRelease(cf_str_hostname)
else:
ssl_policy = Security.SecPolicyCreateSSL(True, None)
policies = ssl_policy
if ssl_context.verify_flags & ssl.VERIFY_CRL_CHECK_CHAIN:
# Add explicit policy requiring positive revocation checks
policies = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
CoreFoundation.CFArrayAppendValue(policies, ssl_policy)
CoreFoundation.CFRelease(ssl_policy)
revocation_policy = Security.SecPolicyCreateRevocation(
kSecRevocationUseAnyAvailableMethod
| kSecRevocationRequirePositiveResponse
)
CoreFoundation.CFArrayAppendValue(policies, revocation_policy)
CoreFoundation.CFRelease(revocation_policy)
elif ssl_context.verify_flags & ssl.VERIFY_CRL_CHECK_LEAF:
raise NotImplementedError("VERIFY_CRL_CHECK_LEAF not implemented for macOS")
certs = None
try:
certs = _der_certs_to_cf_cert_array(cert_chain)
# Now that we have certificates loaded and a SecPolicy
# we can finally create a SecTrust object!
trust = Security.SecTrustRef()
Security.SecTrustCreateWithCertificates(
certs, policies, ctypes.byref(trust)
)
finally:
# The certs are now being held by SecTrust so we can
# release our handles for the array.
if certs:
CoreFoundation.CFRelease(certs)
# If there are additional trust anchors to load we need to transform
# the list of DER-encoded certificates into a CFArray. Otherwise
# pass 'None' to signal that we only want system / fetched certificates.
ctx_ca_certs_der: list[bytes] | None = ssl_context.get_ca_certs(
binary_form=True
)
if ctx_ca_certs_der:
ctx_ca_certs = None
try:
ctx_ca_certs = _der_certs_to_cf_cert_array(cert_chain)
Security.SecTrustSetAnchorCertificates(trust, ctx_ca_certs)
finally:
if ctx_ca_certs:
CoreFoundation.CFRelease(ctx_ca_certs)
else:
Security.SecTrustSetAnchorCertificates(trust, None)
cf_error = CoreFoundation.CFErrorRef()
sec_trust_eval_result = Security.SecTrustEvaluateWithError(
trust, ctypes.byref(cf_error)
)
# sec_trust_eval_result is a bool (0 or 1)
# where 1 means that the certs are trusted.
if sec_trust_eval_result == 1:
is_trusted = True
elif sec_trust_eval_result == 0:
is_trusted = False
else:
raise ssl.SSLError(
f"Unknown result from Security.SecTrustEvaluateWithError: {sec_trust_eval_result!r}"
)
cf_error_code = 0
if not is_trusted:
cf_error_code = CoreFoundation.CFErrorGetCode(cf_error)
# If the error is a known failure that we're
# explicitly okay with from SSLContext configuration
# we can set is_trusted accordingly.
if ssl_context.verify_mode != ssl.CERT_REQUIRED and (
cf_error_code == CFConst.errSecNotTrusted
or cf_error_code == CFConst.errSecCertificateExpired
):
is_trusted = True
elif (
not ssl_context.check_hostname
and cf_error_code == CFConst.errSecHostNameMismatch
):
is_trusted = True
# If we're still not trusted then we start to
# construct and raise the SSLCertVerificationError.
if not is_trusted:
cf_error_string_ref = None
try:
cf_error_string_ref = CoreFoundation.CFErrorCopyDescription(cf_error)
# Can this ever return 'None' if there's a CFError?
cf_error_message = (
_cf_string_ref_to_str(cf_error_string_ref)
or "Certificate verification failed"
)
# TODO: Not sure if we need the SecTrustResultType for anything?
# We only care whether or not it's a success or failure for now.
sec_trust_result_type = Security.SecTrustResultType()
Security.SecTrustGetTrustResult(
trust, ctypes.byref(sec_trust_result_type)
)
err = ssl.SSLCertVerificationError(cf_error_message)
err.verify_message = cf_error_message
err.verify_code = cf_error_code
raise err
finally:
if cf_error_string_ref:
CoreFoundation.CFRelease(cf_error_string_ref)
finally:
if policies:
CoreFoundation.CFRelease(policies)
if trust:
CoreFoundation.CFRelease(trust)

View File

@ -0,0 +1,66 @@
import contextlib
import os
import re
import ssl
import typing
# candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes
_CA_FILE_CANDIDATES = [
# Alpine, Arch, Fedora 34+, OpenWRT, RHEL 9+, BSD
"/etc/ssl/cert.pem",
# Fedora <= 34, RHEL <= 9, CentOS <= 9
"/etc/pki/tls/cert.pem",
# Debian, Ubuntu (requires ca-certificates)
"/etc/ssl/certs/ca-certificates.crt",
# SUSE
"/etc/ssl/ca-bundle.pem",
]
_HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$")
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
# First, check whether the default locations from OpenSSL
# seem like they will give us a usable set of CA certs.
# ssl.get_default_verify_paths already takes care of:
# - getting cafile from either the SSL_CERT_FILE env var
# or the path configured when OpenSSL was compiled,
# and verifying that that path exists
# - getting capath from either the SSL_CERT_DIR env var
# or the path configured when OpenSSL was compiled,
# and verifying that that path exists
# In addition we'll check whether capath appears to contain certs.
defaults = ssl.get_default_verify_paths()
if defaults.cafile or (defaults.capath and _capath_contains_certs(defaults.capath)):
ctx.set_default_verify_paths()
else:
# cafile from OpenSSL doesn't exist
# and capath from OpenSSL doesn't contain certs.
# Let's search other common locations instead.
for cafile in _CA_FILE_CANDIDATES:
if os.path.isfile(cafile):
ctx.load_verify_locations(cafile=cafile)
break
yield
def _capath_contains_certs(capath: str) -> bool:
"""Check whether capath exists and contains certs in the expected format."""
if not os.path.isdir(capath):
return False
for name in os.listdir(capath):
if _HASHED_CERT_FILENAME_RE.match(name):
return True
return False
def _verify_peercerts_impl(
ssl_context: ssl.SSLContext,
cert_chain: list[bytes],
server_hostname: str | None = None,
) -> None:
# This is a no-op because we've enabled SSLContext's built-in
# verification via verify_mode=CERT_REQUIRED, and don't need to repeat it.
pass

View File

@ -0,0 +1,31 @@
import ssl
import sys
import typing
# Hold on to the original class so we can create it consistently
# even if we inject our own SSLContext into the ssl module.
_original_SSLContext = ssl.SSLContext
_original_super_SSLContext = super(_original_SSLContext, _original_SSLContext)
# CPython is known to be good, but non-CPython implementations
# may implement SSLContext differently so to be safe we don't
# subclass the SSLContext.
# This is returned by truststore.SSLContext.__class__()
_truststore_SSLContext_dunder_class: typing.Optional[type]
# This value is the superclass of truststore.SSLContext.
_truststore_SSLContext_super_class: type
if sys.implementation.name == "cpython":
_truststore_SSLContext_super_class = _original_SSLContext
_truststore_SSLContext_dunder_class = None
else:
_truststore_SSLContext_super_class = object
_truststore_SSLContext_dunder_class = _original_SSLContext
def _set_ssl_context_verify_mode(
ssl_context: ssl.SSLContext, verify_mode: ssl.VerifyMode
) -> None:
_original_super_SSLContext.verify_mode.__set__(ssl_context, verify_mode) # type: ignore[attr-defined]

View File

@ -0,0 +1,554 @@
import contextlib
import ssl
import typing
from ctypes import WinDLL # type: ignore
from ctypes import WinError # type: ignore
from ctypes import (
POINTER,
Structure,
c_char_p,
c_ulong,
c_void_p,
c_wchar_p,
cast,
create_unicode_buffer,
pointer,
sizeof,
)
from ctypes.wintypes import (
BOOL,
DWORD,
HANDLE,
LONG,
LPCSTR,
LPCVOID,
LPCWSTR,
LPFILETIME,
LPSTR,
LPWSTR,
)
from typing import TYPE_CHECKING, Any
from ._ssl_constants import _set_ssl_context_verify_mode
HCERTCHAINENGINE = HANDLE
HCERTSTORE = HANDLE
HCRYPTPROV_LEGACY = HANDLE
class CERT_CONTEXT(Structure):
_fields_ = (
("dwCertEncodingType", DWORD),
("pbCertEncoded", c_void_p),
("cbCertEncoded", DWORD),
("pCertInfo", c_void_p),
("hCertStore", HCERTSTORE),
)
PCERT_CONTEXT = POINTER(CERT_CONTEXT)
PCCERT_CONTEXT = POINTER(PCERT_CONTEXT)
class CERT_ENHKEY_USAGE(Structure):
_fields_ = (
("cUsageIdentifier", DWORD),
("rgpszUsageIdentifier", POINTER(LPSTR)),
)
PCERT_ENHKEY_USAGE = POINTER(CERT_ENHKEY_USAGE)
class CERT_USAGE_MATCH(Structure):
_fields_ = (
("dwType", DWORD),
("Usage", CERT_ENHKEY_USAGE),
)
class CERT_CHAIN_PARA(Structure):
_fields_ = (
("cbSize", DWORD),
("RequestedUsage", CERT_USAGE_MATCH),
("RequestedIssuancePolicy", CERT_USAGE_MATCH),
("dwUrlRetrievalTimeout", DWORD),
("fCheckRevocationFreshnessTime", BOOL),
("dwRevocationFreshnessTime", DWORD),
("pftCacheResync", LPFILETIME),
("pStrongSignPara", c_void_p),
("dwStrongSignFlags", DWORD),
)
if TYPE_CHECKING:
PCERT_CHAIN_PARA = pointer[CERT_CHAIN_PARA] # type: ignore[misc]
else:
PCERT_CHAIN_PARA = POINTER(CERT_CHAIN_PARA)
class CERT_TRUST_STATUS(Structure):
_fields_ = (
("dwErrorStatus", DWORD),
("dwInfoStatus", DWORD),
)
class CERT_CHAIN_ELEMENT(Structure):
_fields_ = (
("cbSize", DWORD),
("pCertContext", PCERT_CONTEXT),
("TrustStatus", CERT_TRUST_STATUS),
("pRevocationInfo", c_void_p),
("pIssuanceUsage", PCERT_ENHKEY_USAGE),
("pApplicationUsage", PCERT_ENHKEY_USAGE),
("pwszExtendedErrorInfo", LPCWSTR),
)
PCERT_CHAIN_ELEMENT = POINTER(CERT_CHAIN_ELEMENT)
class CERT_SIMPLE_CHAIN(Structure):
_fields_ = (
("cbSize", DWORD),
("TrustStatus", CERT_TRUST_STATUS),
("cElement", DWORD),
("rgpElement", POINTER(PCERT_CHAIN_ELEMENT)),
("pTrustListInfo", c_void_p),
("fHasRevocationFreshnessTime", BOOL),
("dwRevocationFreshnessTime", DWORD),
)
PCERT_SIMPLE_CHAIN = POINTER(CERT_SIMPLE_CHAIN)
class CERT_CHAIN_CONTEXT(Structure):
_fields_ = (
("cbSize", DWORD),
("TrustStatus", CERT_TRUST_STATUS),
("cChain", DWORD),
("rgpChain", POINTER(PCERT_SIMPLE_CHAIN)),
("cLowerQualityChainContext", DWORD),
("rgpLowerQualityChainContext", c_void_p),
("fHasRevocationFreshnessTime", BOOL),
("dwRevocationFreshnessTime", DWORD),
)
PCERT_CHAIN_CONTEXT = POINTER(CERT_CHAIN_CONTEXT)
PCCERT_CHAIN_CONTEXT = POINTER(PCERT_CHAIN_CONTEXT)
class SSL_EXTRA_CERT_CHAIN_POLICY_PARA(Structure):
_fields_ = (
("cbSize", DWORD),
("dwAuthType", DWORD),
("fdwChecks", DWORD),
("pwszServerName", LPCWSTR),
)
class CERT_CHAIN_POLICY_PARA(Structure):
_fields_ = (
("cbSize", DWORD),
("dwFlags", DWORD),
("pvExtraPolicyPara", c_void_p),
)
PCERT_CHAIN_POLICY_PARA = POINTER(CERT_CHAIN_POLICY_PARA)
class CERT_CHAIN_POLICY_STATUS(Structure):
_fields_ = (
("cbSize", DWORD),
("dwError", DWORD),
("lChainIndex", LONG),
("lElementIndex", LONG),
("pvExtraPolicyStatus", c_void_p),
)
PCERT_CHAIN_POLICY_STATUS = POINTER(CERT_CHAIN_POLICY_STATUS)
class CERT_CHAIN_ENGINE_CONFIG(Structure):
_fields_ = (
("cbSize", DWORD),
("hRestrictedRoot", HCERTSTORE),
("hRestrictedTrust", HCERTSTORE),
("hRestrictedOther", HCERTSTORE),
("cAdditionalStore", DWORD),
("rghAdditionalStore", c_void_p),
("dwFlags", DWORD),
("dwUrlRetrievalTimeout", DWORD),
("MaximumCachedCertificates", DWORD),
("CycleDetectionModulus", DWORD),
("hExclusiveRoot", HCERTSTORE),
("hExclusiveTrustedPeople", HCERTSTORE),
("dwExclusiveFlags", DWORD),
)
PCERT_CHAIN_ENGINE_CONFIG = POINTER(CERT_CHAIN_ENGINE_CONFIG)
PHCERTCHAINENGINE = POINTER(HCERTCHAINENGINE)
X509_ASN_ENCODING = 0x00000001
PKCS_7_ASN_ENCODING = 0x00010000
CERT_STORE_PROV_MEMORY = b"Memory"
CERT_STORE_ADD_USE_EXISTING = 2
USAGE_MATCH_TYPE_OR = 1
OID_PKIX_KP_SERVER_AUTH = c_char_p(b"1.3.6.1.5.5.7.3.1")
CERT_CHAIN_REVOCATION_CHECK_END_CERT = 0x10000000
CERT_CHAIN_REVOCATION_CHECK_CHAIN = 0x20000000
CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS = 0x00000007
CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG = 0x00000008
CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG = 0x00000010
CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG = 0x00000040
CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG = 0x00000020
CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG = 0x00000080
CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS = 0x00000F00
CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG = 0x00008000
CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG = 0x00004000
AUTHTYPE_SERVER = 2
CERT_CHAIN_POLICY_SSL = 4
FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000
FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200
# Flags to set for SSLContext.verify_mode=CERT_NONE
CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS = (
CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS
| CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG
| CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG
| CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG
| CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG
| CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG
| CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS
| CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG
| CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG
)
wincrypt = WinDLL("crypt32.dll")
kernel32 = WinDLL("kernel32.dll")
def _handle_win_error(result: bool, _: Any, args: Any) -> Any:
if not result:
# Note, actually raises OSError after calling GetLastError and FormatMessage
raise WinError()
return args
CertCreateCertificateChainEngine = wincrypt.CertCreateCertificateChainEngine
CertCreateCertificateChainEngine.argtypes = (
PCERT_CHAIN_ENGINE_CONFIG,
PHCERTCHAINENGINE,
)
CertCreateCertificateChainEngine.errcheck = _handle_win_error
CertOpenStore = wincrypt.CertOpenStore
CertOpenStore.argtypes = (LPCSTR, DWORD, HCRYPTPROV_LEGACY, DWORD, c_void_p)
CertOpenStore.restype = HCERTSTORE
CertOpenStore.errcheck = _handle_win_error
CertAddEncodedCertificateToStore = wincrypt.CertAddEncodedCertificateToStore
CertAddEncodedCertificateToStore.argtypes = (
HCERTSTORE,
DWORD,
c_char_p,
DWORD,
DWORD,
PCCERT_CONTEXT,
)
CertAddEncodedCertificateToStore.restype = BOOL
CertCreateCertificateContext = wincrypt.CertCreateCertificateContext
CertCreateCertificateContext.argtypes = (DWORD, c_char_p, DWORD)
CertCreateCertificateContext.restype = PCERT_CONTEXT
CertCreateCertificateContext.errcheck = _handle_win_error
CertGetCertificateChain = wincrypt.CertGetCertificateChain
CertGetCertificateChain.argtypes = (
HCERTCHAINENGINE,
PCERT_CONTEXT,
LPFILETIME,
HCERTSTORE,
PCERT_CHAIN_PARA,
DWORD,
c_void_p,
PCCERT_CHAIN_CONTEXT,
)
CertGetCertificateChain.restype = BOOL
CertGetCertificateChain.errcheck = _handle_win_error
CertVerifyCertificateChainPolicy = wincrypt.CertVerifyCertificateChainPolicy
CertVerifyCertificateChainPolicy.argtypes = (
c_ulong,
PCERT_CHAIN_CONTEXT,
PCERT_CHAIN_POLICY_PARA,
PCERT_CHAIN_POLICY_STATUS,
)
CertVerifyCertificateChainPolicy.restype = BOOL
CertCloseStore = wincrypt.CertCloseStore
CertCloseStore.argtypes = (HCERTSTORE, DWORD)
CertCloseStore.restype = BOOL
CertCloseStore.errcheck = _handle_win_error
CertFreeCertificateChain = wincrypt.CertFreeCertificateChain
CertFreeCertificateChain.argtypes = (PCERT_CHAIN_CONTEXT,)
CertFreeCertificateContext = wincrypt.CertFreeCertificateContext
CertFreeCertificateContext.argtypes = (PCERT_CONTEXT,)
CertFreeCertificateChainEngine = wincrypt.CertFreeCertificateChainEngine
CertFreeCertificateChainEngine.argtypes = (HCERTCHAINENGINE,)
FormatMessageW = kernel32.FormatMessageW
FormatMessageW.argtypes = (
DWORD,
LPCVOID,
DWORD,
DWORD,
LPWSTR,
DWORD,
c_void_p,
)
FormatMessageW.restype = DWORD
def _verify_peercerts_impl(
ssl_context: ssl.SSLContext,
cert_chain: list[bytes],
server_hostname: str | None = None,
) -> None:
"""Verify the cert_chain from the server using Windows APIs."""
pCertContext = None
hIntermediateCertStore = CertOpenStore(CERT_STORE_PROV_MEMORY, 0, None, 0, None)
try:
# Add intermediate certs to an in-memory cert store
for cert_bytes in cert_chain[1:]:
CertAddEncodedCertificateToStore(
hIntermediateCertStore,
X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
cert_bytes,
len(cert_bytes),
CERT_STORE_ADD_USE_EXISTING,
None,
)
# Cert context for leaf cert
leaf_cert = cert_chain[0]
pCertContext = CertCreateCertificateContext(
X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, leaf_cert, len(leaf_cert)
)
# Chain params to match certs for serverAuth extended usage
cert_enhkey_usage = CERT_ENHKEY_USAGE()
cert_enhkey_usage.cUsageIdentifier = 1
cert_enhkey_usage.rgpszUsageIdentifier = (c_char_p * 1)(OID_PKIX_KP_SERVER_AUTH)
cert_usage_match = CERT_USAGE_MATCH()
cert_usage_match.Usage = cert_enhkey_usage
chain_params = CERT_CHAIN_PARA()
chain_params.RequestedUsage = cert_usage_match
chain_params.cbSize = sizeof(chain_params)
pChainPara = pointer(chain_params)
if ssl_context.verify_flags & ssl.VERIFY_CRL_CHECK_CHAIN:
chain_flags = CERT_CHAIN_REVOCATION_CHECK_CHAIN
elif ssl_context.verify_flags & ssl.VERIFY_CRL_CHECK_LEAF:
chain_flags = CERT_CHAIN_REVOCATION_CHECK_END_CERT
else:
chain_flags = 0
try:
# First attempt to verify using the default Windows system trust roots
# (default chain engine).
_get_and_verify_cert_chain(
ssl_context,
None,
hIntermediateCertStore,
pCertContext,
pChainPara,
server_hostname,
chain_flags=chain_flags,
)
except ssl.SSLCertVerificationError:
# If that fails but custom CA certs have been added
# to the SSLContext using load_verify_locations,
# try verifying using a custom chain engine
# that trusts the custom CA certs.
custom_ca_certs: list[bytes] | None = ssl_context.get_ca_certs(
binary_form=True
)
if custom_ca_certs:
_verify_using_custom_ca_certs(
ssl_context,
custom_ca_certs,
hIntermediateCertStore,
pCertContext,
pChainPara,
server_hostname,
chain_flags=chain_flags,
)
else:
raise
finally:
CertCloseStore(hIntermediateCertStore, 0)
if pCertContext:
CertFreeCertificateContext(pCertContext)
def _get_and_verify_cert_chain(
ssl_context: ssl.SSLContext,
hChainEngine: HCERTCHAINENGINE | None,
hIntermediateCertStore: HCERTSTORE,
pPeerCertContext: c_void_p,
pChainPara: PCERT_CHAIN_PARA, # type: ignore[valid-type]
server_hostname: str | None,
chain_flags: int,
) -> None:
ppChainContext = None
try:
# Get cert chain
ppChainContext = pointer(PCERT_CHAIN_CONTEXT())
CertGetCertificateChain(
hChainEngine, # chain engine
pPeerCertContext, # leaf cert context
None, # current system time
hIntermediateCertStore, # additional in-memory cert store
pChainPara, # chain-building parameters
chain_flags,
None, # reserved
ppChainContext, # the resulting chain context
)
pChainContext = ppChainContext.contents
# Verify cert chain
ssl_extra_cert_chain_policy_para = SSL_EXTRA_CERT_CHAIN_POLICY_PARA()
ssl_extra_cert_chain_policy_para.cbSize = sizeof(
ssl_extra_cert_chain_policy_para
)
ssl_extra_cert_chain_policy_para.dwAuthType = AUTHTYPE_SERVER
ssl_extra_cert_chain_policy_para.fdwChecks = 0
if server_hostname:
ssl_extra_cert_chain_policy_para.pwszServerName = c_wchar_p(server_hostname)
chain_policy = CERT_CHAIN_POLICY_PARA()
chain_policy.pvExtraPolicyPara = cast(
pointer(ssl_extra_cert_chain_policy_para), c_void_p
)
if ssl_context.verify_mode == ssl.CERT_NONE:
chain_policy.dwFlags |= CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS
if not ssl_context.check_hostname:
chain_policy.dwFlags |= CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG
chain_policy.cbSize = sizeof(chain_policy)
pPolicyPara = pointer(chain_policy)
policy_status = CERT_CHAIN_POLICY_STATUS()
policy_status.cbSize = sizeof(policy_status)
pPolicyStatus = pointer(policy_status)
CertVerifyCertificateChainPolicy(
CERT_CHAIN_POLICY_SSL,
pChainContext,
pPolicyPara,
pPolicyStatus,
)
# Check status
error_code = policy_status.dwError
if error_code:
# Try getting a human readable message for an error code.
error_message_buf = create_unicode_buffer(1024)
error_message_chars = FormatMessageW(
FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
None,
error_code,
0,
error_message_buf,
sizeof(error_message_buf),
None,
)
# See if we received a message for the error,
# otherwise we use a generic error with the
# error code and hope that it's search-able.
if error_message_chars <= 0:
error_message = f"Certificate chain policy error {error_code:#x} [{policy_status.lElementIndex}]"
else:
error_message = error_message_buf.value.strip()
err = ssl.SSLCertVerificationError(error_message)
err.verify_message = error_message
err.verify_code = error_code
raise err from None
finally:
if ppChainContext:
CertFreeCertificateChain(ppChainContext.contents)
def _verify_using_custom_ca_certs(
ssl_context: ssl.SSLContext,
custom_ca_certs: list[bytes],
hIntermediateCertStore: HCERTSTORE,
pPeerCertContext: c_void_p,
pChainPara: PCERT_CHAIN_PARA, # type: ignore[valid-type]
server_hostname: str | None,
chain_flags: int,
) -> None:
hChainEngine = None
hRootCertStore = CertOpenStore(CERT_STORE_PROV_MEMORY, 0, None, 0, None)
try:
# Add custom CA certs to an in-memory cert store
for cert_bytes in custom_ca_certs:
CertAddEncodedCertificateToStore(
hRootCertStore,
X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
cert_bytes,
len(cert_bytes),
CERT_STORE_ADD_USE_EXISTING,
None,
)
# Create a custom cert chain engine which exclusively trusts
# certs from our hRootCertStore
cert_chain_engine_config = CERT_CHAIN_ENGINE_CONFIG()
cert_chain_engine_config.cbSize = sizeof(cert_chain_engine_config)
cert_chain_engine_config.hExclusiveRoot = hRootCertStore
pConfig = pointer(cert_chain_engine_config)
phChainEngine = pointer(HCERTCHAINENGINE())
CertCreateCertificateChainEngine(
pConfig,
phChainEngine,
)
hChainEngine = phChainEngine.contents
# Get and verify a cert chain using the custom chain engine
_get_and_verify_cert_chain(
ssl_context,
hChainEngine,
hIntermediateCertStore,
pPeerCertContext,
pChainPara,
server_hostname,
chain_flags,
)
finally:
if hChainEngine:
CertFreeCertificateChainEngine(hChainEngine)
CertCloseStore(hRootCertStore, 0)
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
check_hostname = ctx.check_hostname
verify_mode = ctx.verify_mode
ctx.check_hostname = False
_set_ssl_context_verify_mode(ctx, ssl.CERT_NONE)
try:
yield
finally:
ctx.check_hostname = check_hostname
_set_ssl_context_verify_mode(ctx, verify_mode)

View File

View File

@ -20,4 +20,5 @@ setuptools==68.0.0
six==1.16.0
tenacity==8.2.2
tomli==2.0.1
truststore==0.8.0
webencodings==0.5.1

View File

@ -27,20 +27,6 @@ def test_truststore_error_on_old_python(pip: PipRunner) -> None:
assert "The truststore feature is only available for Python 3.10+" in result.stderr
@pytest.mark.skipif(sys.version_info < (3, 10), reason="3.10+ required for truststore")
def test_truststore_error_without_preinstalled(pip: PipRunner) -> None:
result = pip(
"install",
"--no-index",
"does-not-matter",
expect_error=True,
)
assert (
"To use the truststore feature, 'truststore' must be installed into "
"pip's current environment."
) in result.stderr
@pytest.mark.skipif(sys.version_info < (3, 10), reason="3.10+ required for truststore")
@pytest.mark.network
@pytest.mark.parametrize(
@ -56,6 +42,5 @@ def test_trustore_can_install(
pip: PipRunner,
package: str,
) -> None:
script.pip("install", "truststore")
result = pip("install", package)
assert "Successfully installed" in result.stdout