Skip to content

Commit

Permalink
feat(queryEngine): add limit argument to updateMany (#5110)
Browse files Browse the repository at this point in the history
  • Loading branch information
FGoessler authored Jan 13, 2025
1 parent a046f87 commit 8d21d25
Show file tree
Hide file tree
Showing 25 changed files with 327 additions and 75 deletions.
48 changes: 48 additions & 0 deletions prisma-fmt/src/get_dmmf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5523,6 +5523,18 @@ mod tests {
"isList": false
}
]
},
{
"name": "limit",
"isRequired": false,
"isNullable": false,
"inputTypes": [
{
"type": "Int",
"location": "scalar",
"isList": false
}
]
}
],
"isNullable": false,
Expand Down Expand Up @@ -5567,6 +5579,18 @@ mod tests {
"isList": false
}
]
},
{
"name": "limit",
"isRequired": false,
"isNullable": false,
"inputTypes": [
{
"type": "Int",
"location": "scalar",
"isList": false
}
]
}
],
"isNullable": false,
Expand Down Expand Up @@ -5897,6 +5921,18 @@ mod tests {
"isList": false
}
]
},
{
"name": "limit",
"isRequired": false,
"isNullable": false,
"inputTypes": [
{
"type": "Int",
"location": "scalar",
"isList": false
}
]
}
],
"isNullable": false,
Expand Down Expand Up @@ -5941,6 +5977,18 @@ mod tests {
"isList": false
}
]
},
{
"name": "limit",
"isRequired": false,
"isNullable": false,
"inputTypes": [
{
"type": "Int",
"location": "scalar",
"isList": false
}
]
}
],
"isNullable": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,24 @@ mod delete_many {
Ok(())
}

// "The delete many Mutation" should "fail if limit param is negative"
#[connector_test]
async fn should_fail_with_negative_limit(runner: Runner) -> TestResult<()> {
create_row(&runner, r#"{ id: 1, title: "title1" }"#).await?;
create_row(&runner, r#"{ id: 2, title: "title2" }"#).await?;
create_row(&runner, r#"{ id: 3, title: "title3" }"#).await?;
create_row(&runner, r#"{ id: 4, title: "title4" }"#).await?;

assert_error!(
&runner,
r#"mutation { deleteManyTodo(limit: -3){ count }}"#,
2019,
"Provided limit (-3) must be a positive integer."
);

Ok(())
}

fn nested_del_many() -> String {
let schema = indoc! {
r#"model ZChild{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,65 @@ mod update_many {
Ok(())
}

// "An updateMany mutation" should "update max limit number of items"
#[connector_test]
async fn update_max_limit_items(runner: Runner) -> TestResult<()> {
create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?;
create_row(&runner, r#"{ id: 2, optStr: "str2" }"#).await?;
create_row(&runner, r#"{ id: 3, optStr: "str3" }"#).await?;

insta::assert_snapshot!(
run_query!(&runner, r#"mutation {
updateManyTestModel(
where: { }
data: { optStr: { set: "updated" } }
limit: 2
){
count
}
}"#),
@r###"{"data":{"updateManyTestModel":{"count":2}}}"###
);

insta::assert_snapshot!(
run_query!(
&runner,
r#"{
findManyTestModel(orderBy: { id: asc }) {
optStr
}
}"#),
@r###"{"data":{"findManyTestModel":[{"optStr":"updated"},{"optStr":"updated"},{"optStr":"str3"}]}}"###
);

Ok(())
}

// "An updateMany mutation" should "fail if limit param is negative"
#[connector_test]
async fn should_fail_with_negative_limit(runner: Runner) -> TestResult<()> {
create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?;
create_row(&runner, r#"{ id: 2, optStr: "str2" }"#).await?;
create_row(&runner, r#"{ id: 3, optStr: "str3" }"#).await?;

assert_error!(
&runner,
r#"mutation {
updateManyTestModel(
where: { }
data: { optStr: { set: "updated" } }
limit: -2
){
count
}
}"#,
2019,
"Provided limit (-2) must be a positive integer."
);

Ok(())
}

// "An updateMany mutation" should "correctly apply all number operations for Int"
#[connector_test(exclude(CockroachDb))]
async fn apply_number_ops_for_int(runner: Runner) -> TestResult<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl WriteOperations for MongoDbConnection {
model: &Model,
record_filter: connector_interface::RecordFilter,
args: WriteArgs,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(async move {
Expand All @@ -105,7 +106,7 @@ impl WriteOperations for MongoDbConnection {
model,
record_filter,
args,
UpdateType::Many,
UpdateType::Many { limit },
)
.await?;

Expand All @@ -120,6 +121,7 @@ impl WriteOperations for MongoDbConnection {
_record_filter: connector_interface::RecordFilter,
_args: WriteArgs,
_selected_fields: FieldSelection,
_limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<ManyRecords> {
unimplemented!()
Expand Down Expand Up @@ -162,7 +164,7 @@ impl WriteOperations for MongoDbConnection {
&mut self,
model: &Model,
record_filter: connector_interface::RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(write::delete_records(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
model: &Model,
record_filter: connector_interface::RecordFilter,
args: connector_interface::WriteArgs,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(async move {
Expand All @@ -136,7 +137,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
model,
record_filter,
args,
UpdateType::Many,
UpdateType::Many { limit },
)
.await?;
Ok(result.len())
Expand All @@ -150,6 +151,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
_record_filter: connector_interface::RecordFilter,
_args: connector_interface::WriteArgs,
_selected_fields: FieldSelection,
_limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<ManyRecords> {
unimplemented!()
Expand Down Expand Up @@ -191,7 +193,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
&mut self,
model: &Model,
record_filter: connector_interface::RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(write::delete_records(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::error::MongoError::ConversionError;
use crate::{
error::{DecorateErrorWithFieldInformationExtension, MongoError},
filter::{FilterPrefix, MongoFilter, MongoFilterVisitor},
Expand Down Expand Up @@ -160,6 +161,10 @@ pub async fn update_records<'conn>(
let ids: Vec<Bson> = if let Some(selectors) = record_filter.selectors {
selectors
.into_iter()
.take(match update_type {
UpdateType::Many { limit } => limit.unwrap_or(usize::MAX),
UpdateType::One => 1,
})
.map(|p| {
(&id_field, p.values().next().unwrap())
.into_bson()
Expand Down Expand Up @@ -205,7 +210,7 @@ pub async fn update_records<'conn>(
// It's important we check the `matched_count` and not the `modified_count` here.
// MongoDB returns `modified_count: 0` when performing a noop update, which breaks
// nested connect mutations as it rely on the returned count to know whether the update happened.
if update_type == UpdateType::Many && res.matched_count == 0 {
if matches!(update_type, UpdateType::Many { limit: _ }) && res.matched_count == 0 {
return Ok(Vec::new());
}
}
Expand All @@ -228,15 +233,15 @@ pub async fn delete_records<'conn>(
session: &mut ClientSession,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
) -> crate::Result<usize> {
let coll = database.collection::<Document>(model.db_name());
let id_field = pick_singular_id(model);

let ids = if let Some(selectors) = record_filter.selectors {
selectors
.into_iter()
.take(limit.unwrap_or(i64::MAX) as usize)
.take(limit.unwrap_or(usize::MAX))
.map(|p| {
(&id_field, p.values().next().unwrap())
.into_bson()
Expand Down Expand Up @@ -305,7 +310,7 @@ async fn find_ids(
session: &mut ClientSession,
model: &Model,
filter: MongoFilter,
limit: Option<i64>,
limit: Option<usize>,
) -> crate::Result<Vec<Bson>> {
let id_field = model.primary_identifier();
let mut builder = MongoReadQueryBuilder::new(model.clone());
Expand All @@ -321,7 +326,17 @@ async fn find_ids(

let mut builder = builder.with_model_projection(id_field)?;

builder.limit = limit;
if let Some(limit) = limit {
builder.limit = match i64::try_from(limit) {
Ok(limit) => Some(limit),
Err(_) => {
return Err(ConversionError {
from: "usize".to_owned(),
to: "i64".to_owned(),
})
}
}
}

let query = builder.build()?;
let docs = query.execute(collection, session).await?;
Expand Down
4 changes: 3 additions & 1 deletion query-engine/connectors/query-connector/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ pub trait WriteOperations {
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<usize>;

Expand All @@ -299,6 +300,7 @@ pub trait WriteOperations {
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<ManyRecords>;

Expand Down Expand Up @@ -326,7 +328,7 @@ pub trait WriteOperations {
&mut self,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<usize>;

Expand Down
2 changes: 1 addition & 1 deletion query-engine/connectors/query-connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ pub type Result<T> = std::result::Result<T, error::ConnectorError>;
/// However when we updating any records we want to return an empty array if zero items were updated
#[derive(PartialEq)]
pub enum UpdateType {
Many,
Many { limit: Option<usize> },
One,
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,13 @@ where
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
let ctx = Context::new(&self.connection_info, traceparent);
catch(
&self.connection_info,
write::update_records(&self.inner, model, record_filter, args, &ctx),
write::update_records(&self.inner, model, record_filter, args, limit, &ctx),
)
.await
}
Expand All @@ -242,12 +243,13 @@ where
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<ManyRecords> {
let ctx = Context::new(&self.connection_info, traceparent);
catch(
&self.connection_info,
write::update_records_returning(&self.inner, model, record_filter, args, selected_fields, &ctx),
write::update_records_returning(&self.inner, model, record_filter, args, selected_fields, limit, &ctx),
)
.await
}
Expand All @@ -272,7 +274,7 @@ where
&mut self,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
let ctx = Context::new(&self.connection_info, traceparent);
Expand Down
Loading

0 comments on commit 8d21d25

Please sign in to comment.