-
大小: 10KB · 文件類型: .py · 金幣: 1 · 下載: 0 次 · 發(fā)布日期: 2021-05-20
- 語言: Python
- 標(biāo)簽: vae, autoencoder, python, 實(shí)現(xiàn)
資源簡介
AutoEncoder是深度學(xué)習(xí)的另外一個重要內(nèi)容,并且非常有意思,神經(jīng)網(wǎng)絡(luò)通過大量數(shù)據(jù)集,進(jìn)行end-to-end的訓(xùn)練,不斷提高其準(zhǔn)確率,而AutoEncoder通過設(shè)計encode和decode過程使輸入和輸出越來越接近,是一種無監(jiān)督學(xué)習(xí)過程。
代碼片段和文件信息
# Module setup for a TF 1.x variational autoencoder on MNIST.
# NOTE(review): this script targets TensorFlow 1.x — tf.contrib.slim,
# tf.app.flags, and the contrib MNIST reader were all removed in TF 2.x.
# scipy.misc.imsave was removed in SciPy 1.2 (imageio.imwrite is the
# modern replacement) — confirm pinned versions before running.
import itertools
import os
import time

import matplotlib as mpl
import numpy as np
import seaborn as sns
import tensorflow as tf
import tensorflow.contrib.slim as slim
from matplotlib import pyplot as plt
from scipy.misc import imsave
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

sns.set_style('whitegrid')

distributions = tf.distributions
flags = tf.app.flags

flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data')
flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs')

# Smaller settings, kept for reproducing the 2-D latent-space plots:
# flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
# flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
# flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
# flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
# flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
# flags.DEFINE_integer('n_iterations', 1000, 'number of iterations')

# Settings for the bigger model:
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
flags.DEFINE_integer('n_samples', 1, 'Number of samples to save')
flags.DEFINE_integer('print_every', 1000, 'Print every n iterations')
flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
flags.DEFINE_integer('n_iterations', 100000, 'number of iterations')

FLAGS = flags.FLAGS
def inference_network(x, latent_dim, hidden_size):
  """Construct an inference network parametrizing a Gaussian.

  Args:
    x: A batch of MNIST digits.
    latent_dim: The latent dimensionality.
    hidden_size: The size of the neural net hidden layers.

  Returns:
    mu: Mean parameters for the variational family Normal.
    sigma: Standard deviation parameters for the variational family Normal.
  """
  with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
    net = slim.flatten(x)
    net = slim.fully_connected(net, hidden_size)
    net = slim.fully_connected(net, hidden_size)
    # One output head of width 2*latent_dim: first half is mu, second half
    # parametrizes sigma (no activation; constrained below).
    gaussian_params = slim.fully_connected(
        net, latent_dim * 2, activation_fn=None)
  # The mean parameter is unconstrained.
  mu = gaussian_params[:, :latent_dim]
  # The standard deviation must be positive; parametrize with a softplus.
  sigma = tf.nn.softplus(gaussian_params[:, latent_dim:])
  return mu, sigma
def generative_network(z, hidden_size):
  """Build a generative network parametrizing the likelihood of the data.

  Args:
    z: Samples of latent variables.
    hidden_size: Size of the hidden state of the neural net.

  Returns:
    bernoulli_logits: logits for the Bernoulli likelihood of the data,
      reshaped to MNIST image batches of shape [-1, 28, 28, 1].
  """
  with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
    net = slim.fully_connected(z, hidden_size)
    net = slim.fully_connected(net, hidden_size)
    # 784 = 28*28 pixels per MNIST digit; logits are unconstrained.
    bernoulli_logits = slim.fully_connected(net, 784, activation_fn=None)
    bernoulli_logits = tf.reshape(bernoulli_logits, [-1, 28, 28, 1])
  return bernoulli_logits
評論
共有 條評論