@@ -2,19 +2,11 @@ use lazy_static::lazy_static;
2
2
use std:: collections:: HashSet ;
3
3
4
4
use squawk_syntax:: ast:: AstNode ;
5
- use squawk_syntax:: { Parse , SourceFile } ;
5
+ use squawk_syntax:: { Parse , SourceFile , SyntaxKind } ;
6
6
use squawk_syntax:: { ast, identifier:: Identifier } ;
7
7
8
8
use crate :: { Linter , Rule , Version , Violation } ;
9
9
10
- fn is_const_expr ( expr : & ast:: Expr ) -> bool {
11
- match expr {
12
- ast:: Expr :: Literal ( _) => true ,
13
- ast:: Expr :: CastExpr ( cast) => matches ! ( cast. expr( ) , Some ( ast:: Expr :: Literal ( _) ) ) ,
14
- _ => false ,
15
- }
16
- }
17
-
18
10
lazy_static ! {
19
11
static ref NON_VOLATILE_FUNCS : HashSet <Identifier > = {
20
12
NON_VOLATILE_BUILT_IN_FUNCTIONS
@@ -26,8 +18,18 @@ lazy_static! {
26
18
} ;
27
19
}
28
20
29
- fn is_non_volatile ( expr : & ast:: Expr ) -> bool {
21
+ fn is_non_volatile_or_const ( expr : & ast:: Expr ) -> bool {
30
22
match expr {
23
+ ast:: Expr :: Literal ( _) => true ,
24
+ ast:: Expr :: ArrayExpr ( _) => true ,
25
+ ast:: Expr :: BinExpr ( bin_expr) => {
26
+ if let Some ( lhs) = bin_expr. lhs ( ) {
27
+ if let Some ( rhs) = bin_expr. rhs ( ) {
28
+ return is_non_volatile_or_const ( & lhs) && is_non_volatile_or_const ( & rhs) ;
29
+ }
30
+ }
31
+ false
32
+ }
31
33
ast:: Expr :: CallExpr ( call_expr) => {
32
34
if let Some ( arglist) = call_expr. arg_list ( ) {
33
35
let no_args = arglist. args ( ) . count ( ) == 0 ;
@@ -45,6 +47,24 @@ fn is_non_volatile(expr: &ast::Expr) -> bool {
45
47
false
46
48
}
47
49
}
50
+ // array[]::t[] is non-volatile. We don't check for a plain array expr
51
+ // since postgres will reject it as a default unless it's cast to a type.
52
+ ast:: Expr :: CastExpr ( cast_expr) => {
53
+ if let Some ( inner_expr) = cast_expr. expr ( ) {
54
+ is_non_volatile_or_const ( & inner_expr)
55
+ } else {
56
+ false
57
+ }
58
+ }
59
+ // current_timestamp is the same as calling now()
60
+ ast:: Expr :: NameRef ( name_ref) => {
61
+ if let Some ( child) = name_ref. syntax ( ) . first_child_or_token ( ) {
62
+ if child. kind ( ) == SyntaxKind :: CURRENT_TIMESTAMP_KW {
63
+ return true ;
64
+ }
65
+ }
66
+ false
67
+ }
48
68
_ => false ,
49
69
}
50
70
}
@@ -69,7 +89,7 @@ pub(crate) fn adding_field_with_default(ctx: &mut Linter, parse: &Parse<SourceFi
69
89
continue ;
70
90
} ;
71
91
if ctx. settings . pg_version > Version :: new ( 11 , None , None )
72
- && ( is_const_expr ( & expr) || is_non_volatile ( & expr ) )
92
+ && is_non_volatile_or_const ( & expr)
73
93
{
74
94
continue ;
75
95
}
@@ -181,6 +201,33 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT true;
181
201
assert_debug_snapshot ! ( errors) ;
182
202
}
183
203
204
+ #[ test]
205
+ fn default_empty_array_ok ( ) {
206
+ let sql = r#"
207
+ alter table t add column a double precision[] default array[]::double precision[];
208
+
209
+ alter table t add column b bigint[] default cast(array[] as bigint[]);
210
+
211
+ alter table t add column c text[] default array['foo', 'bar']::text[];
212
+ "# ;
213
+
214
+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
215
+ assert ! ( errors. is_empty( ) ) ;
216
+ assert_debug_snapshot ! ( errors) ;
217
+ }
218
+
219
+ #[ test]
220
+ fn default_with_const_bin_expr ( ) {
221
+ let sql = r#"
222
+ ALTER TABLE assessments
223
+ ADD COLUMN statistics_last_updated_at timestamptz NOT NULL DEFAULT now() - interval '100 years';
224
+ "# ;
225
+
226
+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
227
+ assert ! ( errors. is_empty( ) ) ;
228
+ assert_debug_snapshot ! ( errors) ;
229
+ }
230
+
184
231
#[ test]
185
232
fn default_str_ok ( ) {
186
233
let sql = r#"
@@ -240,6 +287,7 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now(123);
240
287
assert ! ( !errors. is_empty( ) ) ;
241
288
assert_debug_snapshot ! ( errors) ;
242
289
}
290
+
243
291
#[ test]
244
292
fn default_func_now_ok ( ) {
245
293
let sql = r#"
@@ -252,14 +300,25 @@ ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now();
252
300
assert_debug_snapshot ! ( errors) ;
253
301
}
254
302
303
+ #[ test]
304
+ fn default_func_current_timestamp_ok ( ) {
305
+ let sql = r#"
306
+ alter table t add column c timestamptz default current_timestamp;
307
+ "# ;
308
+
309
+ let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
310
+ assert ! ( errors. is_empty( ) ) ;
311
+ assert_debug_snapshot ! ( errors) ;
312
+ }
313
+
255
314
#[ test]
256
315
fn add_numbers_ok ( ) {
257
- // This should be okay, but we don't handle expressions like this at the moment.
258
316
let sql = r#"
259
317
alter table account_metadata add column blah integer default 2 + 2;
260
318
"# ;
261
319
262
320
let errors = lint ( sql, Rule :: AddingFieldWithDefault ) ;
321
+ assert ! ( errors. is_empty( ) ) ;
263
322
assert_debug_snapshot ! ( errors) ;
264
323
}
265
324
0 commit comments