Skip to content

Commit

Permalink
Merge pull request #2 from junhaideng/main
Browse files Browse the repository at this point in the history
feat: remove async-trait and update deps
  • Loading branch information
Millione authored Nov 6, 2024
2 parents 96b50f1 + c7ddd5c commit 3866af5
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 227 deletions.
362 changes: 169 additions & 193 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "static-graph"
version = "0.2.1"
version = "0.3.0"
edition = "2021"
authors = ["Volo Team <[email protected]>"]
description = "Generate static parallel computation graph from DSL at compile time"
Expand All @@ -12,11 +12,10 @@ repository = "https://github.com/volo-rs/static-graph"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1"
arc-swap = "1"
faststr = "0.2"
fxhash = "0.2"
heck = "0.4"
heck = "0.5"
nom = "7"
proc-macro2 = "1"
quote = "1"
Expand Down
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Add this to your `Cargo.toml`:

```toml
[build-dependencies]
static-graph = "0.2"
static-graph = "0.3"
```

## Example
Expand Down Expand Up @@ -74,6 +74,7 @@ fn main() {
.unwrap();
}
```
> if you want to generate a mermaid file, just set `enable_mermaid(true)`
Finally, in `main.rs` write your own logic for your nodes in the graph. The generated code will be in the `OUT_DIR` directory by default, the graph name is `G`, and the nodes name are `E`, `X`, `Y`, `O`. You should implement the `Runnable` trait for each node, and then you can automatically run the graph in maximum parallel by calling `G::new().run()`.

Expand Down Expand Up @@ -122,7 +123,6 @@ pub struct Request {
#[derive(Clone)]
pub struct EResponse(Duration);

#[async_trait::async_trait]
impl Runnable<Request, ()> for E {
type Resp = EResponse;
type Error = ();
Expand All @@ -136,7 +136,6 @@ impl Runnable<Request, ()> for E {
#[derive(Clone)]
pub struct XResponse(bool);

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for X {
type Resp = XResponse;
type Error = ();
Expand All @@ -150,7 +149,6 @@ impl Runnable<Request, EResponse> for X {
#[derive(Clone)]
pub struct YResponse(bool);

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for Y {
type Resp = YResponse;
type Error = ();
Expand All @@ -164,7 +162,6 @@ impl Runnable<Request, EResponse> for Y {
#[derive(Clone, Debug)]
pub struct OResponse(String);

#[async_trait::async_trait]
impl Runnable<Request, (XResponse, YResponse)> for O {
type Resp = OResponse;
type Error = ();
Expand Down
2 changes: 2 additions & 0 deletions examples/build.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
fn main() {
static_graph::configure()
.file_name("example.rs")
.enable_mermaid(true)
.compile("./graphs/example.graph")
.unwrap();
static_graph::configure()
.file_name("parallel.rs")
.enable_mermaid(true)
.compile("./graphs/parallel.graph")
.unwrap();
}
9 changes: 9 additions & 0 deletions examples/graphs/example.graph
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,13 @@ node O {
o: string,
}

node A -> B{

}

node B {

}

graph G(E)
graph H(A)
4 changes: 0 additions & 4 deletions examples/src/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ pub struct Request {
#[derive(Clone)]
pub struct EResponse(Duration);

#[async_trait::async_trait]
impl Runnable<Request, ()> for E {
type Resp = EResponse;
type Error = ();
Expand All @@ -56,7 +55,6 @@ impl Runnable<Request, ()> for E {
#[derive(Clone)]
pub struct XResponse(bool);

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for X {
type Resp = XResponse;
type Error = ();
Expand All @@ -70,7 +68,6 @@ impl Runnable<Request, EResponse> for X {
#[derive(Clone)]
pub struct YResponse(bool);

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for Y {
type Resp = YResponse;
type Error = ();
Expand All @@ -84,7 +81,6 @@ impl Runnable<Request, EResponse> for Y {
#[derive(Clone, Debug)]
pub struct OResponse(String);

#[async_trait::async_trait]
impl Runnable<Request, (XResponse, YResponse)> for O {
type Resp = OResponse;
type Error = ();
Expand Down
8 changes: 0 additions & 8 deletions examples/src/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ pub struct Request;
#[derive(Clone)]
pub struct EResponse;

#[async_trait::async_trait]
impl Runnable<Request, ()> for E {
type Resp = EResponse;
type Error = ();
Expand All @@ -48,7 +47,6 @@ impl Runnable<Request, ()> for E {
#[derive(Clone)]
pub struct XResponse;

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for X {
type Resp = XResponse;
type Error = ();
Expand All @@ -62,7 +60,6 @@ impl Runnable<Request, EResponse> for X {
#[derive(Clone)]
pub struct YResponse;

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for Y {
type Resp = YResponse;
type Error = ();
Expand All @@ -76,7 +73,6 @@ impl Runnable<Request, EResponse> for Y {
#[derive(Clone)]
pub struct WResponse;

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for W {
type Resp = WResponse;
type Error = ();
Expand All @@ -89,7 +85,6 @@ impl Runnable<Request, EResponse> for W {
#[derive(Clone)]
pub struct ZResponse;

#[async_trait::async_trait]
impl Runnable<Request, EResponse> for Z {
type Resp = ZResponse;
type Error = ();
Expand All @@ -102,7 +97,6 @@ impl Runnable<Request, EResponse> for Z {
#[derive(Clone)]
pub struct QResponse;

#[async_trait::async_trait]
impl Runnable<Request, (XResponse, YResponse)> for Q {
type Resp = QResponse;
type Error = ();
Expand All @@ -119,7 +113,6 @@ impl Runnable<Request, (XResponse, YResponse)> for Q {
#[derive(Clone)]
pub struct RResponse;

#[async_trait::async_trait]
impl Runnable<Request, (WResponse, ZResponse)> for R {
type Resp = RResponse;
type Error = ();
Expand All @@ -137,7 +130,6 @@ impl Runnable<Request, (WResponse, ZResponse)> for R {
#[derive(Clone, Debug)]
pub struct OResponse;

#[async_trait::async_trait]
impl Runnable<Request, (QResponse, RResponse)> for O {
type Resp = OResponse;
type Error = ();
Expand Down
58 changes: 55 additions & 3 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,59 @@ impl Codegen {
stream
}

/// draw a mermaid graph , go [mermaid.js.org](https://mermaid.js.org) for more detail
/// ```mermaid
///graph TD;
/// subgraph G
/// E --> X;
/// E --> Y;
/// X --> O;
/// Y --> O;
/// O
/// end
/// ```
pub fn mermaid(&self, def_ids: &[DefId]) -> String {
let mut ret = String::from("graph TD;\n");
for def_id in def_ids.iter() {
if let Some(graph) = self.graph(*def_id) {
let mut visited = FxHashSet::default();
let mut bytes = format!("subgraph {}\n", graph.name);
let mut node_queue = VecDeque::new();
node_queue.push_back(graph.entry_node);
visited.insert(graph.entry_node);
while !node_queue.is_empty() {
let node = node_queue.pop_front().unwrap();
if let Some(node) = self.node(node) {
if !node.to_nodes.is_empty() {
for to in node.to_nodes.iter() {
if !visited.contains(to) {
node_queue.push_back(*to);
visited.insert(*to);
}
if let Some(to) = self.node(*to) {
bytes.push_str(" ");
bytes.push_str(&node.name);
bytes.push_str("-->");
bytes.push_str(&to.name);
bytes.push_str(";\n");
}
}
} else {
bytes.push_str(" ");
bytes.push_str(&node.name);
bytes.push('\n');
}
}
}

bytes.push_str("end\n\n");
ret.push_str(&bytes);
}
}

ret
}

pub fn write_graph(&mut self, def_id: DefId, stream: &mut TokenStream) {
let graph = self.graph(def_id).unwrap();
let graph_name = self.upper_camel_name(&graph.name).as_syn_ident();
Expand Down Expand Up @@ -166,11 +219,10 @@ impl Codegen {
#[inline]
fn write_trait(&mut self, stream: &mut TokenStream) {
stream.extend(quote::quote! {
#[static_graph::async_trait]
pub trait Runnable<Req, PrevResp> {
type Resp;
type Error;
async fn run(&self, req: Req, prev_resp: PrevResp) -> ::std::result::Result<Self::Resp, Self::Error>;
fn run(&self, req: Req, prev_resp: PrevResp) -> impl std::future::Future<Output = ::std::result::Result<Self::Resp, Self::Error>> + Send;
}
});
}
Expand All @@ -179,7 +231,7 @@ impl Codegen {
let name = self.upper_camel_name(&graph.name).as_syn_ident();
let mut queue = VecDeque::new();

assert!(self.in_degrees.get(&graph.entry_node).is_none());
assert!(!self.in_degrees.contains_key(&graph.entry_node));

queue.push_back(graph.entry_node);
let mut bounds = TokenStream::new();
Expand Down
36 changes: 25 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@
//! #[derive(Clone)]
//! pub struct EResponse(Duration);
//! #[async_trait::async_trait]
//! impl Runnable<Request, ()> for E {
//! //! impl Runnable<Request, ()> for E {
//! type Resp = EResponse;
//! type Error = ();
Expand All @@ -121,8 +120,7 @@
//! #[derive(Clone)]
//! pub struct XResponse(bool);
//! #[async_trait::async_trait]
//! impl Runnable<Request, EResponse> for X {
//! //! impl Runnable<Request, EResponse> for X {
//! type Resp = XResponse;
//! type Error = ();
Expand All @@ -135,8 +133,7 @@
//! #[derive(Clone)]
//! pub struct YResponse(bool);
//! #[async_trait::async_trait]
//! impl Runnable<Request, EResponse> for Y {
//! //! impl Runnable<Request, EResponse> for Y {
//! type Resp = YResponse;
//! type Error = ();
Expand All @@ -149,8 +146,7 @@
//! #[derive(Clone, Debug)]
//! pub struct OResponse(String);
//! #[async_trait::async_trait]
//! impl Runnable<Request, (XResponse, YResponse)> for O {
//! //! impl Runnable<Request, (XResponse, YResponse)> for O {
//! type Resp = OResponse;
//! type Error = ();
Expand Down Expand Up @@ -179,7 +175,6 @@ pub mod symbol;
pub mod tags;

pub use arc_swap::*;
pub use async_trait::*;
pub use tokio::*;

use crate::{
Expand All @@ -188,8 +183,9 @@ use crate::{
parser::{document::Document, Parser},
resolver::{ResolveResult, Resolver},
};

use std::{
io::{self, Write},
io::Write,
path::{Path, PathBuf},
process::{exit, Command},
};
Expand All @@ -207,6 +203,7 @@ pub fn configure() -> Builder {
emit_rerun_if_changed: std::env::var_os("CARGO").is_some(),
out_dir: None,
file_name: "gen_graph.rs".into(),
enable_mermaid: false,
}
}

Expand All @@ -215,6 +212,7 @@ pub struct Builder {
emit_rerun_if_changed: bool,
out_dir: Option<PathBuf>,
file_name: PathBuf,
enable_mermaid: bool, // generate mermaid file
}

impl Builder {
Expand All @@ -236,7 +234,13 @@ impl Builder {
self
}

pub fn compile(self, graph: impl AsRef<Path>) -> io::Result<()> {
#[must_use]
pub fn enable_mermaid(mut self, enable: bool) -> Self {
self.enable_mermaid = enable;
self
}

pub fn compile(self, graph: impl AsRef<Path>) -> std::io::Result<()> {
let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
out_dir.clone()
} else {
Expand Down Expand Up @@ -265,6 +269,16 @@ impl Builder {
cx.set_tags(tags);

let mut cg = Codegen::new(cx);

if self.enable_mermaid {
let ret = cg.mermaid(&entrys);
let mut name = self.file_name.file_stem().unwrap().to_os_string();
name.push(".mermaid");
let out = out_dir.join(name);
let mut file = std::io::BufWriter::new(std::fs::File::create(&out).unwrap());
file.write_all(ret.trim().as_bytes()).unwrap();
file.flush().unwrap();
}
let stream = cg.write_document(entrys);
let out = out_dir.join(self.file_name);
let mut file = std::io::BufWriter::new(std::fs::File::create(&out).unwrap());
Expand Down

0 comments on commit 3866af5

Please sign in to comment.