資源簡介
利用gan神經網絡生成一維數據,代碼有備注,真實數據為一個列表。利用gan神經網絡生成一維數據,代碼有備注,真實數據為一個列表。
代碼片段和文件信息
#coding=utf-8
#?實現簡單的DCGAN(深度卷積生成對抗網絡)
import?warnings
from?keras.layers?import?Dense?LeakyReLU?BatchNormalization?Input?Dropout
from?keras.layers?import?Activation
import?matplotlib.pyplot?as?plt
from?keras.models?import?Sequential?Model?load_model
from?keras.optimizers?import?adam
from?keras.utils?import?plot_model
warnings.filterwarnings(“ignore“)
import?numpy?as?np
import?csv
from?sklearn.metrics?import?mean_squared_error
def?uniform_sampling(n_sample?dim):
????#?均勻分布采樣
????return?np.random.uniform(0?1?size=(n_sample?dim))
def?normal_sampling(n_sample?dim):
????#?均勻分布采樣
????return?np.random.randn(n_sample?dim)
#?構建判別網絡
d_model?=?Sequential()
d_model.add(Dense(units=64?input_dim=1))
d_model.add(LeakyReLU(alpha=0.2))
d_model.add(BatchNormalization(momentum=0.8))
d_model.add(Dense(256))
d_model.add(LeakyReLU(alpha=0.2))
d_model.add(BatchNormalization(momentum=0.8))
d_model.add(Activation(‘tanh‘))
d_model.add(Dense(1?activation=‘sigmoid‘))??#?輸出樣本標記為1,即假樣本的概率
d_model.summary()
#?構建生成網絡
g_model?=?Sequential()
g_model.add(Dense(units=64?input_dim=1))
g_model.add(LeakyReLU(alpha=0.2))
g_model.add(BatchNormalization(momentum=0.8))
g_model.add(Dense(256))
g_model.add(LeakyReLU(alpha=0.2))
g_model.add(BatchNormalization(momentum=0.8))
g_model.add(Activation(‘tanh‘))
g_model.add(Dense(1?activation=‘tanh‘))
g_model.summary()
class?DCGAN:
????def?__init__(self?d_model?g_model
?????????????????input_dim=1?g_dim=1
?????????????????max_step=200?sample_size=256?d_iter=3?kind=‘normal‘):
????????self.input_dim?=?input_dim??#?圖像的展開維度,即判別網絡的輸入維度
????????self.g_dim?=?g_dim??#?隨機噪聲維度,即生成網絡的輸入維度
????????self.max_step?=?max_step??#?整個模型的迭代次數
????????self.sample_size?=?sample_size??#?訓練過程中小批量采樣的個數的一半
????????self.d_iter?=?d_iter??#?每次迭代,判別網絡訓練的次數
????????self.kind?=?kind??#?隨機噪聲分布類型
????????self.d_model?=?d_model??#?判別模型
????????self.g_model?=?g_model??#?生成模型
????????self.m_model?=?self.merge_model()??#?合并模型
????????self.optimizer?=?adam(lr=0.0002?beta_1=0.5)
????????self.d_model.compile(optimizer=self.optimizer?loss=‘binary_crossentropy‘)
????def?merge_model(self):
????????#?合并生成網絡與判別網絡
????????noise?=?Input(shape=(self.g_dim))
????????gen_sample?=?self.g_model(noise)
????????self.d_model.trainable?=?False??#?固定判別網絡,訓練合并網絡等同與訓練生成網絡
????????d_output?=?self.d_model(gen_sample)
????????m_model?=?Model(noise?d_output)??#?模型輸出生成樣本的預測結果,越接近0越好
????????m_model.compile(optimizer=‘adam‘?loss=‘binary_crossentropy‘)
????????return?m_model
????def?gen_noise(self?num_sample):
????????#?生成隨機噪聲數據
????????if?self.kind?==?‘normal‘:
????????????f?=?normal_sampling
????????elif?self.kind?==?‘uniform‘:
????????????f?=?uniform_sampling
????????else:
????????????raise?ValueError(‘暫不支持分布{}‘.format(self.kind))
????????return?f(num_sample?self.g_dim)
????def?gen_real_data(self?train_data):
????????#?真實樣本采樣
????????n_samples?=?train_data.shape[0]
????????i
評論
共有 條評論