reth_codecs_derive/compact/
generator.rs

1//! Code generator for the `Compact` trait.
2
3use super::*;
4use crate::ZstdConfig;
5use convert_case::{Case, Casing};
6use syn::{Attribute, LitStr};
7
8/// Generates code to implement the `Compact` trait for a data type.
9pub fn generate_from_to(
10    ident: &Ident,
11    attrs: &[Attribute],
12    has_lifetime: bool,
13    fields: &FieldList,
14    zstd: Option<ZstdConfig>,
15) -> TokenStream2 {
16    let flags = format_ident!("{ident}Flags");
17
18    let reth_codecs = parse_reth_codecs_path(attrs).unwrap();
19
20    let to_compact = generate_to_compact(fields, ident, zstd.clone(), &reth_codecs);
21    let from_compact = generate_from_compact(fields, ident, zstd);
22
23    let snake_case_ident = ident.to_string().to_case(Case::Snake);
24
25    let fuzz = format_ident!("fuzz_test_{snake_case_ident}");
26    let test = format_ident!("fuzz_{snake_case_ident}");
27
28    let lifetime = if has_lifetime {
29        quote! { 'a }
30    } else {
31        quote! {}
32    };
33
34    let impl_compact = if has_lifetime {
35        quote! {
36           impl<#lifetime> #reth_codecs::Compact for #ident<#lifetime>
37        }
38    } else {
39        quote! {
40           impl #reth_codecs::Compact for #ident
41        }
42    };
43
44    let has_ref_fields = fields.iter().any(|field| {
45        if let FieldTypes::StructField(field) = field {
46            field.is_reference
47        } else {
48            false
49        }
50    });
51
52    let fn_from_compact = if has_ref_fields {
53        quote! { unimplemented!("from_compact not supported with ref structs") }
54    } else {
55        quote! {
56            let (flags, mut buf) = #flags::from(buf);
57            #from_compact
58        }
59    };
60
61    let fuzz_tests = if has_lifetime {
62        quote! {}
63    } else {
64        quote! {
65            #[cfg(test)]
66            #[expect(dead_code)]
67            #[test_fuzz::test_fuzz]
68            fn #fuzz(obj: #ident)  {
69                use #reth_codecs::Compact;
70                let mut buf = vec![];
71                let len = obj.clone().to_compact(&mut buf);
72                let (same_obj, buf) = #ident::from_compact(buf.as_ref(), len);
73                assert_eq!(obj, same_obj);
74            }
75
76            #[test]
77            #[expect(missing_docs)]
78            pub fn #test() {
79                #fuzz(#ident::default())
80            }
81        }
82    };
83
84    // Build function
85    quote! {
86        #fuzz_tests
87
88        #impl_compact {
89            fn to_compact<B>(&self, buf: &mut B) -> usize where B: #reth_codecs::__private::bytes::BufMut + AsMut<[u8]> {
90                let mut flags = #flags::default();
91                let mut total_length = 0;
92                #(#to_compact)*
93                total_length
94            }
95
96            fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) {
97                #fn_from_compact
98            }
99        }
100    }
101}
102
103/// Generates code to implement the `Compact` trait method `to_compact`.
104fn generate_from_compact(
105    fields: &FieldList,
106    ident: &Ident,
107    zstd: Option<ZstdConfig>,
108) -> TokenStream2 {
109    let mut lines = vec![];
110    let mut known_types = vec![
111        "B256",
112        "Address",
113        "Bloom",
114        "Vec",
115        "TxHash",
116        "BlockHash",
117        "FixedBytes",
118        "Cow",
119        "TxSeismicElements",
120    ];
121
122    // Only types without `Bytes` should be added here. It's currently manually added, since
123    // it's hard to figure out with derive_macro which types have Bytes fields.
124    //
125    // This removes the requirement of the field to be placed last in the struct.
126    known_types.extend_from_slice(&["TxKind", "AccessList", "Signature", "CheckpointBlockRange"]);
127
128    // let mut handle = FieldListHandler::new(fields);
129    let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_)));
130
131    if is_enum {
132        let enum_lines = EnumHandler::new(fields).generate_from(ident);
133
134        // Builds the object instantiation.
135        lines.push(quote! {
136            let obj = match flags.variant() {
137                #(#enum_lines)*
138                _ => unreachable!()
139            };
140        });
141    } else {
142        let mut struct_handler = StructHandler::new(fields);
143        lines.append(&mut struct_handler.generate_from(known_types.as_slice()));
144
145        // Builds the object instantiation.
146        if struct_handler.is_wrapper {
147            lines.push(quote! {
148                let obj = #ident(placeholder);
149            });
150        } else {
151            let fields = fields.iter().filter_map(|field| {
152                if let FieldTypes::StructField(field) = field {
153                    let ident = format_ident!("{}", field.name);
154                    return Some(quote! {
155                        #ident: #ident,
156                    })
157                }
158                None
159            });
160
161            lines.push(quote! {
162                let obj = #ident {
163                    #(#fields)*
164                };
165            });
166        }
167    }
168
169    // If the type has compression support, then check the `__zstd` flag. Otherwise, use the default
170    // code branch. However, even if it's a type with compression support, not all values are
171    // to be compressed (thus the zstd flag). Ideally only the bigger ones.
172    if let Some(zstd) = zstd {
173        let decompressor = zstd.decompressor;
174        quote! {
175            if flags.__zstd() != 0 {
176                #decompressor.with(|decompressor| {
177                    let decompressor = &mut decompressor.borrow_mut();
178                    let decompressed = decompressor.decompress(buf);
179                    let mut original_buf = buf;
180
181                    let mut buf: &[u8] = decompressed;
182                    #(#lines)*
183                    (obj, original_buf)
184                })
185            } else {
186                #(#lines)*
187                (obj, buf)
188            }
189        }
190    } else {
191        quote! {
192            #(#lines)*
193            (obj, buf)
194        }
195    }
196}
197
198/// Generates code to implement the `Compact` trait method `from_compact`.
199fn generate_to_compact(
200    fields: &FieldList,
201    ident: &Ident,
202    zstd: Option<ZstdConfig>,
203    reth_codecs: &syn::Path,
204) -> Vec<TokenStream2> {
205    let mut lines = vec![quote! {
206        let mut buffer = #reth_codecs::__private::bytes::BytesMut::new();
207    }];
208
209    let is_enum = fields.iter().any(|field| matches!(field, FieldTypes::EnumVariant(_)));
210
211    if is_enum {
212        let enum_lines = EnumHandler::new(fields).generate_to(ident);
213
214        lines.push(quote! {
215            flags.set_variant(match self {
216                #(#enum_lines)*
217            });
218        })
219    } else {
220        lines.append(&mut StructHandler::new(fields).generate_to());
221    }
222
223    // Just because a type supports compression, doesn't mean all its values are to be compressed.
224    // We skip the smaller ones, and thus require a flag` __zstd` to specify if this value is
225    // compressed or not.
226    if zstd.is_some() {
227        lines.push(quote! {
228            let mut zstd = buffer.len() > 7;
229            if zstd {
230                flags.set___zstd(1);
231            }
232        });
233    }
234
235    // Places the flag bits.
236    lines.push(quote! {
237        let flags = flags.into_bytes();
238        total_length += flags.len() + buffer.len();
239        buf.put_slice(&flags);
240    });
241
242    if let Some(zstd) = zstd {
243        let compressor = zstd.compressor;
244        lines.push(quote! {
245            if zstd {
246                #compressor.with(|compressor| {
247                    let mut compressor = compressor.borrow_mut();
248
249                    let compressed = compressor.compress(&buffer).expect("Failed to compress.");
250                    buf.put(compressed.as_slice());
251                });
252            } else {
253                buf.put(buffer);
254            }
255        });
256    } else {
257        lines.push(quote! {
258            buf.put(buffer);
259        })
260    }
261
262    lines
263}
264
265/// Function to extract the crate path from `reth_codecs(crate = "...")` attribute.
266pub(crate) fn parse_reth_codecs_path(attrs: &[Attribute]) -> syn::Result<syn::Path> {
267    // let default_crate_path: syn::Path = syn::parse_str("reth-codecs").unwrap();
268    let mut reth_codecs_path: syn::Path = syn::parse_quote!(reth_codecs);
269    for attr in attrs {
270        if attr.path().is_ident("reth_codecs") {
271            attr.parse_nested_meta(|meta| {
272                if meta.path.is_ident("crate") {
273                    let value = meta.value()?;
274                    let lit: LitStr = value.parse()?;
275                    reth_codecs_path = syn::parse_str(&lit.value())?;
276                    Ok(())
277                } else {
278                    Err(meta.error("unsupported attribute"))
279                }
280            })?;
281        }
282    }
283
284    Ok(reth_codecs_path)
285}