-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
196 lines (164 loc) · 4.4 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
package main
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"runtime"
"github.com/davecheney/profile"
"github.com/wlattner/rf/tree"
flag "github.com/docker/docker/pkg/mflag"
)
// Input and output locations. Each flag is declared with its accepted
// names, its default value and the help text shown in the usage output.
var (
	dataFile    = flag.String([]string{"d", "-data"}, "", "example data")
	predictFile = flag.String([]string{"p", "-predictions"}, "", "file to output predictions")
	modelFile   = flag.String([]string{"f", "-final_model"}, "rf.model", "file to output fitted model")
	impFile     = flag.String([]string{"-var_importance"}, "", "file to output variable importance estimates")
)

// Parameters that control how the model is fitted.
var (
	nTree       = flag.Int([]string{"-trees"}, 10, "number of trees")
	minSplit    = flag.Int([]string{"-min_split"}, 2, "minimum number of samples required to split an internal node")
	minLeaf     = flag.Int([]string{"-min_leaf"}, 1, "minimum number of samples in newly created leaves")
	maxFeatures = flag.Int([]string{"-max_features"}, -1, "number of features to consider when looking for the best split, -1 will default to √(# features)")
	impurity    = flag.String([]string{"-impurity"}, "gini", "impurity measure for evaluating splits")
)

// Parser override: treat integer targets/labels as classes.
var (
	forceClf = flag.Bool([]string{"c", "-classification"}, false, "force parser to use integer targets/labels for classification")
)

// Settings that affect how the program runs rather than what it fits.
var (
	nWorkers   = flag.Int([]string{"-workers"}, 1, "number of workers for fitting trees")
	runProfile = flag.Bool([]string{"-profile"}, false, "cpu profile")
)
// modelOptions bundles the fitting parameters collected from the command
// line so they can be handed to the model as a single value.
type modelOptions struct {
	// nTree is the number of trees.
	nTree int
	// minSplit is the minimum number of samples required to split an
	// internal node.
	minSplit int
	// minLeaf is the minimum number of samples in newly created leaves.
	minLeaf int
	// maxFeatures is the number of features considered when looking for the
	// best split; the command line default is -1.
	maxFeatures int
	// impurity is the measure used for evaluating splits.
	impurity tree.ImpurityMeasure
	// nWorkers is the number of workers for fitting trees.
	nWorkers int
}
// impurityCode translates the textual value of the impurity flag into the
// corresponding tree.ImpurityMeasure. Names missing from this table are
// rejected by parseModelOpts.
var impurityCode = map[string]tree.ImpurityMeasure{
	"gini":    tree.Gini,
	"entropy": tree.Entropy,
}
// parseModelOpts gathers the model flags into a modelOptions value.
//
// It returns an error when the impurity flag does not name an entry in
// impurityCode; in that case the returned options carry every field except
// impurity.
func parseModelOpts() (modelOptions, error) {
	var opts modelOptions
	opts.nTree = *nTree
	opts.minSplit = *minSplit
	opts.minLeaf = *minLeaf
	opts.maxFeatures = *maxFeatures
	opts.nWorkers = *nWorkers

	// Success path first: a known impurity name completes the options.
	if measure, known := impurityCode[*impurity]; known {
		opts.impurity = measure
		return opts, nil
	}

	return opts, errors.New("invalid impurity option, choices are gini or entropy")
}
// main parses the command line, reads the data file and then runs in one of
// two modes: when a predictions file is given it loads a previously saved
// model and writes predictions for the data; otherwise it fits a new model
// to the data and saves it (plus, optionally, variable importance
// estimates).
//
// Any failure is reported through fatal, which exits with status 1.
func main() {
	flag.Parse()

	if *nWorkers > 1 {
		runtime.GOMAXPROCS(runtime.NumCPU())
	}

	if *runProfile {
		// Stop only runs if main returns normally; paths that end in
		// os.Exit (usage error, fatal) skip it.
		defer profile.Start(profile.CPUProfile).Stop()
	}

	// make sure user specified csv file w/ data
	if *dataFile == "" {
		fmt.Fprintf(os.Stderr, "Usage of rf:\n\n")
		flag.PrintDefaults()
		os.Exit(1)
	}

	f, err := os.Open(*dataFile)
	if err != nil {
		fatal("error opening data file", err.Error())
	}
	defer f.Close()

	d, err := parseCSV(f, *forceClf)
	if err != nil {
		fatal("error parsing input data", err.Error())
	}

	// consider non-blank *predictFile as prediction, fit otherwise
	if *predictFile != "" {
		m, err := loadModel(*modelFile)
		if err != nil {
			fatal("error opening model file", err.Error())
		}

		pred, err := m.Predict(d)
		if err != nil {
			fatal(err.Error())
		}

		// write the predictions to file
		o, err := os.Create(*predictFile)
		if err != nil {
			fatal("error creating", *predictFile, err.Error())
		}
		if err = writePred(o, pred); err != nil {
			fatal("error writing predictions", err.Error())
		}
		// Close explicitly so a failure to commit the written data is
		// reported instead of being dropped by a deferred Close.
		if err = o.Close(); err != nil {
			fatal("error writing predictions", err.Error())
		}

		// Return rather than os.Exit(0): the exit status is the same, but
		// returning lets the deferred calls above run.
		return
	}

	// must be model fitting
	opt, err := parseModelOpts()
	if err != nil {
		fatal("invalid model option", err.Error())
	}

	// fit model
	m := new(Model)
	m.Fit(d, opt)

	// save model to disk
	o, err := os.Create(*modelFile)
	if err != nil {
		fatal("error saving model", err.Error())
	}
	if err = m.Save(o); err != nil {
		fatal("error saving model", err.Error())
	}
	if err = o.Close(); err != nil {
		fatal("error saving model", err.Error())
	}

	// write var importance to file
	if *impFile != "" {
		impOut, err := os.Create(*impFile)
		if err != nil {
			fatal("error saving variable importance", err.Error())
		}
		if err = m.SaveVarImp(impOut); err != nil {
			fatal("error saving variable importance", err.Error())
		}
		if err = impOut.Close(); err != nil {
			fatal("error saving variable importance", err.Error())
		}
	}

	m.Report(os.Stderr)
}
// loadModel opens the file named fName and loads a Model from it.
//
// It returns a nil model together with the error when the file cannot be
// opened or when Model.Load fails. The file is closed before returning.
func loadModel(fName string) (*Model, error) {
	// Open the path that was passed in, not the global model flag, so the
	// function honours its argument.
	f, err := os.Open(fName)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	m := new(Model)
	if err := m.Load(f); err != nil {
		return nil, err
	}
	return m, nil
}
// fatal reports a to standard error, formatted the way fmt.Println would
// (operands separated by spaces, newline appended), and then terminates the
// program with exit status 1. It does not return.
func fatal(a ...interface{}) {
	line := fmt.Sprintln(a...)
	os.Stderr.WriteString(line)
	os.Exit(1)
}
// writePred writes each entry of prediction to w on its own line, using a
// buffered writer that is flushed before returning.
//
// The first error reported by the buffered writer, or by the final flush,
// is returned.
func writePred(w io.Writer, prediction []string) error {
	out := bufio.NewWriter(w)

	for i := range prediction {
		if _, err := out.WriteString(prediction[i]); err != nil {
			return err
		}
		if err := out.WriteByte('\n'); err != nil {
			return err
		}
	}

	return out.Flush()
}