From 8d550c58ccab3597edff6fe1abc7178c3ed98d9a Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Mon, 11 Dec 2017 14:32:39 +0530
Subject: [PATCH] fixed batch normalization layer before activation

---
 segment_model.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/segment_model.py b/segment_model.py
index 4c7a85b..8e52f72 100644
--- a/segment_model.py
+++ b/segment_model.py
@@ -4,7 +4,7 @@ import numpy as np
 from keras.models import Model,load_model,model_from_yaml
 from keras.layers import Input,Concatenate,Lambda, Reshape, Dropout
 from keras.layers import Dense,Conv2D, LSTM, Bidirectional, GRU
-from keras.layers import BatchNormalization
+from keras.layers import BatchNormalization,Activation
 from keras.losses import categorical_crossentropy
 from keras.utils import to_categorical
 from keras.optimizers import RMSprop
@@ -51,15 +51,19 @@ def segment_model(input_dim):
 def simple_segment_model(input_dim):
     inp = Input(shape=input_dim)
     b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
-    bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1)
-    b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(bn_b_gr1)
-    bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2)
-    d1 = Dense(32, activation='relu')(bn_b_gr2)
+    # bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1)
+    b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1)
+    # bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2)
+    d1 = Dense(32)(b_gr2)
     bn_d1 = BatchNormalization(momentum=0.98)(d1)
-    d2 = Dense(8, activation='relu')(bn_d1)
+    bn_da1 = Activation('relu')(bn_d1)
+    d2 = Dense(8)(bn_da1)
     bn_d2 = BatchNormalization(momentum=0.98)(d2)
-    d3 = Dense(1, activation='softmax')(bn_d2)
-    oup = Reshape(target_shape=(input_dim[0],))(d3)
+    bn_da2 = Activation('relu')(bn_d2)
+    d3 = Dense(1)(bn_da2)
+    bn_d3 = BatchNormalization(momentum=0.98)(d3)
+    bn_da3 = Activation('softmax')(bn_d3)
+    oup = Reshape(target_shape=(input_dim[0],))(bn_da3)
     return Model(inp, oup)
 
 def write_model_arch(mod,mod_file):
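
Note: the patch moves BatchNormalization between each Dense layer's linear transform and its non-linearity. A minimal sketch of that Dense -> BatchNormalization -> Activation ordering, in the same Keras 2.x functional style as segment_model.py, is below; the (100, 40) input shape and the summary() call are illustrative assumptions, not part of this patch.

    # Sketch of the pre-activation BatchNormalization pattern applied in this patch.
    # The input shape (100, 40) is an assumed (timesteps, features) example.
    from keras.models import Model
    from keras.layers import Input, Dense, BatchNormalization, Activation

    inp = Input(shape=(100, 40))                # assumed example input
    d = Dense(32)(inp)                          # linear projection, no activation here
    bn = BatchNormalization(momentum=0.98)(d)   # normalize the pre-activations
    out = Activation('relu')(bn)                # non-linearity applied after normalization

    model = Model(inp, out)
    model.summary()                             # shows Dense -> BatchNormalization -> Activation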