Skip to content

Commit c450124

Browse files
committed
Batch events
1 parent 876979d commit c450124

File tree

16 files changed

+307
-77
lines changed

16 files changed

+307
-77
lines changed

plugins/snowflake/superduper_snowflake/data_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from snowflake.snowpark.types import BooleanType, StringType, VariantType
1313
from superduper import CFG, logging
1414
from superduper.backends.base.data_backend import BaseDataBackend
15+
from superduper.base.event import Create, CreateTable
1516
from superduper.base.query import Query
1617
from superduper.base.schema import Schema
1718
from watchdog.events import FileSystemEventHandler
@@ -115,6 +116,13 @@ def db(self, value):
115116
"""
116117
self._db = value
117118

119+
def create_tables_and_schemas(self, events: t.List[CreateTable]):
    """Batch-create tables and schemas in the data-backend (not implemented).

    :param events: ``CreateTable`` events describing the tables to create.
    :raises NotImplementedError: Always; this method has no implementation.
    """
    # NOTE(review): this stub raises unconditionally. If a parent class
    # supplies a working default for batch table creation, this override
    # disables it for this backend -- confirm that is intended.
    raise NotImplementedError()
125+
118126
def create_table_and_schema(self, identifier: str, schema: Schema, primary_id: str):
119127
"""Create a schema in the data-backend.
120128

superduper/backends/base/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ def get_tool(self, uuid: str):
4545
tool_id = self.uuid_tool_mapping[uuid]
4646
return self.tools[tool_id]
4747

48-
def put_component(self, component: 'Component', **kwargs):
48+
def put_component(self, component: 'Component', uuid: str):
4949
"""Put a component to the backend.
5050
5151
:param component: Component to put.
5252
:param uuid: UUID passed to ``self.db.load`` to load the component.
5353
"""
54+
component = self.db.load(component=component, uuid=uuid)
5455
logging.info(
5556
f'Putting component: {component.huuid} on to {self.__class__.__name__}'
5657
)
@@ -68,7 +69,7 @@ def put_component(self, component: 'Component', **kwargs):
6869
return
6970
self.tool_uuid_mapping[tool.identifier].add(component.uuid)
7071
self.tools[tool.identifier] = tool
71-
tool.initialize(**kwargs)
72+
tool.initialize()
7273

7374
def drop_component(self, component: str, identifier: str):
7475
"""Drop the component from backend.

superduper/backends/base/compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def initialize(self):
6565
self.put_component(m)
6666

6767
@abstractmethod
68-
def put_component(self, component: 'Component'):
68+
def put_component(self, component: 'Component', uuid: str):
6969
"""Create handler on component declare.
7070
7171
:param component: Component to put.

superduper/backends/base/data_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
if t.TYPE_CHECKING:
1414
from superduper.base.schema import Schema
15+
from superduper.base.event import CreateTable
1516

1617

1718
class BaseDataBackend(ABC):
@@ -85,6 +86,19 @@ def db(self, value):
8586
"""
8687
self._db = value
8788

89+
def create_tables_and_schemas(self, events: t.List['CreateTable']):
    """Create one table and schema per event in the data-backend.

    Events are handled one at a time, in the order given, by delegating
    to ``create_table_and_schema``.

    :param events: List of `CreateTable` events.
    """
    # Imported here rather than at module level; the module only imports
    # ``Schema`` under ``t.TYPE_CHECKING``.
    from superduper.base.schema import Schema

    for table_event in events:
        table_name = table_event.identifier
        # The event carries the field definitions as keyword arguments
        # for ``Schema.build``.
        table_schema = Schema.build(**table_event.fields)
        self.create_table_and_schema(
            table_name,
            schema=table_schema,
            primary_id=table_event.primary_id,
        )
101+
88102
@abstractmethod
89103
def create_table_and_schema(
90104
self, identifier: str, schema: 'Schema', primary_id: str

superduper/backends/base/scheduler.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from superduper import CFG, logging
1010
from superduper.backends.base.backends import BaseBackend
1111
from superduper.base.base import Base
12+
from superduper.base.event import Create, CreateTable, PutComponent, Update
1213

1314
DependencyType = t.Union[t.Dict[str, str], t.Sequence[t.Dict[str, str]]]
1415

@@ -140,6 +141,30 @@ def _consume_event_type(event_type, ids, table, db: 'Datalayer'):
140141
db.cluster.compute.release_futures(context)
141142

142143

144+
def cluster_events(
    events: t.List[Base],
) -> t.Tuple[
    t.List['CreateTable'],
    t.List[t.Union['Create', 'Update']],
    t.List['PutComponent'],
    t.List['Job'],
]:
    """
    Cluster events into table, create, put-component and job events.

    Each event is tested with ``isinstance`` against ``CreateTable``,
    ``(Update, Create)``, ``Job`` and ``PutComponent``, in that order, and
    appended to the first group that matches. Input order is preserved
    within each group.

    :param events: List of events to be clustered.
    :return: Tuple ``(table_events, create_events, put_events, job_events)``.
             ``create_events`` holds both ``Create`` and ``Update`` events.
    """
    from superduper.base.metadata import Job

    table_events: t.List['CreateTable'] = []
    create_events: t.List[t.Union['Create', 'Update']] = []
    job_events: t.List['Job'] = []
    put_events: t.List['PutComponent'] = []

    for event in events:
        # ``CreateTable`` is checked first so that it wins over the
        # ``(Update, Create)`` test should the classes be related.
        if isinstance(event, CreateTable):
            table_events.append(event)
        elif isinstance(event, (Update, Create)):
            create_events.append(event)
        elif isinstance(event, Job):
            job_events.append(event)
        elif isinstance(event, PutComponent):
            put_events.append(event)
        # NOTE(review): an event matching none of the types above is
        # dropped silently here -- confirm that callers never rely on
        # such events being executed.

    return table_events, create_events, put_events, job_events
166+
167+
143168
def consume_events(events: t.List[Base], table: str, db: 'Datalayer'):
144169
"""
145170
Consume events from table queue.
@@ -152,7 +177,32 @@ def consume_events(events: t.List[Base], table: str, db: 'Datalayer'):
152177
logging.info(f'Consuming {len(events)} events on {table}.')
153178
consume_streaming_events(events=events, table=table, db=db)
154179
else:
155-
logging.info(f'Consuming {len(events)} _apply events')
156-
for event in events:
157-
event.execute(db)
180+
table_events, create_events, put_events, job_events = cluster_events(events)
181+
182+
if table_events:
183+
logging.info(f'Consuming {len(events)} `CreateTable` events')
184+
CreateTable.batch_execute(
185+
events=table_events,
186+
db=db,
187+
)
188+
189+
if create_events:
190+
logging.info(f'Consuming {len(events)} `Create` events')
191+
Create.batch_execute(
192+
events=create_events,
193+
db=db,
194+
)
195+
196+
if put_events:
197+
logging.info(f'Consuming {len(events)} `PutComponent` events')
198+
PutComponent.batch_execute(
199+
events=put_events,
200+
db=db,
201+
)
202+
203+
if job_events:
204+
logging.info(f'Consuming {len(events)} jobs (`Job`)')
205+
for job in job_events:
206+
job.execute(db)
207+
158208
return

superduper/backends/local/cdc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ def list_uuids(self):
3232
"""List UUIDs of components."""
3333
return list(self._trigger_uuid_mapping.values())
3434

35-
def put_component(self, component):
36-
assert isinstance(component, CDC)
37-
self.triggers.add((component.component, component.identifier))
35+
def put_component(self, component: str, uuid: str):
    """Record a ``(component, uuid)`` pair in ``self.triggers``.

    :param component: First element of the stored pair.
    :param uuid: Second element of the stored pair.
    """
    # NOTE(review): the pair is stored with the uuid; verify that the code
    # removing triggers looks them up with the same pair shape.
    trigger = (component, uuid)
    self.triggers.add(trigger)
3837

3938
def drop_component(self, component, identifier):
4039
c = self.db.load(component=component, identifier=identifier)

superduper/backends/local/compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def list_uuids(self):
7777
def drop_component(self, component: str, identifier: str):
7878
"""Drop a component from the compute."""
7979

80-
def put_component(self, component: 'Component'):
80+
def put_component(self, component: 'Component', uuid: str):
    """Create a handler on compute.

    This implementation does nothing with its arguments and returns
    ``None``.

    :param component: Component to put.
    :param uuid: UUID supplied alongside the component.
    """
    return None
8282

8383
def initialize(self):

superduper/base/apply.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from superduper import CFG, Component, logging
99
from superduper.base import exceptions
1010
from superduper.base.document import Document
11-
from superduper.base.event import Create, Signal, Update
11+
from superduper.base.event import Create, PutComponent, Signal, Update
1212
from superduper.components.component import running_status
1313
from superduper.misc.tree import dict_to_tree
1414

@@ -62,9 +62,10 @@ def apply(
6262
force = db.cfg.force_apply
6363

6464
# This holds a record of the changes
65-
diff: t.Dict = {}
65+
diff = {}
66+
6667
# context allows us to track the origin of the component creation
67-
create_events, job_events = _apply(
68+
table_events, create_events, put_events, job_events = _apply(
6869
db=db,
6970
object=object,
7071
context=object.uuid,
@@ -87,6 +88,14 @@ def apply(
8788

8889
logging.info('Found these changes and/ or additions that need to be made:')
8990

91+
logging.info('-' * 100)
92+
logging.info('TABLE EVENTS:')
93+
logging.info('-' * 100)
94+
steps = {t.huuid: str(i) for i, t in enumerate(table_events.values())}
95+
96+
for i, t in enumerate(table_events.values()):
97+
logging.info(f'[{i}]: {t.huuid}')
98+
9099
logging.info('-' * 100)
91100
logging.info('METADATA EVENTS:')
92101
logging.info('-' * 100)
@@ -104,6 +113,14 @@ def apply(
104113
else:
105114
logging.info(f'[{i}]: {c.huuid}: {c.__class__.__name__}')
106115

116+
logging.info('-' * 100)
117+
logging.info('PUT EVENTS:')
118+
logging.info('-' * 100)
119+
steps = {p.huuid: str(i) for i, p in enumerate(put_events.values())}
120+
121+
for i, t in enumerate(put_events.values()):
122+
logging.info(f'[{i}]: {t.huuid}')
123+
107124
logging.info('-' * 100)
108125
logging.info('JOBS EVENTS:')
109126
logging.info('-' * 100)
@@ -124,7 +141,9 @@ def apply(
124141
logging.info('-' * 100)
125142

126143
events = [
144+
*list(table_events.values()),
127145
*list(create_events.values()),
146+
*list(put_events.values()),
128147
*list(job_events.values()),
129148
Signal(context=object.uuid, msg='done'),
130149
]
@@ -192,12 +211,14 @@ def _apply(
192211
object.db = db
193212

194213
create_events = {}
214+
table_events = {}
215+
put_events = {}
195216
children = []
196217

197218
def wrapper(child):
198219
nonlocal create_events
199220

200-
c, j = _apply(
221+
t, c, p, j = _apply(
201222
db=db,
202223
object=child,
203224
context=context,
@@ -209,6 +230,8 @@ def wrapper(child):
209230

210231
job_events.update(j)
211232
create_events.update(c)
233+
table_events.update(t)
234+
put_events.update(p)
212235
children.append((child.component, child.identifier, child.uuid))
213236
return f'&:component:{child.huuid}'
214237

@@ -252,7 +275,7 @@ def wrapper(child):
252275
serialized = db._save_artifact(serialized.encode())
253276

254277
if apply_status == 'same':
255-
return create_events, job_events
278+
return table_events, create_events, put_events, job_events
256279

257280
elif apply_status == 'new':
258281

@@ -264,11 +287,23 @@ def wrapper(child):
264287
children=children,
265288
)
266289

290+
table_events.update(object.create_table_events())
291+
267292
these_job_events = object.create_jobs(
268293
event_type='apply',
269294
jobs=list(job_events.values()),
270295
context=context,
271296
)
297+
298+
for service in object.services:
299+
put_events[f'{object.huuid}/{service}'] = PutComponent(
300+
component=object.component,
301+
identifier=object.identifier,
302+
uuid=object.uuid,
303+
context=context,
304+
service=service,
305+
)
306+
272307
elif apply_status == 'breaking':
273308

274309
metadata_event = Create(
@@ -278,12 +313,23 @@ def wrapper(child):
278313
parent=parent,
279314
children=children,
280315
)
316+
317+
table_events.update(object.create_table_events())
281318

282319
these_job_events = object.create_jobs(
283320
event_type='apply',
284321
jobs=list(job_events.values()),
285322
context=context,
286323
)
324+
325+
for service in object.services:
326+
put_events[f'{object.huuid}/{service}'] = PutComponent(
327+
component=object.component,
328+
identifier=object.identifier,
329+
uuid=object.uuid,
330+
context=context,
331+
service=service,
332+
)
287333
else:
288334
assert apply_status == 'update'
289335

@@ -318,4 +364,4 @@ def wrapper(child):
318364

319365
create_events[metadata_event.huuid] = metadata_event
320366
job_events.update({jj.huuid: jj for jj in these_job_events})
321-
return create_events, job_events
367+
return table_events, create_events, put_events, job_events

0 commit comments

Comments
 (0)