First cut of dbsqlite simplification (w/ weakrefs)

This commit is contained in:
Thomas Perl 2011-07-16 14:30:08 +02:00
parent 34b54e94b8
commit 52779f611b
9 changed files with 130 additions and 275 deletions

View File

@ -160,13 +160,12 @@ class Database(object):
return self.db.cursor()
def commit(self):
self.lock.acquire()
try:
logger.debug('Commit.')
self.db.commit()
except Exception, e:
logger.error('Cannot commit: %s', e, exc_info=True)
self.lock.release()
with self.lock:
try:
logger.debug('Commit.')
self.db.commit()
except Exception, e:
logger.error('Cannot commit: %s', e, exc_info=True)
def get_content_types(self, id):
"""Given a podcast ID, returns the content types"""
@ -177,34 +176,11 @@ class Database(object):
yield mime_type
cur.close()
def get_podcast_statistics(self, id):
def get_podcast_statistics(self, podcast_id=None):
"""Given a podcast ID, returns the statistics for it
Returns a tuple (total, deleted, new, downloaded, unplayed)
"""
total, deleted, new, downloaded, unplayed = 0, 0, 0, 0, 0
with self.lock:
cur = self.cursor()
cur.execute('SELECT COUNT(*), state, is_new FROM %s WHERE podcast_id = ? GROUP BY state, is_new' % self.TABLE_EPISODE, (id,))
for count, state, is_new in cur:
total += count
if state == gpodder.STATE_DELETED:
deleted += count
elif state == gpodder.STATE_NORMAL and is_new:
new += count
elif state == gpodder.STATE_DOWNLOADED and is_new:
downloaded += count
unplayed += count
elif state == gpodder.STATE_DOWNLOADED:
downloaded += count
cur.close()
return (total, deleted, new, downloaded, unplayed)
def get_total_count(self):
"""Get statistics for episodes in all podcasts
If the podcast_id is omitted (using the default value), the
statistics will be calculated over all podcasts.
Returns a tuple (total, deleted, new, downloaded, unplayed)
"""
@ -212,28 +188,33 @@ class Database(object):
with self.lock:
cur = self.cursor()
cur.execute('SELECT COUNT(*), state, is_new FROM %s GROUP BY state, is_new' % self.TABLE_EPISODE)
if podcast_id is not None:
cur.execute('SELECT COUNT(*), state, is_new FROM %s WHERE podcast_id = ? GROUP BY state, is_new' % self.TABLE_EPISODE, (podcast_id,))
else:
cur.execute('SELECT COUNT(*), state, is_new FROM %s GROUP BY state, is_new' % self.TABLE_EPISODE)
for count, state, is_new in cur:
total += count
if state == gpodder.STATE_DELETED:
deleted += count
elif state == gpodder.STATE_NORMAL and is_new:
new += count
elif state == gpodder.STATE_DOWNLOADED and is_new:
downloaded += count
unplayed += count
elif state == gpodder.STATE_DOWNLOADED:
downloaded += count
if is_new:
unplayed += count
cur.close()
return (total, deleted, new, downloaded, unplayed)
def load_podcasts(self, factory=None, url=None):
def load_podcasts(self, factory, cache_lookup):
"""
Returns podcast descriptions as a list of dictionaries or objects,
returned by the factory() function, which receives the dictionary
as the only argument.
The cache_lookup function takes a podcast ID and should return the
podcast object in case it is cached already in memory.
"""
logger.debug('load_podcasts')
@ -243,23 +224,22 @@ class Database(object):
cur.execute('SELECT * FROM %s ORDER BY title COLLATE UNICODE' % self.TABLE_PODCAST)
result = []
keys = list(desc[0] for desc in cur.description)
for row in cur:
podcast = dict(zip(keys, row))
if url is None or url == podcast['url']:
if factory is None:
result.append(podcast)
else:
result.append(factory(podcast, self))
keys = [desc[0] for desc in cur.description]
id_index = keys.index('id')
def make_podcast(row):
o = cache_lookup(row[id_index])
if o is None:
o = factory(dict(zip(keys, row)), self)
# TODO: save in cache!
else:
logger.debug('Cache hit: podcast %d', o.id)
return o
result = map(make_podcast, cur)
cur.close()
return result
def save_podcast(self, podcast):
    # Insert-or-update: _save_object() INSERTs when podcast.id is None
    # (and assigns the new rowid) or UPDATEs the existing row otherwise.
    self._save_object(podcast, self.TABLE_PODCAST, self.COLUMNS_PODCAST)
def delete_podcast(self, podcast):
assert podcast.id
@ -273,133 +253,69 @@ class Database(object):
cur.close()
# Commit changes
self.db.commit()
# TODO: podcast.id -> remove from cache!
def load_all_episodes(self, podcast_mapping, limit=10000):
    """Load the newest episodes across all podcasts.

    podcast_mapping maps a podcast ID to its podcast object; each
    row's podcast_id column selects the podcast whose
    episode_factory() turns the row dict into an episode object.

    Returns a list of episode objects ordered by "published",
    newest first, capped at "limit" rows.
    """
    logger.info('Loading all episodes from the database')
    sql = 'SELECT * FROM %s ORDER BY published DESC LIMIT ?' % (self.TABLE_EPISODE,)
    args = (limit,)
    with self.lock:
        cur = self.cursor()
        cur.execute(sql, args)
        # Column names of the result set, in SELECT order
        keys = [desc[0] for desc in cur.description]
        id_index = keys.index('podcast_id')
        # NOTE: Python 2 map() is eager, so all rows are consumed
        # from the cursor before cur.close() below.
        result = map(lambda row: podcast_mapping[row[id_index]].episode_factory(dict(zip(keys, row))), cur)
        cur.close()
    return result
def load_episodes(self, podcast, factory=lambda x: x, limit=1000, state=None):
def load_episodes(self, podcast, factory, cache_lookup):
assert podcast.id
limit = 1000
logger.info('Loading episodes for podcast %d', podcast.id)
if state is None:
sql = 'SELECT * FROM %s WHERE podcast_id = ? ORDER BY published DESC LIMIT ?' % (self.TABLE_EPISODE,)
args = (podcast.id, limit)
else:
sql = 'SELECT * FROM %s WHERE podcast_id = ? AND state = ? ORDER BY published DESC LIMIT ?' % (self.TABLE_EPISODE,)
args = (podcast.id, state, limit)
sql = 'SELECT * FROM %s WHERE podcast_id = ? ORDER BY published DESC LIMIT ?' % (self.TABLE_EPISODE,)
args = (podcast.id, limit)
with self.lock:
cur = self.cursor()
cur.execute(sql, args)
keys = [desc[0] for desc in cur.description]
result = map(lambda row: factory(dict(zip(keys, row)), self), cur)
id_index = keys.index('id')
def make_episode(row):
o = cache_lookup(row[id_index])
if o is None:
o = factory(dict(zip(keys, row)))
# TODO: save in cache!
else:
logger.debug('Cache hit: episode %d', o.id)
return o
result = map(make_episode, cur)
cur.close()
return result
def load_single_episode(self, podcast, factory=lambda x: x, **kwargs):
    """Load one episode with keywords

    Return an episode object (created by "factory") for a
    given podcast. You can use keyword arguments to specify
    the attributes that the episode object should have.

    Returns None if the episode cannot be found.
    """
    assert podcast.id

    # Inject podcast_id into query to reduce search space
    kwargs['podcast_id'] = podcast.id

    # We need to have the keys in the same order as the values, so
    # we use items() and unzip the resulting list into two ordered lists
    keys, args = zip(*kwargs.items())

    # Column names are interpolated into the SQL; values go through
    # "?" placeholders, so only trusted attribute names may be used.
    sql = 'SELECT * FROM %s WHERE %s LIMIT 1' % (self.TABLE_EPISODE, \
            ' AND '.join('%s=?' % k for k in keys))

    with self.lock:
        cur = self.cursor()
        cur.execute(sql, args)
        # "keys" is rebound here to the result columns (SELECT order)
        keys = [desc[0] for desc in cur.description]
        row = cur.fetchone()
        if row:
            # factory() receives the row as a dict plus this Database
            result = factory(dict(zip(keys, row)), self)
        else:
            result = None
        cur.close()
    return result
def load_episode(self, id):
    """Load episode as dictionary by its id

    This will return the data for an episode as
    dictionary or None if it does not exist.
    """
    assert id

    with self.lock:
        cur = self.cursor()
        try:
            cur.execute('SELECT * from %s WHERE id = ? LIMIT 1' % (self.TABLE_EPISODE,), (id,))
            row = cur.fetchone()
            # Capture the column names while the cursor is still open
            keys = [desc[0] for desc in cur.description]
        finally:
            # Always release the cursor, even if execute() raises
            cur.close()

        if row is None:
            # Episode does not exist. The old code reached this case by
            # letting zip(..., None) raise and swallowing it with a bare
            # "except:", which also hid genuine database errors.
            return None

        logger.info('Loaded episode %d', id)
        return dict(zip(keys, row))
def get_podcast_id_from_episode_url(self, url):
    """Return the (first) associated podcast ID given an episode URL"""
    assert url
    # self.get() executes the query and returns the first column of the
    # first row, or None when no episode matches this URL
    return self.get('SELECT podcast_id FROM %s WHERE url = ? LIMIT 1' % (self.TABLE_EPISODE,), (url,))
def save_podcast(self, podcast):
    # Insert-or-update via the generic object writer: INSERT when
    # podcast.id is None, UPDATE of the existing row otherwise.
    self._save_object(podcast, self.TABLE_PODCAST, self.COLUMNS_PODCAST)
def save_episode(self, episode):
    # Episodes must belong to a podcast and carry a GUID before they
    # can be persisted; _save_object() then INSERTs or UPDATEs the row.
    assert episode.podcast_id
    assert episode.guid
    self._save_object(episode, self.TABLE_EPISODE, self.COLUMNS_EPISODE)
def _save_object(self, o, table, columns):
self.lock.acquire()
try:
cur = self.cursor()
values = [getattr(o, name) for name in columns]
if o.id is None:
qmarks = ', '.join('?'*len(columns))
sql = 'INSERT INTO %s (%s) VALUES (%s)' % (table, ', '.join(columns), qmarks)
cur.execute(sql, values)
o.id = cur.lastrowid
else:
qmarks = ', '.join('%s = ?' % name for name in columns)
values.append(o.id)
sql = 'UPDATE %s SET %s WHERE id = ?' % (table, qmarks)
cur.execute(sql, values)
except Exception, e:
logger.error('Cannot save %s: %s', o, e, exc_info=True)
cur.close()
self.lock.release()
def update_episode_state(self, episode):
    """Persist only the state-related columns of an episode.

    Writes state, is_new and archive for an already-saved episode
    (episode.id must be set) without touching any other column.
    """
    assert episode.id is not None

    with self.lock:
        cur = self.cursor()
        try:
            cur.execute('UPDATE %s SET state = ?, is_new = ?, archive = ? WHERE id = ?' % (self.TABLE_EPISODE,), (episode.state, episode.is_new, episode.archive, episode.id))
        finally:
            # Close the cursor explicitly (the original leaked it until
            # garbage collection), matching the other helpers in this class.
            cur.close()
try:
cur = self.cursor()
values = [getattr(o, name) for name in columns]
if o.id is None:
qmarks = ', '.join('?'*len(columns))
sql = 'INSERT INTO %s (%s) VALUES (%s)' % (table, ', '.join(columns), qmarks)
cur.execute(sql, values)
o.id = cur.lastrowid
else:
qmarks = ', '.join('%s = ?' % name for name in columns)
values.append(o.id)
sql = 'UPDATE %s SET %s WHERE id = ?' % (table, qmarks)
cur.execute(sql, values)
except Exception, e:
logger.error('Cannot save %s: %s', o, e, exc_info=True)
cur.close()
# TODO: o -> into cache!
def get(self, sql, params=None):
"""
@ -451,4 +367,5 @@ class Database(object):
cur = self.cursor()
cur.execute('DELETE FROM %s WHERE podcast_id = ? AND guid = ?' % self.TABLE_EPISODE, \
(podcast_id, guid))
# TODO: Delete episode from cache

View File

@ -434,9 +434,6 @@ class DownloadQueueManager(object):
and forcefully start the download right away.
"""
if task.status != DownloadTask.INIT:
# This task is old so update episode from db
task.episode.reload_from_db()
# Remove the task from its current position in the
# download queue (if any) to avoid race conditions
# where two worker threads download the same file

View File

@ -428,7 +428,6 @@ class EpisodeListModel(gtk.GenericTreeModel):
for index, episode in enumerate(self._episodes):
if episode.url in urls:
episode.reload_from_db()
self.emit('row-changed', (index,), self.create_tree_iter(index))
def update_by_filter_iter(self, iter, downloading=None, \
@ -439,7 +438,7 @@ class EpisodeListModel(gtk.GenericTreeModel):
downloading, include_description, generate_thumbnails)
def update_by_iter(self, iter, downloading=None, include_description=False, \
generate_thumbnails=False, reload_from_db=True):
generate_thumbnails=False):
self._downloading = downloading
self._include_description = include_description
@ -447,8 +446,6 @@ class EpisodeListModel(gtk.GenericTreeModel):
index = self.get_user_data(iter)
episode = self._episodes[index]
if reload_from_db:
episode.reload_from_db()
self.emit('row-changed', (index,), self.create_tree_iter(index))

View File

@ -143,8 +143,6 @@ class gPodderShownotesBase(BuilderWidget):
"""Called from main window for download status changes"""
if self.main_window.get_property('visible'):
self.task = task
if self.episode is not None:
self.episode.reload_from_db()
self.on_episode_status_changed()
def _download_status_progress(self):
@ -154,15 +152,6 @@ class gPodderShownotesBase(BuilderWidget):
#############################################################
def episode_is_new(self):
    # No episode shown -> it cannot be new; otherwise delegate to the
    # episode's own check, passing our download-status predicate.
    return (self.episode is not None and
            self.episode.check_is_new(
                downloading=self._episode_is_downloading))
#############################################################
def show(self, episode):
if self.main_window.get_property('visible'):
self.episode = None

View File

@ -546,20 +546,17 @@ class gPodder(BuilderWidget, dbus.service.Object):
file_parts = [x for x in filename.split(os.sep) if x]
if len(file_parts) == 2:
dir_name, filename = file_parts
channels = [c for c in self.channels if c.download_folder == dir_name]
if len(channels) == 1:
channel = channels[0]
return channel.get_episode_by_filename(filename)
foldername, filename = file_parts
for channel in filter(lambda c: c.download_folder == foldername, self.channels):
for episode in filter(lambda e: e.download_filename == filename, channel.get_all_episodes()):
return episode
else:
# Possibly remote file - search the database for a podcast
channel_id = self.db.get_podcast_id_from_episode_url(uri)
if channel_id is not None:
channels = [c for c in self.channels if c.id == channel_id]
if len(channels) == 1:
channel = channels[0]
return channel.get_episode_by_url(uri)
for channel in filter(lambda c: c.id == channel_id, self.channels):
for episode in filter(lambda e: e.url == url, channel.get_all_episodes()):
return episode
return None
@ -3252,15 +3249,16 @@ class gPodder(BuilderWidget, dbus.service.Object):
self.show_message(_('Please check for new episodes later.'), \
_('No new episodes available'), widget=self.btnUpdateFeeds)
def episode_is_new(self, episode):
    """Check whether an episode should be offered as 'new'.

    New means: in the normal (non-deleted, non-downloaded) state,
    flagged as new, and not currently in the download queue.
    """
    if episode.state != gpodder.STATE_NORMAL:
        return False
    if not episode.is_new:
        return False
    return not self.episode_is_downloading(episode)
def get_new_episodes(self, channels=None):
if channels is None:
channels = self.channels
episodes = []
for channel in channels:
for episode in channel.get_new_episodes(downloading=self.episode_is_downloading):
episodes.append(episode)
return episodes
return [e for c in channels for e in filter(self.episode_is_new, c.get_all_episodes())]
def commit_changes_to_database(self):
"""This will be called after the sync process is finished"""

View File

@ -275,8 +275,7 @@ class EpisodeListModel(gtk.ListStore):
episode.total_time, \
episode.archive))
self.update_by_iter(iter, downloading, include_description, \
reload_from_db=False)
self.update_by_iter(iter, downloading, include_description)
self._on_filter_changed(self.has_episodes())
@ -297,11 +296,8 @@ class EpisodeListModel(gtk.ListStore):
self.update_by_iter(self._filter.convert_iter_to_child_iter(iter), \
downloading, include_description)
def update_by_iter(self, iter, downloading=None, include_description=False, \
reload_from_db=True):
def update_by_iter(self, iter, downloading=None, include_description=False):
episode = self.get_value(iter, self.C_EPISODE)
if reload_from_db:
episode.reload_from_db()
show_bullet = False
show_padlock = False
@ -432,12 +428,12 @@ class PodcastChannelProxy(object):
def get_statistics(self):
# Get the total statistics for all channels from the database
return self._db.get_total_count()
return self._db.get_podcast_statistics()
def get_all_episodes(self):
"""Returns a generator that yields every episode"""
channel_lookup_map = dict((c.id, c) for c in self.channels)
return self._db.load_all_episodes(channel_lookup_map)
return Model.sort_episodes_by_pubdate((e for c in self.channels
for e in c.get_all_episodes()), True)
class PodcastListModel(gtk.ListStore):

View File

@ -40,6 +40,8 @@ import datetime
import rfc822
import hashlib
import feedparser
import collections
import weakref
_ = gpodder.gettext
@ -84,6 +86,11 @@ class PodcastModelObject(object):
A generic base class for our podcast model providing common helper
and utility functions.
"""
_cache = collections.defaultdict(weakref.WeakValueDictionary)
@classmethod
def _get_cached_object(cls, id):
    # Look up an already-loaded instance of this class by database ID.
    # Entries live in a per-class WeakValueDictionary, so objects drop
    # out automatically once no strong reference keeps them alive.
    return cls._cache[cls].get(id, None)
@classmethod
def create_from_dict(cls, d, *args):
@ -91,18 +98,18 @@ class PodcastModelObject(object):
Create a new object, passing "args" to the constructor
and then updating the object with the values from "d".
"""
o = cls(*args)
o.update_from_dict(d)
return o
o = cls._get_cached_object(d['id'])
def update_from_dict(self, d):
"""
Updates the attributes of this object with values from the
dictionary "d" by using the keys found in "d".
"""
for k in d:
if hasattr(self, k):
setattr(self, k, d[k])
if o is None:
o = cls(*args)
# XXX: all(map(lambda k: hasattr(o, k), d))?
for k, v in d.iteritems():
setattr(o, k, v)
cls._cache[cls][o.id] = o
else:
logger.debug('Reusing reference to %s %d', cls.__name__, o.id)
return o
class PodcastEpisode(PodcastModelObject):
@ -124,18 +131,6 @@ class PodcastEpisode(PodcastModelObject):
# Accessor for the "podcast_id" DB column
podcast_id = property(fget=_get_podcast_id, fset=_set_podcast_id)
def reload_from_db(self):
    """
    Re-reads all episode details for this object from the
    database and updates this object accordingly. Can be
    used to refresh existing objects when the database has
    been updated (e.g. the filename has been set after a
    download where it was not set before the download)
    """
    # load_episode() returns None for a missing row; "d or {}" turns
    # that into a no-op update instead of a TypeError. Returns self so
    # the call can be chained.
    d = self.db.load_episode(self.id)
    self.update_from_dict(d or {})
    return self
def has_website_link(self):
    # A link counts as a website link when it exists and differs from
    # the media URL itself -- except for YouTube links, which count
    # even when link and URL are identical.
    return bool(self.link) and (self.link != self.url or \
            youtube.is_video_link(self.link))
@ -313,7 +308,7 @@ class PodcastEpisode(PodcastModelObject):
def set_state(self, state):
self.state = state
self.db.update_episode_state(self)
self.save()
def playback_mark(self):
self.is_new = False
@ -327,7 +322,7 @@ class PodcastEpisode(PodcastModelObject):
self.is_new = not is_played
if is_locked is not None:
self.archive = is_locked
self.db.update_episode_state(self)
self.save()
def age_in_days(self):
return util.file_age_in_days(self.local_filename(create=False, \
@ -503,24 +498,13 @@ class PodcastEpisode(PodcastModelObject):
ext = util.extension_from_mimetype(self.mime_type)
return ext
def check_is_new(self, downloading=lambda e: False):
    """
    Returns True if this episode is to be considered new.

    "Downloading" should be a callback that gets an episode
    as its parameter and returns True if the episode is
    being downloaded at the moment.
    """
    # Short-circuit order is preserved: the (possibly costly)
    # downloading() callback only runs when state and is_new match.
    is_pristine = self.state == gpodder.STATE_NORMAL and self.is_new
    return is_pristine and not downloading(self)
def mark_new(self):
self.is_new = True
self.db.update_episode_state(self)
self.save()
def mark_old(self):
self.is_new = False
self.db.update_episode_state(self)
self.save()
def file_exists(self):
filename = self.local_filename(create=False, check_only=True)
@ -669,8 +653,9 @@ class PodcastChannel(PodcastModelObject):
found = False
basename = os.path.basename(filename)
existing = self.get_episode_by_filename(basename)
existing = [e for e in all_episodes if e.download_filename == basename]
if existing:
existing = existing[0]
logger.info('Importing external download: %s', filename)
existing.on_downloaded(filename)
count += 1
@ -721,7 +706,7 @@ class PodcastChannel(PodcastModelObject):
@classmethod
def load_from_db(cls, db):
return db.load_podcasts(factory=cls.create_from_dict)
return db.load_podcasts(cls.create_from_dict, cls._get_cached_object)
@classmethod
def load(cls, db, url, create=True, authentication_tokens=None,\
@ -730,10 +715,14 @@ class PodcastChannel(PodcastModelObject):
if isinstance(url, unicode):
url = url.encode('utf-8')
tmp = db.load_podcasts(factory=cls.create_from_dict, url=url)
if len(tmp):
return tmp[0]
elif create:
existing = [podcast for podcast in
db.load_podcasts(cls.create_from_dict, cls._get_cached_object)
if podcast.url == url]
if existing:
return existing[0]
if create:
tmp = cls(db)
tmp.url = url
if authentication_tokens is not None:
@ -748,7 +737,7 @@ class PodcastChannel(PodcastModelObject):
tmp.save()
return tmp
def episode_factory(self, d, db__parameter_is_unused=None):
def episode_factory(self, d):
"""
This function takes a dictionary containing key-value pairs for
episodes and returns a new PodcastEpisode object that is connected
@ -1044,33 +1033,10 @@ class PodcastChannel(PodcastModelObject):
self.title = custom_title
def get_downloaded_episodes(self):
return self.db.load_episodes(self, factory=self.episode_factory, state=gpodder.STATE_DOWNLOADED)
def get_new_episodes(self, downloading=lambda e: False):
"""
Get a list of new episodes. You can optionally specify
"downloading" as a callback that takes an episode as
a parameter and returns True if the episode is currently
being downloaded or False if not.
By default, "downloading" is implemented so that it
reports all episodes as not downloading.
"""
return [episode for episode in self.db.load_episodes(self, \
factory=self.episode_factory, state=gpodder.STATE_NORMAL) if \
episode.check_is_new(downloading=downloading)]
def get_episode_by_url(self, url):
return self.db.load_single_episode(self, \
factory=self.episode_factory, url=url)
def get_episode_by_filename(self, filename):
return self.db.load_single_episode(self, \
factory=self.episode_factory, \
download_filename=filename)
return filter(lambda e: e.was_downloaded(), self.get_all_episodes())
def get_all_episodes(self):
return self.db.load_episodes(self, factory=self.episode_factory)
return self.db.load_episodes(self, self.episode_factory, self.EpisodeClass._get_cached_object)
def find_unique_folder_name(self, download_folder):
# Remove trailing dots to avoid errors on Windows (bug 600)

View File

@ -354,14 +354,9 @@ class qtPodder(QObject):
self.podcast_model.set_podcasts(self.db, podcasts)
def select_podcast(self, podcast):
# If the currently-playing episode exists in the podcast,
# use it instead of the object from the database
current_ep = self.main.currentEpisode
episodes = [x if current_ep is None or x.id != current_ep.id \
else current_ep for x in podcast.get_all_episodes()]
self.episode_model.set_objects(episodes)
self.episode_model.set_objects(podcast.get_all_episodes())
import gc
gc.collect()
self.main.state = 'episodes'
def save_pending_data(self):

View File

@ -59,6 +59,7 @@ class QEpisode(QObject, model.PodcastEpisode):
self._qt_playing = False
changed = Signal()
never_changed = Signal()
def _title(self):
    # Episode title passed through convert() for Qt consumption
    # (convert() is defined elsewhere in this module -- presumably a
    # unicode coercion helper; confirm against the module header)
    return convert(self.title)
@ -80,7 +81,7 @@ class QEpisode(QObject, model.PodcastEpisode):
def _filetype(self):
return self.file_type() or 'download' # FIXME
qfiletype = Property(unicode, _filetype, notify=changed)
qfiletype = Property(unicode, _filetype, notify=never_changed)
def _downloaded(self):
    # True when the episode's media file has been downloaded
    return self.state == gpodder.STATE_DOWNLOADED
@ -120,7 +121,6 @@ class QEpisode(QObject, model.PodcastEpisode):
self.changed.emit()
task.add_progress_callback(cb)
task.run()
self.reload_from_db()
self._qt_downloading = False
self.changed.emit()
@ -308,7 +308,7 @@ class EpisodeSubsetView(QObject):
if self.eql is not None:
return 0
total, deleted, new, downloaded, unplayed = self.db.get_total_count()
total, deleted, new, downloaded, unplayed = self.db.get_podcast_statistics()
return downloaded
qdownloaded = Property(int, _downloaded, notify=changed)
@ -317,7 +317,7 @@ class EpisodeSubsetView(QObject):
if self.eql is not None:
return 0
total, deleted, new, downloaded, unplayed = self.db.get_total_count()
total, deleted, new, downloaded, unplayed = self.db.get_podcast_statistics()
return new
qnew = Property(int, _new, notify=changed)