MVP serialize #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Closed · wants to merge 4 commits
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -39,9 +39,11 @@ default = []
testing = []

[dependencies]
bincode = "1.3.3"
comemo-macros = { workspace = true }
once_cell = { workspace = true }
parking_lot = { workspace = true }
serde = { version = "1.0.197", features = ["serde_derive"] }
siphasher = { workspace = true }

[dev-dependencies]
15 changes: 11 additions & 4 deletions macros/src/lib.rs
@@ -61,7 +61,7 @@ use syn::{parse_quote, Error, Result};
/// arguments.
///
/// # Example
/// ```
/// ```ignore
/// /// Evaluate a `.calc` script.
/// #[comemo::memoize]
/// fn evaluate(script: &str, files: comemo::Tracked<Files>) -> i32 {
@@ -77,9 +77,16 @@ use syn::{parse_quote, Error, Result};
/// ```
///
#[proc_macro_attribute]
pub fn memoize(_: TokenStream, stream: TokenStream) -> TokenStream {
pub fn memoize(attr: TokenStream, stream: TokenStream) -> TokenStream {
let func = syn::parse_macro_input!(stream as syn::Item);
memoize::expand(&func)
let serializable = match attr.to_string().as_str() {
"serialize" => true,
"" => false,
invalid => {
panic!("invalid attribute: {invalid}\nvalid attributes is `serialize`")
}
};
memoize::expand(&func, serializable)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
@@ -133,7 +140,7 @@ pub fn memoize(_: TokenStream, stream: TokenStream) -> TokenStream {
/// - They cannot use destructuring patterns in their arguments.
///
/// # Example
/// ```
/// ```ignore
/// /// File storage.
/// struct Files(HashMap<PathBuf, String>);
///
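For context, not part of the diff: with the attribute parsing above, the only accepted forms are a bare `#[memoize]` and `#[memoize(serialize)]`; anything else hits the panic. A minimal usage sketch under that assumption, with ordinary hashable arguments:

use comemo::memoize;

/// Plain memoization, unchanged behavior.
#[memoize]
fn double(x: u32) -> u32 {
    x * 2
}

/// Memoized and additionally registered for cache (de)serialization.
#[memoize(serialize)]
fn triple(x: u32) -> u32 {
    x * 3
}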
36 changes: 32 additions & 4 deletions macros/src/memoize.rs
@@ -1,13 +1,13 @@
use super::*;

/// Memoize a function.
pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
pub fn expand(item: &syn::Item, serializable: bool) -> Result<proc_macro2::TokenStream> {
let syn::Item::Fn(item) = item else {
bail!(item, "`memoize` can only be applied to functions and methods");
};

// Preprocess and validate the function.
let function = prepare(item)?;
let function = prepare(item, serializable)?;

// Rewrite the function's body to memoize it.
process(&function)
@@ -18,6 +18,7 @@ struct Function {
item: syn::ItemFn,
args: Vec<Argument>,
output: syn::Type,
serializable: bool,
}

/// An argument to a memoized function.
@@ -27,7 +28,7 @@ enum Argument {
}

/// Preprocess and validate a function.
fn prepare(function: &syn::ItemFn) -> Result<Function> {
fn prepare(function: &syn::ItemFn, serializable: bool) -> Result<Function> {
let mut args = vec![];

for input in &function.sig.inputs {
@@ -39,7 +40,7 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
};

Ok(Function { item: function.clone(), args, output })
Ok(Function { item: function.clone(), args, output, serializable })
}

/// Preprocess a function argument.
@@ -79,6 +80,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result<Argument> {
/// Rewrite a function's body to memoize it.
fn process(function: &Function) -> Result<TokenStream> {
// Construct assertions that the arguments fulfill the necessary bounds.
// todo: assert serializable/deserializable somehow?
let bounds = function.args.iter().map(|arg| {
let val = match arg {
Argument::Receiver(token) => quote! { #token },
@@ -124,12 +126,38 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

let serialization = if function.serializable {
Some(quote! {
static __UNIQUE_PATH: ::comemo::internal::once_cell::sync::Lazy<u128> =
::comemo::internal::once_cell::sync::Lazy::new(|| {
::comemo::internal::hash(
&format!("{}-{}-{}-{}", module_path!(), file!(), line!(), column!())
)
});

::comemo::internal::register_serializer(|| {
(
__UNIQUE_PATH.clone(),
::comemo::internal::bincode::serialize(&*__CACHE.inner().read())
.unwrap_or_default(),
)
});

::comemo::internal::register_loader(__UNIQUE_PATH.clone(), |data| {
*__CACHE.inner().write() =
::comemo::internal::bincode::deserialize(data).unwrap_or_default();
});
})
} else {
None
};
wrapped.block = parse_quote! { {
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
#serialization
::core::default::Default::default()
});

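An aside, not part of the diff: the `__UNIQUE_PATH` value above keys each cache by hashing its call site. A standalone sketch of that idea, using `siphasher` directly (already a workspace dependency) rather than `comemo::internal::hash`, which may combine the input differently:

use std::hash::Hasher;

use siphasher::sip128::{Hasher128, SipHasher13};

/// Illustration only: derive a stable 128-bit id for this call site by hashing
/// the module path, file, line and column, mirroring the generated code above.
fn unique_path_id() -> u128 {
    let key = format!("{}-{}-{}-{}", module_path!(), file!(), line!(), column!());
    let mut hasher = SipHasher13::new();
    hasher.write(key.as_bytes());
    hasher.finish128().as_u128()
}

fn main() {
    println!("cache id: {:032x}", unique_path_id());
}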
7 changes: 7 additions & 0 deletions src/cache.rs
@@ -3,6 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};

use once_cell::sync::Lazy;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use siphasher::sip128::{Hasher128, SipHasher13};

use crate::accelerate;
@@ -115,8 +116,13 @@ impl<C: 'static, Out: 'static> Cache<C, Out> {
pub fn evict(&self, max_age: usize) {
self.0.write().evict(max_age)
}
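
/// Direct access to the cache's inner data, so that macro-generated code can
/// serialize and restore it.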
#[doc(hidden)]
pub fn inner(&self) -> &Lazy<RwLock<CacheData<C, Out>>> {
&self.0
}
}

/// The internal data for a cache.
#[derive(Serialize, Deserialize)]
pub struct CacheData<C, Out> {
/// Maps from hashes to memoized results.
@@ -166,6 +172,7 @@ impl<C, Out> Default for CacheData<C, Out> {
}
}

/// A memoized result.
#[derive(Serialize, Deserialize)]
struct CacheEntry<C, Out> {
/// The memoized function's constraint.
5 changes: 5 additions & 0 deletions src/lib.rs
@@ -87,12 +87,14 @@ mod cache;
mod constraint;
mod input;
mod prehashed;
mod serialization;
mod track;

pub use crate::cache::evict;
pub use crate::prehashed::Prehashed;
pub use crate::track::{Track, Tracked, TrackedMut, Validate};
pub use comemo_macros::{memoize, track};
pub use crate::serialization::{deserialize, serialize};

/// These are implementation details. Do not rely on them!
#[doc(hidden)]
@@ -102,6 +104,9 @@ pub mod internal {
pub use crate::cache::{memoized, register_evictor, Cache, CacheData};
pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint};
pub use crate::input::{assert_hashable_or_trackable, Args, Input};
pub use crate::serialization::{
bincode, once_cell, register_loader, register_serializer,
};
pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces};

#[cfg(feature = "testing")]
45 changes: 45 additions & 0 deletions src/serialization.rs
@@ -0,0 +1,45 @@
pub use {bincode, once_cell};

use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::{collections::HashMap, ops::Deref};

static SERIALIZERS: RwLock<Vec<fn() -> (u128, Vec<u8>)>> = RwLock::new(vec![]);
static LOADERS: RwLock<Lazy<HashMap<u128, fn(&[u8])>>> =
RwLock::new(Lazy::new(|| HashMap::new()));

pub fn register_serializer(serializer: fn() -> (u128, Vec<u8>)) {
SERIALIZERS.write().push(serializer)
}

pub fn register_loader(unique_id: u128, loader: fn(&[u8])) {
let conflict = LOADERS.write().insert(unique_id, loader);
debug_assert!(conflict.is_none(), "duplicate loader registered for cache id {unique_id:x}");
}

/// Serializes all registered caches into a blob of binary data that can later be restored with [`deserialize`].
pub fn serialize() -> bincode::Result<Vec<u8>> {
bincode::serialize(
&SERIALIZERS
.read()
.deref()
.iter()
.map(|f| f())
.collect::<HashMap<u128, Vec<u8>>>(),
)
}

/// Returns an error if the data is invalid, in which case the caches are left untouched.
pub fn deserialize(data: Vec<u8>) -> Result<(), ()> {
let Ok(data) = bincode::deserialize::<HashMap<u128, Vec<u8>>>(&data) else {
return Err(());
};

for (unique_id, load) in LOADERS.read().deref().iter() {
let Some(data) = data.get(unique_id) else {
continue;
};
load(data);
}
Ok(())
}
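Not part of the diff — a hypothetical usage sketch of the new public API, persisting all registered caches to a file between runs (the file name is invented). Since serializers and loaders are registered inside each cache's lazy initializer, saving and restoring only affects memoized functions whose caches have been touched at least once in the current process; the test below exercises exactly that order.

use std::fs;

/// Write every registered cache to disk.
fn save_caches() -> std::io::Result<()> {
    let blob = comemo::serialize().expect("serializing caches should not fail");
    fs::write("comemo-caches.bin", blob)
}

/// Try to restore caches from disk; a missing or stale blob is simply ignored.
fn load_caches() {
    if let Ok(blob) = fs::read("comemo-caches.bin") {
        let _ = comemo::deserialize(blob);
    }
}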
30 changes: 30 additions & 0 deletions tests/tests.rs
@@ -2,6 +2,7 @@

use std::collections::HashMap;
use std::hash::Hash;
use std::ops::Deref;
use std::path::{Path, PathBuf};

use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate};
@@ -452,3 +453,32 @@ impl Impure {
VAL.fetch_add(1, Ordering::SeqCst)
}
}

#[test]
#[serial]
fn feature_store_all() {
#[memoize(serialize)]
fn empty() -> String {
format!("The world is {}", "big")
}

#[memoize(serialize)]
fn sum(a: i32, b: i32) -> i32 {
a + b
}

test!(miss: empty(), "The world is big");
test!(hit: empty(), "The world is big");

test!(miss: sum(1, 2), 3);
test!(hit: sum(1, 2), 3);

let data = comemo::serialize().unwrap();
comemo::evict(0);
test!(miss: empty(), "The world is big");
test!(miss: sum(1, 2), 3);

comemo::deserialize(data).unwrap();
test!(hit: empty(), "The world is big");
test!(hit: sum(1, 2), 3);
}