gpg-lacre/lacre/repositories.py

158 lines
5.2 KiB
Python

"""Lacre identity and key repositories."""
import logging

from sqlalchemy import create_engine, select, delete, and_, func
from sqlalchemy.exc import OperationalError

from lacre._keyringcommon import KeyRing, KeyCache
import lacre.config as conf
import lacre.dbschema as db
# Logger shared by every repository class in this module.
LOG = logging.getLogger(name=__name__)

# Internal state: the SQLAlchemy engine created lazily by connect();
# None until the first connection is requested.
_engine = None
def connect(url):
    """Open and return a new connection from the module-wide engine.

    The engine is built from ``url`` on the first call only and cached in
    the module-level ``_engine``; every later call reuses that engine, so
    the ``url`` argument of subsequent calls is ignored.
    """
    global _engine
    # Create the engine only if none has been cached yet.
    _engine = _engine or create_engine(url)
    return _engine.connect()
class IdentityRepository(KeyRing):
    """Stores the mapping between e-mail identities and key fingerprints.

    Rows live in the ``db.LACRE_IDENTITIES`` table (``email`` and
    ``fingerprint`` columns).  A connection may be supplied up-front;
    otherwise one is opened lazily from ``db_url`` on first use and then
    reused for the lifetime of the repository.
    """

    def __init__(self, /, connection=None, db_url=None):
        """Create a repository.

        :param connection: an already-open connection to use, or None to
            connect lazily (via ``connect``) using ``db_url``.
        :param db_url: database URL used when no connection was provided.
        """
        self._identities = db.LACRE_IDENTITIES
        self._conn = connection
        self._url = db_url
        # True once self._conn holds a connection we can execute on.
        self._initialised = connection is not None

    def register_or_update(self, email, fprint):
        """Assign fingerprint ``fprint`` to ``email``.

        Updates the existing row for ``email`` if there is one, otherwise
        inserts a new row.  Both arguments must be non-empty.
        """
        assert email, "email is mandatory"
        assert fprint, "fprint is mandatory"

        if self._exists(email):
            self._update(email, fprint)
        else:
            self._insert(email, fprint)

    def _exists(self, email: str) -> bool:
        """Return True if an identity row for ``email`` is present."""
        self._ensure_connected()
        selq = select(self._identities.c.email) \
            .where(self._identities.c.email == email) \
            .limit(1)
        # first() fetches at most one row and closes the result, instead of
        # materialising every matching row just to test for emptiness.
        return self._conn.execute(selq).first() is not None

    def _insert(self, email, fprint):
        """Insert a new (email, fingerprint) row."""
        self._ensure_connected()
        insq = self._identities.insert().values(email=email, fingerprint=fprint)

        LOG.debug('Registering identity %s: %s', email, insq)
        self._conn.execute(insq)

    def _update(self, email, fprint):
        """Set the fingerprint of the existing row for ``email``."""
        self._ensure_connected()
        upq = self._identities.update() \
            .values(fingerprint=fprint) \
            .where(self._identities.c.email == email)

        LOG.debug('Updating identity %s: %s', email, upq)
        self._conn.execute(upq)

    def _ensure_connected(self):
        """Open a connection from ``self._url`` unless one is already held."""
        if not self._initialised:
            LOG.debug('Connecting with %s', self._url)
            self._conn = connect(self._url)
            # Remember that we are connected; without this every call would
            # open (and leak) a fresh connection.
            self._initialised = True

    def delete(self, email):
        """Delete the identity row(s) for ``email``."""
        self._ensure_connected()
        delq = delete(self._identities).where(self._identities.c.email == email)
        LOG.debug('Deleting keys assigned to %s', email)
        self._conn.execute(delq)

    def delete_all(self):
        """Delete every identity row from the table."""
        # Connect first: previously this failed when the repository had been
        # created with only a db_url and no other method had run yet.
        self._ensure_connected()
        LOG.warning('Deleting all identities from the database')
        delq = delete(self._identities)
        self._conn.execute(delq)

    def freeze_identities(self) -> KeyCache:
        """Return a static, async-safe copy of the identity map.

        Depending on the value of [daemon]bounce_on_keys_missing value,
        if we get a database exception (including failure to connect),
        this method will either return empty collection or let the
        exception be propagated.

        :raises OperationalError: only when bounce_on_keys_missing is set.
        """
        try:
            # Connecting can raise OperationalError too, so it belongs inside
            # the try for the configured policy to apply.
            self._ensure_connected()
            return self._load_identities()
        except OperationalError:
            if conf.flag_enabled('daemon', 'bounce_on_keys_missing'):
                raise
            LOG.exception('Failed to load keys, returning empty collection')
            return KeyCache({})

    def _load_identities(self) -> KeyCache:
        """Read all rows into a KeyCache mapping fingerprint -> email."""
        all_identities = select(self._identities.c.fingerprint, self._identities.c.email)
        LOG.debug('Retrieving all keys')
        result = self._conn.execute(all_identities)
        return KeyCache({key_id: email for key_id, email in result})
class KeyConfirmationQueue:
    """Encapsulates access to lacre_keys table."""

    # Default number of items retrieved from the database.
    keys_read_max = 100

    def __init__(self, connection):
        """Wrap an already-open database ``connection``."""
        self._keys = db.LACRE_KEYS
        self._conn = connection

    def fetch_keys(self, /, max_keys=None):
        """Runs a query to retrieve at most `keys_read_max` keys and returns db result.

        Selects rows with status ``db.ST_DEFAULT`` and an empty ``confirm``
        column; each row carries (publickey, id, email).  ``max_keys``
        overrides the limit when given (falsy values fall back to the
        default).
        """
        max_keys = max_keys or self.keys_read_max

        selq = select(self._keys.c.publickey, self._keys.c.id, self._keys.c.email) \
            .where(and_(self._keys.c.status == db.ST_DEFAULT, self._keys.c.confirm == "")) \
            .limit(max_keys)

        LOG.debug('Retrieving keys to be processed: %s', selq)
        return self._conn.execute(selq)

    def count_keys(self):
        """Return the number of rows in the keys table, or None on database error."""
        selq = select(func.count(self._keys.c.id))

        LOG.debug('Counting all keys: %s', selq)
        try:
            # COUNT yields exactly one row with one column.
            return self._conn.execute(selq).scalar()
        except OperationalError:
            LOG.exception('Cannot count keys')
            return None

    def fetch_keys_to_delete(self, /, max_keys=None):
        """Return (email, id) rows of keys with status ``db.ST_TO_BE_DELETED``.

        At most ``max_keys`` rows are returned (``keys_read_max`` when not
        given), mirroring ``fetch_keys``.
        """
        max_keys = max_keys or self.keys_read_max
        seldel = select(self._keys.c.email, self._keys.c.id) \
            .where(self._keys.c.status == db.ST_TO_BE_DELETED) \
            .limit(max_keys)
        return self._conn.execute(seldel)

    def delete_keys(self, row_id, /, email=None):
        """Remove keys from the database.

        Without ``email``: delete the single row whose id is ``row_id``.
        With ``email``: delete every *other* key stored for that address,
        keeping row ``row_id`` (presumably the newly confirmed key).
        """
        if email is not None:
            delq = delete(self._keys).where(and_(self._keys.c.email == email, self._keys.c.id != row_id))
            LOG.debug('Deleting public keys associated with confirmed email: %s', delq)
        else:
            # Previously `id != row_id`, which wiped every key in the table
            # except the one that was meant to be removed.
            delq = delete(self._keys).where(self._keys.c.id == row_id)
            LOG.debug('Deleting key %s: %s', row_id, delq)

        self._conn.execute(delq)

    def mark_accepted(self, row_id):
        """Set the status of row ``row_id`` to ``db.ST_IMPORTED``."""
        modq = self._keys.update().where(self._keys.c.id == row_id).values(status=db.ST_IMPORTED)
        LOG.debug("Key imported, updating key: %s", modq)
        self._conn.execute(modq)