#!/usr/bin/env python

"""
Usage:
>> ./server.py
>> ./train_generative_model.py autoencoder
"""
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import argparse
import importlib
import logging
import time

import numpy as np
import tensorflow as tf
from keras import callbacks as cbks

from server import client_generator
from models.utils import save_images
mixtures = 1


def old_cleanup(data):
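  """Fallback cleanup: if the batch has a singleton temporal axis, drop it
  and rescale pixels from [0, 255] to [-1, 1]."""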
  X = data[0]
  if X.shape[1] == 1:
    X = X[:, -1, :]/127.5 - 1.
  return X


def gen(hwm, host, port):
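  """Stream batches from the data server, passing each through `cleanup`
  (taken from the model file below); train_model expects (z, x) pairs."""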
  for tup in client_generator(hwm=hwm, host=host, port=port):
    X = cleanup(tup)
    yield X


def train_model(name, g_train, d_train, sampler, generator, samples_per_epoch, nb_epoch,
                z_dim=100, verbose=1, callbacks=None,
                validation_data=None, nb_val_samples=None,
                saver=None):
    """
    Main training loop.
    modified from Keras fit_generator
    """
    epoch = 0
    counter = 0
    # names of the metrics returned by g_train/d_train, plus per-step wall time
    out_labels = ['g_loss', 'd_loss', 'd_loss_fake', 'd_loss_legit', 'time']
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [history]
    if verbose:
        callbacks += [cbks.ProgbarLogger()]
    callbacks = cbks.CallbackList(callbacks)

    callbacks._set_params({
        'nb_epoch': nb_epoch,
        'nb_sample': samples_per_epoch,
        'verbose': verbose,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

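    # main loop: count samples until samples_per_epoch are consumed, then
    # fire the epoch-end callbacks and optionally checkpoint via `saver`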
    while epoch < nb_epoch:
      callbacks.on_epoch_begin(epoch)
      samples_seen = 0
      batch_index = 0
      epoch_logs = {}
      while samples_seen < samples_per_epoch:
        z, x = next(generator)
        # build batch logs
        batch_logs = {}
        if isinstance(x, list):
          batch_size = len(x[0])
        elif isinstance(x, dict):
          batch_size = len(list(x.values())[0])
        else:
          batch_size = len(x)
        batch_logs['batch'] = batch_index
        batch_logs['size'] = batch_size
        callbacks.on_batch_begin(batch_index, batch_logs)

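        # one discriminator update, then one generator update on a fresh batch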
        t1 = time.time()
        d_losses = d_train(x, z, counter)
        z, x = next(generator)
        g_loss, samples, xs = g_train(x, z, counter)
        outs = (g_loss, ) + d_losses + (time.time() - t1, )
        counter += 1

        # save samples
        if batch_index % 100 == 0:
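          # interleave 64 generated and 64 real frames into a 16x8 image grid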
          join_image = np.zeros_like(np.concatenate([samples[:64], xs[:64]], axis=0))
          for j, (i1, i2) in enumerate(zip(samples[:64], xs[:64])):
            join_image[j*2] = i1
            join_image[j*2+1] = i2
          save_images(join_image, [8*2, 8],
                      './outputs/samples_%s/train_%s_%s.png' % (name, epoch, batch_index))

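          # same grid from the sampler path, saved as the test image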
          samples, xs = sampler(z, x)
          join_image = np.zeros_like(np.concatenate([samples[:64], xs[:64]], axis=0))
          for j, (i1, i2) in enumerate(zip(samples[:64], xs[:64])):
            join_image[j*2] = i1
            join_image[j*2+1] = i2
          save_images(join_image, [8*2, 8],
                      './outputs/samples_%s/test_%s_%s.png' % (name, epoch, batch_index))

        for l, o in zip(out_labels, outs):
          batch_logs[l] = o

        callbacks.on_batch_end(batch_index, batch_logs)

        batch_index += 1
        samples_seen += batch_size

      if saver is not None:
        saver(epoch)

      callbacks.on_epoch_end(epoch, epoch_logs)
      epoch += 1

    # _stop.set()
    callbacks.on_train_end()


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description='Generative model trainer')
  parser.add_argument('model', type=str, default="bn_model", help='Model definition file.')
  parser.add_argument('--name', type=str, default="autoencoder", help='Name of the model.')
  parser.add_argument('--host', type=str, default="localhost", help='Data server IP address.')
  parser.add_argument('--port', type=int, default=5557, help='Port of the server.')
  # parser.add_argument('--time', type=int, default=1, help='How many temporal frames in a single input.')
  parser.add_argument('--batch', type=int, default=64, help='Batch size.')
  parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
  parser.add_argument('--gpu', type=int, default=0, help='Which GPU to use.')
  parser.add_argument('--epochsize', type=int, default=10000, help='How many frames per epoch.')
  parser.add_argument('--loadweights', dest='loadweights', action='store_true', help='Start from checkpoint.')
  parser.set_defaults(loadweights=False)
  args = parser.parse_args()

  MODEL_NAME = args.model
  logging.info("Importing get_model from %s", MODEL_NAME)
  model_module = importlib.import_module("models." + MODEL_NAME)
  get_model = model_module.get_model
  # fall back to the default cleanup if the model file does not provide one
  cleanup = getattr(model_module, "cleanup", old_cleanup)

  with open('models/' + MODEL_NAME + '.py') as f:
    model_code = f.read()

  if not os.path.exists("./outputs/results_" + args.name):
    os.makedirs("./outputs/results_" + args.name)
  if not os.path.exists("./outputs/samples_" + args.name):
    os.makedirs("./outputs/samples_" + args.name)

  with tf.Session() as sess:
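    # the model file builds the graph and returns train ops plus
    # sampler/saver/loader helpers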
    g_train, d_train, sampler, saver, loader, extras = get_model(sess=sess, name=args.name, batch_size=args.batch, gpu=args.gpu)

    # start from checkpoint
    if args.loadweights:
      loader()

    train_model(args.name, g_train, d_train, sampler,
                gen(20, args.host, port=args.port),
                samples_per_epoch=args.epochsize,
                nb_epoch=args.epoch, verbose=1, saver=saver
                )