1
1
use async_trait:: async_trait;
2
- use casbin:: { error:: AdapterError , Adapter , Error as CasbinError , Model , Result } ;
2
+ use casbin:: { error:: AdapterError , Adapter , Error as CasbinError , Filter , Model , Result } ;
3
3
use diesel:: {
4
4
self ,
5
5
r2d2:: { ConnectionManager , Pool } ,
@@ -11,6 +11,7 @@ use std::time::Duration;
11
11
12
12
pub struct DieselAdapter {
13
13
pool : Pool < ConnectionManager < adapter:: Connection > > ,
14
+ is_filtered : bool ,
14
15
}
15
16
16
17
pub const TABLE_NAME : & str = "casbin_rules" ;
@@ -27,7 +28,10 @@ impl<'a> DieselAdapter {
27
28
. get ( )
28
29
. map_err ( |err| CasbinError :: from ( AdapterError ( Box :: new ( Error :: PoolError ( err) ) ) ) ) ;
29
30
30
- adapter:: new ( conn) . map ( |_| Self { pool } )
31
+ adapter:: new ( conn) . map ( |_| Self {
32
+ pool,
33
+ is_filtered : false ,
34
+ } )
31
35
}
32
36
33
37
pub ( crate ) fn save_policy_line (
@@ -82,6 +86,39 @@ impl<'a> DieselAdapter {
82
86
None
83
87
}
84
88
89
+ pub ( crate ) fn load_filtered_policy_line (
90
+ & self ,
91
+ casbin_rule : & CasbinRule ,
92
+ f : & Filter ,
93
+ ) -> Option < Vec < String > > {
94
+ if let Some ( sec) = casbin_rule. ptype . chars ( ) . next ( ) {
95
+ if let Some ( policy) = self . normalize_policy ( casbin_rule) {
96
+ let mut is_filtered = false ;
97
+ if sec == 'p' {
98
+ for ( i, rule) in f. p . iter ( ) . enumerate ( ) {
99
+ if !rule. is_empty ( ) && rule != & policy[ i] {
100
+ is_filtered = true
101
+ }
102
+ }
103
+ } else if sec == 'g' {
104
+ for ( i, rule) in f. g . iter ( ) . enumerate ( ) {
105
+ if !rule. is_empty ( ) && rule != & policy[ i] {
106
+ is_filtered = true
107
+ }
108
+ }
109
+ } else {
110
+ return None ;
111
+ }
112
+
113
+ if !is_filtered {
114
+ return Some ( policy) ;
115
+ }
116
+ }
117
+ }
118
+
119
+ None
120
+ }
121
+
85
122
fn normalize_policy ( & self , casbin_rule : & CasbinRule ) -> Option < Vec < String > > {
86
123
let mut result = vec ! [
87
124
& casbin_rule. v0,
@@ -135,6 +172,33 @@ impl Adapter for DieselAdapter {
135
172
Ok ( ( ) )
136
173
}
137
174
175
+ async fn load_filtered_policy ( & mut self , m : & mut dyn Model , f : Filter ) -> Result < ( ) > {
176
+ let conn = self
177
+ . pool
178
+ . get ( )
179
+ . map_err ( |err| CasbinError :: from ( AdapterError ( Box :: new ( Error :: PoolError ( err) ) ) ) ) ?;
180
+
181
+ let rules = adapter:: load_policy ( conn) ?;
182
+
183
+ for casbin_rule in & rules {
184
+ let rule = self . load_filtered_policy_line ( casbin_rule, & f) ;
185
+
186
+ if let Some ( rule) = rule {
187
+ if let Some ( ref sec) = casbin_rule. ptype . chars ( ) . next ( ) . map ( |x| x. to_string ( ) ) {
188
+ if let Some ( t1) = m. get_mut_model ( ) . get_mut ( sec) {
189
+ if let Some ( t2) = t1. get_mut ( & casbin_rule. ptype ) {
190
+ t2. get_mut_policy ( ) . insert ( rule) ;
191
+ }
192
+ }
193
+ }
194
+ } else {
195
+ self . is_filtered = true ;
196
+ }
197
+ }
198
+
199
+ Ok ( ( ) )
200
+ }
201
+
138
202
async fn save_policy ( & mut self , m : & mut dyn Model ) -> Result < ( ) > {
139
203
let conn = self
140
204
. pool
@@ -241,6 +305,10 @@ impl Adapter for DieselAdapter {
241
305
Ok ( false )
242
306
}
243
307
}
308
+
309
+ fn is_filtered ( & self ) -> bool {
310
+ self . is_filtered
311
+ }
244
312
}
245
313
246
314
#[ cfg( test) ]
@@ -411,5 +479,47 @@ mod tests {
411
479
)
412
480
. await
413
481
. is_ok( ) ) ;
482
+
483
+ // shadow the previous enforcer
484
+ let mut e = Enforcer :: new (
485
+ "examples/rbac_with_domains_model.conf" ,
486
+ "examples/rbac_with_domains_policy.csv" ,
487
+ )
488
+ . await
489
+ . unwrap ( ) ;
490
+
491
+ assert ! ( adapter. save_policy( e. get_mut_model( ) ) . await . is_ok( ) ) ;
492
+ e. set_adapter ( adapter) . await . unwrap ( ) ;
493
+
494
+ let filter = Filter {
495
+ p : vec ! [ "" , "domain1" ] ,
496
+ g : vec ! [ "" , "" , "domain1" ] ,
497
+ } ;
498
+
499
+ e. load_filtered_policy ( filter) . await . unwrap ( ) ;
500
+ assert ! ( e
501
+ . enforce( & [ "alice" , "domain1" , "data1" , "read" ] )
502
+ . await
503
+ . unwrap( ) ) ;
504
+ assert ! ( e
505
+ . enforce( & [ "alice" , "domain1" , "data1" , "write" ] )
506
+ . await
507
+ . unwrap( ) ) ;
508
+ assert ! ( !e
509
+ . enforce( & [ "alice" , "domain1" , "data2" , "read" ] )
510
+ . await
511
+ . unwrap( ) ) ;
512
+ assert ! ( !e
513
+ . enforce( & [ "alice" , "domain1" , "data2" , "write" ] )
514
+ . await
515
+ . unwrap( ) ) ;
516
+ assert ! ( !e
517
+ . enforce( & [ "bob" , "domain2" , "data2" , "read" ] )
518
+ . await
519
+ . unwrap( ) ) ;
520
+ assert ! ( !e
521
+ . enforce( & [ "bob" , "domain2" , "data2" , "write" ] )
522
+ . await
523
+ . unwrap( ) ) ;
414
524
}
415
525
}
0 commit comments