Spaces:
Sleeping
Sleeping
| # Copyright 2023 The Orbit Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Provides a utility class for managing summary writing.""" | |
| import os | |
| from orbit.utils.summary_manager_interface import SummaryManagerInterface | |
| import tensorflow as tf, tf_keras | |
class SummaryManager(SummaryManagerInterface):
  """Writes (possibly nested) summaries under a single root directory.

  One summary writer is created lazily per subdirectory and cached, so
  repeated writes to the same relative path reuse the same writer.
  """

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Creates a new `SummaryManager`.

    Args:
      summary_dir: Root directory that receives the summaries. Passing `None`
        disables the manager: every summary writing operation it offers
        becomes a no-op.
      summary_fn: A callable taking `name`, `value`, and `step`, which is
        expected to call into `tf.summary` to actually emit the summary.
      global_step: A `tf.Variable` holding the global step. When omitted, the
        value returned by `tf.summary.experimental.get_step()` is used.
    """
    self._summary_dir = summary_dir
    self._enabled = summary_dir is not None
    self._summary_fn = summary_fn
    # Cache of summary writers, keyed by path relative to `summary_dir`.
    self._summary_writers = {}
    self._global_step = (
        tf.summary.experimental.get_step()
        if global_step is None else global_step)

  def summary_writer(self, relative_path=""):
    """Gets (creating it on first use) the writer for a subdirectory.

    Args:
      relative_path: Where to write summaries, expressed relative to the
        summary directory. The default, an empty string, refers to the root
        directory itself.

    Returns:
      The cached summary writer for `relative_path`. A file writer is created
      when the manager is enabled, and a no-op writer otherwise.
    """
    writers = self._summary_writers
    if not writers or relative_path not in writers:
      if self._enabled:
        target_dir = os.path.join(self._summary_dir, relative_path)
        writers[relative_path] = tf.summary.create_file_writer(target_dir)
      else:
        writers[relative_path] = tf.summary.create_noop_writer()
    return writers[relative_path]

  def flush(self):
    """Flushes every summary writer created so far (no-op when disabled)."""
    if not self._enabled:
      return
    tf.nest.map_structure(tf.summary.flush, self._summary_writers)

  def write_summaries(self, summary_dict):
    """Writes out the values contained in `summary_dict`.

    Nested dictionaries are mapped onto nested subdirectories: each dictionary
    found as a value produces a subdirectory named after its key, and this is
    applied recursively. The resulting directory hierarchy shows up in the
    TensorBoard UI as differently colored curves.

    As an example, when evaluating on several datasets one may pass:

      {
          "dataset1": {
              "loss": loss1,
              "accuracy": accuracy1
          },
          "dataset2": {
              "loss": loss2,
              "accuracy": accuracy2
          },
      }

    which yields the subdirectories "dataset1" and "dataset2" below the
    summary root directory, each holding event files with both the "loss" and
    the "accuracy" summaries.

    Args:
      summary_dict: A dictionary of values. Values that are themselves
        dictionaries become subdirectories named by their key (recursively);
        leaf values are summarized with the summary writer belonging to the
        relative path of their parent.
    """
    if self._enabled:
      self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    """Recursively writes `summary_dict` below `relative_path`.

    Args:
      summary_dict: The (possibly nested) dictionary of values to write.
      relative_path: Subdirectory, relative to the summary directory, that
        receives the leaf values found directly in `summary_dict`.
    """
    for key, entry in summary_dict.items():
      if isinstance(entry, dict):
        # A nested dictionary becomes a subdirectory named after its key.
        subdirectory = os.path.join(relative_path, key)
        self._write_summaries(entry, relative_path=subdirectory)
        continue
      writer = self.summary_writer(relative_path)
      with writer.as_default():
        self._summary_fn(key, entry, step=self._global_step)