# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Provides TFM orbit actions and associated helper functions/classes.""" | |
| import os | |
| from typing import List | |
| from absl import logging | |
| import gin | |
| import orbit | |
| import tensorflow as tf, tf_keras | |
| from official.core import base_trainer | |
| from official.core import config_definitions | |
| from official.modeling import optimization | |


class PruningAction:
  """Train action to update pruning-related information.

  This action updates the pruning step at the end of each training loop and
  logs pruning metrics to TensorBoard.

  This action must be used when training a pruned model to avoid pruning
  errors.
  """

  def __init__(
      self,
      export_dir: str,
      model: tf_keras.Model,
      optimizer: tf_keras.optimizers.Optimizer,
  ):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the pruning summaries.
      model: `tf_keras.Model` model instance used for training. This will be
        used to assign a pruning step to each prunable weight.
      optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to find the current training step.
    """
    # TODO(b/221490190): Avoid local import when the bug is fixed.
    import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top

    self._optimizer = optimizer
    self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
    self.update_pruning_step.set_model(model)
    self.update_pruning_step.on_train_begin()

    self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
        log_dir=export_dir)
    model.optimizer = optimizer
    self.pruning_summaries.set_model(model)

  def __call__(self, output: orbit.runner.Output):
    """Updates the pruning step and logs pruning summaries.

    Args:
      output: The train output.
    """
    self.update_pruning_step.on_epoch_end(batch=None)
    self.pruning_summaries.on_epoch_begin(epoch=None)
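

# NOTE: Illustrative sketch only (not part of the original module). It shows
# the call contract of `PruningAction` under assumed names (`my_model`,
# `my_optimizer`); in this repo the action is normally constructed and
# invoked by an orbit controller via `get_train_actions` below.
def _example_pruning_action_usage(my_model: tf_keras.Model,
                                  my_optimizer: tf_keras.optimizers.Optimizer,
                                  export_dir: str) -> None:
  """Hypothetical helper: invokes PruningAction as a controller would."""
  pruning_action = PruningAction(
      export_dir=export_dir, model=my_model, optimizer=my_optimizer)
  # After each train loop, the controller passes the loop output to every
  # registered train action; PruningAction only bumps the pruning step and
  # writes pruning summaries, so the output contents are not inspected here.
  train_output = {'training_loss': tf.constant(1.0)}
  pruning_action(train_output)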


class EMACheckpointing:
  """Eval action to save a checkpoint with average weights when EMA is used.

  This action swaps the weights of the model with the average weights, then
  saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
  expensive for large models, so doing this action during eval is more
  efficient than during training.
  """

  def __init__(self,
               export_dir: str,
               optimizer: tf_keras.optimizers.Optimizer,
               checkpoint: tf.train.Checkpoint,
               max_to_keep: int = 1):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the EMA average weights.
      optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to swap the model weights with the average
        weights.
      checkpoint: `tf.train.Checkpoint` instance.
      max_to_keep: `int` for the maximum number of checkpoints to keep in the
        ema_checkpoints subdir.
    """
    if not isinstance(optimizer, optimization.ExponentialMovingAverage):
      raise ValueError('Optimizer has to be an instance of '
                       'optimization.ExponentialMovingAverage for the '
                       'EMACheckpointing action.')

    export_dir = os.path.join(export_dir, 'ema_checkpoints')
    tf.io.gfile.makedirs(os.path.dirname(export_dir))
    self._optimizer = optimizer
    self._checkpoint = checkpoint
    self._checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=export_dir,
        max_to_keep=max_to_keep,
        checkpoint_name='average_weights')

  def __call__(self, output: orbit.runner.Output):
    """Swaps in the average weights, saves the checkpoint, and swaps back.

    Args:
      output: The train or eval output.
    """
    self._optimizer.swap_weights()
    self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
    self._optimizer.swap_weights()
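

# NOTE: Illustrative sketch only (not part of the original module). Under
# assumed names (`my_model`, `base_optimizer`), it shows that EMACheckpointing
# expects an `optimization.ExponentialMovingAverage` optimizer plus a
# `tf.train.Checkpoint`; the equivalent wiring in this repo lives in
# `get_eval_actions` below, and the EMA constructor arguments here are only
# plausible defaults.
def _example_ema_checkpointing_usage(
    my_model: tf_keras.Model,
    base_optimizer: tf_keras.optimizers.Optimizer,
    model_dir: str) -> EMACheckpointing:
  """Hypothetical helper: builds an EMA checkpointing eval action."""
  # Wrap the base optimizer so shadow (averaged) copies of the weights are
  # maintained alongside the trained weights.
  ema_optimizer = optimization.ExponentialMovingAverage(
      base_optimizer, average_decay=0.99)
  ema_optimizer.shadow_copy(my_model)
  checkpoint = tf.train.Checkpoint(model=my_model, optimizer=ema_optimizer)
  # Registered as an *eval* action so that saving the averaged weights does
  # not add checkpointing cost to the train loop.
  return EMACheckpointing(
      export_dir=model_dir,
      optimizer=ema_optimizer,
      checkpoint=checkpoint,
      max_to_keep=1)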


class RecoveryAction:
  """Train action to recover from a loss blowup.

  The loss value is checked against a given threshold (see
  `RecoveryCondition`); if it is exceeded, this action recovers the model by
  restoring the checkpoint on disk.
  """

  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
    self.checkpoint_manager = checkpoint_manager

  def __call__(self, _):
    """Recovers the training by triggering checkpoint restoration."""
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning('Recovering the model from checkpoint: %s.',
                    checkpoint_path)


class RecoveryCondition:
  """Recovery condition that detects a loss blowup.

  Returns True when the training loss is NaN, or when it exceeds
  `loss_upper_bound` after `recovery_begin_steps` steps. Raises a RuntimeError
  once this has happened more than `recovery_max_trials` times.
  """

  def __init__(self,
               global_step: tf.Variable,
               loss_upper_bound: float,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.global_step = global_step

  def __call__(self, outputs: orbit.runner.Output):
    loss_value = outputs['training_loss']
    if tf.math.is_nan(loss_value):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            'The loss value is NaN after the training loop; this has '
            'happened %d times.' % self.recover_counter)
      return True
    if (self.global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            f'The loss value is {loss_value}, which is larger than the bound '
            f'{self.loss_upper_bound}; this has happened '
            f'{self.recover_counter} times.')
      return True

    return False
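

# NOTE: Illustrative sketch only (not part of the original module). It shows
# how RecoveryCondition and RecoveryAction are meant to be paired through
# `orbit.actions.ConditionalAction`; `get_train_actions` below performs the
# same wiring from an ExperimentConfig. `my_checkpoint_manager`,
# `my_global_step`, and the threshold values are assumed for illustration.
def _example_recovery_action_usage(
    my_checkpoint_manager: tf.train.CheckpointManager,
    my_global_step: tf.Variable) -> orbit.Action:
  """Hypothetical helper: builds a conditional loss-blowup recovery action."""
  condition = RecoveryCondition(
      global_step=my_global_step,
      loss_upper_bound=1e6,  # Treat any loss above this bound as a blowup.
      recovery_begin_steps=0,  # Start checking from the first step.
      recovery_max_trials=3)  # Raise RuntimeError after three recoveries.
  # The condition inspects the train output's 'training_loss'; when it returns
  # True, the wrapped RecoveryAction restores the latest checkpoint on disk.
  return orbit.actions.ConditionalAction(
      condition=condition, action=RecoveryAction(my_checkpoint_manager))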


def get_eval_actions(params: config_definitions.ExperimentConfig,
                     trainer: base_trainer.Trainer,
                     model_dir: str) -> List[orbit.Action]:
  """Gets eval actions for TFM trainer."""
  eval_actions = []
  # Adds ema checkpointing action to save the average weights under
  # ema_checkpoints subdir.
  if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
    eval_actions.append(
        EMACheckpointing(
            export_dir=model_dir,
            optimizer=trainer.optimizer,
            checkpoint=trainer.checkpoint,
            max_to_keep=params.trainer.max_to_keep))

  return eval_actions


def get_train_actions(
    params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
    model_dir: str,
    checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
  """Gets train actions for TFM trainer."""
  train_actions = []
  # Adds pruning callback actions.
  if hasattr(params.task, 'pruning') and params.task.pruning:
    train_actions.append(
        PruningAction(
            export_dir=model_dir,
            model=trainer.model,
            optimizer=trainer.optimizer))

  if params.trainer.recovery_max_trials >= 0:
    recovery_condition = RecoveryCondition(
        global_step=trainer.global_step,
        loss_upper_bound=params.trainer.loss_upper_bound,
        recovery_begin_steps=params.trainer.recovery_begin_steps,
        recovery_max_trials=params.trainer.recovery_max_trials,
    )
    recover_action = orbit.actions.ConditionalAction(
        condition=recovery_condition,
        action=RecoveryAction(checkpoint_manager),
    )
    train_actions.append(recover_action)

  if (
      params.trainer.preemption_on_demand_checkpoint
      and trainer.strategy.cluster_resolver
  ):
    on_demand_checkpoint_action = orbit.actions.SaveCheckpointIfPreempted(
        trainer.strategy.cluster_resolver,
        checkpoint_manager,
        trainer.global_step,
        keep_running_after_save=True,
    )
    train_actions.append(on_demand_checkpoint_action)

  return train_actions
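

# NOTE: Illustrative sketch only (not part of the original module). It shows
# one plausible way the action lists built above could be handed to an
# `orbit.Controller`; names such as `my_params`, `my_trainer`, and
# `my_evaluator` are assumptions, and the actual training drivers in this repo
# (e.g. official/core/train_lib.py) may wire things up differently.
def _example_controller_with_actions(
    my_params: config_definitions.ExperimentConfig,
    my_trainer: base_trainer.Trainer,
    my_evaluator: orbit.AbstractEvaluator,
    model_dir: str,
    checkpoint_manager: tf.train.CheckpointManager) -> orbit.Controller:
  """Hypothetical helper: builds a controller using the TFM action lists."""
  return orbit.Controller(
      trainer=my_trainer,
      evaluator=my_evaluator,
      global_step=my_trainer.global_step,
      steps_per_loop=my_params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      # Train actions run after each train loop; eval actions after each eval.
      train_actions=get_train_actions(
          my_params, my_trainer, model_dir, checkpoint_manager),
      eval_actions=get_eval_actions(my_params, my_trainer, model_dir))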