forked from cystanford/breast_cancer_data
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbreast_svm.py
74 lines (64 loc) · 2.11 KB
/
breast_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
# 乳腺癌诊断分类
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
# 加载数据集,你需要把数据放到目录中
data = pd.read_csv("./data.csv")
# 数据探索
# 因为数据集中列比较多,我们需要把 dataframe 中的列全部显示出来
pd.set_option("display.max_columns", None)
print(data.columns)
print(data.head(5))
print(data.describe())
# 将特征字段分成 3 组
features_mean = list(data.columns[2:12])
features_se = list(data.columns[12:22])
features_worst = list(data.columns[22:32])
# 数据清洗
# ID 列没有用,删除该列
data.drop("id", axis=1, inplace=True)
# 将 B 良性替换为 0,M 恶性替换为 1
data["diagnosis"] = data["diagnosis"].map({"M": 1, "B": 0})
# 将肿瘤诊断结果可视化
sns.countplot(data["diagnosis"], label="Count")
plt.show()
# 用热力图呈现 features_mean 字段之间的相关性
corr = data[features_mean].corr()
plt.figure(figsize=(14, 14))
# annot=True 显示每个方格的数据
sns.heatmap(corr, annot=True)
plt.show()
# 特征选择
features_remain = [
"radius_mean",
"texture_mean",
"smoothness_mean",
"compactness_mean",
"symmetry_mean",
"fractal_dimension_mean",
]
# 抽取 30% 的数据作为测试集,其余作为训练集
train, test = train_test_split(
data, test_size=0.3
) # in this our main data is splitted into train and test
# 抽取特征选择的数值作为训练和测试数据
train_X = train[features_remain]
train_y = train["diagnosis"]
test_X = test[features_remain]
test_y = test["diagnosis"]
# 采用 Z-Score 规范化数据,保证每个特征维度的数据均值为 0,方差为 1
ss = StandardScaler()
train_X = ss.fit_transform(train_X)
test_X = ss.transform(test_X)
# 创建 SVM 分类器
model = svm.SVC()
# 用训练集做训练
model.fit(train_X, train_y)
# 用测试集做预测
prediction = model.predict(test_X)
print("准确率:", metrics.accuracy_score(prediction, test_y))