Skip to content

Commit

Permalink
fix(hydroflow_plus): handle send_bincode with local structs (#1151)
Browse files Browse the repository at this point in the history
fix(hydroflow_plus): handle send_bincode with local structs

Fixes #1144
  • Loading branch information
shadaj committed Apr 11, 2024
1 parent 9a89595 commit 0cafbdb
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 40 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 3 additions & 25 deletions hydroflow_plus/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use serde::de::DeserializeOwned;
use serde::Serialize;
use stageleft::{q, IntoQuotedMut, Quoted};
use syn::parse_quote;
use syn::visit_mut::VisitMut;

use crate::ir::{HfPlusLeaf, HfPlusNode, HfPlusSource};
use crate::location::{Cluster, HfSend, Location};
Expand Down Expand Up @@ -398,29 +397,10 @@ fn get_this_crate() -> TokenStream {
}
}

// TODO(shadaj): has to be public due to temporary stageleft limitations
/// Rewrites use of alloc::string::* to use std::string::*
pub struct RewriteAlloc {}
impl VisitMut for RewriteAlloc {
fn visit_path_mut(&mut self, i: &mut syn::Path) {
if i.segments.iter().take(2).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("alloc", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("string", Span::call_site())),
]
{
*i.segments.first_mut().unwrap() =
syn::PathSegment::from(syn::Ident::new("std", Span::call_site()));
}
}
}

fn serialize_bincode<T: Serialize>(is_demux: bool) -> Pipeline {
let root = get_this_crate();

// This may fail when instantiated in an environment with different deps
let mut t_type: syn::Type = syn::parse_str(std::any::type_name::<T>()).unwrap();
RewriteAlloc {}.visit_type_mut(&mut t_type);
let t_type: syn::Type = stageleft::quote_type::<T>();

if is_demux {
parse_quote! {
Expand All @@ -437,12 +417,10 @@ fn serialize_bincode<T: Serialize>(is_demux: bool) -> Pipeline {
}
}

fn deserialize_bincode<T2: DeserializeOwned>(tagged: bool) -> Pipeline {
fn deserialize_bincode<T: DeserializeOwned>(tagged: bool) -> Pipeline {
let root = get_this_crate();

// This may fail when instantiated in an environment with different deps
let mut t_type: syn::Type = syn::parse_str(std::any::type_name::<T2>()).unwrap();
RewriteAlloc {}.visit_type_mut(&mut t_type);
let t_type: syn::Type = stageleft::quote_type::<T>();

if tagged {
parse_quote! {
Expand Down
1 change: 1 addition & 0 deletions hydroflow_plus_test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tokio = { version = "1.16", features = [ "full" ] }
stageleft = { path = "../stageleft", version = "^0.2.1" }
hydroflow_plus_cli_integration = { path = "../hydro_deploy/hydroflow_plus_cli_integration", version = "^0.6.1" }
rand = "0.8.5"
serde = { version = "1", features = [ "derive" ] }

hydroflow_plus_test_macro = { path = "../hydroflow_plus_test_macro" }

Expand Down
9 changes: 8 additions & 1 deletion hydroflow_plus_test/src/first_ten.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use hydroflow_plus::*;
use serde::{Deserialize, Serialize};
use stageleft::*;

pub fn first_ten<'a, D: LocalDeploy<'a>>(
Expand All @@ -18,6 +19,11 @@ pub fn first_ten_runtime<'a>(
flow.extract().optimize_default()
}

#[derive(Serialize, Deserialize)]
struct SendOverNetwork {
pub n: u32,
}

pub fn first_ten_distributed<'a, D: Deploy<'a>>(
flow: &FlowBuilder<'a, D>,
process_spec: &impl ProcessSpec<'a, D>,
Expand All @@ -27,8 +33,9 @@ pub fn first_ten_distributed<'a, D: Deploy<'a>>(

let numbers = flow.source_iter(&process, q!(0..10));
numbers
.map(q!(|n| SendOverNetwork { n }))
.send_bincode(&second_process)
.for_each(q!(|n| println!("{}", n)));
.for_each(q!(|n: SendOverNetwork| println!("{}", n.n))); // TODO(shadaj): why is the explicit type required here?

second_process
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
---
source: hydroflow_plus_test/src/first_ten.rs
expression: builder.build().ir()
expression: built.ir()
---
[
ForEach {
f: { use crate :: __staged :: first_ten :: * ; | n | println ! ("{}" , n) },
f: { use crate :: __staged :: first_ten :: * ; | n : SendOverNetwork | println ! ("{}" , n . n) },
input: Network {
to_location: 1,
serialize_pipeline: Some(
Operator(
Operator {
path: "map",
args: [
"| data | { hydroflow_plus :: runtime_support :: bincode :: serialize :: < i32 > (& data) . unwrap () . into () }",
"| data | { hydroflow_plus :: runtime_support :: bincode :: serialize :: < hydroflow_plus_test :: first_ten :: SendOverNetwork > (& data) . unwrap () . into () }",
],
},
),
Expand All @@ -24,16 +24,19 @@ expression: builder.build().ir()
Operator {
path: "map",
args: [
"| res | { hydroflow_plus :: runtime_support :: bincode :: deserialize :: < i32 > (& res . unwrap ()) . unwrap () }",
"| res | { hydroflow_plus :: runtime_support :: bincode :: deserialize :: < hydroflow_plus_test :: first_ten :: SendOverNetwork > (& res . unwrap ()) . unwrap () }",
],
},
),
),
input: Source {
source: Iter(
{ use crate :: __staged :: first_ten :: * ; 0 .. 10 },
),
location_id: 0,
input: Map {
f: { use crate :: __staged :: first_ten :: * ; | n | SendOverNetwork { n } },
input: Source {
source: Iter(
{ use crate :: __staged :: first_ten :: * ; 0 .. 10 },
),
location_id: 0,
},
},
},
},
Expand Down
5 changes: 1 addition & 4 deletions hydroflow_plus_test_macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ edition = "2021"
proc-macro = true
path = "../hydroflow_plus_test/src/lib.rs"

[features]
default = ["macro"]
macro = []

[dependencies]
hydroflow_plus = { path = "../hydroflow_plus", version = "^0.6.1" }
tokio = { version = "1.16", features = [ "full" ] }
stageleft = { path = "../stageleft", version = "^0.2.1" }
hydroflow_plus_cli_integration = { path = "../hydro_deploy/hydroflow_plus_cli_integration", version = "^0.6.1" }
rand = "0.8.5"
serde = { version = "1", features = [ "derive" ] }

[build-dependencies]
stageleft_tool = { path = "../stageleft_tool", version = "^0.1.1" }
3 changes: 3 additions & 0 deletions stageleft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use runtime_support::FreeVariable;

use crate::runtime_support::get_final_crate_name;

mod type_name;
pub use type_name::quote_type;

#[cfg(windows)]
#[macro_export]
macro_rules! PATH_SEPARATOR {
Expand Down
11 changes: 11 additions & 0 deletions stageleft/src/runtime_support.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::BorrowMut;
use std::marker::PhantomData;
use std::mem::MaybeUninit;

Expand Down Expand Up @@ -27,6 +28,16 @@ pub fn get_final_crate_name(crate_name: &str) -> TokenStream {
}
}

thread_local! {
pub(crate) static MACRO_TO_CRATE: std::cell::RefCell<Option<(String, String)>> = const { std::cell::RefCell::new(None) };
}

pub fn set_macro_to_crate(macro_name: &str, crate_name: &str) {
MACRO_TO_CRATE.with_borrow_mut(|cell| {
*cell.borrow_mut() = Some((macro_name.to_string(), crate_name.to_string()));
});
}

pub trait ParseFromLiteral {
fn parse_from_literal(literal: &syn::Expr) -> Self;
}
Expand Down
48 changes: 48 additions & 0 deletions stageleft/src/type_name.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use proc_macro2::Span;
use syn::parse_quote;
use syn::visit_mut::VisitMut;

use crate::runtime_support::get_final_crate_name;

/// Rewrites use of alloc::string::* to use std::string::*
struct RewriteAlloc {
mapping: Option<(String, String)>,
}

impl VisitMut for RewriteAlloc {
fn visit_path_mut(&mut self, i: &mut syn::Path) {
if i.segments.iter().take(2).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("alloc", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("string", Span::call_site())),
]
{
*i.segments.first_mut().unwrap() =
syn::PathSegment::from(syn::Ident::new("std", Span::call_site()));
} else if let Some((macro_name, final_name)) = &self.mapping {
if i.segments.first().unwrap().ident == macro_name {
*i.segments.first_mut().unwrap() =
syn::parse2(get_final_crate_name(final_name)).unwrap();

i.segments.insert(1, parse_quote!(__staged));
}
}
}
}

/// Captures a fully qualified path to a given type, which is useful when
/// the generated code needs to explicitly refer to a type known at staging time.
///
/// This API is fairly experimental, and comes with caveats. For example, it cannot
/// handle closure types. In addition, when a user refers to a re-exported type,
/// the original type path may be returned here, which could involve private modules.
///
/// Also, users must be careful to ensure that any crates referred in the type are
/// available where it is spliced.
pub fn quote_type<T>() -> syn::Type {
let mut t_type: syn::Type = syn::parse_str(std::any::type_name::<T>()).unwrap();
let mapping = super::runtime_support::MACRO_TO_CRATE.with(|m| m.borrow().clone());
RewriteAlloc { mapping }.visit_type_mut(&mut t_type);

t_type
}
5 changes: 4 additions & 1 deletion stageleft_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,14 @@ pub fn entry(

#(#param_parsing)*

let macro_crate_name = env!("CARGO_PKG_NAME");
let final_crate_name = env!("STAGELEFT_FINAL_CRATE_NAME");
#root::runtime_support::set_macro_to_crate(macro_crate_name, final_crate_name);

let output_core = {
#root::Quoted::splice(#input_name #passed_generics(#root::QuotedContext::create(), #(#params_to_pass),*))
};

let final_crate_name = env!("STAGELEFT_FINAL_CRATE_NAME");
let final_crate_root = #root::runtime_support::get_final_crate_name(final_crate_name);

let module_path: #root::internal::syn::Path = #root::internal::syn::parse_str(module_path!()).unwrap();
Expand Down

0 comments on commit 0cafbdb

Please sign in to comment.