-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
42 lines (36 loc) · 1006 Bytes
/
train.py
File metadata and controls
42 lines (36 loc) · 1006 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
import pandas as pd
from sklearn.linear_model import ElasticNet
import pickle
# Defaults preserved from the original hard-coded training configuration.
DEFAULT_ALPHA = 0.6
DEFAULT_L1_RATIO = 0.4
# Column the model is trained to predict; per the original author it is a
# scalar in the range [3, 9].
TARGET_COLUMN = "quality"


def parse_args(argv=None):
    """Parse the command-line options for this training script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``input_data``, ``model_path``,
        ``alpha`` and ``l1_ratio`` attributes.

    Raises:
        SystemExit: If a required option is missing or a value is invalid
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Fit an ElasticNet model on a CSV file and pickle it."
    )
    parser.add_argument(
        "-i",
        "--input_data",
        type=str,
        dest="input_data",
        required=True,
        help="input data file.",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        dest="model_path",
        required=True,
        help="path for saving trained model.",
    )
    # Optional flags; their defaults reproduce the original fixed values, so
    # existing invocations train exactly the same model.
    parser.add_argument(
        "--alpha",
        type=float,
        dest="alpha",
        default=DEFAULT_ALPHA,
        help="ElasticNet alpha (default: %(default)s).",
    )
    parser.add_argument(
        "--l1_ratio",
        type=float,
        dest="l1_ratio",
        default=DEFAULT_L1_RATIO,
        help="ElasticNet l1_ratio (default: %(default)s).",
    )
    return parser.parse_args(argv)


def split_features_target(data, target=TARGET_COLUMN):
    """Split a DataFrame into features and a single-column target frame.

    Args:
        data: DataFrame containing the *target* column plus feature columns.
        target: Name of the column to predict.

    Returns:
        ``(features, target_frame)``. ``features`` is *data* without the
        target column. ``target_frame`` is a one-column DataFrame, matching
        the original ``data[["quality"]]`` selection.

    Raises:
        KeyError: If *target* is not a column of *data*.
    """
    features = data.drop([target], axis=1)
    target_frame = data[[target]]
    return features, target_frame


def train_model(data, alpha=DEFAULT_ALPHA, l1_ratio=DEFAULT_L1_RATIO):
    """Fit an ElasticNet regressor on *data* and return the fitted model.

    Args:
        data: Training DataFrame that includes the ``TARGET_COLUMN`` column.
        alpha: ElasticNet regularisation strength.
        l1_ratio: ElasticNet L1/L2 mixing parameter.

    Returns:
        The fitted ``ElasticNet`` instance.
    """
    train_x, train_y = split_features_target(data)
    # random_state is kept from the original for reproducibility. Per the
    # scikit-learn docs it only matters when selection="random".
    model = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
    model.fit(train_x, train_y)
    return model


def main(argv=None):
    """Read the CSV given on the command line, train, and pickle the model.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)
    data = pd.read_csv(args.input_data, sep=",")
    model = train_model(data, alpha=args.alpha, l1_ratio=args.l1_ratio)
    # The context manager guarantees the pickle is flushed and the handle is
    # closed; the original left the file object open.
    # NOTE: the resulting pickle must only ever be loaded from a trusted source.
    with open(args.model_path, "wb") as model_file:
        pickle.dump(model, model_file)
    print("#INFO: Model is successfully processed!")


if __name__ == "__main__":
    # Argument parsing now happens only when run as a script, so importing
    # this module no longer consumes or rejects sys.argv.
    main()