File size: 7,223 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import functools
from typing import Any, Callable, List, Optional

from lightning.pytorch.callbacks import Callback as PTLCallback
from nemo.lightning.base_callback import BaseCallback
from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback


class CallbackGroup:
    """A singleton registry to host and fan-out lifecycle callbacks.

    Other code should call methods on this group (e.g., `on_model_init_start`).
    The group will iterate all registered callbacks and, if a callback implements
    the method, invoke it with the provided arguments.

    Any public attribute that is not defined on the group itself resolves to a
    dispatcher (see `__getattr__`). Names starting with an underscore are never
    dispatched; looking them up raises `AttributeError` as usual. This keeps
    Python protocol probes (`copy` looking for `__deepcopy__`, `pickle` looking
    for `__setstate__`, ...) from receiving a fake dispatcher.
    """

    # Process-wide singleton; created lazily by `get_instance`.
    _instance: Optional['CallbackGroup'] = None

    @classmethod
    def get_instance(cls) -> 'CallbackGroup':
        """Get the singleton instance of CallbackGroup.

        The instance is created on first use and cached on the class.

        Returns:
            CallbackGroup: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = CallbackGroup()
        return cls._instance

    def __init__(self) -> None:
        # The OneLogger callback is always present and registered first.
        self._callbacks: List[BaseCallback] = [OneLoggerNeMoCallback()]
        # Ensure application-end is emitted at most once per process
        self._app_end_emitted: bool = False

    def register(self, callback: BaseCallback) -> None:
        """Register a callback to the callback group.

        Callbacks are appended; registering the same object twice means it is
        invoked twice per dispatch.

        Args:
            callback: The callback to register.
        """
        self._callbacks.append(callback)

    def update_config(self, nemo_version: str, trainer: Any, **kwargs) -> None:
        """Update configuration across all registered callbacks and attach them to trainer.

        Each registered callback that is a `BaseCallback` has its `update_config`
        (if callable) invoked. Afterwards `trainer.callbacks` is replaced by those
        group callbacks followed by the trainer's pre-existing Lightning callbacks
        that are not already part of the group (identity comparison). Calling this
        repeatedly therefore does not duplicate group callbacks on the trainer.

        Args:
            nemo_version: Version key (e.g., 'v1' or 'v2') for downstream config builders.
            trainer: Lightning Trainer to which callbacks should be attached if missing.
                A missing or `None` `callbacks` attribute is treated as empty.
            **kwargs: Forwarded to each callback's update_config implementation.
        """
        # Forward update to each callback that supports update_config
        sanitized_group_callbacks: List[BaseCallback] = []
        for cb in self._callbacks:
            # Will ignore other callbacks like unittest.mock.MagicMock
            if not isinstance(cb, BaseCallback):
                continue
            method = getattr(cb, 'update_config', None)
            if callable(method):
                method(nemo_version=nemo_version, trainer=trainer, **kwargs)
            sanitized_group_callbacks.append(cb)

        # Filter trainer callbacks to avoid leaking MagicMocks from tests, and skip
        # callbacks already contributed by the group so they are not attached twice.
        existing = getattr(trainer, 'callbacks', None) or []
        group_ids = {id(cb) for cb in sanitized_group_callbacks}
        sanitized_trainer_callbacks = [
            cb for cb in existing if isinstance(cb, PTLCallback) and id(cb) not in group_ids
        ]

        callbacks = sanitized_group_callbacks + sanitized_trainer_callbacks

        # Sanitize callback state_key for pickling safety: a callback whose state_key
        # is not a string gets one derived from its class path.
        for cb in callbacks:
            try:
                key = getattr(cb, 'state_key', None)
                if not isinstance(key, str):
                    cls_ = cb.__class__
                    safe_key = f"{cls_.__module__}.{getattr(cls_, '__qualname__', cls_.__name__)}"
                    setattr(cb, 'state_key', safe_key)
            except Exception:
                # Deliberately best-effort: a callback that rejects the assignment
                # (e.g. a read-only property) is attached unchanged.
                pass

        trainer.callbacks = callbacks

    @property
    def callbacks(self) -> List['BaseCallback']:
        """Get the list of registered callbacks.

        Returns:
            List[BaseCallback]: The live internal list (not a copy).
        """
        return self._callbacks

    def __getattr__(self, method_name: str) -> Callable:
        """Dynamically create a dispatcher for unknown public attributes.

        Only called when normal attribute lookup fails. A public name is treated
        as a lifecycle method name: the returned dispatcher calls that method on
        each registered callback that has it as a callable attribute, forwarding
        all arguments. Names starting with an underscore raise `AttributeError`
        so that dunder/protocol lookups and missing private state behave normally.

        Raises:
            AttributeError: If `method_name` starts with an underscore.
        """
        if method_name.startswith('_'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {method_name!r}")

        def dispatcher(*args, **kwargs):
            for cb in self._callbacks:
                method = getattr(cb, method_name, None)
                if callable(method):
                    method(*args, **kwargs)

        return dispatcher

    # Explicit idempotent app-end to avoid duplicate emissions across multiple callers
    def on_app_end(self, *args, **kwargs) -> None:
        """Emit application-end callbacks exactly once per process.

        Invokes `on_app_end` on each registered callback, if present. Subsequent
        calls are no-ops. All positional and keyword arguments are forwarded.
        The emitted flag is set before dispatching, so the group does not retry
        if a callback raises.
        """
        if self._app_end_emitted:
            return
        self._app_end_emitted = True
        for cb in self._callbacks:
            method = getattr(cb, 'on_app_end', None)
            if callable(method):
                method(*args, **kwargs)


def hook_class_init_with_callbacks(cls, start_callback: str, end_callback: str) -> None:
    """Hook a class's __init__ to emit CallbackGroup start/end hooks.

    Replaces `cls.__init__` with a wrapper that calls the group's
    `start_callback` before the original `__init__` and `end_callback` after it
    returns successfully. If the original `__init__` raises, the end hook is not
    emitted. Hooking the same `__init__` twice is a no-op, and nested wrapped
    `__init__` calls on the same instance (e.g. via `super().__init__`) emit the
    hooks only once, from the outermost call.

    Args:
        cls (type): Class whose __init__ should be wrapped.
        start_callback (str): CallbackGroup method to call before __init__.
        end_callback (str): CallbackGroup method to call after __init__.
    """
    if not hasattr(cls, '__init__'):
        return

    original_init = cls.__init__

    # Idempotence guard: avoid wrapping the same __init__ multiple times (e.g., in multiple inheritance)
    if getattr(original_init, '_init_wrapped_for_callbacks', False):
        return

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # Reentrancy guard: avoid double-emitting hooks across super().__init__ chains
        if getattr(self, '_in_wrapped_init', False):
            # If we're already inside a wrapped __init__, just call the original
            return original_init(self, *args, **kwargs)

        setattr(self, '_in_wrapped_init', True)
        try:
            group = CallbackGroup.get_instance()
            if hasattr(group, start_callback):
                getattr(group, start_callback)()
            result = original_init(self, *args, **kwargs)
            if hasattr(group, end_callback):
                getattr(group, end_callback)()
            return result
        finally:
            # Clear the guard once the outermost call exits (including on error),
            # so a later __init__ call on the same instance emits hooks again.
            setattr(self, '_in_wrapped_init', False)

    # Marker read by the idempotence guard above.
    wrapped_init._init_wrapped_for_callbacks = True
    cls.__init__ = wrapped_init


__all__ = ['CallbackGroup', 'hook_class_init_with_callbacks']

# Build the singleton at import time so callers that run before any explicit
# setup can already dispatch through it.
CallbackGroup.get_instance()


@atexit.register
def _emit_app_end_at_exit() -> None:
    """Emit a single app-end when the interpreter shuts down.

    Covers entrypoints that never call `on_app_end` themselves (for example a
    pytest session or a non-Hydra script). Repeated emission is harmless because
    `CallbackGroup.on_app_end` ignores every call after the first.
    """
    CallbackGroup.get_instance().on_app_end()