Skip to content

Commit 4f6f6a2

Browse files
authored
Add last_value to ScalarTimeSeries interface. (#6579)
This allows `list_scalars` of some of the data provider implementations to return the last scalar value, which will improve the performance when loading experiments with Hparams data. Googlers, see b/292102513 for context. Tested internally: cl/563163418 #hparams
1 parent 1dde0cb commit 4f6f6a2

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

tensorboard/data/provider.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ class _TimeSeries:
777777
"_plugin_content",
778778
"_description",
779779
"_display_name",
780+
"_last_value",
780781
)
781782

782783
def __init__(
@@ -787,12 +788,14 @@ def __init__(
787788
plugin_content,
788789
description,
789790
display_name,
791+
last_value=None,
790792
):
791793
self._max_step = max_step
792794
self._max_wall_time = max_wall_time
793795
self._plugin_content = plugin_content
794796
self._description = description
795797
self._display_name = display_name
798+
self._last_value = last_value
796799

797800
@property
798801
def max_step(self):
@@ -814,6 +817,10 @@ def description(self):
814817
def display_name(self):
815818
return self._display_name
816819

820+
@property
821+
def last_value(self):
822+
return self._last_value
823+
817824

818825
class ScalarTimeSeries(_TimeSeries):
819826
"""Metadata about a scalar time series for a particular run and tag.
@@ -830,6 +837,9 @@ class ScalarTimeSeries(_TimeSeries):
830837
empty if no description was specified.
831838
display_name: An optional long-form Markdown description, as a `str` that is
832839
empty if no description was specified. Deprecated; may be removed soon.
840+
last_value: An optional value for the latest scalar in the time series,
841+
corresponding to the scalar at `max_step`. Note that this field might NOT
842+
be populated by all data provider implementations.
833843
"""
834844

835845
def __eq__(self, other):
@@ -845,6 +855,8 @@ def __eq__(self, other):
845855
return False
846856
if self._display_name != other._display_name:
847857
return False
858+
if self._last_value != other._last_value:
859+
return False
848860
return True
849861

850862
def __hash__(self):
@@ -855,6 +867,7 @@ def __hash__(self):
855867
self._plugin_content,
856868
self._description,
857869
self._display_name,
870+
self._last_value,
858871
)
859872
)
860873

@@ -866,6 +879,7 @@ def __repr__(self):
866879
"plugin_content=%r" % (self._plugin_content,),
867880
"description=%r" % (self._description,),
868881
"display_name=%r" % (self._display_name,),
882+
"last_value=%r" % (self._last_value,),
869883
)
870884
)
871885

tensorboard/data/provider_test.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,13 @@ def test_repr(self):
9191

9292
class ScalarTimeSeriesTest(tb_test.TestCase):
9393
def _scalar_time_series(
94-
self, max_step, max_wall_time, plugin_content, description, display_name
94+
self,
95+
max_step,
96+
max_wall_time,
97+
plugin_content,
98+
description,
99+
display_name,
100+
last_value,
95101
):
96102
# Helper to use explicit kwargs.
97103
return provider.ScalarTimeSeries(
@@ -100,6 +106,7 @@ def _scalar_time_series(
100106
plugin_content=plugin_content,
101107
description=description,
102108
display_name=display_name,
109+
last_value=last_value,
103110
)
104111

105112
def test_repr(self):
@@ -109,26 +116,28 @@ def test_repr(self):
109116
plugin_content=b"AB\xCD\xEF!\x00",
110117
description="test test",
111118
display_name="one two",
119+
last_value=0.0001,
112120
)
113121
repr_ = repr(x)
114122
self.assertIn(repr(x.max_step), repr_)
115123
self.assertIn(repr(x.max_wall_time), repr_)
116124
self.assertIn(repr(x.plugin_content), repr_)
117125
self.assertIn(repr(x.description), repr_)
118126
self.assertIn(repr(x.display_name), repr_)
127+
self.assertIn(repr(x.last_value), repr_)
119128

120129
def test_eq(self):
121-
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
122-
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
123-
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum")
130+
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
131+
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
132+
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024)
124133
self.assertEqual(x1, x2)
125134
self.assertNotEqual(x1, x3)
126135
self.assertNotEqual(x1, object())
127136

128137
def test_hash(self):
129-
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
130-
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
131-
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum")
138+
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
139+
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
140+
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024)
132141
self.assertEqual(hash(x1), hash(x2))
133142
# The next check is technically not required by the `__hash__`
134143
# contract, but _should_ pass; failure on this assertion would at

0 commit comments

Comments
 (0)