-
大小: 3KB  文件類型: .rar  金幣: 1  下載: 0 次  發布日期: 2021-06-10
- 語言: 其他
- 標簽: CGAN tensorflow
資源簡介
條件生成對抗網絡(CGAN), tensorflow實現
代碼片段和文件信息
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import cv2
import random
# Optional GPU selection, left disabled as in the original script.
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Command-line flags, declared through the TF 1.x `tf.app.flags` wrapper.
flags = tf.app.flags
flags.DEFINE_integer("iter", 1000000, "Iteration to train [1000000]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_string("model_path", './model/cgan.model', "Save model path ['./model/cgan.model']")
flags.DEFINE_boolean("is_train", False, "Train or test [False]")
flags.DEFINE_integer("test_number", None, "The number that want to generate if None generate randomly [None]")
FLAGS = flags.FLAGS

# MNIST is read from (or downloaded into) ./MNIST_data with one-hot labels.
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

Z_dim = 100                          # width of the noise input Z
X_dim = mnist.train.images.shape[1]  # second dimension of the training image array
y_dim = mnist.train.labels.shape[1]  # second dimension of the one-hot label array
h_dim = 128                          # hidden-layer width used by both networks
def xavier_init(size):
    """Draw initial weights from a normal distribution scaled by the fan-in.

    Args:
        size: Shape of the tensor to create; ``size[0]`` is taken as the
            input dimension.

    Returns:
        A ``tf.random_normal`` tensor of shape ``size`` whose standard
        deviation is ``1 / sqrt(size[0] / 2)``.
    """
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
""" Discriminator Net model """
# X holds image rows and y the matching label rows; both have a free batch dimension.
# NOTE(review): X is hard-coded to 784 columns while D_W1 is sized from X_dim;
# these agree only if mnist.train.images has 784 columns — confirm.
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])

# First layer consumes the image concatenated with its label.
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

# Second layer maps the hidden activations to a single logit.
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# All discriminator parameters, collected in one list.
theta_D = [D_W1, D_W2, D_b1, D_b2]
def discriminator(x, y):
    """Score inputs ``x`` conditioned on labels ``y``.

    Args:
        x: Tensor of samples, one per row.
        y: Tensor of conditioning labels, one per row, concatenated to ``x``
            along axis 1.

    Returns:
        A tuple ``(D_prob, D_logit)``: the sigmoid of the output layer and
        the raw output-layer value before the sigmoid.
    """
    inputs = tf.concat(axis=1, values=[x, y])
    D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit
""" Generator Net model """
# Z holds noise rows with a free batch dimension.
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

# First layer consumes the noise concatenated with the label.
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

# Second layer maps the hidden activations to X_dim outputs.
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

# All generator parameters, collected in one list.
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z, y):
    """Generate samples from noise ``z`` conditioned on labels ``y``.

    Args:
        z: Tensor of noise vectors, one per row.
        y: Tensor of conditioning labels, one per row, concatenated to ``z``
            along axis 1.

    Returns:
        The sigmoid of the output layer, so every value lies in (0, 1).
    """
    inputs = tf.concat(axis=1, values=[z, y])
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    return G_prob
def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn uniformly from [-1, 1).

    Args:
        m: Number of rows.
        n: Number of columns.

    Returns:
        A NumPy array of shape ``(m, n)``.
    """
    return np.random.uniform(-1., 1., size=[m, n])
def plot(samples):
    """Draw ``samples`` as 28x28 grayscale tiles on a 4x4 grid.

    Args:
        samples: Iterable of arrays, each reshapeable to ``(28, 28)``. The
            grid has 16 cells, so at most 16 samples can be placed.

    Returns:
        The matplotlib figure containing the grid.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        # Hide axes and tick labels so only the image tiles are visible.
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
# Wire the graph: the generator output and the real images X are both scored
# by the same discriminator, each conditioned on the same labels y.
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
D_loss_real
 屬性            大小     日期    時間   名稱
----------- ---------  ---------- -----  ----
     文件       5453  2017-08-16 19:20  cgan_tensorflow.py
     文件        973  2017-08-27 12:57  README.md
----------- ---------  ---------- -----  ----
                 6426                    2
評論
共有 條評論