# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for creating loop functions."""

from absl import logging

from orbit.utils import tpu_summaries

import tensorflow as tf, tf_keras


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of
      the function (except that it must be compatible with any `reduce_fn`
      provided to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below
    for additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

      for _ in range(num_steps):
        step_outputs = step_fn(iterator)
        state = reduce_fn(state, step_outputs)
      return state

    However, the implementation is slightly more complicated in order to
    support looping until the iterator is exhausted (when `num_steps == -1`)
    and to properly catch exceptions when running under async remote eager (as
    is the case in TPU training setups involving separate coordinator/worker
    machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`,
        iterates until the iterator is exhausted.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    step = 0
    try:
      # To make sure the OutOfRangeError exception can be handled well under
      # async remote eager, we need to wrap the loop body in `async_scope`.
      with tf.experimental.async_scope():
        while num_steps == -1 or step < num_steps:
          outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, outputs)
          step += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.", step)
      tf.experimental.async_clear_error()
      return state

  return loop_fn
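

# A minimal usage sketch for `create_loop_fn` (illustrative only: the dataset,
# `step_fn`, and `reduce_fn` below are assumptions, not part of this module):
#
#   dataset = tf.data.Dataset.range(10)
#   loop_fn = create_loop_fn(lambda it: next(it))
#   total = loop_fn(
#       iter(dataset),
#       num_steps=-1,  # Run until the iterator is exhausted.
#       state=tf.constant(0, tf.int64),
#       reduce_fn=lambda state, value: state + value)
#   # `total` is now 45: the sum of all dataset elements. The StopIteration
#   # raised at exhaustion is caught inside `loop_fn`.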


def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside the
      # `tf.while_loop` don't get "while/" as a name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn
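

# A minimal usage sketch for `create_tf_while_loop_fn` (illustrative only:
# `counter` and the dataset are assumptions). Wrapping the returned loop
# function in `tf.function` lets AutoGraph convert the `tf.range` loop into a
# single `tf.while_loop` op, so all steps run without returning to Python
# between them:
#
#   counter = tf.Variable(0, dtype=tf.int64)
#   step_fn = lambda it: counter.assign_add(next(it))
#   loop_fn = tf.function(create_tf_while_loop_fn(step_fn))
#   loop_fn(iter(tf.data.Dataset.range(10)), tf.constant(4))
#   # `counter` is now 6 (0 + 1 + 2 + 3).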


def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allows a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Its return value is passed to `reduce_fn` as the
      `value` argument.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shapes of the input nested structure `s`."""
      return tf.nest.pack_sequence_as(
          state, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside the
      # `tf.while_loop` don't get "while/" as a name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful, e.g., to aggregate the outputs
        # from each step and concatenate them onto `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
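

# A minimal usage sketch for `create_tf_while_loop_fn_with_state`
# (illustrative only: the dataset, `step_fn`, and `reduce_fn` are
# assumptions). Each step's output is concatenated onto `state`, which is why
# the relaxed shape invariants above are needed: `state` grows by one element
# per iteration.
#
#   step_fn = lambda it: tf.reshape(next(it), [1])
#   reduce_fn = lambda state, value: tf.concat([state, value], axis=0)
#   loop_fn = tf.function(create_tf_while_loop_fn_with_state(step_fn))
#   state = loop_fn(
#       iter(tf.data.Dataset.range(10)),
#       tf.constant(4),
#       tf.zeros([0], tf.int64),
#       reduce_fn)
#   # `state` is now [0, 1, 2, 3].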


class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Makes `num_steps` total calls to the wrapped step function.

    If summaries should currently be recorded, a single step is run with
    summaries enabled, and the remaining `num_steps - 1` steps are run with
    summaries disabled. This avoids checking the summary condition on every
    step of the inner loop. Requires `num_steps >= 1`.
    """
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output
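

# A minimal usage sketch for `LoopFnWithSummaries` (illustrative only:
# `step_fn` and `iterator` are assumptions; `step_fn` is expected to emit
# summaries via `tf.summary` ops). Whether the first step records summaries
# is controlled by the ambient `tf.summary.record_if` state:
#
#   loop_fn = LoopFnWithSummaries(create_tf_while_loop_fn(step_fn))
#   writer = tf.summary.create_file_writer("/tmp/logs")
#   with writer.as_default(), tf.summary.record_if(True):
#     loop_fn(iterator, tf.constant(100))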