Skip to content

Commit 972e300

Browse files
authored
Added enabled attribute (#8)
1 parent 1275982 commit 972e300

File tree

5 files changed

+147
-19
lines changed

5 files changed

+147
-19
lines changed

macros/src/lib.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ macro_rules! bail {
1212

1313
mod memoize;
1414
mod track;
15+
mod utils;
1516

1617
use proc_macro::TokenStream as BoundaryStream;
1718
use proc_macro2::TokenStream;
@@ -76,10 +77,35 @@ use syn::{parse_quote, Error, Result};
7677
/// }
7778
/// ```
7879
///
80+
/// # Disabling memoization conditionally
81+
/// If you want to enable or disable memoization for a function conditionally,
82+
/// you can use the `enabled` attribute. This is useful for cheap function calls
83+
/// where dealing with the caching is more expensive than recomputing the
84+
/// result. This allows you to bypass hashing and constraint validation while
85+
/// still dealing with the same function signature. And allows saving memory and
86+
/// time.
87+
///
88+
/// By default, all functions are unconditionally memoized. To disable
89+
/// memoization conditionally, you must specify an `enabled = <expr>` attribute.
90+
/// The expression can use the parameters and must evaluate to a boolean value.
91+
/// If the expression is `false`, the function will be executed without hashing
92+
/// and caching.
93+
///
94+
/// ## Example
95+
/// ```
96+
/// /// Compute the sum of a slice of floats, but only memoize if the slice is
97+
/// /// longer than 1024 elements.
98+
/// #[comemo::memoize(enabled = add.len() > 1024)]
99+
/// fn evaluate(add: &[f32]) -> f32 {
100+
/// add.iter().copied().sum()
101+
/// }
102+
/// ```
103+
///
79104
#[proc_macro_attribute]
80-
pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
105+
pub fn memoize(args: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
106+
let args = syn::parse_macro_input!(args as TokenStream);
81107
let func = syn::parse_macro_input!(stream as syn::Item);
82-
memoize::expand(&func)
108+
memoize::expand(args, &func)
83109
.unwrap_or_else(|err| err.to_compile_error())
84110
.into()
85111
}

macros/src/memoize.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
use utils::parse_key_value;
2+
13
use super::*;
24

35
/// Memoize a function.
4-
pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
6+
pub fn expand(attrs: TokenStream, item: &syn::Item) -> Result<proc_macro2::TokenStream> {
57
let syn::Item::Fn(item) = item else {
68
bail!(item, "`memoize` can only be applied to functions and methods");
79
};
810

911
// Preprocess and validate the function.
10-
let function = prepare(item)?;
12+
let function = prepare(attrs, item)?;
1113

1214
// Rewrite the function's body to memoize it.
1315
process(&function)
@@ -18,6 +20,18 @@ struct Function {
1820
item: syn::ItemFn,
1921
args: Vec<Argument>,
2022
output: syn::Type,
23+
enabled: Option<syn::Expr>,
24+
}
25+
26+
/// Additional metadata for a memoized function.
27+
struct Meta {
28+
enabled: Option<syn::Expr>,
29+
}
30+
31+
impl syn::parse::Parse for Meta {
32+
fn parse(input: syn::parse::ParseStream) -> Result<Self> {
33+
Ok(Self { enabled: parse_key_value::<kw::enabled, _>(input)? })
34+
}
2135
}
2236

2337
/// An argument to a memoized function.
@@ -27,9 +41,10 @@ enum Argument {
2741
}
2842

2943
/// Preprocess and validate a function.
30-
fn prepare(function: &syn::ItemFn) -> Result<Function> {
31-
let mut args = vec![];
44+
fn prepare(attrs: TokenStream, function: &syn::ItemFn) -> Result<Function> {
45+
let meta = syn::parse2::<Meta>(attrs.clone())?;
3246

47+
let mut args = vec![];
3348
for input in &function.sig.inputs {
3449
args.push(prepare_arg(input)?);
3550
}
@@ -39,7 +54,12 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
3954
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
4055
};
4156

42-
Ok(Function { item: function.clone(), args, output })
57+
Ok(Function {
58+
item: function.clone(),
59+
args,
60+
output,
61+
enabled: meta.enabled,
62+
})
4363
}
4464

4565
/// Preprocess a function argument.
@@ -124,6 +144,8 @@ fn process(function: &Function) -> Result<TokenStream> {
124144
ident.mutability = None;
125145
}
126146

147+
let enabled = function.enabled.clone().unwrap_or(parse_quote! { true });
148+
127149
wrapped.block = parse_quote! { {
128150
static __CACHE: ::comemo::internal::Cache<
129151
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
@@ -134,13 +156,19 @@ fn process(function: &Function) -> Result<TokenStream> {
134156
});
135157

136158
#(#bounds;)*
159+
137160
::comemo::internal::memoized(
138161
::comemo::internal::Args(#arg_tuple),
139162
&::core::default::Default::default(),
140163
&__CACHE,
164+
#enabled,
141165
#closure,
142166
)
143167
} };
144168

145169
Ok(quote! { #wrapped })
146170
}
171+
172+
pub mod kw {
173+
syn::custom_keyword!(enabled);
174+
}

macros/src/utils.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use syn::{
2+
parse::{Parse, ParseStream},
3+
token::Token,
4+
};
5+
6+
use super::*;
7+
8+
/// Parse a metadata key-value pair, separated by `=`.
9+
pub fn parse_key_value<K: Token + Default + Parse, V: Parse>(
10+
input: ParseStream,
11+
) -> Result<Option<V>> {
12+
if !input.peek(|_| K::default()) {
13+
return Ok(None);
14+
}
15+
16+
let _: K = input.parse()?;
17+
let _: syn::Token![=] = input.parse()?;
18+
let value: V = input.parse::<V>()?;
19+
eat_comma(input);
20+
Ok(Some(value))
21+
}
22+
23+
/// Parse a comma if there is one.
24+
pub fn eat_comma(input: ParseStream) {
25+
if input.peek(syn::Token![,]) {
26+
let _: syn::Token![,] = input.parse().unwrap();
27+
}
28+
}

src/cache.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,20 @@ pub fn memoized<'c, In, Out, F>(
2323
mut input: In,
2424
constraint: &'c In::Constraint,
2525
cache: &Cache<In::Constraint, Out>,
26+
enabled: bool,
2627
func: F,
2728
) -> Out
2829
where
2930
In: Input + 'c,
3031
Out: Clone + 'static,
3132
F: FnOnce(In::Tracked<'c>) -> Out,
3233
{
34+
// Early bypass if memoization is disabled.
35+
// Hopefully the compiler will optimize this away, if the condition is constant.
36+
if !enabled {
37+
return memoized_disabled(input, constraint, func);
38+
}
39+
3340
// Compute the hash of the input's key part.
3441
let key = {
3542
let mut state = SipHasher13::new();
@@ -73,6 +80,30 @@ where
7380
output
7481
}
7582

83+
fn memoized_disabled<'c, In, Out, F>(
84+
input: In,
85+
constraint: &'c In::Constraint,
86+
func: F,
87+
) -> Out
88+
where
89+
In: Input + 'c,
90+
Out: Clone + 'static,
91+
F: FnOnce(In::Tracked<'c>) -> Out,
92+
{
93+
// Execute the function with the new constraints hooked in.
94+
let (input, outer) = input.retrack(constraint);
95+
let output = func(input);
96+
97+
// Add the new constraints to the outer ones.
98+
outer.join(constraint);
99+
100+
// Ensure that the last call was a miss during testing.
101+
#[cfg(feature = "testing")]
102+
LAST_WAS_HIT.with(|cell| cell.set(false));
103+
104+
output
105+
}
106+
76107
/// Evict the global cache.
77108
///
78109
/// This removes all memoized results from the cache whose age is larger than or

tests/tests.rs

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,21 @@ fn test_basic() {
7373
test!(hit: sum_iter(1000), 499500);
7474
}
7575

76+
#[memoize]
77+
fn evaluate(script: &str, files: Tracked<Files>) -> i32 {
78+
script
79+
.split('+')
80+
.map(str::trim)
81+
.map(|part| match part.strip_prefix("eval ") {
82+
Some(path) => evaluate(&files.read(path), files),
83+
None => part.parse::<i32>().unwrap(),
84+
})
85+
.sum()
86+
}
7687
/// Test the calc language.
7788
#[test]
7889
#[serial]
7990
fn test_calc() {
80-
#[memoize]
81-
fn evaluate(script: &str, files: Tracked<Files>) -> i32 {
82-
script
83-
.split('+')
84-
.map(str::trim)
85-
.map(|part| match part.strip_prefix("eval ") {
86-
Some(path) => evaluate(&files.read(path), files),
87-
None => part.parse::<i32>().unwrap(),
88-
})
89-
.sum()
90-
}
91-
9291
let mut files = Files(HashMap::new());
9392
files.write("alpha.calc", "2 + eval beta.calc");
9493
files.write("beta.calc", "2 + 3");
@@ -452,3 +451,19 @@ impl Impure {
452451
VAL.fetch_add(1, Ordering::SeqCst)
453452
}
454453
}
454+
455+
#[test]
456+
#[serial]
457+
#[cfg(debug_assertions)]
458+
fn test_with_disabled() {
459+
#[comemo::memoize(enabled = size >= 1000)]
460+
fn disabled(size: usize) -> usize {
461+
size
462+
}
463+
464+
test!(miss: disabled(0), 0);
465+
test!(miss: disabled(0), 0);
466+
467+
test!(miss: disabled(2000), 2000);
468+
test!(hit: disabled(2000), 2000);
469+
}

0 commit comments

Comments
 (0)