package org.thoughtcrime.securesms.service;

import android.content.Context;

import org.jetbrains.annotations.NotNull;
import org.session.libsession.database.StorageProtocol;
import org.session.libsession.messaging.MessagingModuleConfiguration;
import org.session.libsession.messaging.messages.ExpirationConfiguration;
import org.session.libsession.messaging.messages.control.ExpirationTimerUpdate;
import org.session.libsession.messaging.messages.signal.IncomingMediaMessage;
import org.session.libsession.messaging.messages.signal.OutgoingExpirationUpdateMessage;
import org.session.libsession.snode.SnodeAPI;
import org.session.libsession.utilities.Address;
import org.session.libsession.utilities.GroupUtil;
import org.session.libsession.utilities.SSKEnvironment;
import org.session.libsession.utilities.TextSecurePreferences;
import org.session.libsession.utilities.recipients.Recipient;
import org.session.libsignal.messages.SignalServiceGroup;
import org.session.libsignal.utilities.Log;
import org.session.libsignal.utilities.guava.Optional;
import org.thoughtcrime.securesms.database.MmsDatabase;
import org.thoughtcrime.securesms.database.MmsSmsDatabase;
import org.thoughtcrime.securesms.database.SmsDatabase;
import org.thoughtcrime.securesms.database.model.MessageRecord;
import org.thoughtcrime.securesms.dependencies.DatabaseComponent;
import org.thoughtcrime.securesms.mms.MmsException;

import java.io.IOException;
import java.util.Comparator;
import java.util.TreeSet;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import network.loki.messenger.libsession_util.util.ExpiryMode;

/**
 * Tracks messages whose disappearing-message timer has started and deletes them from the
 * SMS/MMS databases once their expiry time has passed.
 *
 * <p>Pending deletions live in a {@link TreeSet} ordered by expiry time (then id, then
 * sms-before-mms). Every access to that set is guarded by synchronizing on the set itself.
 * A single background thread ({@link ProcessTask}) waits on the set's monitor and performs
 * the deletions; {@link #scheduleDeletion} and {@link #checkSchedule} wake it up.
 */
public class ExpiringMessageManager implements SSKEnvironment.MessageExpirationManagerProtocol {

  private static final String TAG = ExpiringMessageManager.class.getSimpleName();

  /** Pending deletions, earliest expiry first. Guarded by its own monitor. */
  private final TreeSet<ExpiringMessageReference> expiringMessageReferences =
      new TreeSet<>(new ExpiringMessageComparator());

  /** Single thread: runs {@link LoadTask} once, then {@link ProcessTask} forever. */
  private final Executor executor = Executors.newSingleThreadExecutor();

  private final SmsDatabase smsDatabase;
  private final MmsDatabase mmsDatabase;
  private final MmsSmsDatabase mmsSmsDatabase;
  private final Context context;

  /**
   * Creates the manager and starts its background thread.
   *
   * <p>The thread first loads every message whose expiration has already started, then
   * processes deletions.
   *
   * @param context any context; only its application context is retained
   */
  public ExpiringMessageManager(Context context) {
    this.context = context.getApplicationContext();
    this.smsDatabase = DatabaseComponent.get(context).smsDatabase();
    this.mmsDatabase = DatabaseComponent.get(context).mmsDatabase();
    this.mmsSmsDatabase = DatabaseComponent.get(context).mmsSmsDatabase();

    executor.execute(new LoadTask());
    executor.execute(new ProcessTask());
  }

  /**
   * Schedules a message for deletion at {@code startedAtTimestamp + expiresInMillis}.
   *
   * @param id the message's database id
   * @param mms whether the message lives in the MMS table (otherwise SMS)
   * @param startedAtTimestamp when the expiration timer started, in millis
   * @param expiresInMillis how long after the start the message should be deleted
   */
  public void scheduleDeletion(long id, boolean mms, long startedAtTimestamp, long expiresInMillis) {
    long expiresAtMillis = startedAtTimestamp + expiresInMillis;

    synchronized (expiringMessageReferences) {
      expiringMessageReferences.add(new ExpiringMessageReference(id, mms, expiresAtMillis));
      expiringMessageReferences.notifyAll();
    }
  }

  /** Wakes the processing thread so it re-evaluates the earliest pending deletion. */
  public void checkSchedule() {
    synchronized (expiringMessageReferences) {
      expiringMessageReferences.notifyAll();
    }
  }

  /**
   * Records an expiration-timer update in the conversation.
   *
   * <p>The update is recorded as an outgoing message when the sender is unknown or is the
   * local user, and as an incoming message otherwise. If the mode has a positive expiry and
   * the sender and sent timestamp are known, expiration of the matching message is started.
   *
   * @param message the received or synced timer update
   * @param expiryMode the expiry mode being applied
   */
  @Override
  public void setExpirationTimer(@NotNull ExpirationTimerUpdate message, ExpiryMode expiryMode) {
    String userPublicKey = TextSecurePreferences.getLocalNumber(context);
    String senderPublicKey = message.getSender();
    long sentTimestamp = message.getSentTimestamp() == null ? 0 : message.getSentTimestamp();
    long expireStartedAt =
        (expiryMode instanceof ExpiryMode.AfterSend || message.isSenderSelf()) ? sentTimestamp : 0;

    // Notify the user. senderPublicKey is non-null on the right-hand side, so calling equals on
    // it stays safe even if the local number is not (yet) known.
    if (senderPublicKey == null || senderPublicKey.equals(userPublicKey)) {
      // sender is self or a linked device
      insertOutgoingExpirationTimerMessage(message, expireStartedAt);
    } else {
      insertIncomingExpirationTimerMessage(message, expireStartedAt);
    }

    if (expiryMode.getExpirySeconds() > 0
        && message.getSentTimestamp() != null
        && senderPublicKey != null) {
      startAnyExpiration(message.getSentTimestamp(), senderPublicKey, expireStartedAt);
    }
  }

  /**
   * Inserts an inbound timer-update message into the MMS inbox of the sender's thread.
   *
   * <p>Group updates go to the group thread instead. Updates from a blocked sender are dropped
   * unless they belong to a group. Nothing is inserted when no thread id can be resolved.
   */
  private void insertIncomingExpirationTimerMessage(ExpirationTimerUpdate message, long expireStartedAt) {
    String senderPublicKey = message.getSender();
    Long sentTimestamp = message.getSentTimestamp();
    String groupId = message.getGroupPublicKey();
    long expiresInMillis = message.getExpiryMode().getExpiryMillis();

    Optional<SignalServiceGroup> groupInfo = Optional.absent();
    Address address = Address.fromSerialized(senderPublicKey);
    Recipient recipient = Recipient.from(context, address, false);

    // if the sender is blocked, we don't display the update, except if it's in a closed group
    if (recipient.isBlocked() && groupId == null) {
      return;
    }

    try {
      if (groupId != null) {
        String groupID = GroupUtil.doubleEncodeGroupID(groupId);
        groupInfo = Optional.of(new SignalServiceGroup(
            GroupUtil.getDecodedGroupIDAsData(groupID), SignalServiceGroup.GroupType.SIGNAL));
        Address groupAddress = Address.fromSerialized(groupID);
        recipient = Recipient.from(context, groupAddress, false);
      }

      Long threadId = MessagingModuleConfiguration.getShared().getStorage().getThreadId(recipient);
      if (threadId == null) {
        return;
      }

      IncomingMediaMessage mediaMessage = new IncomingMediaMessage(address, sentTimestamp, -1,
          expiresInMillis, expireStartedAt, true,
          false,
          false,
          false,
          Optional.absent(),
          groupInfo,
          Optional.absent(),
          Optional.absent(),
          Optional.absent(),
          Optional.absent(),
          Optional.absent());
      //insert the timer update message
      mmsDatabase.insertSecureDecryptedMessageInbox(mediaMessage, threadId, true);
    } catch (IOException | MmsException ioe) {
      // Keep the cause so failures are diagnosable (matches the outgoing variant).
      Log.e("Loki", "Failed to insert expiration update message.", ioe);
    }
  }

  /**
   * Inserts an outbound timer-update message into the MMS outbox.
   *
   * <p>The destination is the group, the sync target, or the message recipient, in that order of
   * preference. The resolved thread id is also stored on {@code message}.
   */
  private void insertOutgoingExpirationTimerMessage(ExpirationTimerUpdate message, long expireStartedAt) {
    Long sentTimestamp = message.getSentTimestamp();
    String groupId = message.getGroupPublicKey();
    long duration = message.getExpiryMode().getExpiryMillis();

    Address address;

    try {
      if (groupId != null) {
        address = Address.fromSerialized(GroupUtil.doubleEncodeGroupID(groupId));
      } else {
        address = Address.fromSerialized(
            (message.getSyncTarget() != null && !message.getSyncTarget().isEmpty())
                ? message.getSyncTarget()
                : message.getRecipient());
      }

      Recipient recipient = Recipient.from(context, address, false);
      StorageProtocol storage = MessagingModuleConfiguration.getShared().getStorage();
      message.setThreadID(storage.getOrCreateThreadIdFor(address));

      OutgoingExpirationUpdateMessage timerUpdateMessage =
          new OutgoingExpirationUpdateMessage(recipient, sentTimestamp, duration, expireStartedAt, groupId);
      mmsDatabase.insertSecureDecryptedMessageOutbox(
          timerUpdateMessage, message.getThreadID(), sentTimestamp, true);
    } catch (MmsException | IOException ioe) {
      Log.e("Loki", "Failed to insert expiration update message.", ioe);
    }
  }

  /**
   * Starts the expiration timer for the message identified by {@code timestamp} and
   * {@code author}, provided its thread has an enabled expiration configuration.
   *
   * <p>Marks the message as expire-started in its table and schedules its deletion. Does
   * nothing if the message is not found or the configuration is missing or disabled.
   */
  @Override
  public void startAnyExpiration(long timestamp, @NotNull String author, long expireStartedAt) {
    MessageRecord messageRecord = mmsSmsDatabase.getMessageFor(timestamp, author);
    if (messageRecord == null) {
      return;
    }
    boolean mms = messageRecord.isMms();
    ExpirationConfiguration config =
        DatabaseComponent.get(context).storage().getExpirationConfiguration(messageRecord.getThreadId());
    if (config == null || !config.isEnabled()) {
      return;
    }
    ExpiryMode mode = config.getExpiryMode();
    if (mms) {
      mmsDatabase.markExpireStarted(messageRecord.getId(), expireStartedAt);
    } else {
      smsDatabase.markExpireStarted(messageRecord.getId(), expireStartedAt);
    }
    scheduleDeletion(messageRecord.getId(), mms, expireStartedAt, (mode != null ? mode.getExpiryMillis() : 0));
  }

  /**
   * Populates the pending-deletion set from messages already marked expire-started in the
   * SMS and MMS tables.
   */
  private class LoadTask implements Runnable {
    @Override
    public void run() {
      SmsDatabase.Reader smsReader = smsDatabase.readerFor(smsDatabase.getExpirationStartedMessages());
      try {
        MessageRecord messageRecord;
        while ((messageRecord = smsReader.getNext()) != null) {
          track(messageRecord);
        }
      } finally {
        // Close even if iteration throws, so the underlying cursor is not leaked.
        smsReader.close();
      }

      MmsDatabase.Reader mmsReader = mmsDatabase.getExpireStartedMessages();
      try {
        MessageRecord messageRecord;
        while ((messageRecord = mmsReader.getNext()) != null) {
          track(messageRecord);
        }
      } finally {
        mmsReader.close();
      }
    }

    /**
     * Adds one record to the pending set.
     *
     * <p>Locks the set because {@link #scheduleDeletion} may run concurrently on another thread
     * and TreeSet is not thread-safe.
     */
    private void track(MessageRecord record) {
      ExpiringMessageReference reference = new ExpiringMessageReference(
          record.getId(), record.isMms(), record.getExpireStarted() + record.getExpiresIn());
      synchronized (expiringMessageReferences) {
        expiringMessageReferences.add(reference);
      }
    }
  }

  /**
   * Loops forever: waits until the earliest pending reference is due, removes it, and deletes
   * the corresponding message outside the lock.
   */
  @SuppressWarnings("InfiniteLoopStatement")
  private class ProcessTask implements Runnable {
    @Override
    public void run() {
      while (true) {
        ExpiringMessageReference expiredMessage = null;

        synchronized (expiringMessageReferences) {
          try {
            while (expiringMessageReferences.isEmpty()) {
              expiringMessageReferences.wait();
            }

            ExpiringMessageReference nextReference = expiringMessageReferences.first();
            long waitTime = nextReference.expiresAtMillis - SnodeAPI.getNowWithOffset();

            if (waitTime > 0) {
              // Not due yet: arm an alarm for the remaining time, then wait. A notify from
              // scheduleDeletion/checkSchedule re-evaluates the head early.
              ExpirationListener.setAlarm(context, waitTime);
              expiringMessageReferences.wait(waitTime);
            } else {
              expiredMessage = nextReference;
              expiringMessageReferences.remove(nextReference);
            }
          } catch (InterruptedException e) {
            // The private executor is never shut down, so an interrupt is unexpected. Log it
            // and keep processing rather than silently stopping expiration.
            Log.w(TAG, e);
          }
        }

        if (expiredMessage != null) {
          if (expiredMessage.mms) {
            mmsDatabase.deleteMessage(expiredMessage.id);
          } else {
            smsDatabase.deleteMessage(expiredMessage.id);
          }
        }
      }
    }
  }

  /** Immutable (id, table, expiry-time) triple identifying a pending deletion. */
  private static class ExpiringMessageReference {
    private final long id;
    private final boolean mms;
    private final long expiresAtMillis;

    private ExpiringMessageReference(long id, boolean mms, long expiresAtMillis) {
      this.id = id;
      this.mms = mms;
      this.expiresAtMillis = expiresAtMillis;
    }

    @Override
    public boolean equals(Object other) {
      // instanceof is false for null, so no separate null check is needed.
      if (!(other instanceof ExpiringMessageReference)) {
        return false;
      }

      ExpiringMessageReference that = (ExpiringMessageReference) other;
      return this.id == that.id && this.mms == that.mms && this.expiresAtMillis == that.expiresAtMillis;
    }

    @Override
    public int hashCode() {
      // Fold both 32-bit halves of each long instead of truncating to the low bits.
      int result = (int) (id ^ (id >>> 32));
      result = 31 * result + (mms ? 1 : 0);
      result = 31 * result + (int) (expiresAtMillis ^ (expiresAtMillis >>> 32));
      return result;
    }
  }

  /**
   * Orders references by expiry time, then by id, then SMS before MMS.
   *
   * <p>Consistent with {@link ExpiringMessageReference#equals}.
   */
  private static class ExpiringMessageComparator implements Comparator<ExpiringMessageReference> {
    @Override
    public int compare(ExpiringMessageReference lhs, ExpiringMessageReference rhs) {
      int byExpiry = Long.compare(lhs.expiresAtMillis, rhs.expiresAtMillis);
      if (byExpiry != 0) {
        return byExpiry;
      }
      int byId = Long.compare(lhs.id, rhs.id);
      if (byId != 0) {
        return byId;
      }
      // false (SMS) sorts before true (MMS).
      return Boolean.compare(lhs.mms, rhs.mms);
    }
  }
}