資源簡(jiǎn)介
用Python自己寫的線性回歸。其方法包括用最小二乘法直接求解和梯度下降法求解。理解代碼能讓你更好地懂得其原理

代碼片段和文件信息
“““直接法和梯度法求解線性回歸“““
import time

import matplotlib.pyplot as plt
import numpy as np
from numpy import linalg as la
# Sample data: columns 0-1 are the two features, column 2 is the target.
data = np.array([[ 0.35291809,  0.16468428,  0.35774628],
                 [-0.55106013, -0.10981663,  0.25468008],
                 [-0.65439632, -0.71406955,  0.1061582 ],
                 [-0.19790689,  0.61536205,  0.43122894],
                 [-0.00171825,  0.66827656,  0.44198075],
                 [-0.2739687 , -1.16342739,  0.01195186],
                 [ 0.11592071, -0.18320789,  0.29397728],
                 [-0.02707248, -0.53269863,  0.21784183],
                 [ 0.7321352 ,  0.27868019,  0.42643361],
                 [-0.76680149, -0.89838545,  0.06411818]])
X = data[:, :2]   # feature matrix, shape (10, 2)
y = data[:, -1]   # target vector, shape (10,)

"""Direct solution via the normal equations: theta = (X^T X)^-1 X^T y."""
b = np.array([1])          # bias (intercept) term
b = b.repeat(10)           # one bias entry per sample, shape (10,)
# Prepend the bias column to the 2 features: X becomes shape (10, 3).
X = np.column_stack((b, X))
xtx = X.transpose().dot(X)
xtx = la.inv(xtx)
theta = xtx.dot(X.transpose()).dot(y)
"""Gradient-descent solution."""
# model
def model(theta, X):
    """Linear hypothesis: return the predictions X @ theta.

    theta may be any array-like of length X.shape[1]; it is converted
    to an ndarray so a plain list works too.
    """
    theta = np.array(theta)
    return X.dot(theta)
# cost
def cost(m, theta, X, y):
    """Half mean squared error: J = sum((y - X.theta)^2) / (2*m).

    m is the number of samples; theta, X, y as in model().
    """
    ele = y - model(theta, X)       # residuals
    item = ele ** 2                 # squared residuals
    item_sum = np.sum(item)
    return item_sum / 2 / m
# gradient
def gradient(m, theta, X, y, cols):
    """Gradient of cost() w.r.t. each of the first `cols` components of theta.

    Returns an ndarray of length cols where entry j is
    -sum((y - X.theta) * X[:, j]) / m.
    """
    grad_theta = []
    residual = y - model(theta, X)  # loop-invariant: compute predictions once
    for j in range(cols):
        grad = residual.dot(X[:, j])   # contribution of feature column j
        grad_sum = np.sum(grad)
        grad_theta.append(-grad_sum / m)
    return np.array(grad_theta)
# theta update
def theta_update(grad_theta, theta, sigma):
    """One gradient-descent step: theta <- theta - sigma * grad_theta.

    sigma is the learning rate; grad_theta must be an ndarray so the
    expression broadcasts even when theta is a plain list.
    """
    return theta - sigma * grad_theta
'stop stratege'
def stop_stratege(cost, cost_update, threshold):
    """Convergence test: stop when the cost improvement drops below threshold.

    NOTE(review): the comparison operator was lost in the scraped source
    (`< threshold` reads like an HTML tag); reconstructed as `<` since the
    caller breaks the descent loop when this returns True.
    """
    return cost - cost_update < threshold
# OLS algorithm
def OLS(X, y, threshold):
    """Fit theta by batch gradient descent until the per-step cost
    improvement falls below `threshold`.

    Returns (cost_record, iters, theta): the cost history, the number of
    completed iterations, and the fitted weights.
    """
    # time.clock() was removed in Python 3.8; perf_counter() is the replacement
    start = time.perf_counter()
    # number of samples (was hard-coded 10; derived from the data instead)
    m = X.shape[0]
    # initial weight vector, one entry per column (bias + features)
    theta = [0] * X.shape[1]
    # iteration counter
    iters = 0
    # history of cost values
    cost_record = []
    # learning rate
    sigma = 0.0001
    cost_val = cost(m, theta, X, y)  # cost at the initial theta
    cost_record.append(cost_val)
    while True:
        grad = gradient(m, theta, X, y, X.shape[1])  # current gradient
        # parameter update
        theta = theta_update(grad, theta, sigma)
        cost_update = cost(m, theta, X, y)
        if stop_stratege(cost_val, cost_update, threshold):
            break
        iters = iters + 1
        cost_val = cost_update
        cost_record.append(cost_val)
    end = time.perf_counter()
    print("OLS convergence duration: %f s" % (end - start))
    return cost_record, iters, theta
if __name__ == "__main__":
    # Run gradient descent and plot the cost history against iteration count.
    cost_record, iters, theta = OLS(X, y, 1e-10)
    # cost_record holds the initial cost plus one entry per iteration: iters+1 points
    x = np.arange(0, iters + 1, 1)
    plt.figure()
    plt.plot(x, cost_record)
    plt.show()
?屬性????????????大小?????日期????時(shí)間???名稱
-----------?---------??----------?-----??----
?????文件????????2710??2020-03-14?22:09??線性回歸的最小二乘法與梯度下降法代碼.py
評(píng)論
共有 條評(píng)論