train_steering_model.py
#!/usr/bin/env python """ Steering angle prediction model """ import os import argparse import json from keras.models import Sequential from keras.layers import Dense, Dropout, Flatten, Lambda, ELU from keras.layers.convolutional import Convolution2D from server import client_generator def gen(hwm, host, port): for tup in client_generator(hwm=hwm, host=host, port=port): X, Y, _ = tup Y = Y[:, -1] if X.shape[1] == 1: # no temporal context X = X[:, -1] yield X, Y def get_model(time_len=1): ch, row, col = 3, 160, 320 # camera format model = Sequential() model.add(Lambda(lambda x: x/127.5 - 1., input_shape=(ch, row, col), output_shape=(ch, row, col))) model.add(Convolution2D(16, 8, 8, subsample=(4, 4), border_mode="same")) model.add(ELU()) model.add(Convolution2D(32, 5, 5, subsample=(2, 2), border_mode="same")) model.add(ELU()) model.add(Convolution2D(64, 5, 5, subsample=(2, 2), border_mode="same")) model.add(Flatten()) model.add(Dropout(.2)) model.add(ELU()) model.add(Dense(512)) model.add(Dropout(.5)) model.add(ELU()) model.add(Dense(1)) model.compile(optimizer="adam", loss="mse") return model if __name__ == "__main__": parser = argparse.ArgumentParser(description='Steering angle model trainer') parser.add_argument('--host', type=str, default="localhost", help='Data server ip address.') parser.add_argument('--port', type=int, default=5557, help='Port of server.') parser.add_argument('--val_port', type=int, default=5556, help='Port of server for validation dataset.') parser.add_argument('--batch', type=int, default=64, help='Batch size.') parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.') parser.add_argument('--epochsize', type=int, default=10000, help='How many frames per epoch.') parser.add_argument('--skipvalidate', dest='skipvalidate', action='store_true', help='Multiple path output.') parser.set_defaults(skipvalidate=False) parser.set_defaults(loadweights=False) args = parser.parse_args() model = get_model() model.fit_generator( gen(20, args.host, port=args.port), samples_per_epoch=10000, nb_epoch=args.epoch, validation_data=gen(20, args.host, port=args.val_port), nb_val_samples=1000 ) print("Saving model weights and configuration file.") if not os.path.exists("./outputs/steering_model"): os.makedirs("./outputs/steering_model") model.save_weights("./outputs/steering_model/steering_angle.keras", True) with open('./outputs/steering_model/steering_angle.json', 'w') as outfile: json.dump(model.to_json(), outfile)