Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Flags for managing compute devices. Currently only contains TPU flags.""" | |
| from absl import flags | |
| from absl import logging | |
| from official.utils.flags._conventions import help_wrap | |
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for flags when a TPU is used.

  The validator is attached to the "tpu" flag together with every flag in
  `flag_names`. It is a no-op when the "tpu" flag is unset; otherwise each
  listed flag must hold a string beginning with "gs://".

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  # Materialize once: the names are joined for the message, concatenated for
  # registration and iterated inside the checker, so a one-shot iterable
  # (or a tuple, which cannot be added to a list) would otherwise break.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  # Without this registration the checker below is never invoked and `msg`
  # is never shown to the user.
  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    # Inspect every flag rather than stopping at the first failure so that
    # all offending flags are logged in a single run.
    for key in flag_names:
      value = flag_values[key]
      # An unset (None) or non-string value cannot be a GCS path; report it
      # as invalid instead of raising on `.startswith`.
      if not (isinstance(value, str) and value.startswith("gs://")):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags
def define_device(tpu=True):
  """Register device specific flags.

  When `tpu` is true, defines the string flags "tpu", "tpu_zone" and
  "tpu_gcp_project" (all defaulting to None) and the integer flag
  "num_tpu_shards" (defaulting to 8). Only "tpu" is reported as a key flag.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu",
        default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone",
        default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "zone from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project",
        default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(
        name="num_tpu_shards",
        default=8,
        help=help_wrap("Number of shards (TPU chips)."))

  return key_flags