From 0cafbdb74a665412a83aa900b4eb10c00e2498dd Mon Sep 17 00:00:00 2001 From: Shadaj Laddad Date: Thu, 11 Apr 2024 11:49:22 -0700 Subject: [PATCH] fix(hydroflow_plus): handle send_bincode with local structs (#1151) fix(hydroflow_plus): handle send_bincode with local structs Fixes #1144 --- Cargo.lock | 2 + hydroflow_plus/src/stream.rs | 28 ++--------- hydroflow_plus_test/Cargo.toml | 1 + hydroflow_plus_test/src/first_ten.rs | 9 +++- ...rst_ten__tests__first_ten_distributed.snap | 21 ++++---- hydroflow_plus_test_macro/Cargo.toml | 5 +- stageleft/src/lib.rs | 3 ++ stageleft/src/runtime_support.rs | 11 +++++ stageleft/src/type_name.rs | 48 +++++++++++++++++++ stageleft_macro/src/lib.rs | 5 +- 10 files changed, 93 insertions(+), 40 deletions(-) create mode 100644 stageleft/src/type_name.rs diff --git a/Cargo.lock b/Cargo.lock index 1331019a1a39..f5a55f93518c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1584,6 +1584,7 @@ dependencies = [ "hydroflow_plus_test_macro", "insta", "rand 0.8.5", + "serde", "stageleft", "stageleft_tool", "tokio", @@ -1596,6 +1597,7 @@ dependencies = [ "hydroflow_plus", "hydroflow_plus_cli_integration", "rand 0.8.5", + "serde", "stageleft", "stageleft_tool", "tokio", diff --git a/hydroflow_plus/src/stream.rs b/hydroflow_plus/src/stream.rs index c0260cb11229..4161ebb18d7b 100644 --- a/hydroflow_plus/src/stream.rs +++ b/hydroflow_plus/src/stream.rs @@ -13,7 +13,6 @@ use serde::de::DeserializeOwned; use serde::Serialize; use stageleft::{q, IntoQuotedMut, Quoted}; use syn::parse_quote; -use syn::visit_mut::VisitMut; use crate::ir::{HfPlusLeaf, HfPlusNode, HfPlusSource}; use crate::location::{Cluster, HfSend, Location}; @@ -398,29 +397,10 @@ fn get_this_crate() -> TokenStream { } } -// TODO(shadaj): has to be public due to temporary stageleft limitations -/// Rewrites use of alloc::string::* to use std::string::* -pub struct RewriteAlloc {} -impl VisitMut for RewriteAlloc { - fn visit_path_mut(&mut self, i: &mut syn::Path) { - if i.segments.iter().take(2).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("alloc", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("string", Span::call_site())), - ] - { - *i.segments.first_mut().unwrap() = - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())); - } - } -} - fn serialize_bincode(is_demux: bool) -> Pipeline { let root = get_this_crate(); - // This may fail when instantiated in an environment with different deps - let mut t_type: syn::Type = syn::parse_str(std::any::type_name::()).unwrap(); - RewriteAlloc {}.visit_type_mut(&mut t_type); + let t_type: syn::Type = stageleft::quote_type::(); if is_demux { parse_quote! { @@ -437,12 +417,10 @@ fn serialize_bincode(is_demux: bool) -> Pipeline { } } -fn deserialize_bincode(tagged: bool) -> Pipeline { +fn deserialize_bincode(tagged: bool) -> Pipeline { let root = get_this_crate(); - // This may fail when instantiated in an environment with different deps - let mut t_type: syn::Type = syn::parse_str(std::any::type_name::()).unwrap(); - RewriteAlloc {}.visit_type_mut(&mut t_type); + let t_type: syn::Type = stageleft::quote_type::(); if tagged { parse_quote! { diff --git a/hydroflow_plus_test/Cargo.toml b/hydroflow_plus_test/Cargo.toml index b89ef797ffbb..b8a44cf36752 100644 --- a/hydroflow_plus_test/Cargo.toml +++ b/hydroflow_plus_test/Cargo.toml @@ -10,6 +10,7 @@ tokio = { version = "1.16", features = [ "full" ] } stageleft = { path = "../stageleft", version = "^0.2.1" } hydroflow_plus_cli_integration = { path = "../hydro_deploy/hydroflow_plus_cli_integration", version = "^0.6.1" } rand = "0.8.5" +serde = { version = "1", features = [ "derive" ] } hydroflow_plus_test_macro = { path = "../hydroflow_plus_test_macro" } diff --git a/hydroflow_plus_test/src/first_ten.rs b/hydroflow_plus_test/src/first_ten.rs index aeaa78e753fe..5d436d412179 100644 --- a/hydroflow_plus_test/src/first_ten.rs +++ b/hydroflow_plus_test/src/first_ten.rs @@ -1,4 +1,5 @@ use hydroflow_plus::*; +use serde::{Deserialize, Serialize}; use stageleft::*; pub fn first_ten<'a, D: LocalDeploy<'a>>( @@ -18,6 +19,11 @@ pub fn first_ten_runtime<'a>( flow.extract().optimize_default() } +#[derive(Serialize, Deserialize)] +struct SendOverNetwork { + pub n: u32, +} + pub fn first_ten_distributed<'a, D: Deploy<'a>>( flow: &FlowBuilder<'a, D>, process_spec: &impl ProcessSpec<'a, D>, @@ -27,8 +33,9 @@ pub fn first_ten_distributed<'a, D: Deploy<'a>>( let numbers = flow.source_iter(&process, q!(0..10)); numbers + .map(q!(|n| SendOverNetwork { n })) .send_bincode(&second_process) - .for_each(q!(|n| println!("{}", n))); + .for_each(q!(|n: SendOverNetwork| println!("{}", n.n))); // TODO(shadaj): why is the explicit type required here? second_process } diff --git a/hydroflow_plus_test/src/snapshots/hydroflow_plus_test__first_ten__tests__first_ten_distributed.snap b/hydroflow_plus_test/src/snapshots/hydroflow_plus_test__first_ten__tests__first_ten_distributed.snap index ec1a123d1ecf..dee1f4ef73a8 100644 --- a/hydroflow_plus_test/src/snapshots/hydroflow_plus_test__first_ten__tests__first_ten_distributed.snap +++ b/hydroflow_plus_test/src/snapshots/hydroflow_plus_test__first_ten__tests__first_ten_distributed.snap @@ -1,10 +1,10 @@ --- source: hydroflow_plus_test/src/first_ten.rs -expression: builder.build().ir() +expression: built.ir() --- [ ForEach { - f: { use crate :: __staged :: first_ten :: * ; | n | println ! ("{}" , n) }, + f: { use crate :: __staged :: first_ten :: * ; | n : SendOverNetwork | println ! ("{}" , n . n) }, input: Network { to_location: 1, serialize_pipeline: Some( @@ -12,7 +12,7 @@ expression: builder.build().ir() Operator { path: "map", args: [ - "| data | { hydroflow_plus :: runtime_support :: bincode :: serialize :: < i32 > (& data) . unwrap () . into () }", + "| data | { hydroflow_plus :: runtime_support :: bincode :: serialize :: < hydroflow_plus_test :: first_ten :: SendOverNetwork > (& data) . unwrap () . into () }", ], }, ), @@ -24,16 +24,19 @@ expression: builder.build().ir() Operator { path: "map", args: [ - "| res | { hydroflow_plus :: runtime_support :: bincode :: deserialize :: < i32 > (& res . unwrap ()) . unwrap () }", + "| res | { hydroflow_plus :: runtime_support :: bincode :: deserialize :: < hydroflow_plus_test :: first_ten :: SendOverNetwork > (& res . unwrap ()) . unwrap () }", ], }, ), ), - input: Source { - source: Iter( - { use crate :: __staged :: first_ten :: * ; 0 .. 10 }, - ), - location_id: 0, + input: Map { + f: { use crate :: __staged :: first_ten :: * ; | n | SendOverNetwork { n } }, + input: Source { + source: Iter( + { use crate :: __staged :: first_ten :: * ; 0 .. 10 }, + ), + location_id: 0, + }, }, }, }, diff --git a/hydroflow_plus_test_macro/Cargo.toml b/hydroflow_plus_test_macro/Cargo.toml index ece6e63cde08..b6095b68c382 100644 --- a/hydroflow_plus_test_macro/Cargo.toml +++ b/hydroflow_plus_test_macro/Cargo.toml @@ -8,16 +8,13 @@ edition = "2021" proc-macro = true path = "../hydroflow_plus_test/src/lib.rs" -[features] -default = ["macro"] -macro = [] - [dependencies] hydroflow_plus = { path = "../hydroflow_plus", version = "^0.6.1" } tokio = { version = "1.16", features = [ "full" ] } stageleft = { path = "../stageleft", version = "^0.2.1" } hydroflow_plus_cli_integration = { path = "../hydro_deploy/hydroflow_plus_cli_integration", version = "^0.6.1" } rand = "0.8.5" +serde = { version = "1", features = [ "derive" ] } [build-dependencies] stageleft_tool = { path = "../stageleft_tool", version = "^0.1.1" } diff --git a/stageleft/src/lib.rs b/stageleft/src/lib.rs index 9b1eaa74af51..7dc28eccc336 100644 --- a/stageleft/src/lib.rs +++ b/stageleft/src/lib.rs @@ -20,6 +20,9 @@ use runtime_support::FreeVariable; use crate::runtime_support::get_final_crate_name; +mod type_name; +pub use type_name::quote_type; + #[cfg(windows)] #[macro_export] macro_rules! PATH_SEPARATOR { diff --git a/stageleft/src/runtime_support.rs b/stageleft/src/runtime_support.rs index 72b6e4506b28..fdd8d6e0ec01 100644 --- a/stageleft/src/runtime_support.rs +++ b/stageleft/src/runtime_support.rs @@ -1,3 +1,4 @@ +use std::borrow::BorrowMut; use std::marker::PhantomData; use std::mem::MaybeUninit; @@ -27,6 +28,16 @@ pub fn get_final_crate_name(crate_name: &str) -> TokenStream { } } +thread_local! { + pub(crate) static MACRO_TO_CRATE: std::cell::RefCell> = const { std::cell::RefCell::new(None) }; +} + +pub fn set_macro_to_crate(macro_name: &str, crate_name: &str) { + MACRO_TO_CRATE.with_borrow_mut(|cell| { + *cell.borrow_mut() = Some((macro_name.to_string(), crate_name.to_string())); + }); +} + pub trait ParseFromLiteral { fn parse_from_literal(literal: &syn::Expr) -> Self; } diff --git a/stageleft/src/type_name.rs b/stageleft/src/type_name.rs new file mode 100644 index 000000000000..36e81c8f6ec0 --- /dev/null +++ b/stageleft/src/type_name.rs @@ -0,0 +1,48 @@ +use proc_macro2::Span; +use syn::parse_quote; +use syn::visit_mut::VisitMut; + +use crate::runtime_support::get_final_crate_name; + +/// Rewrites use of alloc::string::* to use std::string::* +struct RewriteAlloc { + mapping: Option<(String, String)>, +} + +impl VisitMut for RewriteAlloc { + fn visit_path_mut(&mut self, i: &mut syn::Path) { + if i.segments.iter().take(2).collect::>() + == vec![ + &syn::PathSegment::from(syn::Ident::new("alloc", Span::call_site())), + &syn::PathSegment::from(syn::Ident::new("string", Span::call_site())), + ] + { + *i.segments.first_mut().unwrap() = + syn::PathSegment::from(syn::Ident::new("std", Span::call_site())); + } else if let Some((macro_name, final_name)) = &self.mapping { + if i.segments.first().unwrap().ident == macro_name { + *i.segments.first_mut().unwrap() = + syn::parse2(get_final_crate_name(final_name)).unwrap(); + + i.segments.insert(1, parse_quote!(__staged)); + } + } + } +} + +/// Captures a fully qualified path to a given type, which is useful when +/// the generated code needs to explicitly refer to a type known at staging time. +/// +/// This API is fairly experimental, and comes with caveats. For example, it cannot +/// handle closure types. In addition, when a user refers to a re-exported type, +/// the original type path may be returned here, which could involve private modules. +/// +/// Also, users must be careful to ensure that any crates referred in the type are +/// available where it is spliced. +pub fn quote_type() -> syn::Type { + let mut t_type: syn::Type = syn::parse_str(std::any::type_name::()).unwrap(); + let mapping = super::runtime_support::MACRO_TO_CRATE.with(|m| m.borrow().clone()); + RewriteAlloc { mapping }.visit_type_mut(&mut t_type); + + t_type +} diff --git a/stageleft_macro/src/lib.rs b/stageleft_macro/src/lib.rs index f2de429794e7..b809e618939c 100644 --- a/stageleft_macro/src/lib.rs +++ b/stageleft_macro/src/lib.rs @@ -362,11 +362,14 @@ pub fn entry( #(#param_parsing)* + let macro_crate_name = env!("CARGO_PKG_NAME"); + let final_crate_name = env!("STAGELEFT_FINAL_CRATE_NAME"); + #root::runtime_support::set_macro_to_crate(macro_crate_name, final_crate_name); + let output_core = { #root::Quoted::splice(#input_name #passed_generics(#root::QuotedContext::create(), #(#params_to_pass),*)) }; - let final_crate_name = env!("STAGELEFT_FINAL_CRATE_NAME"); let final_crate_root = #root::runtime_support::get_final_crate_name(final_crate_name); let module_path: #root::internal::syn::Path = #root::internal::syn::parse_str(module_path!()).unwrap();