# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """AbstractTrainer/Evaluator subclasses with added functionality. | |
| The classes in this module provide some additional structure to the bare | |
| `AbstractTrainer`/`AbstractEvaluator` APIs. | |
| Both `StandardTrainer` and `StandardEvaluator` split the train/eval loops into | |
| "begin", "step", and "end" methods, and provide an implementation of the loop | |
| itself that makes calls to the relevant step method. | |
| `StandardTrainer` supports running the loop using the TF while loop construct | |
| for added performance (particularly on TPUs). It additionally provides some | |
| functionality to make writing summaries from inside a model more performant when | |
| running on TPUs. | |
| These classes are intended to work well in common settings, however there may | |
| be use cases these classes don't support (for instance, `StandardEvaluator` in | |
| particular doesn't support running full evaluations over multiple different eval | |
| datasets). Users are encouraged to simply fall back to custom `AbstractTrainer` | |
| and `AbstractEvaluator` subclasses in these cases. | |
| """ | |
import abc
import dataclasses
from typing import Any, Optional

from orbit import runner
from orbit.utils import loop_fns

import tensorflow as tf, tf_keras


@dataclasses.dataclass(frozen=True)
class StandardTrainerOptions:
| """Advanced options for `orbit.StandardTrainer`. | |
| Attributes: | |
| use_tf_function: A boolean indicating whether to apply `tf.function` to the | |
| training loop. This will only affect the body of the loop (involving | |
| `train_step`); `train_loop_begin` and `train_loop_end` will always be run | |
| in eager mode. | |
| use_tf_while_loop: A boolean indicating whether to run the training loop | |
| using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`. | |
| use_tpu_summary_optimization: A boolean indicating whether to enable a | |
| performance optimization for summaries in TPUs. Writing summaries | |
| conditionally with outside compilation on TPUs can be extremely slow. If | |
| `True`, this optimization creates two `tf.function`s with two XLA programs | |
| (one with summary calls, and one without). The program with summaries runs | |
| only for one step when summaries should be recorded. | |
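
  For example (a sketch; `MyTrainer` is a hypothetical
  `orbit.StandardTrainer` subclass, and `train_dataset` is assumed to be
  defined elsewhere):

    options = StandardTrainerOptions(
        use_tf_function=True,
        use_tf_while_loop=True,
        use_tpu_summary_optimization=False)
    trainer = MyTrainer(train_dataset, options=options)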
| """ | |
| use_tf_function: bool = True | |
| use_tf_while_loop: bool = True | |
| use_tpu_summary_optimization: bool = False | |


class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
| """Implements standard functionality on top of the AbstractTrainer API. | |
| This class structures the training "inner loop" roughly as follows: | |
| train_loop_begin() | |
| for _ in range(num_steps): | |
| train_step(train_iterator) | |
| return train_loop_end() | |
| Calls to `train_loop_begin` and `train_loop_end` are always done in eager | |
| mode, while the loop/`train_step` may be implemented using `tf.while` and/or | |
| `tf.function`, as determined by the `options` passed to `__init__`. | |
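
  A minimal subclass might look like the following sketch (`my_model`,
  `my_optimizer`, and `my_loss_metric` are hypothetical attributes assumed to
  be created in the subclass's `__init__`):

    class MyTrainer(StandardTrainer):

      def train_loop_begin(self):
        self.my_loss_metric.reset_states()

      def train_step(self, iterator):
        inputs = next(iterator)
        with tf.GradientTape() as tape:
          loss = self.my_model(inputs, training=True)
        grads = tape.gradient(loss, self.my_model.trainable_variables)
        self.my_optimizer.apply_gradients(
            zip(grads, self.my_model.trainable_variables))
        self.my_loss_metric.update_state(loss)

      def train_loop_end(self):
        return {"loss": self.my_loss_metric.result()}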
| """ | |

  def __init__(self,
               train_dataset,
               options: Optional[StandardTrainerOptions] = None):
| """Initializes the `StandardTrainer` instance. | |
| Args: | |
| train_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or | |
| `DistributedDataset`. | |
| options: An `orbit.StandardTrainerOptions` instance. | |
| """ | |
| options = options or StandardTrainerOptions() | |
| if options.use_tf_while_loop and not options.use_tf_function: | |
| raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` " | |
| "is not supported") | |
| if options.use_tpu_summary_optimization and not options.use_tf_while_loop: | |
| raise ValueError("`use_tpu_summary_optimization=True` and " | |
| "`use_tf_while_loop=False` is not supported") | |
| self._train_options = options | |
| self._train_dataset = train_dataset | |
| self._train_iter = None | |
| self._train_loop_fn = None | |

  def create_train_loop_fn(self):
    """Creates a training loop from the current step function and options.

    Returns:
      The train loop function, i.e., a wrapper around multiple train steps.
    """
    train_step_fn = self.train_step
    if self._train_options.use_tf_while_loop:
      loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
      if self._train_options.use_tpu_summary_optimization:
        loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
      else:
        loop_fn = tf.function(loop_fn)
    else:
      if self._train_options.use_tf_function:
        train_step_fn = tf.function(train_step_fn)
      loop_fn = loop_fns.create_loop_fn(train_step_fn)
    return loop_fn

  def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
    """Implements `num_steps` steps of training.

    Args:
      num_steps: The number of training steps to run. This corresponds
        directly to the number of calls made to `train_step`.

    Returns:
      The output of `train_loop_end`.
    """
    self.train_loop_begin()

    if self._train_loop_fn is None:
      self._train_loop_fn = self.create_train_loop_fn()

    if self._train_iter is None:
      self._train_iter = tf.nest.map_structure(iter, self.train_dataset)

    self._train_loop_fn(self._train_iter, num_steps)
    return self.train_loop_end()

  def train_loop_begin(self):
    """Called once at the beginning of the training loop.

    This method is always called in eager mode, and is a good place to reset
    metrics that accumulate values over multiple steps of training.

    Note that this method is called before dataset iterator creation.
    """
    pass

  @abc.abstractmethod
  def train_step(self, iterator):
| """Implements one step of training. | |
| What a "step" consists of is up to the implementer. When using distribution | |
| strategies, the call to this method takes place in the "cross-replica | |
| context" for generality, to allow e.g. multiple iterator dequeues and calls | |
| to `strategy.run`. | |
| Note that if `use_tf_function=True`, all the code inside `train_step` should | |
| be compatible with `tf.function` tracing (and in particular, any state | |
| modifications involving `self` should be avoided). In some cases, non- | |
| `tf.function` compatible code can be moved to `train_loop_begin` or | |
| `train_loop_end`, which always execute eagerly. | |
| Args: | |
| iterator: A `tf.nest`-compatible structure of `tf.data.Iterator` or | |
| `DistributedIterator`. The structure of this input matches the structure | |
| of `train_dataset` as passed to `__init__`. | |
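
    For example, a distribution-friendly step might look like this sketch
    (`replica_fn` is a hypothetical per-replica function defined by the
    implementer):

      def train_step(self, iterator):
        def replica_fn(inputs):
          ...  # Per-replica forward/backward pass goes here.
        # Dequeue in the cross-replica context and dispatch to replicas.
        tf.distribute.get_strategy().run(replica_fn, args=(next(iterator),))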
| """ | |
| pass | |

  def train_loop_end(self) -> Optional[runner.Output]:
    """Called once at the end of the training loop.

    This method is always called in eager mode, and is a good place to get
    metric results. The value returned from this function will be returned
    as-is from the `train` method implementation provided by
    `StandardTrainer`.

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries. It can also be a
      nested dictionary, yielding a hierarchy of summary directories.
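
      For example (a sketch; `accuracy` and `train_loss` are hypothetical
      metric attributes):

        return {"accuracy": self.accuracy.result(),
                "losses": {"train": self.train_loss.result()}}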
| """ | |
| pass | |

  @property
  def train_dataset(self):
    """The current training dataset."""
    return self._train_dataset

  @train_dataset.setter
  def train_dataset(self, train_dataset):
    """Sets a new training dataset, replacing the current one.

    Any unprocessed examples in the current dataset are discarded.

    Args:
      train_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
        `DistributedDataset`.
    """
    self._train_dataset = train_dataset
    self._train_iter = None


@dataclasses.dataclass(frozen=True)
class StandardEvaluatorOptions:
| """Advanced options for the `orbit.StandardEvaluator`. | |
| Attributes: | |
| use_tf_function: A boolean indicating whether to apply `tf.function` to the | |
| evaluation loop. This will only affect the body of the loop (involving | |
| `eval_step`); `eval_loop_begin` and `eval_loop_end` will always be run | |
| in eager mode. | |
| use_tf_while_loop: A boolean indicating whether to run the evaluation loop | |
| using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`. | |
| recreate_iterator_for_each_eval: A boolean indicating whether to recreate a | |
| new iterator for the evaluation dataset before each round of evaluation, | |
| which implies each round of evaluation starts from the beginning of | |
| the evaluation dataset. For example, the evaluation dataset is | |
| `[1, 2, 3, 4]`, batch size is 1 and evaluation steps is 2. If `True`, the | |
| data to be evaluated is [1, 2] every time. If `False`, the iterator | |
| state is maintained between calls to `StandardEvaluator.evaluate()`. | |
| """ | |
| use_tf_function: bool = True | |
| use_tf_while_loop: bool = False | |
| recreate_iterator_for_each_eval: bool = True | |


class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
| """Implements the standard functionality of AbstractEvaluator APIs. | |
| This class structures evaluation roughly as follows: | |
| state = eval_begin() | |
| for _ in range(num_steps): | |
| step_outputs = eval_step(eval_iterator) | |
| state = eval_reduce(state, step_outputs) | |
| return eval_end(state) | |
| Calls to `eval_begin` and `eval_end` are always done in eager | |
| mode, while `eval_step` may be compiled with `tf.function` as determined by | |
| the `options` passed to `__init__`. `eval_reduce` is in eager mode if | |
| `use_tf_while_loop=False` in `StandardEvaluatorOptions`, but in graph mode if | |
| `use_tf_while_loop=True`. | |
| This class does not support completely evaluating multiple different datasets | |
| (i.e., where every example of each dataset should be processed, as opposed to | |
| running for a fixed number of evaluation steps). A custom `AbstractEvaluator` | |
| is recommended in this case. | |
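
  A minimal subclass might look like this sketch (`my_model` and
  `my_loss_metric` are hypothetical attributes assumed to be created in the
  subclass's `__init__`):

    class MyEvaluator(StandardEvaluator):

      def eval_begin(self):
        self.my_loss_metric.reset_states()

      def eval_step(self, iterator):
        inputs = next(iterator)
        loss = self.my_model(inputs, training=False)
        self.my_loss_metric.update_state(loss)

      def eval_end(self):
        return {"eval_loss": self.my_loss_metric.result()}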
| """ | |

  def __init__(self,
               eval_dataset,
               options: Optional[StandardEvaluatorOptions] = None):
| """Initializes the `StandardEvaluator` instance. | |
| Args: | |
| eval_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or | |
| `DistributedDataset`. On TPUs, if users want to exaust the dataset | |
| without specifying number of eval steps, it is recommended to set | |
| `drop_remainder=False` when batching the dataset, so the infrastructure | |
| can handle the last partial batch properly. | |
| options: An `orbit.StandardEvaluatorOptions` instance. | |
| """ | |
| options = options or StandardEvaluatorOptions() | |
| if options.use_tf_while_loop and not options.use_tf_function: | |
| raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` " | |
| "is not supported") | |
| self._eval_options = options | |
| self._eval_dataset = eval_dataset | |
| self._eval_iter = None | |
| self._eval_loop_fn = None | |

  def create_eval_loop_fn(self, has_state: bool):
    """Creates an eval loop from the current step function and options.

    Args:
      has_state: Whether the step function maintains state; if `True`, the
        state is threaded through the loop.

    Returns:
      The eval loop function, i.e., a wrapper around multiple eval steps.
    """
    eval_step_fn = self.eval_step
    if self._eval_options.use_tf_while_loop:
      # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop
      # input even when it is not used inside the loop. To work around this
      # limitation, we have to build two tf.functions for it.
      if has_state:
        loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
      else:
        loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
      loop_fn = tf.function(loop_fn)
    else:
      if self._eval_options.use_tf_function:
        eval_step_fn = tf.function(eval_step_fn)
      loop_fn = loop_fns.create_loop_fn(eval_step_fn)
    return loop_fn

  def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
    """Implements `num_steps` steps of evaluation.

    Args:
      num_steps: The number of evaluation steps to run. When this is -1,
        evaluation proceeds until a call to `eval_step` raises a
        `StopIteration` or `tf.errors.OutOfRangeError`.

    Returns:
      The output of `self.eval_end()`.

    Raises:
      ValueError: If `options.use_tf_while_loop` is `True` and `num_steps` is
        unspecified.
    """
    if self._eval_options.use_tf_while_loop and num_steps == -1:
      raise ValueError("Looping until exhausted is not supported if "
                       "`options.use_tf_while_loop` is `True`")

    outputs = self.eval_begin()  # pylint: disable=assignment-from-no-return

    has_state = outputs is not None
    if self._eval_loop_fn is None:
      self._eval_loop_fn = self.create_eval_loop_fn(has_state)

    # If `recreate_iterator_for_each_eval` is `True`, `self._eval_iter` is
    # always None.
    if self._eval_iter is None:
      eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
      if not self._eval_options.recreate_iterator_for_each_eval:
        self._eval_iter = eval_iter
    else:
      eval_iter = self._eval_iter

    if self._eval_options.use_tf_while_loop and not has_state:
      self._eval_loop_fn(eval_iter, num_steps)
    else:
      outputs = self._eval_loop_fn(
          eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)

    if outputs is None:
      return self.eval_end()
    else:
      return self.eval_end(outputs)

  def eval_begin(self) -> Any:
    """Called once at the beginning of the evaluation.

    This method is always called in eager mode, and is a good place to reset
    metrics that accumulate values over the course of evaluation.

    Note that this method is called before dataset iterator creation.

    Returns:
      A value to pass as the `state` argument to `eval_reduce`.
    """
    pass

  @abc.abstractmethod
  def eval_step(self, iterator) -> Any:
| """Implements one step of evaluation. | |
| What a "step" consists of is up to the implementer. When using distribution | |
| strategies, the call to this method takes place in the "cross-replica | |
| context" for generality, to allow e.g. multiple iterator dequeues and calls | |
| to `strategy.run`. | |
| Note that if `use_tf_function=True`, all the code inside `eval_step` should | |
| be compatible with `tf.function` tracing (and in particular, any state | |
| modifications involving `self` should be avoided). In some cases, non- | |
| `tf.function` compatible code can be moved to `eval_loop_begin`, | |
| `eval_reduce`, or `eval_loop_end`, which always execute eagerly. | |
| Args: | |
| iterator: A `tf.nest`-compatible structure of `tf.data.Iterator` or | |
| `DistributedIterator`. | |
| Returns: | |
| An output which is passed as `step_outputs` argument into `eval_reduce` | |
| function. | |
| """ | |
| pass | |

  def eval_end(self, *args) -> Optional[runner.Output]:
    """Called once at the end of the evaluation.

    This method is always called in eager mode, and is a good place to get
    metric results. The value returned from this function will be returned
    as-is from the `evaluate` method implementation provided by
    `StandardEvaluator`.

    Args:
      *args: The outputs from `eval_reduce` for the last eval step, if they
        are non-`None` (if they are `None`, nothing is passed).

    Returns:
      The function may return a dictionary of `Tensors`, which will be
      written to logs and as TensorBoard summaries. It can also be a
      nested dictionary, yielding a hierarchy of summary directories.
    """
    pass

  def eval_reduce(self,
                  state: Optional[Any] = None,
                  step_outputs: Optional[runner.Output] = None) -> Any:
| """A function to perform per-step reduction on the evaluation outputs. | |
| This is useful for passing state throughout evaluation, especially in cases | |
| where maintaining or accumulating state is hard to accomplish using | |
| `tf.metrics.Metric` or other `tf.Variable`-based approaches. For instance, | |
| it can be used to easily accumulate all per-example losses from the full | |
| evaluation for subsequent processing in `eval_end()`. | |
| Args: | |
| state: A state being maintained throughout the evaluation. | |
| step_outputs: Outputs from the current evaluation step. | |
| Returns: | |
| An output which is passed as the `state` argument to this function for the | |
| next step. After evaluation is finished, the output from last step will be | |
| passed to `eval_end`. | |
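
      For example, to accumulate per-example losses across the full
      evaluation (a sketch, assuming the default `use_tf_while_loop=False`,
      that `eval_begin` returns `[]`, and that `eval_step` returns a
      per-batch loss tensor):

        def eval_reduce(self, state=None, step_outputs=None):
          state.append(step_outputs)
          return state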
| """ | |
| pass | |

  @property
  def eval_dataset(self):
    """The current evaluation dataset."""
    return self._eval_dataset

  @eval_dataset.setter
  def eval_dataset(self, eval_dataset):
    """Sets a new eval dataset, replacing the current one.

    Any unprocessed examples in the current dataset are discarded.

    Args:
      eval_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
        `DistributedDataset`.
    """
    self._eval_dataset = eval_dataset
    self._eval_iter = None