# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| """Tests for tensorflow_models.core.trainers.trainer.""" | |
| # pylint: disable=g-direct-tensorflow-import | |
| import gc | |
| import multiprocessing | |
| import os | |
| import sys | |
| from absl.testing import parameterized | |
| import orbit | |
| import portpicker | |
| import tensorflow as tf, tf_keras | |
| from tensorflow.python.distribute import combinations | |
| from tensorflow.python.distribute import strategy_combinations | |
| from official.core import base_trainer as trainer_lib | |
| from official.core import config_definitions as cfg | |
| from official.core import train_lib | |
| from official.utils.testing import mock_task | |
| TPU_TEST = 'test_tpu' in sys.argv[0] | |
| GPU_TEST = 'test_gpu' in sys.argv[0] | |
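# The two globals above sniff the test binary name: a target built as, e.g.,
# `..._test_tpu` or `..._test_gpu` flips TPU_TEST / GPU_TEST so that
# platform-incompatible cases below can skip themselves.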


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
  )


def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]

  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for i in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name='worker',
        task_index=i,
        config=worker_config,
        protocol='grpc')

  for i in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name='ps', task_index=i, protocol='grpc')

  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer='grpc')
  return cluster_resolver
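
# A minimal usage sketch: the resolver returned above seeds a
# ParameterServerStrategy, which is exactly how the async trainer tests
# below use it:
#   resolver = create_in_process_cluster(num_workers=3, num_ps=2)
#   strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)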


def dataset_fn(input_context=None):
  del input_context

  def dummy_data(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset
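
# dataset_fn above repeats a single dummy zero tensor forever, so the train
# and eval loops exercised by the tests never exhaust their input.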


class MockAsyncTrainer(trainer_lib._AsyncTrainer):
  """Mock AsyncTrainer to test the _AsyncTrainer class."""

  def __init__(self):
    self._strategy = tf.distribute.get_strategy()
    self.init_async()

    self.global_step = tf.Variable(
        0,
        dtype=tf.int64,
        name='global_step',
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    self.eval_global_step = tf.Variable(
        0,
        dtype=tf.int64,
        name='eval_global_step',
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    train_dataset = self.distribute_dataset(dataset_fn)
    orbit.StandardTrainer.__init__(
        self, train_dataset, options=orbit.StandardTrainerOptions())

    validation_dataset = self.distribute_dataset(dataset_fn)
    orbit.StandardEvaluator.__init__(
        self,
        validation_dataset,
        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))

  def train_loop_begin(self):
    self.global_step.assign(0)

  def train_step(self, iterator):

    def replica_step(_):
      self.global_step.assign_add(1)

    self._strategy.run(replica_step, args=(next(iterator),))

  def train_loop_end(self):
    self.join()
    return self.global_step.numpy()

  def eval_begin(self):
    self.eval_global_step.assign(0)

  def eval_step(self, iterator):

    def replica_step(_):
      self.eval_global_step.assign_add(1)

    self._strategy.run(replica_step, args=(next(iterator),))

  def eval_end(self):
    self.join()
    return self.eval_global_step.numpy()
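

# The suite below exercises the synchronous Trainer across distribution
# strategies, async (parameter-server) training via the in-process cluster
# helper above, mixed-precision optimizer wiring, and best-checkpoint export.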
class TrainerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self._config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))

  def tearDown(self):
    gc.collect()
    # This will only contain uncollectable garbage, i.e. reference cycles
    # involving objects with __del__ defined.
    self.assertEmpty(gc.garbage)
    super().tearDown()

  def create_test_trainer(self, config, model_dir=None, task=None):
    """Builds a Trainer over a mock task, with an optional ckpt exporter."""
    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                        config.runtime),
        checkpoint_exporter=ckpt_exporter)
    return trainer

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_passing_datasets(self, distribution):
    with distribution.scope():
      task = mock_task.MockTask(self._config)
      train_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.train_data)
      validation_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.validation_data)
      self._config.task.train_data = None
      self._config.task.validation_data = None
      trainer = trainer_lib.Trainer(
          self._config,
          task,
          model=task.build_model(),
          optimizer=task.create_optimizer(
              self._config.trainer.optimizer_config, self._config.runtime),
          train_dataset=train_dataset,
          validation_dataset=validation_dataset)
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
    self.assertIn('learning_rate', logs)
    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('validation_loss', logs)

  def test_base_async_trainer(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    num_workers = 3
    num_ps = 2
    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
    distribution = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    with distribution.scope():
      trainer = MockAsyncTrainer()
      trainer.init_async()
      self.assertIsInstance(
          trainer._coordinator,
          tf.distribute.experimental.coordinator.ClusterCoordinator)
      self.assertEqual(trainer.train(tf.constant(10)), 10)
      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)

  def test_async_trainer_train(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    num_workers = 3
    num_ps = 2
    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
    distribution = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  def test_async_trainer_validate(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    num_workers = 3
    num_ps = 2
    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
    distribution = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('acc', logs)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate_without_loss(self, distribution):

    class MockTaskWithoutValidationLoss(mock_task.MockTask):

      def validation_step(self, inputs, model, metrics=None):
        # Disable validation loss.
        logs = super().validation_step(inputs, model)
        del logs[self.loss]
        return logs

    with distribution.scope():
      task = MockTaskWithoutValidationLoss()
      trainer = self.create_test_trainer(self._config, task=task)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertNotIn('validation_loss', logs)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'float16'],
          loss_scale=[None, 'dynamic', 128],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                },
            })))
    trainer = self.create_test_trainer(config)
    if mixed_precision_dtype == 'float16':
      self.assertIsInstance(trainer.optimizer,
                            tf_keras.mixed_precision.LossScaleOptimizer)
      if loss_scale in (None, 'dynamic'):
        self.assertTrue(trainer.optimizer.dynamic)
      else:
        self.assertFalse(trainer.optimizer.dynamic)
        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
    else:
      self.assertIsInstance(
          trainer.optimizer,
          (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
            best_checkpoint_eval_metric='acc',
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

  def test_model_with_compiled_loss(self):
    task = mock_task.MockTask()
    model = task.build_model()
    model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
    trainer = trainer_lib.Trainer(
        self._config,
        task,
        model=model,
        optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)


if __name__ == '__main__':
  tf.test.main()