diff --git a/src/extensions.rs b/src/extensions.rs index c4862b791..bcfcc8c56 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -11,6 +11,7 @@ bitflags::bitflags! { /// /// During deserialization, this extension requires that structs' names are stated explicitly. const EXPLICIT_STRUCT_NAMES = 0x8; + const ARBITRARY_IDENTIFIERS = 0x10; } } // GRCOV_EXCL_STOP diff --git a/src/parse.rs b/src/parse.rs index c03eb3cb4..69b6cfc41 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -870,6 +870,32 @@ impl<'a> Parser<'a> { return None; } + if self.exts.contains(Extensions::ARBITRARY_IDENTIFIERS) { + // FIXME: optimize + if self.check_str("i\"") { + let cursor_backup = self.cursor; + self.advance_bytes(1); + return match self.escaped_string() { + Ok(ParsedStr::Slice(ident)) => Some(ident), + Ok(ParsedStr::Allocated(_)) | Err(_) => { + self.cursor = cursor_backup; + None + } + }; + } + if self.check_str("ri#") { + let cursor_backup = self.cursor; + self.advance_bytes(2); + return match self.raw_string() { + Ok(ParsedStr::Slice(ident)) => Some(ident), + Ok(ParsedStr::Allocated(_)) | Err(_) => { + self.cursor = cursor_backup; + None + } + }; + } + } + if self.check_str("r#") { // maybe a raw identifier let len = self.next_chars_while_from_len(2, is_ident_raw_char); @@ -896,6 +922,30 @@ impl<'a> Parser<'a> { } pub fn identifier(&mut self) -> Result<&'a str> { + if self.check_str("i\"") { + let cursor_backup = self.cursor; + self.advance_bytes(1); + return match self.escaped_string() { + Ok(ParsedStr::Slice(ident)) => Ok(ident), + Ok(ParsedStr::Allocated(_)) | Err(_) => { + self.cursor = cursor_backup; + Err(Error::ExpectedIdentifier) + } + }; + } + + if self.check_str("ri#") { + let cursor_backup = self.cursor; + self.advance_bytes(2); + return match self.raw_string() { + Ok(ParsedStr::Slice(ident)) => Ok(ident), + Ok(ParsedStr::Allocated(_)) | Err(_) => { + self.cursor = cursor_backup; + Err(Error::ExpectedIdentifier) + } + }; + } + let first = self.peek_char_or_eof()?; if !is_ident_first_char(first) { if is_ident_raw_char(first) { diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 70e62b161..a50ed0f45 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -621,7 +621,18 @@ impl Serializer { } fn write_identifier(&mut self, name: &str) -> Result<()> { - self.validate_identifier(name)?; + let arbitrary_identifiers = self + .extensions() + .contains(Extensions::ARBITRARY_IDENTIFIERS); + + match self.validate_identifier(name) { + Ok(()) => self.write_valid_identifier(name), + Err(_) if arbitrary_identifiers => self.write_arbitrary_identifier(name), + Err(err) => Err(err), + } + } + + fn write_valid_identifier(&mut self, name: &str) -> Result<()> { let mut chars = name.chars(); if !chars.next().map_or(false, is_ident_first_char) || !chars.all(is_xid_continue) @@ -637,6 +648,39 @@ impl Serializer { Ok(()) } + fn write_arbitrary_identifier(&mut self, name: &str) -> Result<()> { + if name.contains('"') || name.contains('\\') { + let (_, num_consecutive_hashes) = + name.chars().fold((0, 0), |(count, max), c| match c { + '#' => (count + 1, max.max(count + 1)), + _ => (0_usize, max), + }); + let hashes: String = "#".repeat(num_consecutive_hashes + 1); + self.output.write_str("ri")?; + self.output.write_str(&hashes)?; + self.output.write_char('"')?; + self.output.write_str(name)?; + self.output.write_char('"')?; + self.output.write_str(&hashes)?; + } else { + self.output.write_str(r#"i""#)?; + self.output.write_str(name)?; + self.output.write_char('"')?; + } + Ok(()) + } + + #[allow(clippy::unused_self)] + fn validate_identifier_with_arbitrary(&self, name: &str) -> Result<()> { + if self + .extensions() + .contains(Extensions::ARBITRARY_IDENTIFIERS) + { + return Ok(()); + } + self.validate_identifier(name) + } + #[allow(clippy::unused_self)] fn validate_identifier(&self, name: &str) -> Result<()> { if name.is_empty() || !name.chars().all(is_ident_raw_char) { @@ -865,7 +909,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { Ok(()) } else { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; self.serialize_unit() } } @@ -876,7 +920,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { _variant_index: u32, variant: &'static str, ) -> Result<()> { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; self.write_identifier(variant)?; Ok(()) @@ -906,7 +950,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { if self.extensions().contains(Extensions::UNWRAP_NEWTYPES) || self.newtype_variant { self.newtype_variant = false; - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; return guard_recursion! { self => value.serialize(&mut *self) }; } @@ -914,7 +958,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { if self.struct_names() { self.write_identifier(name)?; } else { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; } self.implicit_some_depth = 0; @@ -936,7 +980,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { where T: ?Sized + Serialize, { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; self.write_identifier(variant)?; self.output.write_char('(')?; @@ -996,7 +1040,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { if self.struct_names() && !self.newtype_variant { self.write_identifier(name)?; } else { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; } self.serialize_tuple(len) @@ -1012,7 +1056,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { self.newtype_variant = false; self.implicit_some_depth = 0; - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; self.write_identifier(variant)?; self.output.write_char('(')?; @@ -1048,12 +1092,12 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { self.implicit_some_depth = 0; if old_newtype_variant { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; } else { if self.struct_names() { self.write_identifier(name)?; } else { - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; } self.output.write_char('(')?; } @@ -1076,7 +1120,7 @@ impl<'a, W: fmt::Write> ser::Serializer for &'a mut Serializer { self.newtype_variant = false; self.implicit_some_depth = 0; - self.validate_identifier(name)?; + self.validate_identifier_with_arbitrary(name)?; self.write_identifier(variant)?; self.output.write_char('(')?; diff --git a/tests/532_arbitrary_identifier.rs b/tests/532_arbitrary_identifier.rs new file mode 100644 index 000000000..7c76c69bd --- /dev/null +++ b/tests/532_arbitrary_identifier.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Deserialize, Serialize)] +struct Event { + #[serde(rename = "@sender")] + sender: String, +} + +#[test] +fn test_arbitary_identifier() { + let ron = ron::Options::default() + .with_default_extension(ron::extensions::Extensions::ARBITRARY_IDENTIFIERS); + let event = Event { + sender: "test".to_string(), + }; + let ser = ron.to_string(&event).unwrap(); + assert_eq!(ser, r#"(i"@sender":"test")"#); + let de: Event = ron.from_str(&ser).unwrap(); + assert_eq!(de, event); +} + +#[test] +fn test_arbitary_identifier_without_extension() { + let ron = ron::Options::default(); + let event = Event { + sender: "test".to_string(), + }; + let ser = ron.to_string(&event).unwrap_err(); + // FIXME: assert_eq!(ser, ...); + let de = ron + .from_str::(r#"Event(i"@sender": "test")"#) + .unwrap_err(); + // FIXME: assert_eq!(de, ...); +}