diff --git a/docs/book/.gitbook/assets/argilla_annotator.png b/docs/book/.gitbook/assets/argilla_annotator.png
index 4cde7bac206..62327909f6d 100644
Binary files a/docs/book/.gitbook/assets/argilla_annotator.png and b/docs/book/.gitbook/assets/argilla_annotator.png differ
diff --git a/docs/book/component-guide/annotators/annotators.md b/docs/book/component-guide/annotators/annotators.md
index f0592a1eb3f..eed8c494e28 100644
--- a/docs/book/component-guide/annotators/annotators.md
+++ b/docs/book/component-guide/annotators/annotators.md
@@ -55,7 +55,7 @@ The core parts of the annotation workflow include:
 ### List of available annotators
 
 For production use cases, some more flavors can be found in specific `integrations` modules. In terms of annotators,
-ZenML features integrations with `label_studio` and `pigeon`.
+ZenML features integrations with the following tools:
 
 | Annotator | Flavor | Integration | Notes |
 |-----------------------------------------|----------------|----------------|----------------------------------------------------------------------|
diff --git a/docs/book/component-guide/annotators/argilla.md b/docs/book/component-guide/annotators/argilla.md
index b0ed6f92a53..b136e0a4cd9 100644
--- a/docs/book/component-guide/annotators/argilla.md
+++ b/docs/book/component-guide/annotators/argilla.md
@@ -4,12 +4,7 @@ description: Annotating data using Argilla.
 
 # Argilla
 
-[Argilla](https://github.com/argilla-io/argilla) is an open-source data curation
-platform designed to enhance the development of both small and large language
-models (LLMs) and NLP tasks in general. It enables users to build robust
-language models through faster data curation using both human and machine
-feedback, providing support for each step in the MLOps cycle, from data labeling
-to model monitoring.
+[Argilla](https://github.com/argilla-io/argilla) is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. It enables users to build robust language models through faster data curation using both human and machine feedback, providing support for each step in the MLOps cycle, from data labeling to model monitoring.
 
 ![Argilla Annotator](../../.gitbook/assets/argilla_annotator.png)
 
@@ -31,7 +26,7 @@ of Argilla as well as a deployed instance of Argilla.
 There is an easy way to deploy Argilla as a [Hugging Face
 Space](https://huggingface.co/docs/hub/spaces-sdks-docker-argilla), for
 instance, which is documented in the [Argilla
-documentation](https://docs.argilla.io/en/latest/getting_started/installation/deployments/huggingface-spaces.html).
+documentation](https://docs.argilla.io/latest/getting_started/quickstart/).
 
 ### How to deploy it?
 
@@ -59,16 +54,16 @@ zenml secret create argilla_secrets --api_key=""
 Then register your annotator with ZenML:
 
 ```shell
-zenml annotator register argilla --flavor argilla --authentication_secret=argilla_secrets
+zenml annotator register argilla --flavor argilla --authentication_secret=argilla_secrets --port=6900
 ```
 
 When using a deployed instance of Argilla, the instance URL must be specified
 without any trailing `/` at the end. If you are using a Hugging Face Spaces
 instance and its visibility is set to private, you must also set the
-`extra_headers` parameter which would include a Hugging Face token. For example:
+`headers` parameter which should include a Hugging Face token. For example:
 
 ```shell
-zenml annotator register argilla --flavor argilla --authentication_secret=argilla_secrets --instance_url="https://[your-owner-name]-[your_space_name].hf.space" --extra_headers="{"Authorization": f"Bearer {}"}"
+zenml annotator register argilla --flavor argilla --authentication_secret=argilla_secrets --instance_url="https://[your-owner-name]-[your_space_name].hf.space" --headers='{"Authorization": "Bearer [your_hugging_face_token]"}'
 ```
 
 Finally, add all these components to a stack and set it as your active stack.
@@ -95,9 +90,8 @@ functionality via the ZenML SDK.
 
 You can access information about the datasets you're using with the `zenml
 annotator dataset list`. To work on annotation for a particular dataset, you can
-run `zenml annotator dataset annotate <dataset_name>`. What follows is an
-overview of some key components to the Argilla integration and how it can be
-used.
+run `zenml annotator dataset annotate <dataset_name>`. This will open the Argilla
+web interface for you to start annotating the dataset.
 
 #### Argilla Annotator Stack Component
 
diff --git a/docs/mocked_libs.json b/docs/mocked_libs.json
index 796dbeea0a1..605258569b9 100644
--- a/docs/mocked_libs.json
+++ b/docs/mocked_libs.json
@@ -229,10 +229,7 @@
   "xgboost",
   "argilla",
   "argilla.client",
-  "argilla.client.client",
-  "argilla.client.sdk",
-  "argilla.client.sdk.commons",
-  "argilla.client.sdk.commons.errors",
+  "argilla._exceptions._api",
   "peewee",
   "prodigy",
   "prodigy.components",
diff --git a/scripts/install-zenml-dev.sh b/scripts/install-zenml-dev.sh
index c3ca2f61f57..7d0e9521c93 100755
--- a/scripts/install-zenml-dev.sh
+++ b/scripts/install-zenml-dev.sh
@@ -36,7 +36,7 @@ install_integrations() {
   # figure out the python version
   python_version=$(python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))")
 
-  ignore_integrations="feast label_studio bentoml seldon pycaret skypilot_aws skypilot_gcp skypilot_azure pigeon prodigy"
+  ignore_integrations="feast label_studio bentoml seldon pycaret skypilot_aws skypilot_gcp skypilot_azure pigeon prodigy argilla"
 
   # Ignore tensorflow and deepchecks only on Python 3.12
   if [ "$python_version" = "3.12" ]; then
diff --git a/src/zenml/integrations/argilla/__init__.py b/src/zenml/integrations/argilla/__init__.py
index 3953cf3863b..9d87666f673 100644
--- a/src/zenml/integrations/argilla/__init__.py
+++ b/src/zenml/integrations/argilla/__init__.py
@@ -26,7 +26,7 @@ class ArgillaIntegration(Integration):
 
     NAME = ARGILLA
     REQUIREMENTS = [
-        "argilla>=1.20.0,<2",
+        "argilla>=2.0.0",
     ]
 
     @classmethod
diff --git a/src/zenml/integrations/argilla/annotators/argilla_annotator.py b/src/zenml/integrations/argilla/annotators/argilla_annotator.py
index 09e0790c56b..fe04ce40dc5 100644
--- a/src/zenml/integrations/argilla/annotators/argilla_annotator.py
+++ b/src/zenml/integrations/argilla/annotators/argilla_annotator.py
@@ -14,11 +14,12 @@
 """Implementation of the Argilla annotation integration."""
 
 import json
-from typing import Any, List, Tuple, Type, cast
+import webbrowser
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
 
 import argilla as rg
-from argilla.client.client import Argilla as ArgillaClient
-from argilla.client.sdk.commons.errors import BaseClientError, NotFoundApiError
+from argilla._exceptions._api import ArgillaAPIError
+from argilla.client import Argilla as ArgillaClient
 
 from zenml.annotators.base_annotator import BaseAnnotator
 from zenml.integrations.argilla.flavors import (
@@ -67,7 +68,7 @@ def get_url(self) -> str:
         )
 
     def _get_client(self) -> ArgillaClient:
-        """Gets Argilla client.
+        """Gets the Argilla client.
 
         Returns:
             Argilla client.
@@ -75,7 +76,7 @@ def _get_client(self) -> ArgillaClient:
         config = self.config
         init_kwargs = {"api_url": self.get_url()}
 
-        # set the API key from the secret or using settings
+        # Set the API key from the secret or using settings
        authentication_secret = self.get_authentication_secret()
         if config.api_key and authentication_secret:
             api_key = config.api_key
@@ -92,194 +93,333 @@ def _get_client(self) -> ArgillaClient:
         if api_key:
             init_kwargs["api_key"] = api_key
 
-        if config.workspace is not None:
-            init_kwargs["workspace"] = config.workspace
-        if config.extra_headers is not None:
-            init_kwargs["extra_headers"] = json.loads(config.extra_headers)
+        if config.headers is not None:
+            init_kwargs["headers"] = json.loads(config.headers)
         if config.httpx_extra_kwargs is not None:
             init_kwargs["httpx_extra_kwargs"] = json.loads(
                 config.httpx_extra_kwargs
             )
 
         try:
-            _ = rg.active_client()
-        except BaseClientError:
-            rg.init(**init_kwargs)
-        return rg.active_client()
+            _ = rg.Argilla(**init_kwargs).me
+        except ArgillaAPIError as e:
+            logger.error(f"Failed to verify the Argilla instance: {str(e)}")
+        return rg.Argilla(**init_kwargs)
 
-    def get_url_for_dataset(self, dataset_name: str) -> str:
+    def get_url_for_dataset(self, dataset_name: str, **kwargs: Any) -> str:
         """Gets the URL of the annotation interface for the given dataset.
 
         Args:
             dataset_name: The name of the dataset.
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+                -workspace: The name of the workspace. By default, the first available.
 
         Returns:
-            The URL of the annotation interface.
+            The URL of the dataset annotation interface.
         """
-        dataset_id = self.get_dataset(dataset_name=dataset_name).id
+        workspace = kwargs.get("workspace")
+
+        dataset_id = self.get_dataset(
+            dataset_name=dataset_name, workspace=workspace
+        ).id
         return f"{self.get_url()}/dataset/{dataset_id}/annotation-mode"
 
-    def get_datasets(self) -> List[Any]:
+    def get_datasets(self, **kwargs: Any) -> List[Any]:
         """Gets the datasets currently available for annotation.
 
+        Args:
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+                -workspace: The name of the workspace. By default, the first available.
+                    If set, only the datasets in the workspace will be returned.
+
         Returns:
             A list of datasets.
         """
-        old_datasets = self._get_client().list_datasets()
-        new_datasets = rg.FeedbackDataset.list()
+        workspace = kwargs.get("workspace")
+
+        if workspace is None:
+            datasets = list(self._get_client().datasets)
+        else:
+            datasets = list(self._get_client().workspaces(workspace).datasets)
+
+        return datasets
+
+    def get_dataset_names(self, **kwargs: Any) -> List[str]:
+        """Gets the names of the datasets.
+
+        Args:
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+                -workspace: The name of the workspace. By default, the first available.
+                    If set, only the dataset names in the workspace will be returned.
+
+        Returns:
+            A list of dataset names.
+        """
+        workspace = kwargs.get("workspace")
+
+        if workspace is None:
+            dataset_names = [dataset.name for dataset in self.get_datasets()]
+        else:
+            dataset_names = [
+                dataset.name
+                for dataset in self.get_datasets(workspace=workspace)
+            ]
+
+        return dataset_names
+
+    def _get_data_by_status(
+        self, dataset_name: str, status: str, workspace: Optional[str]
+    ) -> Any:
+        """Gets the records with the specified status.
+
+        Args:
+            dataset_name: The name of the dataset.
+            status: The response status to filter by ('completed' for labeled,
+                'pending' for unlabeled).
+            workspace: The name of the workspace. By default, the first available.
+
+        Returns:
+            The list of records with the specified status.
+        """
+        dataset = self.get_dataset(
+            dataset_name=dataset_name, workspace=workspace
+        )
 
-        # Deduplicate datasets based on their names
-        dataset_names = set()
-        deduplicated_datasets = []
-        for dataset in new_datasets + old_datasets:
-            if dataset.name not in dataset_names:
-                dataset_names.add(dataset.name)
-                deduplicated_datasets.append(dataset)
+        query = rg.Query(filter=rg.Filter([("status", "==", status)]))
 
-        return deduplicated_datasets
+        return dataset.records(
+            query=query,
+            with_suggestions=True,
+            with_vectors=True,
+            with_responses=True,
+        ).to_list()
 
-    def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
+    def get_dataset_stats(
+        self, dataset_name: str, **kwargs: Any
+    ) -> Tuple[int, int]:
         """Gets the statistics of the given dataset.
 
         Args:
             dataset_name: The name of the dataset.
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+                -workspace: The name of the workspace. By default, the first available.
 
         Returns:
             A tuple containing (labeled_task_count, unlabeled_task_count) for
                 the dataset.
         """
-        dataset = self.get_dataset(dataset_name=dataset_name)
+        workspace = kwargs.get("workspace")
+
         labeled_task_count = len(
-            dataset.filter_by(response_status="submitted")
+            self._get_data_by_status(
+                dataset_name=dataset_name,
+                status="completed",
+                workspace=workspace,
+            )
         )
         unlabeled_task_count = len(
-            dataset.filter_by(response_status="pending")
+            self._get_data_by_status(
+                dataset_name=dataset_name,
+                status="pending",
+                workspace=workspace,
+            )
         )
+
         return (labeled_task_count, unlabeled_task_count)
 
-    def add_dataset(self, **kwargs: Any) -> Any:
-        """Registers a dataset for annotation.
+    def launch(self, **kwargs: Any) -> None:
+        """Launches the annotation interface.
+
+        Args:
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+        """
+        url = kwargs.get("api_url") or self.get_url()
 
-        You must pass a `dataset_name` and a `dataset` object to this method.
+        if self._get_client():
+            webbrowser.open(url, new=1, autoraise=True)
+        else:
+            logger.warning(
+                "Could not launch the annotation interface "
+                "because the connection could not be established."
+            )
+
+    def add_dataset(self, **kwargs: Any) -> Any:
+        """Creates a dataset for annotation.
 
         Args:
-            **kwargs: Additional keyword arguments to pass to the Argilla
-                client.
+            **kwargs: Additional keyword arguments to pass to the Argilla client.
+                -dataset_name: The name of the dataset.
+                -settings: The settings for the dataset.
+                -workspace: The name of the workspace. By default, the first available.
 
         Returns:
             An Argilla dataset object.
 
         Raises:
-            ValueError: if 'dataset_name' and 'dataset' aren't provided.
+            ValueError: if `dataset_name` or `settings` aren't provided.
+            RuntimeError: if the workspace creation fails.
+            RuntimeError: if the dataset creation fails.
         """
         dataset_name = kwargs.get("dataset_name")
-        dataset = kwargs.get("dataset")
+        settings = kwargs.get("settings")
+        workspace = kwargs.get("workspace")
 
-        if not dataset_name:
-            raise ValueError("`dataset_name` keyword argument is required.")
-        elif dataset is None:
-            raise ValueError("`dataset` keyword argument is required.")
+        if dataset_name is None or settings is None:
+            raise ValueError(
+                "`dataset_name` and `settings` keyword arguments are required."
+ ) + + if workspace is None and not self._get_client().workspaces: + workspace_to_create = rg.Workspace(name="argilla") + try: + workspace = workspace_to_create.create() + except Exception as e: + raise RuntimeError( + "Failed to create the `argilla` workspace." + ) from e try: - logger.info(f"Pushing dataset '{dataset_name}' to Argilla...") - dataset.push_to_argilla(name=dataset_name) - logger.info(f"Dataset '{dataset_name}' pushed successfully.") + dataset = rg.Dataset( + name=dataset_name, workspace=workspace, settings=settings + ) + logger.info(f"Creating the dataset '{dataset_name}' in Argilla...") + dataset.create() + logger.info(f"Dataset '{dataset_name}' successfully created.") + return self.get_dataset( + dataset_name=dataset_name, workspace=workspace + ) except Exception as e: logger.error( - f"Failed to push dataset '{dataset_name}' to Argilla: {str(e)}" + f"Failed to create dataset '{dataset_name}' in Argilla: {str(e)}" ) - raise ValueError( - f"Failed to push dataset to Argilla: {str(e)}" + raise RuntimeError( + f"Failed to create the dataset '{dataset_name}' in Argilla: {str(e)}" ) from e - return self.get_dataset(dataset_name=dataset_name) - def delete_dataset(self, **kwargs: Any) -> None: - """Deletes a dataset from the annotation interface. + def add_records( + self, + dataset_name: str, + records: Union[Any, List[Dict[str, Any]]], + workspace: Optional[str] = None, + mapping: Optional[Dict[str, str]] = None, + ) -> Any: + """Add records to an Argilla dataset for annotation. Args: - **kwargs: Additional keyword arguments to pass to the Argilla - client. + dataset_name: The name of the dataset. + records: The records to add to the dataset. + workspace: The name of the workspace. By default, the first available. + mapping: The mapping of the records to the dataset fields. By default, None. Raises: - ValueError: If the dataset name is not provided. + RuntimeError: If the records cannot be loaded to Argilla. """ - dataset_name = kwargs.get("dataset_name") - if not dataset_name: - raise ValueError("`dataset_name` keyword argument is required.") + dataset = self.get_dataset( + dataset_name=dataset_name, workspace=workspace + ) try: - self._get_client().delete(name=dataset_name) - self.get_dataset(dataset_name=dataset_name).delete() - logger.info(f"Dataset '{dataset_name}' deleted successfully.") - except ValueError: - logger.warning( - f"Dataset '{dataset_name}' not found. Skipping deletion." + logger.info( + f"Loading the records to '{dataset_name}' in Argilla..." ) + dataset.records.log(records=records, mapping=mapping) + logger.info( + f"Records loaded successfully to Argilla for '{dataset_name}'." + ) + except Exception as e: + logger.error( + f"Failed to load the records to Argilla for '{dataset_name}': {str(e)}" + ) + raise RuntimeError( + f"Failed to load the records to Argilla: {str(e)}" + ) from e def get_dataset(self, **kwargs: Any) -> Any: """Gets the dataset with the given name. Args: **kwargs: Additional keyword arguments to pass to the Argilla client. + -dataset_name: The name of the dataset. + -workspace: The name of the workspace. By default, the first available. Returns: - The Argilla DatasetModel object for the given name. + The Argilla Dataset for the given name and workspace, if specified. Raises: ValueError: If the dataset name is not provided or if the dataset does not exist. 
""" dataset_name = kwargs.get("dataset_name") + workspace = kwargs.get("workspace") + if not dataset_name: raise ValueError("`dataset_name` keyword argument is required.") try: - if rg.FeedbackDataset.from_argilla(name=dataset_name) is not None: - return rg.FeedbackDataset.from_argilla(name=dataset_name) + dataset = self._get_client().datasets( + name=dataset_name, workspace=workspace + ) + if dataset is None: + logger.error(f"Dataset '{dataset_name}' not found.") else: - return self._get_client().get_dataset(name=dataset_name) - except (NotFoundApiError, ValueError) as e: + return dataset + except ValueError as e: logger.error(f"Dataset '{dataset_name}' not found.") raise ValueError(f"Dataset '{dataset_name}' not found.") from e - def get_data_by_status(self, dataset_name: str, status: str) -> Any: - """Gets the dataset containing the data with the specified status. + def delete_dataset(self, **kwargs: Any) -> None: + """Deletes a dataset from the annotation interface. Args: - dataset_name: The name of the dataset. - status: The response status to filter by ('submitted' for labeled, - 'pending' for unlabeled). - - Returns: - The dataset containing the data with the specified status. + **kwargs: Additional keyword arguments to pass to the Argilla client. + -dataset_name: The name of the dataset. + -workspace: The name of the workspace. By default, the first available Raises: - ValueError: If the dataset name is not provided. + ValueError: If the dataset name is not provided or if the datasets + is not found. """ + dataset_name = kwargs.get("dataset_name") + workspace = kwargs.get("workspace") + if not dataset_name: - raise ValueError("`dataset_name` argument is required.") + raise ValueError("`dataset_name` keyword argument is required.") - return self.get_dataset(dataset_name=dataset_name).filter_by( - response_status=status - ) + try: + dataset = self.get_dataset( + dataset_name=dataset_name, workspace=workspace + ) + dataset.delete() + logger.info(f"Dataset '{dataset_name}' deleted successfully.") + except ValueError: + logger.warning( + f"Dataset '{dataset_name}' not found. Skipping deletion." + ) def get_labeled_data(self, **kwargs: Any) -> Any: """Gets the dataset containing the labeled data. Args: **kwargs: Additional keyword arguments to pass to the Argilla client. + -dataset_name: The name of the dataset. + -workspace: The name of the workspace. By default, the first available. Returns: - The dataset containing the labeled data. + The list of annotated records. Raises: ValueError: If the dataset name is not provided. """ - if dataset_name := kwargs.get("dataset_name"): - return self.get_data_by_status(dataset_name, status="submitted") - else: + dataset_name = kwargs.get("dataset_name") + workspace = kwargs.get("workspace") + + if not dataset_name: raise ValueError("`dataset_name` keyword argument is required.") + return self._get_data_by_status( + dataset_name, workspace=workspace, status="completed" + ) + def get_unlabeled_data(self, **kwargs: str) -> Any: """Gets the dataset containing the unlabeled data. @@ -287,12 +427,17 @@ def get_unlabeled_data(self, **kwargs: str) -> Any: **kwargs: Additional keyword arguments to pass to the Argilla client. Returns: - The dataset containing the unlabeled data. + The list of pending records for annotation. Raises: ValueError: If the dataset name is not provided. 
""" - if dataset_name := kwargs.get("dataset_name"): - return self.get_data_by_status(dataset_name, status="pending") - else: + dataset_name = kwargs.get("dataset_name") + workspace = kwargs.get("workspace") + + if not dataset_name: raise ValueError("`dataset_name` keyword argument is required.") + + return self._get_data_by_status( + dataset_name, workspace=workspace, status="pending" + ) diff --git a/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py b/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py index f14d5a86c4d..649c9eb4cd7 100644 --- a/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +++ b/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py @@ -24,6 +24,7 @@ from zenml.config.base_settings import BaseSettings from zenml.integrations.argilla import ARGILLA_ANNOTATOR_FLAVOR from zenml.stack.authentication_mixin import AuthenticationConfigMixin +from zenml.utils import deprecation_utils from zenml.utils.secret_utils import SecretField if TYPE_CHECKING: @@ -43,19 +44,23 @@ class ArgillaAnnotatorSettings(BaseSettings): Attributes: instance_url: URL of the Argilla instance. api_key: The api_key for Argilla - workspace: The workspace to use for the annotation interface. port: The port to use for the annotation interface. - extra_headers: Extra headers to include in the request. + headers: Extra headers to include in the request. httpx_extra_kwargs: Extra kwargs to pass to the client. """ instance_url: str = DEFAULT_LOCAL_INSTANCE_URL api_key: Optional[str] = SecretField(default=None) - workspace: Optional[str] = "admin" - port: Optional[int] - extra_headers: Optional[str] = None + port: Optional[int] = DEFAULT_LOCAL_ARGILLA_PORT + headers: Optional[str] = None httpx_extra_kwargs: Optional[str] = None + extra_headers: Optional[str] = None + + _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( + ("extra_headers", "headers"), + ) + @field_validator("instance_url") @classmethod def ensure_instance_url_ends_without_slash(cls, instance_url: str) -> str: