// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.

import Foundation
import GRDB
import SessionSnodeKit
import SessionUtilitiesKit

/// Abstract base class for `VisibleMessage` and `ControlMessage`.
public class Message: Codable {
    public var id: String?
    public var sentTimestamp: UInt64?
    public var receivedTimestamp: UInt64?
    public var recipient: String?
    public var sender: String?
    public var openGroupServerMessageId: UInt64?
    public var serverHash: String?

    /// Time-to-live for the message: 14 days expressed in milliseconds. Subclasses may override.
    public var ttl: UInt64 { 14 * 24 * 60 * 60 * 1000 }
    public var isSelfSendValid: Bool { false }
    public var shouldBeRetryable: Bool { false }

    // MARK: - Validation

    /// Whether the message passes basic validation.
    ///
    /// Any timestamp that is present must be non-zero, and both `sender` and `recipient` must be set.
    public var isValid: Bool {
        if let sentTimestamp = sentTimestamp {
            guard sentTimestamp > 0 else { return false }
        }

        if let receivedTimestamp = receivedTimestamp {
            guard receivedTimestamp > 0 else { return false }
        }

        return sender != nil && recipient != nil
    }

    // MARK: - Initialization

    /// Creates a message with the provided metadata.
    ///
    /// - Note: `groupPublicKey` is accepted in the signature but is not stored by this initialiser.
    public init(
        id: String? = nil,
        sentTimestamp: UInt64? = nil,
        receivedTimestamp: UInt64? = nil,
        recipient: String? = nil,
        sender: String? = nil,
        groupPublicKey: String? = nil,
        openGroupServerMessageId: UInt64? = nil,
        serverHash: String? = nil
    ) {
        self.id = id
        self.sentTimestamp = sentTimestamp
        self.receivedTimestamp = receivedTimestamp
        self.recipient = recipient
        self.sender = sender
        self.openGroupServerMessageId = openGroupServerMessageId
        self.serverHash = serverHash
    }

    // MARK: - Proto Conversion

    /// Abstract: subclasses must override to build an instance from `proto`, returning `nil` when the
    /// proto does not describe this message type.
    public class func fromProto(_ proto: SNProtoContent, sender: String) -> Self? {
        preconditionFailure("fromProto(_:sender:) is abstract and must be overridden.")
    }

    /// Abstract: subclasses must override to convert the message into its protobuf representation.
    public func toProto(_ db: Database) -> SNProtoContent? {
        preconditionFailure("toProto(_:) is abstract and must be overridden.")
    }
}

// MARK: - Message Parsing/Processing

/// The result of successfully parsing a received message: the thread it belongs to, the decoded proto and
/// the info needed to schedule/perform its `MessageReceiveJob`.
public typealias ProcessedMessage = (
    threadId: String,
    threadVariant: SessionThread.Variant,
    proto: SNProtoContent,
    messageInfo: MessageReceiveJob.Details.MessageInfo
)

public extension Message {
    /// Codable discriminator identifying the concrete `Message` subclass.
    enum Variant: String, Codable {
        case readReceipt
        case typingIndicator
        case closedGroupControlMessage
        case dataExtractionNotification
        case expirationTimerUpdate
        case configurationMessage
        case unsendRequest
        case messageRequestResponse
        case visibleMessage
        case callMessage
        case sharedConfigMessage

        /// Maps a concrete message instance to its variant, or `nil` for an unknown subclass.
        init?(from type: Message) {
            switch type {
                case is ReadReceipt: self = .readReceipt
                case is TypingIndicator: self = .typingIndicator
                case is ClosedGroupControlMessage: self = .closedGroupControlMessage
                case is DataExtractionNotification: self = .dataExtractionNotification
                case is ExpirationTimerUpdate: self = .expirationTimerUpdate
                case is ConfigurationMessage: self = .configurationMessage
                case is UnsendRequest: self = .unsendRequest
                case is MessageRequestResponse: self = .messageRequestResponse
                case is VisibleMessage: self = .visibleMessage
                case is CallMessage: self = .callMessage
                case is SharedConfigMessage: self = .sharedConfigMessage
                default: return nil
            }
        }

        /// The concrete `Message` subclass for this variant.
        var messageType: Message.Type {
            switch self {
                case .readReceipt: return ReadReceipt.self
                case .typingIndicator: return TypingIndicator.self
                case .closedGroupControlMessage: return ClosedGroupControlMessage.self
                case .dataExtractionNotification: return DataExtractionNotification.self
                case .expirationTimerUpdate: return ExpirationTimerUpdate.self
                case .configurationMessage: return ConfigurationMessage.self
                case .unsendRequest: return UnsendRequest.self
                case .messageRequestResponse: return MessageRequestResponse.self
                case .visibleMessage: return VisibleMessage.self
                case .callMessage: return CallMessage.self
                case .sharedConfigMessage: return SharedConfigMessage.self
            }
        }

        /// Decodes the concrete subclass for this variant from `container` at `key`.
        ///
        /// Generic over the key type so any caller's keyed container can be used.
        /// - Throws: Any `DecodingError` raised by the container.
        func decode<CodingKeys: CodingKey>(
            from container: KeyedDecodingContainer<CodingKeys>,
            forKey key: CodingKeys
        ) throws -> Message {
            switch self {
                case .readReceipt: return try container.decode(ReadReceipt.self, forKey: key)
                case .typingIndicator: return try container.decode(TypingIndicator.self, forKey: key)

                case .closedGroupControlMessage:
                    return try container.decode(ClosedGroupControlMessage.self, forKey: key)

                case .dataExtractionNotification:
                    return try container.decode(DataExtractionNotification.self, forKey: key)

                case .expirationTimerUpdate: return try container.decode(ExpirationTimerUpdate.self, forKey: key)
                case .configurationMessage: return try container.decode(ConfigurationMessage.self, forKey: key)
                case .unsendRequest: return try container.decode(UnsendRequest.self, forKey: key)
                case .messageRequestResponse: return try container.decode(MessageRequestResponse.self, forKey: key)
                case .visibleMessage: return try container.decode(VisibleMessage.self, forKey: key)
                case .callMessage: return try container.decode(CallMessage.self, forKey: key)
                case .sharedConfigMessage: return try container.decode(SharedConfigMessage.self, forKey: key)
            }
        }
    }

    /// Attempts to parse `proto` as each message variant in priority order.
    ///
    /// - Returns: The first successfully parsed message, or `nil` if no variant accepts the proto.
    static func createMessageFrom(_ proto: SNProtoContent, sender: String) -> Message? {
        // Note: This array is ordered intentionally to ensure the correct types are processed
        // and aren't parsed as the wrong type
        let prioritisedVariants: [Variant] = [
            .readReceipt,
            .typingIndicator,
            .closedGroupControlMessage,
            .dataExtractionNotification,
            .expirationTimerUpdate,
            .configurationMessage,
            .unsendRequest,
            .messageRequestResponse,
            .visibleMessage,
            .callMessage,
            .sharedConfigMessage
        ]

        // Stop at the first variant which successfully parses the proto
        for variant in prioritisedVariants {
            if let message: Message = variant.messageType.fromProto(proto, sender: sender) {
                return message
            }
        }

        return nil
    }

    /// Whether `message` may only be processed if its conversation already exists.
    ///
    /// Only `.legacyGroup` threads require this, and a `ClosedGroupControlMessage` of kind `.new` is exempt.
    static func requiresExistingConversation(message: Message, threadVariant: SessionThread.Variant) -> Bool {
        switch threadVariant {
            case .contact, .community: return false

            case .legacyGroup:
                switch message {
                    case let controlMessage as ClosedGroupControlMessage:
                        switch controlMessage.kind {
                            case .new: return false
                            default: return true
                        }

                    default: return true
                }

            case .group: return false
        }
    }

    /// Whether `message` should be synced.
    ///
    /// True for `VisibleMessage`, `ExpirationTimerUpdate`, `ConfigurationMessage` and `UnsendRequest`, for
    /// `ClosedGroupControlMessage` of kind `.new`, and for `CallMessage` of kind `.answer` or `.endCall`.
    static func shouldSync(message: Message) -> Bool {
        switch message {
            case is VisibleMessage: return true
            case is ExpirationTimerUpdate: return true
            case is ConfigurationMessage: return true
            case is UnsendRequest: return true

            case let controlMessage as ClosedGroupControlMessage:
                switch controlMessage.kind {
                    case .new: return true
                    default: return false
                }

            case let callMessage as CallMessage:
                switch callMessage.kind {
                    case .answer, .endCall: return true
                    default: return false
                }

            default: return false
        }
    }

    /// Determines the thread id that `message` belongs to when sent to `destination`.
    ///
    /// For `.contact` destinations, a `syncTarget` on a `VisibleMessage` or `ExpirationTimerUpdate` takes
    /// precedence over the destination public key.
    static func threadId(forMessage message: Message, destination: Message.Destination) -> String {
        switch destination {
            case .contact(let publicKey):
                // Extract the 'syncTarget' value if there is one
                let maybeSyncTarget: String?

                switch message {
                    case let message as VisibleMessage: maybeSyncTarget = message.syncTarget
                    case let message as ExpirationTimerUpdate: maybeSyncTarget = message.syncTarget
                    default: maybeSyncTarget = nil
                }

                return (maybeSyncTarget ?? publicKey)

            case .closedGroup(let groupPublicKey): return groupPublicKey
            case .openGroup(let roomToken, let server, _, _, _):
                return OpenGroup.idFor(roomToken: roomToken, server: server)

            case .openGroupInbox(_, _, let blindedPublicKey): return blindedPublicKey
        }
    }

    /// Parses a message retrieved from a service node and records its hash for de-duping.
    ///
    /// - Throws: `MessageReceiverError.invalidMessage` if no envelope can be built from `rawMessage`.
    ///   Throws `MessageReceiverError.duplicateMessageNewSnode` if the hash was already stored for
    ///   another node. Any error from parsing or database writes is also rethrown.
    static func processRawReceivedMessage(
        _ db: Database,
        rawMessage: SnodeReceivedMessage
    ) throws -> ProcessedMessage? {
        guard let envelope = SNProtoEnvelope.from(rawMessage) else {
            throw MessageReceiverError.invalidMessage
        }

        do {
            let processedMessage: ProcessedMessage? = try processRawReceivedMessage(
                db,
                envelope: envelope,
                serverExpirationTimestamp: (TimeInterval(rawMessage.info.expirationDateMs) / 1000),
                serverHash: rawMessage.info.hash,
                handleClosedGroupKeyUpdateMessages: true
            )

            // Ensure we actually want to de-dupe messages for this namespace, otherwise just
            // succeed early
            guard rawMessage.namespace.shouldDedupeMessages else {
                // If we want to track the last hash then upsert the raw message info (don't
                // want to fail if it already exists because we don't want to dedupe messages
                // in this namespace)
                if rawMessage.namespace.shouldFetchSinceLastHash {
                    _ = try rawMessage.info.saved(db)
                }

                return processedMessage
            }

            // Retrieve the number of entries we have for the hash of this message (a failed
            // lookup is deliberately treated as "no existing entries")
            let numExistingHashes: Int = (try? SnodeReceivedMessageInfo
                .filter(SnodeReceivedMessageInfo.Columns.hash == rawMessage.info.hash)
                .fetchCount(db))
                .defaulting(to: 0)

            // Try to insert the raw message info into the database (used for both request paging and
            // de-duping purposes)
            _ = try rawMessage.info.inserted(db)

            // If the above insertion worked then we hadn't processed this message for this specific
            // service node, but may have done so for another node - if the hash already existed in
            // the database before we inserted it for this node then we can ignore this message as a
            // duplicate
            guard numExistingHashes == 0 else { throw MessageReceiverError.duplicateMessageNewSnode }

            return processedMessage
        }
        catch {
            // If we get 'selfSend' or 'duplicateControlMessage' errors then we still want to insert
            // the SnodeReceivedMessageInfo to prevent retrieving and attempting to process the same
            // message again (as well as ensure the next poll doesn't retrieve the same message)
            switch error {
                case MessageReceiverError.selfSend, MessageReceiverError.duplicateControlMessage:
                    _ = try? rawMessage.info.inserted(db)

                default: break
            }

            throw error
        }
    }

    /// Parses a serialized `SNProtoEnvelope`.
    ///
    /// The server expiration is set to the current offset time plus
    /// `ControlMessageProcessRecord.defaultExpirationSeconds`.
    ///
    /// - Throws: `MessageReceiverError.invalidMessage` if `serializedData` is not a valid envelope.
    static func processRawReceivedMessage(
        _ db: Database,
        serializedData: Data,
        serverHash: String?
    ) throws -> ProcessedMessage? {
        guard let envelope = try? SNProtoEnvelope.parseData(serializedData) else {
            throw MessageReceiverError.invalidMessage
        }

        return try processRawReceivedMessage(
            db,
            envelope: envelope,
            serverExpirationTimestamp: (
                (TimeInterval(SnodeAPI.currentOffsetTimestampMs()) / 1000) +
                ControlMessageProcessRecord.defaultExpirationSeconds
            ),
            serverHash: serverHash,
            handleClosedGroupKeyUpdateMessages: true
        )
    }

    /// This method behaves slightly differently from the other `processRawReceivedMessage` methods as it doesn't
    /// insert the "message info" for deduping (we want the poller to re-process the message) and also avoids handling any
    /// closed group key update messages (the `NotificationServiceExtension` does this itself)
    static func processRawReceivedMessageAsNotification(
        _ db: Database,
        envelope: SNProtoEnvelope
    ) throws -> ProcessedMessage? {
        return try processRawReceivedMessage(
            db,
            envelope: envelope,
            serverExpirationTimestamp: (
                (TimeInterval(SnodeAPI.currentOffsetTimestampMs()) / 1000) +
                ControlMessageProcessRecord.defaultExpirationSeconds
            ),
            serverHash: nil,
            handleClosedGroupKeyUpdateMessages: false
        )
    }

    /// Wraps a community (open group) message's content in an envelope and parses it.
    ///
    /// - Returns: `nil` if the message has no `sender` or `posted` timestamp.
    /// - Throws: `MessageReceiverError.invalidMessage` if the envelope cannot be built. Any parsing error is
    ///   also rethrown.
    static func processReceivedOpenGroupMessage(
        _ db: Database,
        openGroupId: String,
        openGroupServerPublicKey: String,
        message: OpenGroupAPI.Message,
        data: Data,
        dependencies: SMKDependencies = SMKDependencies()
    ) throws -> ProcessedMessage? {
        // Need a sender in order to process the message
        guard let sender: String = message.sender, let timestamp = message.posted else { return nil }

        // Note: The `posted` value is in seconds but all messages in the database use milliseconds for timestamps
        let envelopeBuilder = SNProtoEnvelope.builder(type: .sessionMessage, timestamp: UInt64(floor(timestamp * 1000)))
        envelopeBuilder.setContent(data)
        envelopeBuilder.setSource(sender)

        guard let envelope = try? envelopeBuilder.build() else {
            throw MessageReceiverError.invalidMessage
        }

        return try processRawReceivedMessage(
            db,
            envelope: envelope,
            serverExpirationTimestamp: nil,
            serverHash: nil,
            openGroupId: openGroupId,
            openGroupMessageServerId: message.id,
            openGroupServerPublicKey: openGroupServerPublicKey,
            handleClosedGroupKeyUpdateMessages: false,
            dependencies: dependencies
        )
    }

    /// Wraps a community direct (inbox/outbox) message's content in an envelope and parses it.
    ///
    /// The message is deliberately not treated as an open group message (no `openGroupId` is passed).
    /// - Throws: `MessageReceiverError.invalidMessage` if the envelope cannot be built. Any parsing error is
    ///   also rethrown.
    static func processReceivedOpenGroupDirectMessage(
        _ db: Database,
        openGroupServerPublicKey: String,
        message: OpenGroupAPI.DirectMessage,
        data: Data,
        isOutgoing: Bool? = nil,
        otherBlindedPublicKey: String? = nil,
        dependencies: SMKDependencies = SMKDependencies()
    ) throws -> ProcessedMessage? {
        // Note: The `posted` value is in seconds but all messages in the database use milliseconds for timestamps
        let envelopeBuilder = SNProtoEnvelope.builder(type: .sessionMessage, timestamp: UInt64(floor(message.posted * 1000)))
        envelopeBuilder.setContent(data)
        envelopeBuilder.setSource(message.sender)

        guard let envelope = try? envelopeBuilder.build() else {
            throw MessageReceiverError.invalidMessage
        }

        return try processRawReceivedMessage(
            db,
            envelope: envelope,
            serverExpirationTimestamp: nil,
            serverHash: nil,
            openGroupId: nil,   // Explicitly null since it shouldn't be handled as an open group message
            openGroupMessageServerId: message.id,
            openGroupServerPublicKey: openGroupServerPublicKey,
            isOutgoing: isOutgoing,
            otherBlindedPublicKey: otherBlindedPublicKey,
            handleClosedGroupKeyUpdateMessages: false,
            dependencies: dependencies
        )
    }

    /// Builds `Reaction` records for a community message's reactions.
    ///
    /// Pending (not yet reflected) reaction changes are applied on top of the server state.
    /// For each emoji, the result contains the first reactor (carrying the count), up to the
    /// remaining reactor limit of other reactors (count 0), and the current user's reaction
    /// (count 1) when applicable. Other reactors are limited to 4 in total when the current
    /// user's reaction is added, otherwise 5.
    static func processRawReceivedReactions(
        _ db: Database,
        openGroupId: String,
        message: OpenGroupAPI.Message,
        associatedPendingChanges: [OpenGroupAPI.PendingChange],
        dependencies: SMKDependencies = SMKDependencies()
    ) -> [Reaction] {
        var results: [Reaction] = []
        guard let reactions = message.reactions else { return results }

        let userPublicKey: String = getUserHexEncodedPublicKey(db)
        let blindedUserPublicKey: String? = SessionThread
            .getUserHexEncodedBlindedKey(
                db,
                threadId: openGroupId,
                threadVariant: .community
            )

        // Order the pending changes newest-first once up front, instead of re-sorting them for every
        // emoji. Changes without a `seqNo` are treated as the newest (`Int64.max`). Ties are broken by
        // position so the later entry in `associatedPendingChanges` comes first; this gives
        // `sorted(by:)` the strict weak ordering it requires (a `>=` comparator does not provide one).
        let pendingChangesNewestFirst: [OpenGroupAPI.PendingChange] = associatedPendingChanges
            .enumerated()
            .sorted { lhs, rhs -> Bool in
                let lhsSeqNo: Int64 = (lhs.element.seqNo ?? Int64.max)
                let rhsSeqNo: Int64 = (rhs.element.seqNo ?? Int64.max)

                guard lhsSeqNo == rhsSeqNo else { return lhsSeqNo > rhsSeqNo }

                return lhs.offset > rhs.offset
            }
            .map { $0.element }

        for (encodedEmoji, rawReaction) in reactions {
            guard
                let decodedEmoji: String = encodedEmoji.removingPercentEncoding,
                rawReaction.count > 0,
                let reactors = rawReaction.reactors
            else { continue }

            // Decide whether we need to ignore all reactions
            let pendingChangeRemoveAllReaction: Bool = associatedPendingChanges.contains { pendingChange in
                if case .reaction(_, let emoji, let action) = pendingChange.metadata {
                    return emoji == decodedEmoji && action == .removeAll
                }

                return false
            }

            // Decide whether we need to add an extra reaction from current user
            let pendingChangeSelfReaction: Bool? = {
                // Find the newest 'PendingChange' entry with a matching emoji, if one exists, and
                // set the "self reaction" value based on its action
                let maybePendingChange: OpenGroupAPI.PendingChange? = pendingChangesNewestFirst
                    .first { pendingChange in
                        if case .reaction(_, let emoji, _) = pendingChange.metadata {
                            return emoji == decodedEmoji
                        }

                        return false
                    }

                // If there is no pending change for this reaction then return nil
                guard
                    let pendingChange: OpenGroupAPI.PendingChange = maybePendingChange,
                    case .reaction(_, _, let action) = pendingChange.metadata
                else { return nil }

                // Otherwise add/remove accordingly
                return action == .add
            }()
            let shouldAddSelfReaction: Bool = (
                pendingChangeSelfReaction ??
                ((rawReaction.you || reactors.contains(userPublicKey)) && !pendingChangeRemoveAllReaction)
            )

            let count: Int64 = rawReaction.you ? rawReaction.count - 1 : rawReaction.count
            let timestampMs: Int64 = SnodeAPI.currentOffsetTimestampMs()
            let maxLength: Int = shouldAddSelfReaction ? 4 : 5
            let desiredReactorIds: [String] = reactors
                .filter { $0 != blindedUserPublicKey && $0 != userPublicKey } // Remove current user for now, will add back if needed
                .prefix(maxLength)
                .map { $0 }

            results = results
                .appending( // Add the first reaction (with the count)
                    pendingChangeRemoveAllReaction ?
                        nil :
                        desiredReactorIds.first
                            .map { reactor in
                                Reaction(
                                    interactionId: message.id,
                                    serverHash: nil,
                                    timestampMs: timestampMs,
                                    authorId: reactor,
                                    emoji: decodedEmoji,
                                    count: count,
                                    sortId: rawReaction.index
                                )
                            }
                )
                .appending( // Add all other reactions
                    contentsOf: desiredReactorIds.count <= 1 || pendingChangeRemoveAllReaction ?
                        [] :
                        desiredReactorIds
                            .suffix(from: 1)
                            .map { reactor in
                                Reaction(
                                    interactionId: message.id,
                                    serverHash: nil,
                                    timestampMs: timestampMs,
                                    authorId: reactor,
                                    emoji: decodedEmoji,
                                    count: 0,   // Only want this on the first reaction
                                    sortId: rawReaction.index
                                )
                            }
                )
                .appending( // Add the current user reaction (if applicable and not already included)
                    !shouldAddSelfReaction ?
                        nil :
                        Reaction(
                            interactionId: message.id,
                            serverHash: nil,
                            timestampMs: timestampMs,
                            authorId: userPublicKey,
                            emoji: decodedEmoji,
                            count: 1,
                            sortId: rawReaction.index
                        )
                )
        }

        return results
    }

    /// Shared implementation used by the public entry points: parses `envelope` and prepares the
    /// `ProcessedMessage`.
    ///
    /// When `handleClosedGroupKeyUpdateMessages` is `true`, a `ClosedGroupControlMessage` of kind
    /// `.encryptionKeyPair` is handled immediately and `nil` is returned.
    /// - Throws: `MessageReceiverError.invalidMessage` for an unknown message subclass.
    ///   Throws `MessageReceiverError.duplicateControlMessage` when the `ControlMessageProcessRecord`
    ///   insertion hits a unique constraint. Any other parsing or database error is also rethrown.
    private static func processRawReceivedMessage(
        _ db: Database,
        envelope: SNProtoEnvelope,
        serverExpirationTimestamp: TimeInterval?,
        serverHash: String?,
        openGroupId: String? = nil,
        openGroupMessageServerId: Int64? = nil,
        openGroupServerPublicKey: String? = nil,
        isOutgoing: Bool? = nil,
        otherBlindedPublicKey: String? = nil,
        handleClosedGroupKeyUpdateMessages: Bool,
        dependencies: SMKDependencies = SMKDependencies()
    ) throws -> ProcessedMessage? {
        let (message, proto, threadId, threadVariant) = try MessageReceiver.parse(
            db,
            envelope: envelope,
            serverExpirationTimestamp: serverExpirationTimestamp,
            openGroupId: openGroupId,
            openGroupMessageServerId: openGroupMessageServerId,
            openGroupServerPublicKey: openGroupServerPublicKey,
            isOutgoing: isOutgoing,
            otherBlindedPublicKey: otherBlindedPublicKey,
            dependencies: dependencies
        )
        message.serverHash = serverHash

        // Ignore invalid messages and hashes for messages we have previously handled
        guard let variant: Message.Variant = Message.Variant(from: message) else {
            throw MessageReceiverError.invalidMessage
        }

        /// **Note:** We want to immediately handle any `ClosedGroupControlMessage` with the kind `encryptionKeyPair` as
        /// we need the keyPair in storage in order to be able to parse any messages which were signed with the new key (also no need to add
        /// these as jobs as they will be fully handled in here)
        if handleClosedGroupKeyUpdateMessages {
            switch message {
                case let closedGroupControlMessage as ClosedGroupControlMessage:
                    switch closedGroupControlMessage.kind {
                        case .encryptionKeyPair:
                            try MessageReceiver.handleClosedGroupControlMessage(
                                db,
                                threadId: threadId,
                                threadVariant: threadVariant,
                                message: closedGroupControlMessage
                            )
                            return nil

                        default: break
                    }

                default: break
            }
        }

        // Prevent ControlMessages from being handled multiple times if not supported
        do {
            try ControlMessageProcessRecord(
                threadId: threadId,
                message: message,
                serverExpirationTimestamp: serverExpirationTimestamp
            )?.insert(db)
        }
        catch {
            // We want to custom handle this
            if case DatabaseError.SQLITE_CONSTRAINT_UNIQUE = error {
                throw MessageReceiverError.duplicateControlMessage
            }

            throw error
        }

        return (
            threadId,
            threadVariant,
            proto,
            try MessageReceiveJob.Details.MessageInfo(
                message: message,
                variant: variant,
                threadVariant: threadVariant,
                serverExpirationTimestamp: serverExpirationTimestamp,
                proto: proto
            )
        )
    }
}

// MARK: - Mutation

internal extension Message {
    /// Sets `sentTimestamp` on this instance (mutating it in place) and returns the same instance for chaining.
    func with(sentTimestamp: UInt64) -> Message {
        self.sentTimestamp = sentTimestamp
        return self
    }
}