Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion sea-orm-macros/src/derives/model_ex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ pub fn expand_sea_orm_model(input: ItemStruct, compact: bool) -> syn::Result<Tok
};

let mut model_fields = Vec::new();
let mut required_model_fields = Vec::new();

for field in all_fields.iter_mut() {
let field_type = &field.ty;
let field_type = quote! { #field_type }
.to_string() // e.g.: "Option < String >"
.replace(' ', ""); // Remove spaces

if !(field_type.starts_with("Option<")
|| field_type.starts_with("HasMany<")
|| field_type.starts_with("HasOne<"))
{
required_model_fields.push(field.clone());
}
if is_compound_field(&field_type) {
let entity_path = extract_compound_entity(&field_type);
if field_type.starts_with("Option<") || field_type.starts_with("HasOne<") {
Expand All @@ -69,6 +75,31 @@ pub fn expand_sea_orm_model(input: ItemStruct, compact: bool) -> syn::Result<Tok
}
}

let non_required_model_field_names = model_fields
.iter()
.filter(|item| !required_model_fields.contains(item))
.map(|item| &item.ident)
.collect::<Vec<_>>();

let required_field_names = required_model_fields
.iter()
.map(|item| &item.ident)
.collect::<Vec<_>>();
let required_field_types = required_model_fields
.iter()
.map(|item| &item.ty)
.collect::<Vec<_>>();
let from_required_fields = quote! {
impl #model_ex {
pub fn model_from_required(#(#required_field_names: #required_field_types),*) -> #model {
#model {
#(#required_field_names),*,
#(#non_required_model_field_names : None),*
}
}
}
};

Ok(quote! {
#(#model_attrs)*
#[sea_orm(model_ex)]
Expand All @@ -79,6 +110,7 @@ pub fn expand_sea_orm_model(input: ItemStruct, compact: bool) -> syn::Result<Tok
#(#model_ex_attrs)*
#compact_model
#vis struct #model_ex #all_fields
#from_required_fields
})
}

Expand Down Expand Up @@ -785,6 +817,8 @@ fn to_upper_camel_case(i: &Ident) -> Ident {
#[cfg(test)]
mod test {
use super::format_tuple;
use crate::{DbErr, entity::*, tests_cfg::*};
use pretty_assertions::assert_eq;

#[test]
fn test_format_tuple() {
Expand All @@ -799,4 +833,16 @@ mod test {
"(super::Column::A, super::Column::B)"
);
}

#[test]
fn test_new_from_required() {
assert_eq!(
fruit::ModelEx::model_from_required(1, "Example Fruit"),
fruit::Model {
id: 1,
name: "Example Fruit",
cake_id: None,
}
);
}
}
Loading