# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for running models in a distributed setting."""

import json
import os

import tensorflow as tf, tf_keras


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]
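
# A minimal usage sketch (not part of the original module): the returned enum
# is passed to MultiWorkerMirroredStrategy, mirroring what
# `get_distribution_strategy` does below.
#
#   communication = _collective_communication("nccl")
#   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
#       communication=communication)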


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)
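
# A minimal usage sketch (not part of the original module): the returned
# CrossDeviceOps is plugged into MirroredStrategy, as done in
# `get_distribution_strategy` below. The device list is a placeholder.
#
#   cross_device_ops = _mirrored_cross_device_ops("nccl", num_packs=2)
#   strategy = tf.distribute.MirroredStrategy(
#       devices=["device:GPU:0", "device:GPU:1"],
#       cross_device_ops=cross_device_ops)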


def tpu_initialize(tpu_address):
  """Initializes TPU for TF 2.x training.

  Args:
    tpu_address: string, bns address of master TPU worker.

  Returns:
    A TPUClusterResolver.
  """
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  if tpu_address not in ("", "local"):
    tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver
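
# A minimal usage sketch (not part of the original module): "local" targets a
# locally attached TPU, while a non-empty address names a remote TPU worker.
#
#   resolver = tpu_initialize("local")
#   strategy = tf.distribute.TPUStrategy(resolver)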


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None,
                              **kwargs):
  """Return a Strategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
      "off" means to use the default strategy which is obtained from
      tf.distribute.get_strategy (for details on the default strategy, see
      https://www.tensorflow.org/guide/distributed_training#default_strategy).
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
    **kwargs: Additional kwargs for internal usages.

  Returns:
    tf.distribute.Strategy object.

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  del kwargs
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  if not isinstance(distribution_strategy, str):
    msg = ("distribution_strategy must be a string but got: %s." %
           (distribution_strategy,))
    if distribution_strategy == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      msg += (" If you meant to pass the string 'off', make sure you add "
              "quotes around 'off' so that yaml interprets it as a string "
              "instead of a bool.")
    raise ValueError(msg)

  distribution_strategy = distribution_strategy.lower()

  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(f"When {num_gpus} GPUs are specified, "
                       "distribution_strategy flag cannot be set to `off`.")
    # Return the default distribution strategy.
    return tf.distribute.get_strategy()

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_initialize(tpu_address)
    return tf.distribute.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)
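
# A minimal usage sketch (not part of the original module); the flag values
# shown are placeholders for whatever the caller's setup requires.
#
#   strategy = get_distribution_strategy(
#       distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")
#   with strategy.scope():
#     model = tf_keras.Sequential(
#         [tf_keras.layers.Dense(10, input_shape=(5,))])
#     model.compile(optimizer="sgd", loss="mse")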


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the worker.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (
        len(tf_config["cluster"].get("chief", [])) +
        len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {
            "type": "worker",
            "index": task_index
        }
    })
  else:
    num_workers = 1
  return num_workers
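
# A minimal usage sketch (not part of the original module); the host:port
# pairs are placeholders, and each worker process would pass its own
# task_index before building the multi-worker strategy.
#
#   num_workers = configure_cluster(
#       worker_hosts="10.0.0.1:2222,10.0.0.2:2222", task_index=0)
#   strategy = get_distribution_strategy("multi_worker_mirrored")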


def get_strategy_scope(strategy):
  """Returns the strategy's scope, or a no-op context if strategy is None."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is given."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
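
# A minimal usage sketch (not part of the original module): `get_strategy_scope`
# lets calling code use the same `with` block whether or not a strategy was
# created; `build_model` is a hypothetical model-building helper.
#
#   strategy = get_distribution_strategy("one_device", num_gpus=1)
#   with get_strategy_scope(strategy):
#     model = build_model()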