資源簡介
Python 3 編寫的EM算法，包含兩個程序：一個是EM分類，一個是GMM應用。
EM算法推導（收斂性證明及在GMM中的應用）詳見我的博客：https://blog.csdn.net/kevinoop/article/details/80522477

代碼片段和文件信息
#?!/usr/bin/python
#?-*-?coding:utf-8?-*-
import?numpy?as?np
from?scipy.stats?import?multivariate_normal
from?sklearn.mixture?import?GaussianMixture
from?mpl_toolkits.mplot3d?import?Axes3D
import?matplotlib?as?mpl
import?matplotlib.pyplot?as?plt
from?sklearn.metrics.pairwise?import?pairwise_distances_argmin
# Use SimHei as the sans-serif font so the Chinese text used in this
# script's plots (e.g. subplot titles) can be rendered.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Draw negative tick labels with an ASCII hyphen-minus instead of U+2212
# (presumably because SimHei lacks the U+2212 glyph -- confirm on target).
mpl.rcParams['axes.unicode_minus'] = False
if?__name__?==?‘__main__‘:
????style?=?‘myself‘
????np.random.seed(0)
????mu1_fact?=?(0?0?0)
????cov1_fact?=?np.diag((1?2?3))
????data1?=?np.random.multivariate_normal(mu1_fact?cov1_fact?400)
????mu2_fact?=?(2?2?1)
????cov2_fact?=?np.array(((1?1?3)?(1?2?1)?(0?0?1)))
????data2?=?np.random.multivariate_normal(mu2_fact?cov2_fact?100)
????data?=?np.vstack((data1?data2))
????y?=?np.array([True]?*?400?+?[False]?*?100)
????if?style?==?‘sklearn‘:
????????g?=?GaussianMixture(n_components=2?covariance_type=‘full‘?tol=1e-6?max_iter=1000)
????????g.fit(data)
????????print(‘類別概率:\t‘?g.weights_[0])
????????print(‘均值:\n‘?g.means_?‘\n‘)
????????print(‘方差:\n‘?g.covariances_?‘\n‘)
????????mu1?mu2?=?g.means_
????????sigma1?sigma2?=?g.covariances_
????else:
????????num_iter?=?100
????????n?d?=?data.shape
????????#?隨機指定
????????#?mu1?=?np.random.standard_normal(d)
????????#?print?mu1
????????#?mu2?=?np.random.standard_normal(d)
????????#?print?mu2
????????mu1?=?data.min(axis=0)
????????mu2?=?data.max(axis=0)
????????sigma1?=?np.identity(d)
????????sigma2?=?np.identity(d)
????????pi?=?0.5
????????#?EM
????????for?i?in?range(num_iter):
????????????#?E?Step
????????????norm1?=?multivariate_normal(mu1?sigma1)
????????????norm2?=?multivariate_normal(mu2?sigma2)
????????????tau1?=?pi?*?norm1.pdf(data)
????????????tau2?=?(1?-?pi)?*?norm2.pdf(data)
????????????gamma?=?tau1?/?(tau1?+?tau2)
????????????#?M?Step
????????????mu1?=?np.dot(gamma?data)?/?np.sum(gamma)
????????????mu2?=?np.dot((1?-?gamma)?data)?/?np.sum((1?-?gamma))
????????????sigma1?=?np.dot(gamma?*?(data?-?mu1).T?data?-?mu1)?/?np.sum(gamma)
????????????sigma2?=?np.dot((1?-?gamma)?*?(data?-?mu2).T?data?-?mu2)?/?np.sum(1?-?gamma)
????????????pi?=?np.sum(gamma)?/?n
????????????print(i?“:\t“?mu1?mu2)
????????print(‘類別概率:\t‘?pi)
????????print(‘均值:\t‘?mu1?mu2)
????????print(‘方差:\n‘?sigma1?‘\n\n‘?sigma2?‘\n‘)
????#?預(yù)測分類
????norm1?=?multivariate_normal(mu1?sigma1)
????norm2?=?multivariate_normal(mu2?sigma2)
????tau1?=?norm1.pdf(data)
????tau2?=?norm2.pdf(data)
????fig?=?plt.figure(figsize=(10?5)?facecolor=‘w‘)
????ax?=?fig.add_subplot(121?projection=‘3d‘)
????ax.scatter(data[:?0]?data[:?1]?data[:?2]?c=‘b‘?s=30?marker=‘o‘?edgecolors=‘k‘?depthshade=True)
????ax.set_xlabel(‘X‘)
????ax.set_ylabel(‘Y‘)
????ax.set_zlabel(‘Z‘)
????ax.set_title(‘原始數(shù)據(jù)‘?fontsize=15)
????ax?=?fig.add_subplot(122?projection=‘3d‘)
????order?=?pairwise_distances_argmin([mu1_fact?mu2_fact]?[mu1?mu2]?metri
?屬性????????????大小?????日期????時間???名稱
-----------?---------??----------?-----??----
?????文件????????2784??2017-07-04?10:01??20.6.GMM_pdf.py
?????文件????????3737??2017-07-04?09:33??20.1.EM.py
評論
共有 條評論