diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f5267f4840d..16ca14f4b5c 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -490,15 +490,14 @@ def apply_filter( Returns: The query with filter applied. """ - from zenml.zen_stores.schemas import TagResourceSchema + from zenml.zen_stores.schemas import TagResourceSchema, TagSchema query = super().apply_filter(query=query, table=table) if self.tag: - query = ( - query.join(getattr(table, "tags")) - .join(TagResourceSchema.tag) - .distinct() - ) + query = query.join( + TagResourceSchema, + TagResourceSchema.resource_id == getattr(table, "id"), + ).join(TagSchema, TagSchema.id == TagResourceSchema.tag_id) return query diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 02e842a5fb5..487ad7c7946 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -59,10 +59,8 @@ from zenml.zen_stores.schemas.model_schemas import ( ModelVersionArtifactSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceSchema, - ) - from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema + from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.tag_schemas import TagSchema class ArtifactSchema(NamedSchema, table=True): @@ -82,11 +80,12 @@ class ArtifactSchema(NamedSchema, table=True): back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) - tags: List["TagResourceSchema"] = Relationship( - back_populates="artifact", + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -136,7 +135,7 @@ def to_model( body = ArtifactResponseBody( created=self.created, updated=self.updated, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], latest_version_name=latest_name, latest_version_id=latest_id, ) @@ -192,11 +191,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): uri: str = Field(sa_column=Column(TEXT, nullable=False)) materializer: str = Field(sa_column=Column(TEXT, nullable=False)) data_type: str = Field(sa_column=Column(TEXT, nullable=False)) - tags: List["TagResourceSchema"] = Relationship( - back_populates="artifact_version", + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -244,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): workspace: "WorkspaceSchema" = Relationship( back_populates="artifact_versions" ) - run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="artifact_versions", + run_metadata: List["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", - cascade="delete", - overlaps="run_metadata_resources", + secondary="run_metadata_resource", + primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", + secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)", + overlaps="run_metadata", ), ) output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship( @@ -365,7 +365,7 @@ def to_model( data_type=data_type, created=self.created, updated=self.updated, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], producer_pipeline_run_id=producer_pipeline_run_id, save_type=ArtifactSaveType(self.save_type), artifact_store_id=self.artifact_store_id, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 41c186c75ca..03551f36ba8 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -56,11 +56,9 @@ from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema -from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceSchema, -) +from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field -from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema +from zenml.zen_stores.schemas.tag_schemas import TagSchema from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.utils import ( RunMetadataInterface, @@ -114,11 +112,12 @@ class ModelSchema(NamedSchema, table=True): save_models_to_registry: bool = Field( sa_column=Column(BOOLEAN, nullable=False) ) - tags: List["TagResourceSchema"] = Relationship( - back_populates="model", + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -168,7 +167,7 @@ def to_model( Returns: The created `ModelResponse`. """ - tags = [t.tag.to_model() for t in self.tags] + tags = [tag.to_model() for tag in self.tags] if self.model_versions: version_numbers = [mv.number for mv in self.model_versions] @@ -299,11 +298,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): back_populates="model_version", sa_relationship_kwargs={"cascade": "delete"}, ) - tags: List["TagResourceSchema"] = Relationship( - back_populates="model_version", + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -316,12 +316,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="model_versions", + run_metadata: List["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", - cascade="delete", - overlaps="run_metadata_resources", + secondary="run_metadata_resource", + primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", + secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)", + overlaps="run_metadata", ), ) pipeline_runs: List["PipelineRunSchema"] = Relationship( @@ -471,7 +471,7 @@ def to_model( data_artifact_ids=data_artifact_ids, deployment_artifact_ids=deployment_artifact_ids, pipeline_run_ids=pipeline_run_ids, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], ) return ModelVersionResponse( diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 67236f0ab7d..1481a90d2de 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -58,12 +58,10 @@ ModelVersionPipelineRunSchema, ModelVersionSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceSchema, - ) + from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema - from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema + from zenml.zen_stores.schemas.tag_schemas import TagSchema class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): @@ -140,12 +138,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="pipeline_runs", + run_metadata: List["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", - cascade="delete", - overlaps="run_metadata_resources", + secondary="run_metadata_resource", + primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", + secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)", + overlaps="run_metadata", ), ) logs: Optional["LogsSchema"] = Relationship( @@ -215,10 +213,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): services: List["ServiceSchema"] = Relationship( back_populates="pipeline_run", ) - tags: List["TagResourceSchema"] = Relationship( + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -291,12 +291,6 @@ def to_model( Raises: RuntimeError: if the model creation fails. """ - orchestrator_environment = ( - json.loads(self.orchestrator_environment) - if self.orchestrator_environment - else {} - ) - if self.deployment is not None: deployment = self.deployment.to_model(include_metadata=True) @@ -377,6 +371,11 @@ def to_model( # in the response -> We need to reset the metadata here step.metadata = None + orchestrator_environment = ( + json.loads(self.orchestrator_environment) + if self.orchestrator_environment + else {} + ) metadata = PipelineRunResponseMetadata( workspace=self.workspace.to_model(), run_metadata=self.fetch_metadata(), @@ -405,7 +404,7 @@ def to_model( resources = PipelineRunResponseResources( model_version=model_version, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], ) return PipelineRunResponse( diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index 3719a64b207..0513b4b5a7b 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -43,7 +43,7 @@ ) from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema - from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema + from zenml.zen_stores.schemas.tag_schemas import TagSchema class PipelineSchema(NamedSchema, table=True): @@ -95,10 +95,12 @@ class PipelineSchema(NamedSchema, table=True): deployments: List["PipelineDeploymentSchema"] = Relationship( back_populates="pipeline", ) - tags: List["TagResourceSchema"] = Relationship( + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -162,7 +164,7 @@ def to_model( latest_run_user=latest_run_user.to_model() if latest_run_user else None, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], ) return PipelineResponse( diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index f4465b13e66..8ef2426a1a5 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -13,13 +13,12 @@ # permissions and limitations under the License. """SQLModel implementation of pipeline run metadata tables.""" -from typing import TYPE_CHECKING, List, Optional +from typing import Optional from uuid import UUID, uuid4 from sqlalchemy import TEXT, VARCHAR, Column from sqlmodel import Field, Relationship, SQLModel -from zenml.enums import MetadataResourceTypes from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field @@ -27,22 +26,12 @@ from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema -if TYPE_CHECKING: - from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema - from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema - from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema - class RunMetadataSchema(BaseSchema, table=True): """SQL Model for run metadata.""" __tablename__ = "run_metadata" - # Relationship to link to resources - resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs={"cascade": "delete"}, - ) stack_component_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, @@ -105,36 +94,3 @@ class RunMetadataResourceSchema(SQLModel, table=True): ondelete="CASCADE", nullable=False, ) - - # Relationship back to the base metadata table - run_metadata: RunMetadataSchema = Relationship(back_populates="resources") - - # Relationship to link specific resource types - pipeline_runs: List["PipelineRunSchema"] = Relationship( - back_populates="run_metadata_resources", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", - overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions", - ), - ) - step_runs: List["StepRunSchema"] = Relationship( - back_populates="run_metadata_resources", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", - overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions", - ), - ) - artifact_versions: List["ArtifactVersionSchema"] = Relationship( - back_populates="run_metadata_resources", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions", - ), - ) - model_versions: List["ModelVersionSchema"] = Relationship( - back_populates="run_metadata_resources", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", - overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions", - ), - ) diff --git a/src/zenml/zen_stores/schemas/run_template_schemas.py b/src/zenml/zen_stores/schemas/run_template_schemas.py index c2869e099f4..05021801569 100644 --- a/src/zenml/zen_stores/schemas/run_template_schemas.py +++ b/src/zenml/zen_stores/schemas/run_template_schemas.py @@ -41,7 +41,7 @@ PipelineDeploymentSchema, ) from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema - from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema + from zenml.zen_stores.schemas.tag_schemas import TagSchema class RunTemplateSchema(BaseSchema, table=True): @@ -110,10 +110,12 @@ class RunTemplateSchema(BaseSchema, table=True): } ) - tags: List["TagResourceSchema"] = Relationship( + tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)", - cascade="delete", + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", overlaps="tags", ), ) @@ -253,7 +255,7 @@ def to_model( pipeline=pipeline, build=build, code_reference=code_reference, - tags=[t.tag.to_model() for t in self.tags], + tags=[tag.to_model() for tag in self.tags], ) return RunTemplateResponse( diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index ea01de1ab24..33acd07522d 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -58,9 +58,7 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema - from zenml.zen_stores.schemas.run_metadata_schemas import ( - RunMetadataResourceSchema, - ) + from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): @@ -150,12 +148,12 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): deployment: Optional["PipelineDeploymentSchema"] = Relationship( back_populates="step_runs" ) - run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( - back_populates="step_runs", + run_metadata: List["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", - cascade="delete", - overlaps="run_metadata_resources", + secondary="run_metadata_resource", + primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", + secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)", + overlaps="run_metadata", ), ) input_artifacts: List["StepRunInputArtifactSchema"] = Relationship( diff --git a/src/zenml/zen_stores/schemas/tag_schemas.py b/src/zenml/zen_stores/schemas/tag_schemas.py index d1e6a0483b4..6e3767ebb38 100644 --- a/src/zenml/zen_stores/schemas/tag_schemas.py +++ b/src/zenml/zen_stores/schemas/tag_schemas.py @@ -14,7 +14,7 @@ """SQLModel implementation of tag tables.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, List +from typing import Any, List from uuid import UUID from sqlalchemy import VARCHAR, Column @@ -33,16 +33,6 @@ from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field -if TYPE_CHECKING: - from zenml.zen_stores.schemas.artifact_schemas import ( - ArtifactSchema, - ArtifactVersionSchema, - ) - from zenml.zen_stores.schemas.model_schemas import ( - ModelSchema, - ModelVersionSchema, - ) - class TagSchema(NamedSchema, table=True): """SQL Model for tag.""" @@ -52,7 +42,7 @@ class TagSchema(NamedSchema, table=True): color: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) links: List["TagResourceSchema"] = Relationship( back_populates="tag", - sa_relationship_kwargs={"cascade": "delete"}, + sa_relationship_kwargs={"overlaps": "tags", "cascade": "delete"}, ) @classmethod @@ -130,37 +120,11 @@ class TagResourceSchema(BaseSchema, table=True): ondelete="CASCADE", nullable=False, ) - tag: "TagSchema" = Relationship(back_populates="links") + tag: "TagSchema" = Relationship( + back_populates="links", sa_relationship_kwargs={"overlaps": "tags"} + ) resource_id: UUID resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) - artifact: List["ArtifactSchema"] = Relationship( - back_populates="tags", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)", - overlaps="tags,model,artifact_version,model_version", - ), - ) - artifact_version: List["ArtifactVersionSchema"] = Relationship( - back_populates="tags", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)", - overlaps="tags,model,artifact,model_version", - ), - ) - model: List["ModelSchema"] = Relationship( - back_populates="tags", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)", - overlaps="tags,artifact,artifact_version,model_version", - ), - ) - model_version: List["ModelVersionSchema"] = Relationship( - back_populates="tags", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", - overlaps="tags,model,artifact,artifact_version", - ), - ) @classmethod def from_request(cls, request: TagResourceRequest) -> "TagResourceSchema": diff --git a/src/zenml/zen_stores/schemas/utils.py b/src/zenml/zen_stores/schemas/utils.py index 5484a6a9cc8..c2a3d24478c 100644 --- a/src/zenml/zen_stores/schemas/utils.py +++ b/src/zenml/zen_stores/schemas/utils.py @@ -75,35 +75,34 @@ def get_page_from_list( class RunMetadataInterface: """The interface for entities with run metadata.""" - run_metadata_resources = Relationship() + run_metadata = Relationship() def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: - """Fetches all the metadata entries related to the artifact version. + """Fetches all the metadata entries related to the entity. Returns: - a dictionary, where the key is the key of the metadata entry + A dictionary, where the key is the key of the metadata entry and the values represent the list of entries with this key. """ metadata_collection: Dict[str, List[RunMetadataEntry]] = {} - # Fetch the metadata related to this step - for rm in self.run_metadata_resources: - if rm.run_metadata.key not in metadata_collection: - metadata_collection[rm.run_metadata.key] = [] - metadata_collection[rm.run_metadata.key].append( + for rm in self.run_metadata: + if rm.key not in metadata_collection: + metadata_collection[rm.key] = [] + metadata_collection[rm.key].append( RunMetadataEntry( - value=json.loads(rm.run_metadata.value), - created=rm.run_metadata.created, + value=json.loads(rm.value), + created=rm.created, ) ) return metadata_collection def fetch_metadata(self) -> Dict[str, MetadataType]: - """Fetches the latest metadata entry related to the artifact version. + """Fetches the latest metadata entry related to the entity. Returns: - a dictionary, where the key is the key of the metadata entry + A dictionary, where the key is the key of the metadata entry and the values represent the latest entry with this key. """ metadata_collection = self.fetch_metadata_collection()