@@ -2,6 +2,7 @@ package golden
2
2
3
3
import (
4
4
"fmt"
5
+ "path/filepath"
5
6
"sync"
6
7
"testing"
7
8
)
@@ -75,12 +76,22 @@ func DagTest(t *testing.T, cases []DagTestCase) {
75
76
if nextCase .Config != nil {
76
77
config = * nextCase .Config
77
78
}
79
+ if len (config .ScriptExtensions ) == 0 {
80
+ // Default script extension if none is provided.
81
+ config .ScriptExtensions = []ScriptExtension {{Extension : ".sh" , Command : "bash" }}
82
+ }
83
+
84
+ // Get the script extension for the test case.
85
+ ext , err := dagGetScriptExtension (nextCase .Path , config )
86
+ if err != nil {
87
+ t .Fatal (err )
88
+ }
78
89
79
- nextCase := nextCase
90
+ nextCase := nextCase // Capture the variable for the goroutine.
80
91
go func () {
92
+ defer wg .Done ()
81
93
// Run the test case.
82
- BashTestFile (t , nextCase .Path , config )
83
- wg .Done ()
94
+ ScriptTestFile (t , ext .Command , nextCase .Path , config )
84
95
}()
85
96
}
86
97
@@ -100,6 +111,18 @@ func DagTest(t *testing.T, cases []DagTestCase) {
100
111
}
101
112
}
102
113
114
+ func dagGetScriptExtension (path string , config ScriptConfig ) (ScriptExtension , error ) {
115
+ // Get extension from the path.
116
+ ext := filepath .Ext (path )
117
+ // Search for fitting script definition among config.ScriptExtensions.
118
+ for _ , def := range config .ScriptExtensions {
119
+ if def .Extension == ext {
120
+ return def , nil
121
+ }
122
+ }
123
+ return ScriptExtension {}, fmt .Errorf ("no script definition found for path %s with extension %s" , path , ext )
124
+ }
125
+
103
126
func validate (cases []DagTestCase ) error {
104
127
// Ensure that all cases have unique names.
105
128
names := make (map [string ]bool )
0 commit comments