-
大小: 12KB | 文件類型: .py | 金幣: 1 | 下載: 0 次 | 發布日期: 2021-05-13
- 語言: Python
- 標簽: tensorflow
資源簡介
構建一個四層神經網絡來識別手寫體數據集 MNIST，然後將注意力模塊 CBAM 插入到網絡的第一層之後，以查看注意力模塊的性能。CBAM 模塊的插入位置可以改變，做到任意插入。
代碼片段和文件信息
import warnings

# The filter is installed before the TensorFlow import so that it is already
# active while that import runs.
warnings.filterwarnings('ignore', category=FutureWarning)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Short alias for the TF-Slim module.
slim = tf.contrib.slim
# Create weights for each layer
def get_weights(shape):
    """Create a weight variable of the given shape.

    Args:
        shape: Shape forwarded to ``tf.truncated_normal``.

    Returns:
        A ``tf.Variable`` whose initial value is drawn from a truncated
        normal distribution with ``stddev=0.1``.
    """
    data = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(data)
def get_biases(shape):
    """Create a bias variable of the given shape.

    Args:
        shape: Shape forwarded to ``tf.constant``.

    Returns:
        A ``tf.Variable`` whose initial value is the constant 0.1 in every
        element.
    """
    data = tf.constant(0.1, shape=shape)
    return tf.Variable(data)
# 2d convolutional function
def convolution_2d(x, w):
    """Apply a 2-D convolution with unit strides and 'SAME' padding.

    Args:
        x: Input tensor passed to ``tf.nn.conv2d``.
        w: Filter tensor passed to ``tf.nn.conv2d``.

    Returns:
        The result of ``tf.nn.conv2d`` on ``x`` and ``w``.
    """
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
# 2*2 max pooling
def max_pooling(x):
    """Apply max pooling with a 2x2 window, stride 2 and 'SAME' padding.

    The window and stride are 2 along the two middle dimensions of ``x`` and
    1 along the first and last ones.

    Args:
        x: Input tensor passed to ``tf.nn.max_pool``.

    Returns:
        The result of ``tf.nn.max_pool`` on ``x``.
    """
    return tf.nn.max_pool(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
def combined_static_and_dynamic_shape(tensor):
    """Return a list of static and dynamic values for the shape dimensions.

    This is useful to preserve static shapes when they are available in a
    reshape operation.

    Args:
        tensor: A tensor of any type.

    Returns:
        A list of size ``tensor.shape.ndims``. Each entry is the static
        dimension when it is known (not ``None``); otherwise it is the
        matching element of ``tf.shape(tensor)``.
    """
    static_tensor_shape = tensor.shape.as_list()
    dynamic_tensor_shape = tf.shape(tensor)
    # Prefer the statically known size; fall back to the run-time value only
    # for dimensions that are unknown at graph-construction time.
    return [
        dim if dim is not None else dynamic_tensor_shape[index]
        for index, dim in enumerate(static_tensor_shape)
    ]
def?convolutional_block_attention_module(feature_map?index?reduction_ratio?=?0.5):
????“““CBAM:convolutional?block?attention?module
????Args:
????????feature_map:input?feature?map
????????index:the?index?of?the?module
????????reduction_ratio:output?units?number?of?first?MLP?layer:reduction_ratio?*?feature?map
????Return:
????????feature?map?with?channel?and?spatial?attention“““
????with?tf.variable_scope(“cbam_%s“?%?(index)):
????????feature_map_shape?=?combined_static_and_dynamic_shape(feature_map)
????????#?channel?attention?module
????????channel_avg_weights?=?tf.nn.avg_pool(value=feature_map
?????????????????????????????????????????????ksize=[1?feature_map_shape[1]?feature_map_shape[2]?1]
?????????????????????????????????????????????strides=[1?1?1?1]
?????????????????????????????????????????????padding=‘VALID‘)??#?global?average?pool
????????channel_max_weights?=?tf.nn.max_pool(value=feature_map
?????????????????????????????????????????????ksize=[1?feature_map_shape[1]?feature_map_shape[2]?1]
?????????????????????????????????????????????strides=[1?1?1?1]
?????????????????????????????????????????????padding=‘VALID‘)
????????channel_avg_reshape?=?tf.reshape(channel_avg_weights
?????????????????????????????????????????[feature_map_shape[0]?1?feature_map_shape[3]])
????????channel_max_reshape?=?tf.reshape(channel_max_weights
?????????????????????????????????????????[feature_map_shape[0]?1?feature_map_shape[3]])
????????channel_w_reshape?=?tf.concat([channel_avg_reshape?channel_max_re
評論
共有 條評論