import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import export_graphviz
# Load the data: build a DataFrame of the iris measurements, one column per
# feature, plus a 'Species' column holding the integer class label.
iris = load_iris()
data = pd.DataFrame(iris.data, columns=iris.feature_names)
data['Species'] = iris.target  # reuse the dataset already loaded above
print(data)

# Split the data.
# Columns 2:4 of the iris feature matrix are petal length and petal width.
x = data.iloc[:, 2:4]
y = data.iloc[:, -1]  # last column is the 'Species' label added above
# NOTE(review): test_size=0.75 holds out 75% of the samples for testing and
# trains on only 25% -- confirm this is intended (0.25 is the more usual split).
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.75, random_state=42
)

# Train the model.
tree_clf = DecisionTreeClassifier(max_depth=8, criterion='gini')
tree_clf.fit(x_train, y_train)

# Graphviz is a powerful and convenient tool for drawing relationship graphs
# and flowcharts, which maps naturally onto the way a decision tree is
# displayed. scikit-learn provides export_graphviz to generate the .dot file.
export_graphviz(
    tree_clf,
    out_file="./iris_tree.dot",
    # Label the nodes with the features the tree was actually trained on
    # (the columns of x), not the first two columns of the dataset.
    feature_names=list(x.columns),
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)