資源簡介
BP算法實現iris數據集分類
代碼片段和文件信息
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
print('======= Data load & Data process =======')

# Iris data set (UCI ``iris.data``): one sample per row, no header row.
#   1. sepal length in cm
#   2. sepal width in cm
#   3. petal length in cm
#   4. petal width in cm
#   5. label:
#      -- Iris Setosa
#      -- Iris Versicolour
#      -- Iris Virginica
FEATURE_COLUMNS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
# Fixed seed for the train/valid/test split. The global RNG seeding happens later
# in the script, after the split, so it cannot make the split reproducible.
SPLIT_SEED = 1

raw_data = pd.read_csv('iris.data', header=None,
                       names=FEATURE_COLUMNS + ['label'])
data = raw_data[FEATURE_COLUMNS].values
# Sort the class names before numbering them. Iterating a ``set`` of strings follows
# the (per-process randomized) string hash order, which would silently permute the
# class ids from one run to the next.
class_mapping = {name: idx for idx, name in enumerate(sorted(set(raw_data['label'])))}
label = raw_data['label'].map(class_mapping).values.reshape(-1, 1)
print("data num is: {} labels num is: {}".format(len(data), len(label)))

# 80% train, then split the remaining 20% in half: 10% valid / 10% test.
# ``stratify`` keeps every class represented in each (small) split.
X_train, X_test, Y_train, Y_test = train_test_split(
    data, label, test_size=0.2, shuffle=True,
    random_state=SPLIT_SEED, stratify=label.ravel())
X_valid, X_test, Y_valid, Y_test = train_test_split(
    X_test, Y_test, test_size=0.5, shuffle=True,
    random_state=SPLIT_SEED, stratify=Y_test.ravel())
print('train data_num:{}\nvalid data_num:{}\ntest data_num:{}'.format(len(X_train), len(X_valid), len(X_test)))
def set_random_seed(seed):
    """Seed every random number generator the script draws from.

    Args:
        seed: integer applied to Python's ``random``, NumPy's global RNG and torch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. The call is safe (silently
    # ignored) when CUDA is unavailable, so no availability check is needed.
    torch.cuda.manual_seed_all(seed)
class IrisDataset(Dataset):
    """Map-style dataset over in-memory feature and label arrays.

    Args:
        x_data: array-like of feature rows; stored as a float32 tensor.
        y_data: array-like of labels, one entry per feature row; stored as a
            float32 tensor. In this script it receives class ids reshaped to (N, 1).

    Raises:
        ValueError: if ``x_data`` and ``y_data`` have different lengths.
    """

    def __init__(self, x_data, y_data):
        super().__init__()
        # Explicit float32 copy. It yields the same dtype the legacy torch.Tensor(...)
        # constructor produced, so downstream code is unaffected.
        self.x_data = torch.tensor(np.asarray(x_data), dtype=torch.float32)
        self.y_data = torch.tensor(np.asarray(y_data), dtype=torch.float32)
        if len(self.x_data) != len(self.y_data):
            raise ValueError(
                f"x_data has {len(self.x_data)} rows but y_data has {len(self.y_data)}")

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x_data)
def evaluate(X, Y, w1, w2, w3):
    """Return the classification accuracy of the 3-layer network on ``(X, Y)``.

    Forward pass: sigmoid(X @ w1) -> sigmoid(. @ w2) -> . @ w3. The predicted
    class is the argmax over the output columns.

    Args:
        X: feature matrix, array-like of shape (N, w1.shape[0]).
        Y: integer class ids as a NumPy array with N elements (any shape;
           this script passes (N, 1)).
        w1, w2, w3: weight matrices of the three layers.

    Returns:
        Fraction of samples whose predicted class equals ``Y``, as a float.

    Raises:
        ValueError: if ``Y`` does not hold exactly one label per row of ``X``.
        ZeroDivisionError: if there are no samples.
    """
    # Inference only: without no_grad, autograd would record a graph through the
    # trainable weights on every call and then discard it.
    with torch.no_grad():
        # Place inputs on the weights' device rather than relying on a global.
        inputs = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=w1.device)
        hidden = torch.sigmoid(torch.sigmoid(inputs.mm(w1)).mm(w2))
        logits = hidden.mm(w3)
    # Softmax is monotonic, so the argmax of the logits is the predicted class.
    pred_y = logits.argmax(dim=1).cpu().numpy()
    target_y = np.asarray(Y).reshape(-1)
    if pred_y.shape != target_y.shape:
        # Guard against silent broadcasting, e.g. a single label compared to N predictions.
        raise ValueError(
            f"expected {pred_y.size} labels, got {target_y.size}")
    return float((pred_y == target_y).sum()) / float(target_y.size)
# init parameters
set_random_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input/output sizes come from the loaded data (4 features, 3 classes for Iris)
# so they cannot drift out of sync with it. Hidden sizes are free choices.
D_in, H_1, H_2, D_out = X_train.shape[1], 20, 20, len(class_mapping)
LEARNING_RATE = 0.1
MAX_EPOCH = 100
BATCH_SIZE = 40
# Bias-free weight matrices, trained directly via autograd.
# NOTE(review): N(0, 1) init applied to unscaled cm-valued features can saturate
# the sigmoid layers. Consider standardizing inputs or scaling the init.
w1 = torch.randn(D_in, H_1, requires_grad=True, device=device)
w2 = torch.randn(H_1, H_2, requires_grad=True, device=device)
w3 = torch.randn(H_2, D_out, requires_grad=True, device=device)

# data_loader
train_dataset = IrisDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_losses, val_losses = [], []
for?epoch?in?range(MAX_EPOCH?+?1):
????batch_losses?=?[]
????for?step?(x?y)?in?enumerate(train_loader):
????????x?y?=?x.to(device)?y.to(device)
????????#?x:?batch_size*4??y:?b
 屬性            大小     日期    時間   名稱
-----------  ---------  ---------- -----  ----
     文件        5822  2021-01-04 07:20  main.py
     文件         670  2021-01-04 07:20  __MACOSX\._main.py
     文件        4551  2021-01-04 07:23  iris.data
     文件         654  2021-01-04 07:23  __MACOSX\._iris.data
評論
共有 條評論