Skip to content

Commit 5b7f9ad

Browse files
authored
data: expose downsampling preferences to plugins (#3271)
Summary:
We add a `sampling_hints` attribute to the `TBContext` magic container, which is populated with the parsed form of the `--samples_per_plugin` flag. Existing plugins’ generic data modes are updated to read from this map instead of using hard-coded thresholds.

Test Plan:
This change is not actually observable as is, because the multiplexer data provider ignores its downsampling argument. But after patching in a change to make the data provider respect the downsampling argument, this change has the effect that raising a plugin's `--samples_per_plugin` value above its default (e.g., `images=20`) now properly increases the number of samples shown in generic data mode, whereas previously it had no effect.

wchargin-branch: data-downsampling-flag
1 parent 9bb99e6 commit 5b7f9ad

File tree

6 files changed

+50
-25
lines changed

6 files changed

+50
-25
lines changed

tensorboard/backend/application.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,20 @@
9797
logger = tb_logging.get_logger()
9898

9999

100-
def tensor_size_guidance_from_flags(flags):
101-
"""Apply user per-summary size guidance overrides."""
102-
103-
tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
100+
def _parse_samples_per_plugin(flags):
101+
result = {}
104102
if not flags or not flags.samples_per_plugin:
105-
return tensor_size_guidance
106-
103+
return result
107104
for token in flags.samples_per_plugin.split(","):
108105
k, v = token.strip().split("=")
109-
tensor_size_guidance[k] = int(v)
106+
result[k] = int(v)
107+
return result
110108

109+
110+
def _apply_tensor_size_guidance(sampling_hints):
111+
"""Apply user per-summary size guidance overrides."""
112+
tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
113+
tensor_size_guidance.update(sampling_hints)
111114
return tensor_size_guidance
112115

113116

@@ -151,9 +154,10 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
151154
multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
152155
else:
153156
# Regular logdir loading mode.
157+
sampling_hints = _parse_samples_per_plugin(flags)
154158
multiplexer = event_multiplexer.EventMultiplexer(
155159
size_guidance=DEFAULT_SIZE_GUIDANCE,
156-
tensor_size_guidance=tensor_size_guidance_from_flags(flags),
160+
tensor_size_guidance=_apply_tensor_size_guidance(sampling_hints),
157161
purge_orphaned_data=flags.purge_orphaned_data,
158162
max_reload_threads=flags.max_reload_threads,
159163
event_file_active_filter=_get_event_file_active_filter(flags),
@@ -238,6 +242,7 @@ def TensorBoardWSGIApp(
238242
multiplexer=deprecated_multiplexer,
239243
assets_zip_provider=assets_zip_provider,
240244
plugin_name_to_instance=plugin_name_to_instance,
245+
sampling_hints=_parse_samples_per_plugin(flags),
241246
window_title=flags.window_title,
242247
)
243248
tbplugins = []

tensorboard/plugins/base_plugin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def __init__(
254254
logdir=None,
255255
multiplexer=None,
256256
plugin_name_to_instance=None,
257+
sampling_hints=None,
257258
window_title=None,
258259
):
259260
"""Instantiates magic container.
@@ -291,6 +292,10 @@ def __init__(
291292
plugin may be absent from this mapping until it is registered. Plugin
292293
logic should handle cases in which a plugin is absent from this
293294
mapping, lest a KeyError is raised.
295+
sampling_hints: Map from plugin name to `int` or `NoneType`, where
296+
the value represents the user-specified downsampling limit as
297+
given to the `--samples_per_plugin` flag, or `None` if none was
298+
explicitly given for this plugin.
294299
window_title: A string specifying the window title.
295300
"""
296301
self.assets_zip_provider = assets_zip_provider
@@ -301,6 +306,7 @@ def __init__(
301306
self.logdir = logdir
302307
self.multiplexer = multiplexer
303308
self.plugin_name_to_instance = plugin_name_to_instance
309+
self.sampling_hints = sampling_hints
304310
self.window_title = window_title
305311

306312

tensorboard/plugins/histogram/histograms_plugin.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from tensorboard.util import tensor_util
4040

4141

42+
_DEFAULT_DOWNSAMPLING = 500 # histograms per time series
43+
44+
4245
class HistogramsPlugin(base_plugin.TBPlugin):
4346
"""Histograms Plugin for TensorBoard.
4447
@@ -62,6 +65,9 @@ def __init__(self, context):
6265
"""
6366
self._multiplexer = context.multiplexer
6467
self._db_connection_provider = context.db_connection_provider
68+
self._downsample_to = (context.sampling_hints or {}).get(
69+
self.plugin_name, _DEFAULT_DOWNSAMPLING
70+
)
6571
if context.flags and context.flags.generic_data == "true":
6672
self._data_provider = context.data_provider
6773
else:
@@ -174,20 +180,21 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None):
174180
"""Result of the form `(body, mime_type)`.
175181
176182
At most `downsample_to` events will be returned. If this value is
177-
`None`, then no downsampling will be performed.
183+
`None`, then default downsampling will be performed.
178184
179185
Raises:
180186
tensorboard.errors.PublicError: On invalid request.
181187
"""
182188
if self._data_provider:
183-
# Downsample reads to 500 histograms per time series, which is
184-
# the default size guidance for histograms under the multiplexer
185-
# loading logic.
186-
SAMPLE_COUNT = downsample_to if downsample_to is not None else 500
189+
sample_count = (
190+
downsample_to
191+
if downsample_to is not None
192+
else self._downsample_to
193+
)
187194
all_histograms = self._data_provider.read_tensors(
188195
experiment_id=experiment,
189196
plugin_name=metadata.PLUGIN_NAME,
190-
downsample=SAMPLE_COUNT,
197+
downsample=sample_count,
191198
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
192199
)
193200
histograms = all_histograms.get(run, {}).get(tag, None)

tensorboard/plugins/image/images_plugin.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
}
4444

4545
_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"
46+
_DEFAULT_DOWNSAMPLING = 10 # images per time series
4647

4748

4849
# Extend imghdr.tests to include svg.
@@ -69,6 +70,9 @@ def __init__(self, context):
6970
"""
7071
self._multiplexer = context.multiplexer
7172
self._db_connection_provider = context.db_connection_provider
73+
self._downsample_to = (context.sampling_hints or {}).get(
74+
self.plugin_name, _DEFAULT_DOWNSAMPLING
75+
)
7276
if context.flags and context.flags.generic_data == "true":
7377
self._data_provider = context.data_provider
7478
else:
@@ -239,14 +243,10 @@ def _image_response_for_run(self, experiment, run, tag, sample):
239243
parameters.
240244
"""
241245
if self._data_provider:
242-
# Downsample reads to 10 images per time series, which is the
243-
# default size guidance for images under the multiplexer loading
244-
# logic.
245-
SAMPLE_COUNT = 10
246246
all_images = self._data_provider.read_blob_sequences(
247247
experiment_id=experiment,
248248
plugin_name=metadata.PLUGIN_NAME,
249-
downsample=SAMPLE_COUNT,
249+
downsample=self._downsample_to,
250250
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
251251
)
252252
images = all_images.get(run, {}).get(tag, None)

tensorboard/plugins/scalar/scalars_plugin.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
from tensorboard.util import tensor_util
4141

4242

43+
_DEFAULT_DOWNSAMPLING = 1000 # scalars per time series
44+
45+
4346
class OutputFormat(object):
4447
"""An enum used to list the valid output formats for API calls."""
4548

@@ -60,6 +63,9 @@ def __init__(self, context):
6063
"""
6164
self._multiplexer = context.multiplexer
6265
self._db_connection_provider = context.db_connection_provider
66+
self._downsample_to = (context.sampling_hints or {}).get(
67+
self.plugin_name, _DEFAULT_DOWNSAMPLING
68+
)
6369
if context.flags and context.flags.generic_data != "false":
6470
self._data_provider = context.data_provider
6571
else:
@@ -169,14 +175,10 @@ def index_impl(self, experiment=None):
169175
def scalars_impl(self, tag, run, experiment, output_format):
170176
"""Result of the form `(body, mime_type)`."""
171177
if self._data_provider:
172-
# Downsample reads to 1000 scalars per time series, which is the
173-
# default size guidance for scalars under the multiplexer loading
174-
# logic.
175-
SAMPLE_COUNT = 1000
176178
all_scalars = self._data_provider.read_scalars(
177179
experiment_id=experiment,
178180
plugin_name=metadata.PLUGIN_NAME,
179-
downsample=SAMPLE_COUNT,
181+
downsample=self._downsample_to,
180182
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
181183
)
182184
scalars = all_scalars.get(run, {}).get(tag, None)

tensorboard/plugins/text/text_plugin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
2d tables are supported. Showing a 2d slice of the data instead."""
4949
)
5050

51+
_DEFAULT_DOWNSAMPLING = 100 # text tensors per time series
52+
5153

5254
def make_table_row(contents, tag="td"):
5355
"""Given an iterable of string contents, make a table row.
@@ -212,6 +214,9 @@ def __init__(self, context):
212214
context: A base_plugin.TBContext instance.
213215
"""
214216
self._multiplexer = context.multiplexer
217+
self._downsample_to = (context.sampling_hints or {}).get(
218+
self.plugin_name, _DEFAULT_DOWNSAMPLING
219+
)
215220
if context.flags and context.flags.generic_data == "true":
216221
self._data_provider = context.data_provider
217222
else:
@@ -261,7 +266,7 @@ def text_impl(self, run, tag, experiment):
261266
all_text = self._data_provider.read_tensors(
262267
experiment_id=experiment,
263268
plugin_name=metadata.PLUGIN_NAME,
264-
downsample=100,
269+
downsample=self._downsample_to,
265270
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
266271
)
267272
text = all_text.get(run, {}).get(tag, None)

0 commit comments

Comments
 (0)