
I would like to use an alternating update rule with Keras, i.e. per batch, first call a regular gradient-based step, and then call a custom step.

I thought about implementing it by either subclassing an optimizer or a callback (and using the on-batch calls). However, neither will do, because both lack access to the batch data and the batch labels (and I need both).

Any idea on how to implement a custom alternating update with keras?

If required, I don't mind directly calling TensorFlow-specific methods, as long as I can keep the project wrapped in the Keras framework (with model.fit, model.predict, ...).
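To make this concrete, here is a minimal sketch of the kind of alternation I'm after, written as a custom train_step. This assumes a tf.keras version where Model.train_step can be overridden (TF 2.2+) and the model is compiled with run_eagerly=True so the plain Python counter works; custom_step is a hypothetical placeholder for my own update rule:

import tensorflow as tf

class AlternatingModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batch_idx = 0  # Python-side counter; requires run_eagerly=True

    def custom_step(self, x, y):
        # Placeholder for my non-gradient update; the key point is that
        # it receives both the batch data and the batch labels.
        pass

    def train_step(self, data):
        x, y = data
        if self._batch_idx % 2 == 0:
            # Regular gradient-based step.
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss = self.compiled_loss(y, y_pred)
            grads = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        else:
            # Custom step: batch data and labels are both available here.
            self.custom_step(x, y)
        self._batch_idx += 1
        return {m.name: m.result() for m in self.metrics}

With something like this in place, model.fit and model.predict would still work as usual.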

Yuval Atzmon

1 Answer


Try creating a custom callback:

import json
import logging

import keras.callbacks as callbacks


class JSONMetrics(callbacks.Callback):
    """Writes loss/accuracy to a JSON file and periodically saves weights."""

    def __init__(self, model, each_epoch, logger=None):
        super().__init__()
        self._file_json = "file_log.json"
        self._model = model
        self._each_epoch = each_epoch
        self._epoch = 0
        self._logger = logger or logging.getLogger(__name__)
        self._metrics = {'loss': [], 'acc': []}

    def on_epoch_begin(self, epoch, logs=None):
        # Resume the metric history from a previous run, if the file exists.
        try:
            with open(self._file_json, 'r') as f:
                self._metrics = json.load(f)
        except (IOError, ValueError):
            pass

    def on_epoch_end(self, epoch, logs=None):
        self._logger.info('Epoch {0} end'.format(epoch))

        logs = logs or {}
        self._metrics['loss'].append(logs.get('loss'))
        self._metrics['acc'].append(logs.get('acc'))
        with open(self._file_json, 'w') as f:
            json.dump(self._metrics, f)

        # Save the weights every `each_epoch` epochs.
        if self._epoch % self._each_epoch == 0:
            file_name = 'weights%08d.h5' % self._epoch
            self._model.save_weights(file_name)

        self._epoch += 1

You can use self.model inside the callback to solve your problem, and save the accuracy and loss, for example.
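For example, attaching it to training (a sketch; x_train, y_train, and the save interval are placeholders):

model = ...  # your compiled Keras model
model.fit(x_train, y_train,
          epochs=20,
          callbacks=[JSONMetrics(model, each_epoch=5)])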