-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtrainer.py
85 lines (59 loc) · 2.49 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pandas
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
import numpy
from sklearn import svm
from sklearn import cross_validation as cv
import matplotlib.pylab as plt
import warnings
# Suppress DeprecationWarning messages reported for line 570 of any module
# whose name starts with "pandas"; every other warning is still shown.
warnings.filterwarnings("ignore", category=DeprecationWarning,
module="pandas", lineno=570)
def return_nonstring_col(data_cols):
    """Split column names into non-text columns and feature columns.

    Args:
        data_cols: iterable of column names.

    Returns:
        A two-element list ``[cols_to_keep, train_cols]``:
        ``cols_to_keep`` is every name except ``'URL'``, ``'host'`` and
        ``'path'`` (input order preserved); ``train_cols`` is
        ``cols_to_keep`` with ``'malicious'`` and ``'result'`` removed too.
    """
    text_columns = ('URL', 'host', 'path')
    label_columns = ('malicious', 'result')

    cols_to_keep = [name for name in data_cols if name not in text_columns]
    # Feature columns are a subset of the kept columns, so filter that list.
    train_cols = [name for name in cols_to_keep if name not in label_columns]
    return [cols_to_keep, train_cols]
def svm_classifier(train, query, train_cols):
    """Fit an SVC on ``train``, print a cross-validated score, label ``query``.

    Args:
        train: DataFrame containing the feature columns listed in
            ``train_cols`` and a ``'malicious'`` label column.
        query: DataFrame containing the same feature columns and a ``'URL'``
            column.
        train_cols: names of the feature columns.

    Returns:
        None. The fitted estimator, the score summary and the
        ``URL``/``result`` table are printed.

    Side effects:
        The feature columns of both ``train`` and ``query`` are overwritten
        in place with their standardised values, and a ``'result'`` column
        holding the predictions is written to ``query``.
    """
    clf = svm.SVC()

    # Learn the centring/scaling statistics from the training features only
    # and apply that same transform to the query features. Standardising the
    # two frames independently would put them in different coordinate
    # systems, so the model would be asked to predict on inputs that do not
    # line up with what it was fitted on.
    scaler = preprocessing.StandardScaler().fit(train[train_cols])
    train[train_cols] = scaler.transform(train[train_cols])
    query[train_cols] = scaler.transform(query[train_cols])

    print(clf.fit(train[train_cols], train['malicious']))

    scores = cv.cross_val_score(clf, train[train_cols], train['malicious'], cv=30)
    # The spread shown after "+/-" is half the standard deviation of the
    # per-fold scores.
    print('Estimated score SVM: %0.5f (+/- %0.5f)' % (scores.mean(), scores.std() / 2))

    query['result'] = clf.predict(query[train_cols])
    print(query[['URL', 'result']])
# Called from gui
def forest_classifier_gui(train, query, train_cols):
    """Fit a 150-tree random forest on ``train`` and predict labels for ``query``.

    Args:
        train: DataFrame with the feature columns in ``train_cols`` and a
            ``'malicious'`` label column.
        query: DataFrame with the same feature columns and a ``'URL'``
            column; a ``'result'`` column is written to it in place.
        train_cols: names of the feature columns.

    Returns:
        The ``'result'`` column of ``query`` holding the predictions.
    """
    forest = RandomForestClassifier(n_estimators=150)
    fitted = forest.fit(train[train_cols], train['malicious'])
    print(fitted)

    predictions = forest.predict(query[train_cols])
    query['result'] = predictions

    # Only the first two rows are echoed; the full column is returned.
    preview = query[['URL', 'result']].head(2)
    print(preview)
    return query['result']
def forest_classifier(train, query, train_cols):
    """Fit a 150-tree random forest, print a 30-fold CV score, label ``query``.

    Args:
        train: DataFrame with the feature columns in ``train_cols`` and a
            ``'malicious'`` label column.
        query: DataFrame with the same feature columns and a ``'URL'``
            column; a ``'result'`` column is written to it in place.
        train_cols: names of the feature columns.

    Returns:
        None. The fitted estimator, the score summary and the
        ``URL``/``result`` table are printed.
    """
    forest = RandomForestClassifier(n_estimators=150)
    fitted = forest.fit(train[train_cols], train['malicious'])
    print(fitted)

    fold_scores = cv.cross_val_score(
        forest, train[train_cols], train['malicious'], cv=30)
    mean_score = fold_scores.mean()
    # The spread shown after "+/-" is half the standard deviation of the
    # per-fold scores.
    half_std = fold_scores.std() / 2
    print('Estimated score RandomForestClassifier: %0.5f (+/- %0.5f)' % (mean_score, half_std))

    predictions = forest.predict(query[train_cols])
    query['result'] = predictions
    print(query[['URL', 'result']])
def train(db, test_db):
    """Run the SVM and random-forest evaluations for a train/query CSV pair.

    Args:
        db: path (or file-like object) of the training CSV, read with
            ``pandas.read_csv``.
        test_db: path (or file-like object) of the query CSV, read with
            ``pandas.read_csv``.

    Returns:
        None. Each classifier prints its own report.
    """
    query_csv = pandas.read_csv(test_db)
    train_csv = pandas.read_csv(db)

    # Feature columns are taken from the training file's header. Only the
    # second element of the helper's result is needed here.
    train_cols = return_nonstring_col(train_csv.columns)[1]

    # Both classifiers receive the same DataFrame objects, so anything the
    # first call writes into them is visible to the second.
    svm_classifier(train_csv, query_csv, train_cols)
    forest_classifier(train_csv, query_csv, train_cols)
def gui_caller(db, test_db):
    """Load a train/query CSV pair and return random-forest predictions.

    Args:
        db: path (or file-like object) of the training CSV, read with
            ``pandas.read_csv``.
        test_db: path (or file-like object) of the query CSV, read with
            ``pandas.read_csv``.

    Returns:
        Whatever ``forest_classifier_gui`` returns for the loaded frames.
    """
    query_csv = pandas.read_csv(test_db)
    train_csv = pandas.read_csv(db)

    # Feature columns are taken from the training file's header. Only the
    # second element of the helper's result is needed here.
    train_cols = return_nonstring_col(train_csv.columns)[1]

    return forest_classifier_gui(train_csv, query_csv, train_cols)