Skip to content

Commit

Permalink
Implement templates for IN
Browse files Browse the repository at this point in the history
  • Loading branch information
aqrln committed Jan 14, 2025
1 parent 159c4a0 commit 0e6a694
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 2 deletions.
2 changes: 2 additions & 0 deletions quaint/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod compare;
mod conditions;
mod conjunctive;
mod cte;
mod decorated;
mod delete;
mod enums;
mod expression;
Expand All @@ -35,6 +36,7 @@ pub use compare::{Comparable, Compare, JsonCompare, JsonType};
pub use conditions::ConditionTree;
pub use conjunctive::Conjunctive;
pub use cte::{CommonTableExpression, IntoCommonTableExpression};
pub use decorated::{Decoratable, Decorated};
pub use delete::Delete;
pub use enums::{EnumName, EnumVariant};
pub use expression::*;
Expand Down
46 changes: 46 additions & 0 deletions quaint/src/ast/decorated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::borrow::Cow;

use super::{Expression, ExpressionKind};

#[derive(Debug, Clone, PartialEq)]
pub struct Decorated<'a> {
pub(crate) expr: Box<Expression<'a>>,
pub(crate) prefix: Option<Cow<'a, str>>,
pub(crate) suffix: Option<Cow<'a, str>>,
}

impl<'a> Decorated<'a> {
pub fn new<L, R>(expr: Expression<'a>, prefix: Option<L>, suffix: Option<R>) -> Self
where
L: Into<Cow<'a, str>>,
R: Into<Cow<'a, str>>,
{
Decorated {
expr: Box::new(expr),
prefix: prefix.map(<_>::into),
suffix: suffix.map(<_>::into),
}
}
}

expression!(Decorated, Decorated);

pub trait Decoratable<'a> {
fn decorate<L, R>(self, left: Option<L>, right: Option<R>) -> Decorated<'a>
where
L: Into<Cow<'a, str>>,
R: Into<Cow<'a, str>>;
}

impl<'a, T> Decoratable<'a> for T
where
T: Into<Expression<'a>>,
{
fn decorate<L, R>(self, left: Option<L>, right: Option<R>) -> Decorated<'a>
where
L: Into<Cow<'a, str>>,
R: Into<Cow<'a, str>>,
{
Decorated::new(self.into(), left, right)
}
}
2 changes: 2 additions & 0 deletions quaint/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ pub enum ExpressionKind<'a> {
Value(Box<Expression<'a>>),
/// DEFAULT keyword, e.g. for `INSERT INTO ... VALUES (..., DEFAULT, ...)`
Default,
/// An expression wrapped with comments on each side
Decorated(Decorated<'a>),
}

impl ExpressionKind<'_> {
Expand Down
17 changes: 17 additions & 0 deletions quaint/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ pub trait Visitor<'a> {
None => self.write("*")?,
},
ExpressionKind::Default => self.write("DEFAULT")?,
ExpressionKind::Decorated(decorated) => self.visit_decorated(decorated)?,
}

if let Some(alias) = value.alias {
Expand Down Expand Up @@ -1207,4 +1208,20 @@ pub trait Visitor<'a> {
fn visit_comment(&mut self, comment: Cow<'a, str>) -> Result {
self.surround_with("/* ", " */", |ref mut s| s.write(comment))
}

fn visit_decorated(&mut self, decorated: Decorated<'a>) -> Result {
let Decorated { prefix, suffix, expr } = decorated;

if let Some(prefix) = prefix {
self.visit_comment(prefix)?;
}

self.visit_expression(*expr)?;

if let Some(suffix) = suffix {
self.visit_comment(suffix)?;
}

Ok(())
}
}
4 changes: 4 additions & 0 deletions query-engine/connectors/mongodb-query-connector/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ impl MongoFilterVisitor {
doc! { "$not": { "$in": [&field_name, coerce_as_array(self.prefixed_field_ref(&field_ref)?)] } }
}
},
ScalarCondition::InTemplate(_) => unimplemented!("query compiler not supported with mongodb yet"),
ScalarCondition::NotInTemplate(_) => unimplemented!("query compiler not supported with mongodb yet"),
ScalarCondition::JsonCompare(jc) => match *jc.condition {
ScalarCondition::Equals(value) => {
let bson = match value {
Expand Down Expand Up @@ -400,6 +402,8 @@ impl MongoFilterVisitor {
true,
)),
},
ScalarCondition::InTemplate(_) => unimplemented!("query compiler not supported with mongodb yet"),
ScalarCondition::NotInTemplate(_) => unimplemented!("query compiler not supported with mongodb yet"),
ScalarCondition::IsSet(is_set) => Ok(render_is_set(&field_name, is_set)),
ScalarCondition::JsonCompare(_) => Err(MongoError::Unsupported(
"JSON filtering is not yet supported on MongoDB".to_string(),
Expand Down
36 changes: 36 additions & 0 deletions query-engine/connectors/sql-query-connector/src/filter/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,22 @@ fn default_scalar_filter(
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.not_equals(Expression::from(field_ref.aliased_col(alias, ctx)).all())
}
ScalarCondition::InTemplate(ConditionValue::Value(value)) => {
let sql_value = convert_first_value(fields, value, alias, ctx);
comparable.in_selection(sql_value.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
))
}
ScalarCondition::InTemplate(ConditionValue::FieldRef(_)) => todo!(),
ScalarCondition::NotInTemplate(ConditionValue::Value(value)) => {
let sql_value = convert_first_value(fields, value, alias, ctx);
comparable.not_in_selection(sql_value.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
))
}
ScalarCondition::NotInTemplate(ConditionValue::FieldRef(_)) => todo!(),
ScalarCondition::Search(value, _) => {
reachable_only_with_capability!(ConnectorCapability::NativeFullTextSearch);
let query: String = value
Expand Down Expand Up @@ -1139,6 +1155,26 @@ fn insensitive_scalar_filter(
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.compare_raw("NOT ILIKE", Expression::from(field_ref.aliased_col(alias, ctx)).all())
}
ScalarCondition::InTemplate(ConditionValue::Value(value)) => {
let comparable = Expression::from(lower(comparable));
let sql_value = convert_first_value(fields, value, alias, ctx);

comparable.in_selection(sql_value.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
))
}
ScalarCondition::InTemplate(ConditionValue::FieldRef(_)) => todo!(),
ScalarCondition::NotInTemplate(ConditionValue::Value(value)) => {
let comparable = Expression::from(lower(comparable));
let sql_value = convert_first_value(fields, value, alias, ctx);

comparable.in_selection(sql_value.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
))
}
ScalarCondition::NotInTemplate(ConditionValue::FieldRef(_)) => todo!(),
ScalarCondition::Search(value, _) => {
reachable_only_with_capability!(ConnectorCapability::NativeFullTextSearch);
let query: String = value
Expand Down
2 changes: 1 addition & 1 deletion query-engine/core/src/compiler/translate/query/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ fn add_inmemory_join(parent: Expression, nested: Vec<ReadQuery>, ctx: &Context<'
);

let condition = if parent.r#type().is_list() {
ScalarCondition::In(ConditionListValue::list(vec![placeholder]))
ScalarCondition::InTemplate(ConditionValue::value(placeholder))
} else {
ScalarCondition::Equals(ConditionValue::value(placeholder))
};
Expand Down
2 changes: 1 addition & 1 deletion query-engine/query-engine/examples/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn main() -> anyhow::Result<()> {
// })
let query: JsonSingleQuery = serde_json::from_value(json!({
"modelName": "User",
"action": "findUnique",
"action": "findMany",
"query": {
"arguments": {
"where": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub enum ScalarCondition {
GreaterThanOrEquals(ConditionValue),
In(ConditionListValue),
NotIn(ConditionListValue),
InTemplate(ConditionValue),
NotInTemplate(ConditionValue),
JsonCompare(JsonCondition),
Search(ConditionValue, Vec<ScalarProjection>),
NotSearch(ConditionValue, Vec<ScalarProjection>),
Expand Down Expand Up @@ -52,6 +54,8 @@ impl ScalarCondition {
Self::GreaterThanOrEquals(v) => Self::LessThan(v),
Self::In(v) => Self::NotIn(v),
Self::NotIn(v) => Self::In(v),
Self::InTemplate(v) => Self::NotInTemplate(v),
Self::NotInTemplate(v) => Self::InTemplate(v),
Self::JsonCompare(json_compare) => {
let inverted_cond = json_compare.condition.invert(true);

Expand Down Expand Up @@ -86,6 +90,8 @@ impl ScalarCondition {
ScalarCondition::GreaterThanOrEquals(v) => v.as_field_ref(),
ScalarCondition::In(v) => v.as_field_ref(),
ScalarCondition::NotIn(v) => v.as_field_ref(),
ScalarCondition::InTemplate(v) => v.as_field_ref(),
ScalarCondition::NotInTemplate(v) => v.as_field_ref(),
ScalarCondition::JsonCompare(json_cond) => json_cond.condition.as_field_ref(),
ScalarCondition::Search(v, _) => v.as_field_ref(),
ScalarCondition::NotSearch(v, _) => v.as_field_ref(),
Expand Down

0 comments on commit 0e6a694

Please sign in to comment.