# 绘制树状图
# 树状图亦称为树枝状图,是一种通过树状结构描述父子成员层次结构的图形。树形图的形状一般是一个上下颠倒的树,其根部是一个没有父成员的根节点,之后从根节点开始用线连接子成员,使子成员变为子节点,直至线的末端为没有子成员的树叶节点为止。树形图用于说明成员之间的关系和连接,常见于分类学、进化科学、企业组织管理等领域。例如,frog 技术专家 Paul Adams 设计的人工智能树状图 (部分) 如图 8-19 所示。
从图 8-19 可以看出,树状图的树叶节点经过第一层聚类形成两个类簇,即自然语言处理和机器学习,之后经过第二层聚类形成一个类簇————人工智能。
树状图的绘制需要准备聚类数据。单独使用 matplotlib 较为烦琐,因此这里可以结合 scipy 包的功能完成。scipy 是一款基于 numpy 的、专为科学和工程设计的、易于使用的 Python 包,它提供了线性代数、傅里叶变换、信号处理等丰富的功能。
scipy.cluster 模块中包含众多聚类算法,主要包括矢量量化和层次聚类两种,并分别封装到 vq 和 hierarchy 模块中。hierarchy 模块中提供了一系列聚类的功能,可以轻松生成聚类数据并绘制树状图。下面介绍 hierarchy 模块的常用函数。
# dendrogram() 函数
# dendrogram() 函数用于将层次聚类数据绘制为树状图,其语法格式如下所示 :
dendrogram(Z, p=30, truncate_mode=None, color_threshold=None,
get_leaves=True, orientation='top', latbels=None,
count_sort=False, distance_sort=False,
show_leaf_counts=True, **kwargs)
# 该函数常用参数的含义如下。
# · Z :表示编码层次聚类的链接矩阵。
· truncate_mode :表示截断的模式,用于压缩因观测矩阵过大而难以阅读的树状图,可以取值为 None (不执行截断,默认)、'lastp'、'level'。
· color_threshold :表示颜色阈值。
· labels :表示节点对应的文本标签。
# linkage() 函数
# linkage() 函数用于将一维压缩距离矩阵或二维观测向量阵列进行层次聚类或凝聚聚类,其语法格式如下所示 :
linkage(y, method='single', metric='euclidean', optimal_ordering=False)
# 该函数常用参数的含义如下。
( 1 ) y :可以是一维距离向量或二维的坐标矩阵。
( 2 )method :表示计算类簇之间距离的方法, 常用的取值可以为 'single'、'complete'、'averag' 和 'ward',各取值具体含义如下。
# · 'single' :表示将类簇与类簇之间最近的距离作为类簇间距。
· 'complete' :表示将类簇与类簇之间最远的距离作为类簇间距。
· 'average' :表示将类簇与类簇之间的平均距离作为类簇间距。
· 'ward' :表示将每个类簇的方差最小化作为类簇间距。
# inkage() 函数会返回编码层次聚类的链接矩阵。
美国对各州的谋杀、暴力、爆炸等犯罪案件的数量进行了统计,并将统计后的结果整理到 USArrests.xlsx 文件中。下面使用 pandas 读取 USArrests.xlsx 文件的数据,并将犯罪案例数量相似度高的州进行聚类后绘制一个树状图,示例代码如下。
In [9]:
import pandas as pd
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as shc
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
df = pd.read_excel(r'C:\Users\admin\Desktop\USArrests.xlsx')
plt.figure(figsize=(10, 6), dpi= 80)
plt.title("美国各州犯罪案件的树状图", fontsize=12)
# 绘制树状图
dend = shc.dendrogram(shc.linkage(df[['Murder', 'Assault', 'UrbanPop']],
method='ward'), labels=df.State.values, color_threshold=l00)
plt.xticks(fontsize=10.5)
plt.ylabeK('案例数量')
plt.show()