import os
import shutil
from pathlib import Path

import numpy as np
from PIL import Image as PILImage

try:
    from trackio.media.media import TrackioMedia
except ImportError:
    from media.media import TrackioMedia


TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image


class TrackioImage(TrackioMedia):
    """
    An image that can be logged with `trackio.log`.

    Example:
        ```python
        import trackio
        import numpy as np
        from PIL import Image

        # Create an image from numpy array
        # `high` is exclusive in np.random.randint, so 256 covers the full [0, 255] range
        image_data = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
        image = trackio.Image(image_data, caption="Random image")
        trackio.log({"my_image": image})

        # Create an image from PIL Image
        pil_image = Image.new('RGB', (100, 100), color='red')
        image = trackio.Image(pil_image, caption="Red square")
        trackio.log({"red_image": image})

        # Create an image from file path
        image = trackio.Image("path/to/image.jpg", caption="Photo from file")
        trackio.log({"file_image": image})
        ```
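
        Arrays whose dtype is not `np.uint8` are rejected with a `ValueError`, so float
        images have to be converted first. A minimal sketch, assuming float values in
        `[0, 1]`; `to_uint8` is a hypothetical helper and not part of trackio:

        ```python
        import numpy as np


        def to_uint8(arr: np.ndarray) -> np.ndarray:
            # Hypothetical helper (not part of trackio): assumes floats in [0, 1].
            # Out-of-range values are clipped before scaling to [0, 255].
            return (np.clip(arr, 0.0, 1.0) * 255).round().astype(np.uint8)


        float_image = np.random.rand(64, 64, 3)  # float64 values in [0, 1)
        image_data = to_uint8(float_image)  # uint8 array of shape (64, 64, 3)
        ```

        The converted array can then be passed to `trackio.Image` as in the first
        example above.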

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
            A path to an image file, a PIL Image, or a numpy array of shape (height, width, channels).
            A numpy array must have dtype `np.uint8`, with RGB values in the range `[0, 255]`;
            any other dtype raises a `ValueError`.
        caption (`str`, *optional*):
            A string caption for the image.
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        super().__init__(value, caption)
        self._format: str | None = None

        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )
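        # In-memory images (arrays and PIL images) are saved as PNG unless a
        # format has already been set.
        if (
            isinstance(self._value, np.ndarray | PILImage.Image)
            and self._format is None
        ):
            self._format = "png"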

    def _as_pil(self) -> PILImage.Image | None:
        """Return the value as an RGBA PIL image, or None if it is a file path."""
        try:
            if isinstance(self._value, np.ndarray):
                arr = np.asarray(self._value).astype("uint8")
                return PILImage.fromarray(arr).convert("RGBA")
            if isinstance(self._value, PILImage.Image):
                return self._value.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e
        # str and Path values are copied from disk in _save_media instead.
        return None

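    def _save_media(self, file_path: Path):
        """Write the image to `file_path`, encoding in-memory data or copying the source file."""
        pil = self._as_pil()
        # Compare against None explicitly rather than relying on the image's truthiness.
        if pil is not None:
            pil.save(file_path, format=self._format)
        elif isinstance(self._value, str | Path):
            if os.path.isfile(self._value):
                shutil.copy(self._value, file_path)
            else:
                raise ValueError(f"File not found: {self._value}")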