hstr/
wtf8_atom.rs

1use std::{
2    fmt::Debug,
3    hash::Hash,
4    mem::{forget, transmute, ManuallyDrop},
5    ops::Deref,
6};
7
8use debug_unreachable::debug_unreachable;
9
10use crate::{
11    macros::{get_hash, impl_from_alias, partial_eq},
12    tagged_value::TaggedValue,
13    wtf8::Wtf8,
14    Atom, DYNAMIC_TAG, INLINE_TAG, LEN_MASK, LEN_OFFSET, TAG_MASK,
15};
16
17/// A WTF-8 encoded atom. This is like [Atom], but can contain unpaired
18/// surrogates.
19///
20/// [Atom]: crate::Atom
21#[repr(transparent)]
22pub struct Wtf8Atom {
23    pub(crate) unsafe_data: TaggedValue,
24}
25
26impl Wtf8Atom {
27    #[inline(always)]
28    pub fn new<S>(s: S) -> Self
29    where
30        Self: From<S>,
31    {
32        Self::from(s)
33    }
34
35    /// Try to convert this to a UTF-8 [Atom].
36    ///
37    /// Returns [Atom] if the string is valid UTF-8, otherwise returns
38    /// the original [Wtf8Atom].
39    pub fn try_into_atom(self) -> Result<Atom, Wtf8Atom> {
40        if self.as_str().is_some() {
41            let atom = ManuallyDrop::new(self);
42            Ok(Atom {
43                unsafe_data: atom.unsafe_data,
44            })
45        } else {
46            Err(self)
47        }
48    }
49
50    #[inline(always)]
51    fn tag(&self) -> u8 {
52        self.unsafe_data.tag() & TAG_MASK
53    }
54
55    /// Return true if this is a dynamic Atom.
56    #[inline(always)]
57    fn is_dynamic(&self) -> bool {
58        self.tag() == DYNAMIC_TAG
59    }
60}
61
62impl Default for Wtf8Atom {
63    #[inline(never)]
64    fn default() -> Self {
65        Wtf8Atom::new("")
66    }
67}
68
69/// Immutable, so it's safe to be shared between threads
70unsafe impl Send for Wtf8Atom {}
71
72/// Immutable, so it's safe to be shared between threads
73unsafe impl Sync for Wtf8Atom {}
74
75impl Debug for Wtf8Atom {
76    #[inline]
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        Debug::fmt(&**self, f)
79    }
80}
81
#[cfg(feature = "serde")]
impl serde::ser::Serialize for Wtf8Atom {
    /// Serializes the atom as a plain string.
    ///
    /// Code points that map to a `char` are written unchanged, except that a
    /// literal backslash followed by `u` is written as `\\u`. Every code
    /// point without a `char` equivalent (an unpaired surrogate) is written as
    /// `\uXXXX` using four uppercase hexadecimal digits.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        use crate::wtf8::Wtf8;

        /// Produces the escaped textual form described on `serialize`.
        fn convert_wtf8_to_raw(s: &Wtf8) -> String {
            use std::fmt::Write as _;

            let mut result = String::new();
            let mut iter = s.code_points().peekable();

            while let Some(code_point) = iter.next() {
                match code_point.to_char() {
                    Some(c) => {
                        // Escape literal '\u' sequences to avoid ambiguity with the
                        // surrogate encoding. Without this we could not tell apart:
                        // - JavaScript's "\uD800" (actual unpaired surrogate)
                        // - JavaScript's "\\uD800" (literal text '\uD800')
                        //
                        // NOTE(review): one collision remains. A literal backslash
                        // immediately followed by an unpaired surrogate is emitted
                        // as '\' + '\uXXXX', which is the same text produced for the
                        // literal characters '\uXXXX'. Resolving it needs a change
                        // to the wire format, so it is only recorded here.
                        let starts_literal_escape = c == '\\'
                            && iter.peek().map(|cp| cp.to_u32()) == Some('u' as u32);

                        if starts_literal_escape {
                            iter.next(); // skip 'u'
                            result.push_str("\\\\u");
                        } else {
                            result.push(c);
                        }
                    }
                    None => {
                        // Unpaired surrogates can't be represented in valid UTF-8,
                        // so encode them as '\uXXXX' for JavaScript compatibility.
                        // Formatting straight into the buffer avoids a temporary
                        // `String`; writing to a `String` never fails.
                        let _ = write!(result, "\\u{:04X}", code_point.to_u32());
                    }
                }
            }

            result
        }

        serializer.serialize_str(&convert_wtf8_to_raw(self))
    }
}
122
#[cfg(feature = "serde")]
impl<'de> serde::de::Deserialize<'de> for Wtf8Atom {
    /// Deserializes a string written by the `Serialize` impl of this type.
    ///
    /// The text is scanned once and two escape forms are undone:
    /// 1. `\uXXXX` (exactly four ASCII hex digits) becomes the code point
    ///    `XXXX`, which is how unpaired surrogates are restored.
    /// 2. `\\u` becomes the literal text `\u`.
    ///
    /// A `\u` that is not followed by four hex digits is kept as literal
    /// text, together with whatever characters were read while looking for
    /// the digits.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use crate::wtf8::{CodePoint, Wtf8Buf};

        /// Interprets four characters as one big-endian hexadecimal `u16`.
        ///
        /// Returns `None` when a character is missing or is not an ASCII hex
        /// digit. Unlike `u16::from_str_radix`, a leading `+` is rejected.
        fn parse_hex4(digits: &[Option<char>; 4]) -> Option<u16> {
            let mut value: u16 = 0;
            for digit in digits.iter() {
                let nibble = (*digit)?.to_digit(16)?;
                // `nibble` is below 16 and `value` is below 0x1000 before the
                // final shift, so neither the cast nor the shift loses bits.
                value = (value << 4) | nibble as u16;
            }
            Some(value)
        }

        /// Reverses the escaping applied during serialization.
        fn convert_wtf8_string_to_wtf8(s: String) -> Wtf8Buf {
            let mut iter = s.chars().peekable();
            let mut result = Wtf8Buf::with_capacity(s.len());

            while let Some(c) = iter.next() {
                if c != '\\' {
                    result.push_char(c);
                    continue;
                }

                match iter.peek().copied() {
                    Some('u') => {
                        // Found '\u' - might be a surrogate encoding
                        iter.next(); // skip 'u'

                        // Array elements are evaluated left to right, so this
                        // consumes up to four characters in order.
                        let digits = [iter.next(), iter.next(), iter.next(), iter.next()];

                        if let Some(unit) = parse_hex4(&digits) {
                            // SAFETY: `unit` is a `u16`, so the value is at most
                            // 0xFFFF. This assumes `from_u32_unchecked` only needs
                            // a value within the code point range (<= 0x10FFFF).
                            result.push(unsafe {
                                CodePoint::from_u32_unchecked(u32::from(unit))
                            });
                            continue;
                        }

                        // Not a valid escape: keep everything that was read.
                        result.push_char('\\');
                        result.push_char('u');
                        for digit in digits.iter().copied().flatten() {
                            result.push_char(digit);
                        }
                    }
                    Some('\\') => {
                        // Found '\\' - this is an escaped backslash
                        // '\\u' should become literal '\u' text
                        iter.next(); // skip the second '\'
                        if iter.peek() == Some(&'u') {
                            iter.next(); // skip 'u'
                            result.push_char('\\');
                            result.push_char('u');
                        } else {
                            result.push_str("\\\\");
                        }
                    }
                    _ => result.push_char(c),
                }
            }

            result
        }

        String::deserialize(deserializer).map(|v| convert_wtf8_string_to_wtf8(v).into())
    }
}
205
206impl PartialEq for Wtf8Atom {
207    #[inline(never)]
208    fn eq(&self, other: &Self) -> bool {
209        partial_eq!(self, other);
210
211        // If the store is different, the string may be the same, even though the
212        // `unsafe_data` is different
213        self.as_wtf8() == other.as_wtf8()
214    }
215}
216
217impl Eq for Wtf8Atom {}
218
219impl Hash for Wtf8Atom {
220    #[inline(always)]
221    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
222        state.write_u64(self.get_hash());
223    }
224}
225
226impl Drop for Wtf8Atom {
227    #[inline(always)]
228    fn drop(&mut self) {
229        if self.is_dynamic() {
230            unsafe { drop(crate::dynamic::restore_arc(self.unsafe_data)) }
231        }
232    }
233}
234
235impl Clone for Wtf8Atom {
236    #[inline(always)]
237    fn clone(&self) -> Self {
238        Self::from_alias(self.unsafe_data)
239    }
240}
241
242impl Deref for Wtf8Atom {
243    type Target = Wtf8;
244
245    #[inline(always)]
246    fn deref(&self) -> &Self::Target {
247        self.as_wtf8()
248    }
249}
250
251impl AsRef<Wtf8> for Wtf8Atom {
252    #[inline(always)]
253    fn as_ref(&self) -> &Wtf8 {
254        self.as_wtf8()
255    }
256}
257
258impl PartialEq<Wtf8> for Wtf8Atom {
259    #[inline]
260    fn eq(&self, other: &Wtf8) -> bool {
261        self.as_wtf8() == other
262    }
263}
264
265impl PartialEq<crate::Atom> for Wtf8Atom {
266    #[inline]
267    fn eq(&self, other: &crate::Atom) -> bool {
268        self.as_str() == Some(other.as_str())
269    }
270}
271
272impl PartialEq<&'_ Wtf8> for Wtf8Atom {
273    #[inline]
274    fn eq(&self, other: &&Wtf8) -> bool {
275        self.as_wtf8() == *other
276    }
277}
278
279impl PartialEq<Wtf8Atom> for Wtf8 {
280    #[inline]
281    fn eq(&self, other: &Wtf8Atom) -> bool {
282        self == other.as_wtf8()
283    }
284}
285
286impl PartialEq<str> for Wtf8Atom {
287    #[inline]
288    fn eq(&self, other: &str) -> bool {
289        matches!(self.as_str(), Some(s) if s == other)
290    }
291}
292
293impl PartialEq<&str> for Wtf8Atom {
294    #[inline]
295    fn eq(&self, other: &&str) -> bool {
296        matches!(self.as_str(), Some(s) if s == *other)
297    }
298}
299
300impl Wtf8Atom {
301    pub(super) fn get_hash(&self) -> u64 {
302        get_hash!(self)
303    }
304
305    fn as_wtf8(&self) -> &Wtf8 {
306        match self.tag() {
307            DYNAMIC_TAG => unsafe {
308                let item = crate::dynamic::deref_from(self.unsafe_data);
309                Wtf8::from_bytes_unchecked(transmute::<&[u8], &'static [u8]>(&item.slice))
310            },
311            INLINE_TAG => {
312                let len = (self.unsafe_data.tag() & LEN_MASK) >> LEN_OFFSET;
313                let src = self.unsafe_data.data();
314                unsafe { Wtf8::from_bytes_unchecked(&src[..(len as usize)]) }
315            }
316            _ => unsafe { debug_unreachable!() },
317        }
318    }
319}
320
321impl_from_alias!(Wtf8Atom);
322
#[cfg(test)]
impl Wtf8Atom {
    /// Test helper: the strong count of the backing arc for a dynamic atom,
    /// or `1` for every other kind of atom.
    pub(crate) fn ref_count(&self) -> usize {
        if !self.is_dynamic() {
            return 1;
        }

        let entry = unsafe { crate::dynamic::deref_from(self.unsafe_data) };
        triomphe::ThinArc::strong_count(&entry.0)
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wtf8::{CodePoint, Wtf8Buf};

    /// Wraps a raw value as a `CodePoint`; the tests only pass surrogates.
    fn code_point(raw: u32) -> CodePoint {
        unsafe { CodePoint::from_u32_unchecked(raw) }
    }

    /// Builds an atom from a fresh `Wtf8Buf` filled by `fill`.
    fn build_atom(fill: impl FnOnce(&mut Wtf8Buf)) -> Wtf8Atom {
        let mut buf = Wtf8Buf::new();
        fill(&mut buf);
        Wtf8Atom::from(buf)
    }

    /// Serializes an atom to JSON, panicking on failure.
    fn to_json(atom: &Wtf8Atom) -> String {
        serde_json::to_string(atom).unwrap()
    }

    /// Deserializes an atom from JSON, panicking on failure.
    fn from_json(json: &str) -> Wtf8Atom {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_serialize_normal_utf8() {
        let json = to_json(&Wtf8Atom::new("Hello, world!"));
        assert_eq!(json, "\"Hello, world!\"");
    }

    #[test]
    fn test_deserialize_normal_utf8() {
        let atom = from_json("\"Hello, world!\"");
        assert_eq!(atom.as_str(), Some("Hello, world!"));
    }

    #[test]
    fn test_serialize_unpaired_high_surrogate() {
        // A lone high surrogate (U+D800)
        let atom = build_atom(|buf| {
            buf.push(code_point(0xd800));
        });

        // serde_json escapes the backslash of our '\uXXXX' form once more
        assert_eq!(to_json(&atom), "\"\\\\uD800\"");
    }

    #[test]
    fn test_serialize_unpaired_low_surrogate() {
        // A lone low surrogate (U+DC00)
        let atom = build_atom(|buf| {
            buf.push(code_point(0xdc00));
        });

        // serde_json escapes the backslash of our '\uXXXX' form once more
        assert_eq!(to_json(&atom), "\"\\\\uDC00\"");
    }

    #[test]
    fn test_serialize_multiple_surrogates() {
        // Text interleaved with several unpaired surrogates
        let atom = build_atom(|buf| {
            buf.push_str("Hello ");
            buf.push(code_point(0xd800));
            buf.push_str(" World ");
            buf.push(code_point(0xdc00));
        });

        // serde_json escapes the backslash of our '\uXXXX' form once more
        assert_eq!(to_json(&atom), "\"Hello \\\\uD800 World \\\\uDC00\"");
    }

    #[test]
    fn test_serialize_literal_backslash_u() {
        // Literal "\u" text must be escaped by our serializer...
        let atom = Wtf8Atom::new("\\u0041");
        // ...and serde_json then doubles both backslashes, giving four.
        assert_eq!(to_json(&atom), "\"\\\\\\\\u0041\"");
    }

    #[test]
    fn test_deserialize_escaped_backslash_u() {
        // The escaped form used for unpaired surrogates
        let atom = from_json("\"\\\\uD800\"");
        // It must come back as an unpaired surrogate
        assert_eq!(atom.as_str(), None);
        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
    }

    #[test]
    fn test_deserialize_unpaired_surrogates() {
        // Same escaped format that serialization emits
        let atom = from_json("\"\\\\uD800\"");
        // An unpaired surrogate has no `str` view
        assert_eq!(atom.as_str(), None);
        // The lossy conversion still works
        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
    }

    #[test]
    fn test_round_trip_normal_string() {
        let original = Wtf8Atom::new("Hello, δΈ–η•Œ! 🌍");
        let restored = from_json(&to_json(&original));
        assert_eq!(original.as_str(), restored.as_str());
    }

    #[test]
    fn test_round_trip_unpaired_surrogates() {
        let original = build_atom(|buf| {
            buf.push_str("Before ");
            buf.push(code_point(0xd800));
            buf.push_str(" Middle ");
            buf.push(code_point(0xdc00));
            buf.push_str(" After");
        });

        let restored = from_json(&to_json(&original));

        // Equal as WTF-8
        assert_eq!(original, restored);

        // Equal after lossy conversion
        assert_eq!(original.to_string_lossy(), restored.to_string_lossy());
    }

    #[test]
    fn test_round_trip_mixed_content() {
        // Plain text, emoji and unpaired surrogates together
        let original = build_atom(|buf| {
            buf.push_str("Hello δΈ–η•Œ 🌍 ");
            buf.push(code_point(0xd83d)); // Unpaired high
            buf.push_str(" test ");
            buf.push(code_point(0xdca9)); // Unpaired low
        });

        let restored = from_json(&to_json(&original));

        assert_eq!(original, restored);
    }

    #[test]
    fn test_empty_string() {
        let json = to_json(&Wtf8Atom::new(""));
        assert_eq!(json, "\"\"");

        let restored = from_json("\"\"");
        assert_eq!(restored.as_str(), Some(""));
    }

    #[test]
    fn test_special_characters() {
        let cases = [
            ("\"", "\"\\\"\""),
            ("\n\r\t", "\"\\n\\r\\t\""), // serde_json escapes control characters
            ("\\", "\"\\\\\""),
            ("/", "\"/\""),
        ];

        for &(input, expected) in cases.iter() {
            let json = to_json(&Wtf8Atom::new(input));
            assert_eq!(json, expected, "Failed for input: {input:?}");

            let restored = from_json(&json);
            assert_eq!(restored.as_str(), Some(input));
        }
    }

    #[test]
    fn test_consecutive_surrogates_not_paired() {
        // Two high surrogates in a row never form a valid pair
        let atom = build_atom(|buf| {
            buf.push(code_point(0xd800)); // High surrogate
            buf.push(code_point(0xd800)); // Another high surrogate
        });

        let json = to_json(&atom);
        // serde_json escapes the backslash of our '\uXXXX' form once more
        assert_eq!(json, "\"\\\\uD800\\\\uD800\"");

        let restored = from_json(&json);
        assert_eq!(atom, restored);
    }

    #[test]
    fn test_deserialize_incomplete_escape() {
        // Escaped backslash followed by an incomplete sequence.
        // JSON decodes \\\\u123 to \\u123; our deserializer then sees the
        // escaped form of literal '\u' and keeps "\u123" as text.
        let atom = from_json("\"\\\\\\\\u123\"");
        assert_eq!(atom.as_str(), Some("\\u123"));
    }

    #[test]
    fn test_deserialize_invalid_hex() {
        // Escaped backslash followed by invalid hex.
        // JSON decodes \\\\uGGGG to \\uGGGG; our deserializer then sees the
        // escaped form of literal '\u' and keeps "\uGGGG" as text.
        let atom = from_json("\"\\\\\\\\uGGGG\"");
        assert_eq!(atom.as_str(), Some("\\uGGGG"));
    }

    #[test]
    fn test_try_into_atom_valid_utf8() {
        let outcome = Wtf8Atom::new("Valid UTF-8 string").try_into_atom();
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().as_str(), "Valid UTF-8 string");
    }

    #[test]
    fn test_try_into_atom_invalid_utf8() {
        // An atom holding an unpaired surrogate
        let wtf8_atom = build_atom(|buf| {
            buf.push(code_point(0xd800));
        });

        let outcome = wtf8_atom.try_into_atom();
        assert!(outcome.is_err());
        // The error carries the original Wtf8Atom
        let returned = outcome.unwrap_err();
        assert_eq!(returned.to_string_lossy(), "\u{FFFD}");
    }
}
556}