TensorFlow batch_normalization example code
Views: 4212
Published: 2019-05-26

This article is about 6,014 characters long; estimated reading time is 20 minutes.

Without further ado, here is the code. If anything is unclear, feel free to leave a comment.
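For reference, the comments in the code point to the batch-normalization formula. Given a mini-batch with mean $\mu_B$ and variance $\sigma_B^2$, each pre-activation $x$ is standardized and then rescaled by the learned parameters $\gamma$ (scale) and $\beta$ (shift):

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$

where $\epsilon$ is a small constant (0.001 in the code below) that guards against division by zero.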

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

ACTIVATION = tf.nn.relu  # activation function
N_LAYERS = 7             # number of hidden layers
N_HIDDEN_UNITS = 30      # number of neurons in each hidden layer


def fix_seed(seed=1):
    # reproducible
    np.random.seed(seed)
    tf.set_random_seed(seed)


def plot_his(inputs, inputs_norm):
    # plot histogram for the inputs of every layer
    for j, all_inputs in enumerate([inputs, inputs_norm]):
        for i, input in enumerate(all_inputs):
            plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))
            plt.cla()
            if i == 0:
                the_range = (-7, 10)
            else:
                the_range = (-1, 1)
            plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733')
            plt.yticks(())
            if j == 1:
                plt.xticks(the_range)
            else:
                plt.xticks(())
            ax = plt.gca()
            ax.spines['right'].set_color('none')
            ax.spines['top'].set_color('none')
        plt.title("%s normalizing" % ("Without" if j == 0 else "With"))
    plt.draw()
    plt.pause(0.01)


def built_net(xs, ys, norm):
    def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
        # weights and biases (bad initialization for this case)
        Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0., stddev=1.))
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)

        # fully connected product
        Wx_plus_b = tf.matmul(inputs, Weights) + biases

        # normalize fully connected product
        if norm:
            # Batch Normalize
            fc_mean, fc_var = tf.nn.moments(
                Wx_plus_b,
                axes=[0],   # the dimensions to normalize over; [0] is the batch axis
                            # for images, use [0, 1, 2] for [batch, height, width], but not the channel axis
            )
            scale = tf.Variable(tf.ones([out_size]))   # scale parameter (gamma); see the BN formula
            shift = tf.Variable(tf.zeros([out_size]))  # shift parameter (beta); see the BN formula
            epsilon = 0.001

            # apply moving average for mean and var when training on batches
            ema = tf.train.ExponentialMovingAverage(decay=0.5)

            def mean_var_with_update():
                ema_apply_op = ema.apply([fc_mean, fc_var])
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(fc_mean), tf.identity(fc_var)

            mean, var = mean_var_with_update()

            Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)
            # equivalent to these two steps:
            # Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
            # Wx_plus_b = Wx_plus_b * scale + shift

        # activation
        if activation_function is None:
            outputs = Wx_plus_b
        else:
            outputs = activation_function(Wx_plus_b)

        return outputs

    fix_seed(1)

    if norm:
        # BN for the first input
        fc_mean, fc_var = tf.nn.moments(
            xs,
            axes=[0],
        )
        scale = tf.Variable(tf.ones([1]))
        shift = tf.Variable(tf.zeros([1]))
        epsilon = 0.001

        # apply moving average for mean and var when training on batches
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            ema_apply_op = ema.apply([fc_mean, fc_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(fc_mean), tf.identity(fc_var)

        mean, var = mean_var_with_update()
        xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)

    # record inputs for every layer
    layers_inputs = [xs]

    # build hidden layers
    for l_n in range(N_LAYERS):
        layer_input = layers_inputs[l_n]
        in_size = layers_inputs[l_n].get_shape()[1].value

        output = add_layer(
            layer_input,    # input
            in_size,        # input size
            N_HIDDEN_UNITS, # output size
            ACTIVATION,     # activation function
            norm,           # normalize before activation
        )
        layers_inputs.append(output)    # add output for next run

    # build output layer
    prediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None)

    cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
    train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
    return [train_op, cost, layers_inputs]


# make up data
fix_seed(1)
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]
np.random.shuffle(x_data)
noise = np.random.normal(0, 8, x_data.shape)
y_data = np.square(x_data) - 5 + noise

# plot input data
plt.scatter(x_data, y_data)
plt.show()

xs = tf.placeholder(tf.float32, [None, 1])  # [num_samples, num_features]
ys = tf.placeholder(tf.float32, [None, 1])

train_op, cost, layers_inputs = built_net(xs, ys, norm=False)                 # without BN
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True)   # with BN

sess = tf.Session()
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess.run(init)

# record cost
cost_his = []
cost_his_norm = []
record_step = 5

plt.ion()
plt.figure(figsize=(7, 3))
for i in range(250):
    if i % 50 == 0:
        # plot histogram
        all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm],
                                               feed_dict={xs: x_data, ys: y_data})
        plot_his(all_inputs, all_inputs_norm)

    # train on batch
    sess.run([train_op, train_op_norm],
             feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]})

    if i % record_step == 0:
        # record cost
        cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))
        cost_his_norm.append(sess.run(cost_norm, feed_dict={xs: x_data, ys: y_data}))

plt.ioff()
plt.figure()
plt.plot(np.arange(len(cost_his))*record_step, np.array(cost_his), label='no BN')       # no norm
plt.plot(np.arange(len(cost_his))*record_step, np.array(cost_his_norm), label='BN')     # norm
plt.legend()
plt.show()
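As the in-code comment notes, tf.nn.batch_normalization is equivalent to normalizing manually and then applying the scale and shift. Below is a minimal standalone sketch to check that equivalence, assuming a TensorFlow 1.x environment (the tensor shapes here are arbitrary, chosen only for illustration):

import numpy as np
import tensorflow as tf

# a toy batch: 8 samples, 3 features (shapes are illustrative only)
x = tf.constant(np.random.randn(8, 3), dtype=tf.float32)

mean, var = tf.nn.moments(x, axes=[0])   # per-feature batch statistics
scale = tf.ones([3])                     # gamma
shift = tf.zeros([3])                    # beta
epsilon = 0.001

bn = tf.nn.batch_normalization(x, mean, var, shift, scale, epsilon)
manual = (x - mean) / tf.sqrt(var + epsilon) * scale + shift

with tf.Session() as sess:
    a, b = sess.run([bn, manual])
    print(np.allclose(a, b, atol=1e-5))  # expected to print True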

Reposted from: http://eukmi.baihongyu.com/
