diff --git a/crypto/src/mls/conversation/conversation_guard/commit.rs b/crypto/src/mls/conversation/conversation_guard/commit.rs index cc7e3ce2cf..b596efae67 100644 --- a/crypto/src/mls/conversation/conversation_guard/commit.rs +++ b/crypto/src/mls/conversation/conversation_guard/commit.rs @@ -107,37 +107,30 @@ impl ConversationGuard { self.ensure_no_pending_commit().await?; let backend = self.crypto_provider().await?; let credential = self.credential().await?; - let signer = credential.signature_key(); - let database = self.database().await?; - let mut conversation = self.conversation_mut().await; - - let (commit, welcome, group_info) = conversation - .group - .add_members(&backend, signer, key_packages.clone()) - .await - .map_err(|err| { - if Self::err_is_duplicate_signature_key(&err) { - let affected_clients = Self::clients_with_duplicate_signature_keys(key_packages.as_ref()); - Error::DuplicateSignature { affected_clients } - } else { - MlsError::wrap("group add members")(err).into() - } - })?; - - // commit requires an optional welcome - let welcome = Some(welcome); - let group_info = Self::group_info(group_info)?; - conversation.persist_group_when_changed(&database, false).await?; - - let commit = MlsCommitBundle { - commit, - welcome, - group_info, - encrypted_message: None, - }; + self.conversation_mut(async move |conversation, _database| { + let signer = credential.signature_key(); + let (commit, welcome, group_info) = conversation + .group + .add_members(&backend, signer, key_packages.clone()) + .await + .map_err(|err| { + if Self::err_is_duplicate_signature_key(&err) { + let affected_clients = Self::clients_with_duplicate_signature_keys(key_packages.as_ref()); + Error::DuplicateSignature { affected_clients } + } else { + MlsError::wrap("group add members")(err).into() + } + })?; - Ok(commit) + Ok(MlsCommitBundle { + commit, + welcome: Some(welcome), + group_info: Self::group_info(group_info)?, + encrypted_message: None, + }) + }) + .await } fn err_is_duplicate_signature_key( @@ -239,42 +232,42 @@ impl ConversationGuard { /// committed. pub(crate) async fn set_credential_inner(&mut self, credential: &Credential) -> Result { self.ensure_no_pending_commit().await?; - let backend = &self.crypto_provider().await?; - let database = self.database().await?; - let mut conversation = self.conversation_mut().await; - - // If the credential remains the same and we still want to update, we explicitly need to pass `None` to openmls, - // if we just passed an unchanged leaf node, no update commit would be created. - // Also, we can avoid cloning in the case we don't need to create a new leaf node. - let updated_leaf_node = { - let leaf_node = conversation.group.own_leaf().ok_or(LeafError::InternalMlsError)?; - if leaf_node.credential() == &credential.mls_credential { - None - } else { - let mut leaf_node = leaf_node.clone(); - leaf_node.set_credential_with_key(credential.to_mls_credential_with_key()); - Some(leaf_node) - } - }; - - let (commit, welcome, group_info) = conversation - .group - .explicit_self_update(backend, &credential.signature_key_pair, updated_leaf_node) - .await - .map_err(MlsError::wrap("group self update"))?; + let backend = self.crypto_provider().await?; + let credential = credential.clone(); + + self.conversation_mut(async move |conversation, _database| { + // If the credential remains the same and we still want to update, we explicitly need to pass `None` to + // openmls, if we just passed an unchanged leaf node, no update commit would be created. + // Also, we can avoid cloning in the case we don't need to create a new leaf node. + let updated_leaf_node = { + let leaf_node = conversation.group.own_leaf().ok_or(LeafError::InternalMlsError)?; + if leaf_node.credential() == &credential.mls_credential { + None + } else { + let mut leaf_node = leaf_node.clone(); + leaf_node.set_credential_with_key(credential.to_mls_credential_with_key()); + Some(leaf_node) + } + }; - // We should always have ratchet tree extension turned on hence GroupInfo should always be present - let group_info = group_info.ok_or(LeafError::MissingGroupInfo)?; - let group_info = MlsGroupInfoBundle::try_new_full_plaintext(group_info)?; + let (commit, welcome, group_info) = conversation + .group + .explicit_self_update(&backend, &credential.signature_key_pair, updated_leaf_node) + .await + .map_err(MlsError::wrap("group self update"))?; - conversation.persist_group_when_changed(&database, false).await?; + // We should always have ratchet tree extension turned on hence GroupInfo should always be present + let group_info = group_info.ok_or(LeafError::MissingGroupInfo)?; + let group_info = MlsGroupInfoBundle::try_new_full_plaintext(group_info)?; - Ok(MlsCommitBundle { - welcome, - commit, - group_info, - encrypted_message: None, + Ok(MlsCommitBundle { + welcome, + commit, + group_info, + encrypted_message: None, + }) }) + .await } /// Commits all pending proposals of the group diff --git a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs index bf1ec90c31..c11e01dad9 100644 --- a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs +++ b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs @@ -154,7 +154,6 @@ impl ConversationGuard { ) -> Result { let session = &self.session().await?; let provider = &self.crypto_provider().await?; - let database = self.database().await?; let parsed_message = self.parse_message(message.clone()).await?; let message_result = self.process_message(parsed_message).await; @@ -166,15 +165,17 @@ impl ConversationGuard { .. })) = message_result { - let mut conversation = self.conversation_mut().await; - let ct = conversation.extract_confirmation_tag_from_own_commit(&message)?; - let mut decrypted_message = conversation.handle_own_commit(session, &database, provider, ct).await?; + let mut decrypted_message = self + .conversation_mut(async |conversation, database| { + let ct = conversation.extract_confirmation_tag_from_own_commit(&message)?; + conversation.handle_own_commit(session, database, provider, ct).await + }) + .await?; debug_assert!( decrypted_message.buffered_messages.is_none(), "decrypted message should be constructed with empty buffer" ); if recursion_policy == RecursionPolicy::AsNecessary { - drop(conversation); decrypted_message.buffered_messages = self.restore_and_clear_pending_messages().await?; } @@ -228,16 +229,18 @@ impl ConversationGuard { } } ProcessedMessageContent::ProposalMessage(proposal) => { - let mut conversation = self.conversation_mut().await; - info!( - group_id = conversation.id, - sender = Obfuscated::from(proposal.sender()), - proposals = Obfuscated::from(&proposal.proposal); - "Received proposal" - ); + self.conversation_mut(async move |conversation, _database| { + info!( + group_id = conversation.id, + sender = Obfuscated::from(proposal.sender()), + proposals = Obfuscated::from(&proposal.proposal); + "Received proposal" + ); + conversation.group.store_pending_proposal(*proposal); + Ok(()) + }) + .await?; - conversation.group.store_pending_proposal(*proposal); - drop(conversation); if let Some(commit) = self.retrieve_buffered_commit() .await @@ -283,98 +286,107 @@ impl ConversationGuard { } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { self.validate_commit(&staged_commit).await?; - let mut conversation = self.conversation_mut().await; - let pending_proposals = conversation.self_pending_proposals().cloned().collect::>(); + let (proposals, is_active, delay, removed_members, added_members, group_id) = self + .conversation_mut(async |conversation, database| { + let pending_proposals = conversation.self_pending_proposals().cloned().collect::>(); - // getting the pending has to be done before `merge_staged_commit` otherwise it's wiped out - let pending_commit = conversation.group.pending_commit().cloned(); + // getting the pending has to be done before `merge_staged_commit` otherwise it's wiped out + let pending_commit = conversation.group.pending_commit().cloned(); - let removed_indices = staged_commit - .remove_proposals() - .map(|p| p.remove_proposal().removed()) - .collect::>(); + let removed_indices = staged_commit + .remove_proposals() + .map(|p| p.remove_proposal().removed()) + .collect::>(); - let added_credentials = staged_commit - .add_proposals() - .map(|p| p.add_proposal().key_package.leaf_node().credential().to_owned()) - .collect::>(); + let added_credentials = staged_commit + .add_proposals() + .map(|p| p.add_proposal().key_package.leaf_node().credential().to_owned()) + .collect::>(); - let removed_members = Self::members_at_indices(removed_indices, conversation.group()); + let removed_members = Self::members_at_indices(removed_indices, conversation.group()); - conversation - .group - .merge_staged_commit(provider, *staged_commit.clone()) - .await - .map_err(MlsError::wrap("merge staged commit"))?; - - let added_members = conversation - .group - .members() - .filter_map(|member| added_credentials.contains(&member.credential).then(|| member.clone())) - .collect::>(); - - let (proposals_to_renew, needs_update) = Renew::renew( - &conversation.group.own_leaf_index(), - pending_proposals.iter(), - pending_commit.as_ref(), - staged_commit.as_ref(), - ); - let proposals = conversation - .renew_proposals_for_current_epoch( - session, - provider, - &database, - proposals_to_renew.into_iter(), - needs_update, - ) + conversation + .group + .merge_staged_commit(provider, *staged_commit.clone()) + .await + .map_err(MlsError::wrap("merge staged commit"))?; + + let added_members = conversation + .group + .members() + .filter_map(|member| added_credentials.contains(&member.credential).then(|| member.clone())) + .collect::>(); + + let (proposals_to_renew, needs_update) = Renew::renew( + &conversation.group.own_leaf_index(), + pending_proposals.iter(), + pending_commit.as_ref(), + staged_commit.as_ref(), + ); + let proposals = conversation + .renew_proposals_for_current_epoch( + session, + provider, + database, + proposals_to_renew.into_iter(), + needs_update, + ) + .await?; + + let is_active = conversation.group.is_active(); + let delay = conversation.compute_next_commit_delay(); + let group_id = conversation.id.clone(); + + Ok((proposals, is_active, delay, removed_members, added_members, group_id)) + }) .await?; // can't use `.then` because async let mut buffered_messages = None; - // drop conversation to allow borrowing `self` again - drop(conversation); if recursion_policy == RecursionPolicy::AsNecessary { buffered_messages = self.restore_and_clear_pending_messages().await?; } - let conversation = self.conversation().await; let epoch = staged_commit.staged_context().epoch().as_u64(); info!( added = Obfuscated::from(&added_members), removed = Obfuscated::from(&removed_members), - group_id = conversation.id, + group_id, epoch, proposals:? = staged_commit.queued_proposals().map(Obfuscated::from).collect::>(); "Epoch advanced" ); - session.notify_epoch_changed(conversation.id.clone(), epoch).await; + session.notify_epoch_changed(group_id, epoch).await; MlsConversationDecryptMessage { app_msg: None, proposals, - is_active: conversation.group.is_active(), - delay: conversation.compute_next_commit_delay(), + is_active, + delay, sender_client_id: None, identity, buffered_messages, } } ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => { - let mut conversation = self.conversation_mut().await; - info!( - group_id = conversation.id, - sender = Obfuscated::from(proposal.sender()); - "Received external join proposal" - ); - - conversation.group.store_pending_proposal(*proposal); + let delay = self + .conversation_mut(async move |conversation, _database| { + info!( + group_id = conversation.id, + sender = Obfuscated::from(proposal.sender()); + "Received external join proposal" + ); + conversation.group.store_pending_proposal(*proposal); + Ok(conversation.compute_next_commit_delay()) + }) + .await?; MlsConversationDecryptMessage { app_msg: None, proposals: vec![], is_active: true, - delay: conversation.compute_next_commit_delay(), + delay, sender_client_id: None, identity, buffered_messages: None, @@ -382,10 +394,6 @@ impl ConversationGuard { } }; - let mut conversation = self.conversation_mut().await; - - conversation.persist_group_when_changed(&database, false).await?; - Ok(decrypted) } @@ -433,53 +441,55 @@ impl ConversationGuard { ) -> Result { let msg_epoch = protocol_message.epoch().as_u64(); let backend = self.crypto_provider().await?; - let mut conversation = self.conversation_mut().await; - let group_epoch = conversation.group.epoch().as_u64(); - let processed_msg = conversation - .group - .process_message(&backend, protocol_message) - .await - .map_err(|e| match e { - ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( - MessageDecryptionError::GenerationOutOfBound, - )) => Error::DuplicateMessage, - ProcessMessageError::ValidationError(ValidationError::WrongEpoch) => { - if is_duplicate { - Error::DuplicateMessage - } else if msg_epoch == group_epoch + 1 { - // limit to next epoch otherwise if we were buffering a commit for epoch + 2 - // we would fail when trying to decrypt it in [MlsCentral::commit_accepted] - - // We need to buffer the message until the group has advanced to the right - // epoch. We can't do that here--we don't have the appropriate data in scope - // --but we can at least produce the proper error and return that, so our - // caller can handle it. Our caller needs to know about the epoch number, so - // we pass it back inside the error. - Error::BufferedFutureMessage { - message_epoch: msg_epoch, + self.conversation_mut(async move |conversation, _database| { + let group_epoch = conversation.group.epoch().as_u64(); + let processed_msg = conversation + .group + .process_message(&backend, protocol_message) + .await + .map_err(|e| match e { + ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( + MessageDecryptionError::GenerationOutOfBound, + )) => Error::DuplicateMessage, + ProcessMessageError::ValidationError(ValidationError::WrongEpoch) => { + if is_duplicate { + Error::DuplicateMessage + } else if msg_epoch == group_epoch + 1 { + // limit to next epoch otherwise if we were buffering a commit for epoch + 2 + // we would fail when trying to decrypt it in [MlsCentral::commit_accepted] + + // We need to buffer the message until the group has advanced to the right + // epoch. We can't do that here--we don't have the appropriate data in scope + // --but we can at least produce the proper error and return that, so our + // caller can handle it. Our caller needs to know about the epoch number, so + // we pass it back inside the error. + Error::BufferedFutureMessage { + message_epoch: msg_epoch, + } + } else if msg_epoch < group_epoch { + match content_type { + ContentType::Application => Error::StaleMessage, + ContentType::Commit => Error::StaleCommit, + ContentType::Proposal => Error::StaleProposal, + } + } else { + Error::UnbufferedFarFutureMessage } - } else if msg_epoch < group_epoch { - match content_type { - ContentType::Application => Error::StaleMessage, - ContentType::Commit => Error::StaleCommit, - ContentType::Proposal => Error::StaleProposal, - } - } else { - Error::UnbufferedFarFutureMessage } - } - ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( - MessageDecryptionError::AeadError, - )) => Error::DecryptionError, - ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( - MessageDecryptionError::SecretTreeError(SecretTreeError::TooDistantInThePast), - )) => Error::MessageEpochTooOld, - _ => MlsError::wrap("processing message")(e).into(), - })?; - if is_duplicate { - return Err(Error::DuplicateMessage); - } - Ok(processed_msg) + ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( + MessageDecryptionError::AeadError, + )) => Error::DecryptionError, + ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( + MessageDecryptionError::SecretTreeError(SecretTreeError::TooDistantInThePast), + )) => Error::MessageEpochTooOld, + _ => MlsError::wrap("processing message")(e).into(), + })?; + if is_duplicate { + return Err(Error::DuplicateMessage); + } + Ok(processed_msg) + }) + .await } async fn validate_commit(&self, commit: &StagedCommit) -> Result<()> { @@ -1053,10 +1063,12 @@ mod tests { conversation .guard() .await - .conversation_mut() + .conversation_mut(async |conv, _db| { + conv.group.clear_pending_proposals(); + Ok(()) + }) .await - .group - .clear_pending_proposals(); + .unwrap(); let commit_guard = conversation.update_unmerged().await; let old_commit = commit_guard.message().to_bytes().unwrap(); let conversation = commit_guard.finish(); diff --git a/crypto/src/mls/conversation/conversation_guard/encrypt.rs b/crypto/src/mls/conversation/conversation_guard/encrypt.rs index ab6adcb9c6..d8864983b1 100644 --- a/crypto/src/mls/conversation/conversation_guard/encrypt.rs +++ b/crypto/src/mls/conversation/conversation_guard/encrypt.rs @@ -21,23 +21,23 @@ impl ConversationGuard { pub async fn encrypt_message(&mut self, message: impl AsRef<[u8]>) -> Result> { let backend = self.crypto_provider().await?; let credential = self.credential().await?; - let signer = credential.signature_key(); - let database = self.database().await?; - let mut inner = self.conversation_mut().await; - let encrypted = inner - .group - .create_message(&backend, signer, message.as_ref()) - .map_err(MlsError::wrap("creating message"))?; - // make sure all application messages are encrypted - debug_assert!(matches!(encrypted.body, MlsMessageOutBody::PrivateMessage(_))); + self.conversation_mut(async move |conversation, _database| { + let signer = credential.signature_key(); + let encrypted = conversation + .group + .create_message(&backend, signer, message.as_ref()) + .map_err(MlsError::wrap("creating message"))?; - let encrypted = encrypted - .to_bytes() - .map_err(MlsError::wrap("constructing byte vector of encrypted message"))?; + // make sure all application messages are encrypted + debug_assert!(matches!(encrypted.body, MlsMessageOutBody::PrivateMessage(_))); - inner.persist_group_when_changed(&database, false).await?; - Ok(encrypted) + encrypted + .to_bytes() + .map_err(MlsError::wrap("constructing byte vector of encrypted message")) + .map_err(Into::into) + }) + .await } } diff --git a/crypto/src/mls/conversation/conversation_guard/history_sharing.rs b/crypto/src/mls/conversation/conversation_guard/history_sharing.rs index c67f330f72..834ae6d165 100644 --- a/crypto/src/mls/conversation/conversation_guard/history_sharing.rs +++ b/crypto/src/mls/conversation/conversation_guard/history_sharing.rs @@ -157,36 +157,35 @@ impl ConversationGuard { self.clear_pending_commit().await?; - let session = &self.session().await?; - let provider = &self.crypto_provider().await?; + let session = self.session().await?; + let provider = self.crypto_provider().await?; let history_secret = self.generate_history_secret().await?; - let database = &self.database().await?; let key_package = history_secret.key_package.clone().into(); - let mut conversation = self.conversation_mut().await; - - // Propose to remove the old history client - for history_client in existing_history_clients { - conversation - .propose_remove_member(session, provider, database, history_client) - .await?; - } + let remove_and_add = self + .conversation_mut(async move |conversation, database| { + // Propose to remove the old history client + for history_client in existing_history_clients { + conversation + .propose_remove_member(&session, &provider, database, history_client) + .await?; + } - // Propose to add a new history client - conversation - .propose_add_member(session, provider, database, key_package) + // Propose to add a new history client + conversation + .propose_add_member(&session, &provider, database, key_package) + .await?; + + // We're getting the proposals we just created from the pending proposals queue, as the previously + // called `propose_remove()` and `propose_add()` pushed them to that queue as a side effect. + Ok(conversation + .self_pending_proposals() + .map(|proposal| proposal.proposal()) + .cloned() + .collect::>()) + }) .await?; - // We're getting the proposals we just created from the pending proposals queue, as the previously - // called `propose_remove()` and `propose_add()` pushed them to that queue as a side effect. - let remove_and_add = conversation - .self_pending_proposals() - .map(|proposal| proposal.proposal()) - .cloned() - .collect(); - - drop(conversation); - let inline_proposals = [pending_proposals, remove_and_add].concat(); let commit = self diff --git a/crypto/src/mls/conversation/conversation_guard/merge.rs b/crypto/src/mls/conversation/conversation_guard/merge.rs index 1eda71e6fc..d56157c7a0 100644 --- a/crypto/src/mls/conversation/conversation_guard/merge.rs +++ b/crypto/src/mls/conversation/conversation_guard/merge.rs @@ -32,18 +32,17 @@ impl ConversationGuard { /// When the conversation is not found or the proposal reference does not identify a proposal /// in the local pending proposal store pub async fn clear_pending_proposal(&mut self, proposal_ref: MlsProposalRef) -> Result<()> { - let database = self.database().await?; - let mut conversation = self.conversation_mut().await; - conversation - .group - .remove_pending_proposal(&database, &proposal_ref) - .await - .map_err(|mls_group_state_error| match mls_group_state_error { - MlsGroupStateError::PendingProposalNotFound => Error::PendingProposalNotFound(proposal_ref), - _ => MlsError::wrap("removing pending proposal")(mls_group_state_error).into(), - })?; - conversation.persist_group_when_changed(&database, true).await?; - Ok(()) + self.conversation_mut(async move |conversation, database| { + conversation + .group + .remove_pending_proposal(database, &proposal_ref) + .await + .map_err(|mls_group_state_error| match mls_group_state_error { + MlsGroupStateError::PendingProposalNotFound => Error::PendingProposalNotFound(proposal_ref), + _ => MlsError::wrap("removing pending proposal")(mls_group_state_error).into(), + }) + }) + .await } /// Allows to remove a pending commit. Use this when backend rejects the commit @@ -58,17 +57,17 @@ impl ConversationGuard { /// # Errors /// When there is no pending commit pub(crate) async fn clear_pending_commit(&mut self) -> Result<()> { - let database = self.database().await?; - - let mut conversation = self.conversation_mut().await; - if conversation.group.pending_commit().is_some() { - conversation.group.clear_pending_commit(); - conversation.persist_group_when_changed(&database, true).await?; - log::info!(group_id = conversation.id(); "Cleared pending commit."); - Ok(()) - } else { - Err(Error::PendingCommitNotFound) - } + self.conversation_mut(async |conversation, database| { + if conversation.group.pending_commit().is_some() { + conversation.group.clear_pending_commit(); + conversation.persist_group_when_changed(database, true).await?; + log::info!(group_id = conversation.id(); "Cleared pending commit."); + Ok(()) + } else { + Err(Error::PendingCommitNotFound) + } + }) + .await } /// Clear a pending commit if it exists. Unlike [Self::clear_pending_commit], diff --git a/crypto/src/mls/conversation/conversation_guard/mod.rs b/crypto/src/mls/conversation/conversation_guard/mod.rs index fef632d477..efce3be312 100644 --- a/crypto/src/mls/conversation/conversation_guard/mod.rs +++ b/crypto/src/mls/conversation/conversation_guard/mod.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use async_lock::{RwLockReadGuard, RwLockWriteGuard}; +use async_lock::RwLockReadGuard; use core_crypto_keystore::{CryptoKeystoreMls as _, Database}; use openmls::prelude::group_info::GroupInfo; use openmls_traits::OpenMlsCryptoProvider as _; @@ -47,8 +47,30 @@ impl ConversationGuard { Self { inner, central_context } } - pub(crate) async fn conversation_mut(&mut self) -> RwLockWriteGuard<'_, MlsConversation> { - self.inner.write().await + /// Perform an operation on a mutable reference to this conversation. + /// + /// Errors will be propagated. + /// When the operation does not error, [`MlsConversation::persist_group_when_changed`] will be called automatically. + /// This ensures that persistence cannot be forgotten. + /// + /// We choose to implement this as a closure instead of a lightweight holding a reference to the coversation + /// which calls that method on `Drop` because this way we can ensure we do _not_ automatically call it when there is + /// an error. + pub(crate) async fn conversation_mut( + &mut self, + operation: impl AsyncFnOnce(&mut MlsConversation, &Database) -> Result, + ) -> Result { + // we can't get the database if the transaction context has been invalidated, + // and we want to have that error first before evaluating anything in the operation. + let database = self + .central_context + .database() + .await + .map_err(RecursiveError::transaction("getting database from context"))?; + let mut guard = self.inner.write().await; + let ok_result = operation(&mut guard, &database).await?; + guard.persist_group_when_changed(&database, false).await?; + Ok(ok_result) } async fn transport(&self) -> Result> { @@ -75,14 +97,20 @@ impl ConversationGuard { .mls_groups() .await .map_err(RecursiveError::transaction("getting mls groups"))?; - let mut conversation = self.conversation_mut().await; - conversation.wipe_associated_entities(&provider).await?; + + let id = self + .conversation_mut(async |conversation, _| { + conversation.wipe_associated_entities(&provider).await?; + Ok(conversation.id().to_owned()) + }) + .await?; provider .key_store() - .mls_group_delete(conversation.id()) + .mls_group_delete(&id) .await .map_err(KeystoreError::wrap("deleting mls group"))?; - let _ = group_store.remove(conversation.id()); + let _ = group_store.remove(&id); + Ok(()) } diff --git a/crypto/src/mls/conversation/merge.rs b/crypto/src/mls/conversation/merge.rs index 55131ebea5..88a1de2fb1 100644 --- a/crypto/src/mls/conversation/merge.rs +++ b/crypto/src/mls/conversation/merge.rs @@ -65,16 +65,16 @@ mod tests { assert!(conversation.has_pending_commit().await); + let session = alice.session().await; + conversation .guard() .await - .conversation_mut() - .await - .commit_accepted( - &alice.transaction.session().await.unwrap(), - &alice.database().await, - &alice.session().await.crypto_provider, - ) + .conversation_mut(async |conversation, database| { + conversation + .commit_accepted(&session, database, &session.crypto_provider) + .await + }) .await .unwrap(); @@ -101,16 +101,16 @@ mod tests { assert!(conversation.has_pending_proposals().await); assert!(conversation.has_pending_commit().await); + let session = alice.session().await; + conversation .guard() .await - .conversation_mut() - .await - .commit_accepted( - &alice.transaction.session().await.unwrap(), - &alice.database().await, - &alice.session().await.crypto_provider, - ) + .conversation_mut(async |conversation, database| { + conversation + .commit_accepted(&session, database, &session.crypto_provider) + .await + }) .await .unwrap(); assert!(!conversation.has_pending_proposals().await); diff --git a/crypto/src/transaction_context/conversation/external_commit.rs b/crypto/src/transaction_context/conversation/external_commit.rs index e747fdb6ab..c75ab533ed 100644 --- a/crypto/src/transaction_context/conversation/external_commit.rs +++ b/crypto/src/transaction_context/conversation/external_commit.rs @@ -374,28 +374,37 @@ mod tests { // we need an invalid GroupInfo; let's manufacture one. let group_info = { - let mut conversation = conversation.guard().await; - let mut conversation = conversation.conversation_mut().await; - let group = &mut conversation.group; - let ct = group - .credential() + let mut guard = conversation.guard().await; + let (ciphersuite, credential_type) = { + let conversation = guard.conversation().await; + let credential_type = conversation + .group + .credential() + .unwrap() + .credential_type() + .try_into() + .expect("case conversation has a known credential type"); + let ciphersuite = conversation.group.ciphersuite(); + (ciphersuite, credential_type) + }; + let credential = alice.find_any_credential(ciphersuite.into(), credential_type).await; + let mls_provider = alice.transaction.mls_provider().await.unwrap(); + guard + .conversation_mut(async move |conversation, _database| { + let gi = conversation + .group + .export_group_info( + &mls_provider, + &credential.signature_key_pair, + // joining by external commit assumes we include a ratchet tree, but this `false` + // says to leave it out + false, + ) + .unwrap(); + Ok(gi.group_info().unwrap()) + }) + .await .unwrap() - .credential_type() - .try_into() - .expect("case conversation has a known credential type"); - let cs = group.ciphersuite(); - let cb = alice.find_any_credential(cs.into(), ct).await; - - let gi = group - .export_group_info( - &alice.transaction.mls_provider().await.unwrap(), - &cb.signature_key_pair, - // joining by external commit assumes we include a ratchet tree, but this `false` - // says to leave it out - false, - ) - .unwrap(); - gi.group_info().unwrap() }; let join_ext_commit = guest diff --git a/crypto/src/transaction_context/conversation/proposal.rs b/crypto/src/transaction_context/conversation/proposal.rs index 1d6e618aff..8b4228cf65 100644 --- a/crypto/src/transaction_context/conversation/proposal.rs +++ b/crypto/src/transaction_context/conversation/proposal.rs @@ -2,7 +2,9 @@ use openmls::prelude::KeyPackage; use super::{Error, Result}; use crate::{ - ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext, + ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, + mls::conversation::{ConversationWithMls as _, Error as ConversationError}, + transaction_context::TransactionContext, }; impl TransactionContext { @@ -38,34 +40,54 @@ impl TransactionContext { /// If the conversation is not found, an error will be returned. Errors from OpenMls can be /// returned as well, when for example there's a commit pending to be merged async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result { - let mut conversation = self.conversation(id).await?; - let mut conversation = conversation.conversation_mut().await; - let client = &self.session().await?; - let provider = &self.mls_provider().await?; - let database = &self.database().await?; - let proposal = match proposal { - MlsProposal::Add(key_package) => conversation - .propose_add_member(client, provider, database, key_package.into()) - .await - .map_err(RecursiveError::mls_conversation("proposing to add member"))?, - MlsProposal::Update => conversation - .propose_self_update(client, provider, database) - .await - .map_err(RecursiveError::mls_conversation("proposing self update"))?, - MlsProposal::Remove(client_id) => { - let index = conversation + let mut guard = self.conversation(id).await?; + let client = self.session().await?; + let provider = self.mls_provider().await?; + + // For Remove proposals, look up the leaf index before taking the write lock so we can + // surface ClientNotFound as a transaction-level error. + let remove_index = if let MlsProposal::Remove(client_id) = &proposal { + Some( + guard + .conversation() + .await .group .members() .find(|kp| kp.credential.identity() == client_id.as_slice()) - .ok_or(Error::ClientNotFound(client_id)) - .map(|kp| kp.index)?; - (*conversation) - .propose_remove_member(client, provider, database, index) - .await - .map_err(RecursiveError::mls_conversation("proposing to remove member"))? - } + .ok_or_else(|| Error::ClientNotFound(client_id.clone())) + .map(|kp| kp.index)?, + ) + } else { + None }; - Ok(proposal) + + guard + .conversation_mut(async move |conversation, database| { + let proposal = match proposal { + MlsProposal::Add(key_package) => conversation + .propose_add_member(&client, &provider, database, key_package.into()) + .await + .map_err(RecursiveError::mls_conversation("proposing to add member")) + .map_err(ConversationError::from)?, + MlsProposal::Update => conversation + .propose_self_update(&client, &provider, database) + .await + .map_err(RecursiveError::mls_conversation("proposing self update")) + .map_err(ConversationError::from)?, + MlsProposal::Remove(_) => { + let index = remove_index.expect("we always have a remove index for a remove proposal"); + conversation + .propose_remove_member(&client, &provider, database, index) + .await + .map_err(RecursiveError::mls_conversation("proposing to remove member")) + .map_err(ConversationError::from)? + } + }; + Ok(proposal) + }) + .await + .map_err(RecursiveError::mls_conversation("new proposal")) + .map_err(Into::into) } }