Skip to content

Improve the efficiency of some SQL queries #3263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 8, 2025
11 changes: 5 additions & 6 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,15 +490,14 @@ def apply_filter(
Returns:
The query with filter applied.
"""
from zenml.zen_stores.schemas import TagResourceSchema
from zenml.zen_stores.schemas import TagResourceSchema, TagSchema

query = super().apply_filter(query=query, table=table)
if self.tag:
query = (
query.join(getattr(table, "tags"))
.join(TagResourceSchema.tag)
.distinct()
)
query = query.join(
TagResourceSchema,
TagResourceSchema.resource_id == getattr(table, "id"),
).join(TagSchema, TagSchema.id == TagResourceSchema.tag_id)

return query

Expand Down
38 changes: 19 additions & 19 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@
from zenml.zen_stores.schemas.model_schemas import (
ModelVersionArtifactSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class ArtifactSchema(NamedSchema, table=True):
Expand All @@ -82,11 +80,12 @@ class ArtifactSchema(NamedSchema, table=True):
back_populates="artifact",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="artifact",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand Down Expand Up @@ -136,7 +135,7 @@ def to_model(
body = ArtifactResponseBody(
created=self.created,
updated=self.updated,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
latest_version_name=latest_name,
latest_version_id=latest_id,
)
Expand Down Expand Up @@ -192,11 +191,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
uri: str = Field(sa_column=Column(TEXT, nullable=False))
materializer: str = Field(sa_column=Column(TEXT, nullable=False))
data_type: str = Field(sa_column=Column(TEXT, nullable=False))
tags: List["TagResourceSchema"] = Relationship(
back_populates="artifact_version",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand Down Expand Up @@ -244,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
workspace: "WorkspaceSchema" = Relationship(
back_populates="artifact_versions"
)
run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="artifact_versions",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
overlaps="run_metadata",
),
)
output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship(
Expand Down Expand Up @@ -365,7 +365,7 @@ def to_model(
data_type=data_type,
created=self.created,
updated=self.updated,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
producer_pipeline_run_id=producer_pipeline_run_id,
save_type=ArtifactSaveType(self.save_type),
artifact_store_id=self.artifact_store_id,
Expand Down
38 changes: 19 additions & 19 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.utils import (
RunMetadataInterface,
Expand Down Expand Up @@ -114,11 +112,12 @@ class ModelSchema(NamedSchema, table=True):
save_models_to_registry: bool = Field(
sa_column=Column(BOOLEAN, nullable=False)
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand Down Expand Up @@ -168,7 +167,7 @@ def to_model(
Returns:
The created `ModelResponse`.
"""
tags = [t.tag.to_model() for t in self.tags]
tags = [tag.to_model() for tag in self.tags]

if self.model_versions:
version_numbers = [mv.number for mv in self.model_versions]
Expand Down Expand Up @@ -299,11 +298,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
back_populates="model_version",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model_version",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand All @@ -316,12 +316,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
description: str = Field(sa_column=Column(TEXT, nullable=True))
stage: str = Field(sa_column=Column(TEXT, nullable=True))

run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="model_versions",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
overlaps="run_metadata",
),
)
pipeline_runs: List["PipelineRunSchema"] = Relationship(
Expand Down Expand Up @@ -471,7 +471,7 @@ def to_model(
data_artifact_ids=data_artifact_ids,
deployment_artifact_ids=deployment_artifact_ids,
pipeline_run_ids=pipeline_run_ids,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return ModelVersionResponse(
Expand Down
37 changes: 18 additions & 19 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@
ModelVersionPipelineRunSchema,
ModelVersionSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.service_schemas import ServiceSchema
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
Expand Down Expand Up @@ -140,12 +138,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
)
workspace: "WorkspaceSchema" = Relationship(back_populates="runs")
user: Optional["UserSchema"] = Relationship(back_populates="runs")
run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="pipeline_runs",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
overlaps="run_metadata",
),
)
logs: Optional["LogsSchema"] = Relationship(
Expand Down Expand Up @@ -215,10 +213,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
services: List["ServiceSchema"] = Relationship(
back_populates="pipeline_run",
)
tags: List["TagResourceSchema"] = Relationship(
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand Down Expand Up @@ -291,12 +291,6 @@ def to_model(
Raises:
RuntimeError: if the model creation fails.
"""
orchestrator_environment = (
json.loads(self.orchestrator_environment)
if self.orchestrator_environment
else {}
)

if self.deployment is not None:
deployment = self.deployment.to_model(include_metadata=True)

Expand Down Expand Up @@ -377,6 +371,11 @@ def to_model(
# in the response -> We need to reset the metadata here
step.metadata = None

orchestrator_environment = (
json.loads(self.orchestrator_environment)
if self.orchestrator_environment
else {}
)
metadata = PipelineRunResponseMetadata(
workspace=self.workspace.to_model(),
run_metadata=self.fetch_metadata(),
Expand Down Expand Up @@ -405,7 +404,7 @@ def to_model(

resources = PipelineRunResponseResources(
model_version=model_version,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return PipelineRunResponse(
Expand Down
12 changes: 7 additions & 5 deletions src/zenml/zen_stores/schemas/pipeline_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class PipelineSchema(NamedSchema, table=True):
Expand Down Expand Up @@ -95,10 +95,12 @@ class PipelineSchema(NamedSchema, table=True):
deployments: List["PipelineDeploymentSchema"] = Relationship(
back_populates="pipeline",
)
tags: List["TagResourceSchema"] = Relationship(
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
cascade="delete",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
overlaps="tags",
),
)
Expand Down Expand Up @@ -162,7 +164,7 @@ def to_model(
latest_run_user=latest_run_user.to_model()
if latest_run_user
else None,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return PipelineResponse(
Expand Down
46 changes: 1 addition & 45 deletions src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,25 @@
# permissions and limitations under the License.
"""SQLModel implementation of pipeline run metadata tables."""

from typing import TYPE_CHECKING, List, Optional
from typing import Optional
from uuid import UUID, uuid4

from sqlalchemy import TEXT, VARCHAR, Column
from sqlmodel import Field, Relationship, SQLModel

from zenml.enums import MetadataResourceTypes
from zenml.zen_stores.schemas.base_schemas import BaseSchema
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema

if TYPE_CHECKING:
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema


class RunMetadataSchema(BaseSchema, table=True):
"""SQL Model for run metadata."""

__tablename__ = "run_metadata"

# Relationship to link to resources
resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="run_metadata",
sa_relationship_kwargs={"cascade": "delete"},
)
stack_component_id: Optional[UUID] = build_foreign_key_field(
source=__tablename__,
target=StackComponentSchema.__tablename__,
Expand Down Expand Up @@ -105,36 +94,3 @@ class RunMetadataResourceSchema(SQLModel, table=True):
ondelete="CASCADE",
nullable=False,
)

# Relationship back to the base metadata table
run_metadata: RunMetadataSchema = Relationship(back_populates="resources")

# Relationship to link specific resource types
pipeline_runs: List["PipelineRunSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions",
),
)
step_runs: List["StepRunSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions",
),
)
artifact_versions: List["ArtifactVersionSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions",
),
)
model_versions: List["ModelVersionSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions",
),
)
Loading
Loading