
Commit 9e3da9c

Various QoL Updates (#22)
* Added a default 1000-row limit for DLO/DMO reads, validation of DataFrame columns against the output DLO schema, and replaced the static list of standard Python libraries with sys.stdlib_module_names.

Co-authored-by: Chandresh Patel <[email protected]>
1 parent eb79b53 commit 9e3da9c

5 files changed: +120 −64 lines


src/datacustomcode/io/reader/query_api.py
Lines changed: 17 additions & 6 deletions

@@ -43,7 +43,7 @@
 logger = logging.getLogger(__name__)
 
 
-SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {}"
+SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {} LIMIT {}"
 PANDAS_TYPE_MAPPING = {
     "object": StringType(),
     "int64": LongType(),
@@ -85,29 +85,40 @@ def __init__(self, spark: SparkSession) -> None:
         )
 
     def read_dlo(
-        self, name: str, schema: Union[AtomicType, StructType, str, None] = None
+        self,
+        name: str,
+        schema: Union[AtomicType, StructType, str, None] = None,
+        row_limit: int = 1000,
     ) -> PySparkDataFrame:
         """
-        Read a Data Lake Object (DLO) from the Data Cloud.
+        Read a Data Lake Object (DLO) from the Data Cloud, limited to a number of rows.
 
         Args:
             name (str): The name of the DLO.
             schema (Optional[Union[AtomicType, StructType, str]]): Schema of the DLO.
+            row_limit (int): Maximum number of rows to fetch.
 
         Returns:
             PySparkDataFrame: The PySpark DataFrame.
         """
-        pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
+        pandas_df = self._conn.get_pandas_dataframe(
+            SQL_QUERY_TEMPLATE.format(name, row_limit)
+        )
         if not schema:
             # auto infer schema
             schema = _pandas_to_spark_schema(pandas_df)
         spark_dataframe = self.spark.createDataFrame(pandas_df, schema)
         return spark_dataframe
 
     def read_dmo(
-        self, name: str, schema: Union[AtomicType, StructType, str, None] = None
+        self,
+        name: str,
+        schema: Union[AtomicType, StructType, str, None] = None,
+        row_limit: int = 1000,
     ) -> PySparkDataFrame:
-        pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
+        pandas_df = self._conn.get_pandas_dataframe(
+            SQL_QUERY_TEMPLATE.format(name, row_limit)
+        )
         if not schema:
             # auto infer schema
             schema = _pandas_to_spark_schema(pandas_df)

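For orientation, a minimal usage sketch of the new row_limit parameter from the caller's side. The reader API is the one shown in the diff above; the DLO name and the SparkSession bootstrap are illustrative assumptions, not taken from the repository.

# Usage sketch only; the DLO name and SparkSession setup are assumptions,
# the reader API is the one shown in the query_api.py diff above.
from pyspark.sql import SparkSession

from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader

spark = SparkSession.builder.getOrCreate()
reader = QueryAPIDataCloudReader(spark)

# Default now issues: SELECT * FROM Account_Home__dll LIMIT 1000
df = reader.read_dlo("Account_Home__dll")

# The cap can be tightened or relaxed per call
sample = reader.read_dlo("Account_Home__dll", row_limit=50)

# row_limit=0 fetches no rows but still exposes the DLO's column names,
# which is what the writer-side schema validation below relies on.
schema_only = reader.read_dlo("Account_Home__dll", row_limit=0)
print(schema_only.columns)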
src/datacustomcode/io/writer/print.py
Lines changed: 54 additions & 1 deletion

@@ -14,20 +14,73 @@
 # limitations under the License.
 
 
-from pyspark.sql import DataFrame as PySparkDataFrame
+from typing import Optional
 
+from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession
+
+from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
 from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
 
 
 class PrintDataCloudWriter(BaseDataCloudWriter):
     CONFIG_NAME = "PrintDataCloudWriter"
 
+    def __init__(
+        self, spark: SparkSession, reader: Optional[QueryAPIDataCloudReader] = None
+    ) -> None:
+        super().__init__(spark)
+        self.reader = QueryAPIDataCloudReader(self.spark) if reader is None else reader
+
+    def validate_dataframe_columns_against_dlo(
+        self,
+        dataframe: PySparkDataFrame,
+        dlo_name: str,
+    ) -> None:
+        """
+        Validates that all columns in the given dataframe exist in the DLO schema.
+
+        Args:
+            dataframe (PySparkDataFrame): The DataFrame to validate.
+            dlo_name (str): The name of the DLO to check against.
+
+        Raises:
+            ValueError: If any columns in the dataframe are not present in the DLO
+                schema.
+        """
+        # Get DLO schema (no data, just schema)
+        dlo_df = self.reader.read_dlo(dlo_name, row_limit=0)
+        dlo_columns = set(dlo_df.columns)
+        df_columns = set(dataframe.columns)
+
+        # Find columns in dataframe not present in DLO
+        extra_columns = df_columns - dlo_columns
+        if extra_columns:
+            raise ValueError(
+                "The following columns are not present in the \n"
+                f"DLO '{dlo_name}': {sorted(extra_columns)}.\n"
+                "To fix this error, you can either:\n"
+                "  - Drop these columns from your DataFrame before writing, e.g.,\n"
+                "    dataframe = dataframe.drop({cols})\n"
+                "  - Or, add these columns to the DLO schema in Data Cloud.".format(
+                    cols=sorted(extra_columns)
+                )
+            )
+
     def write_to_dlo(
         self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
     ) -> None:
+
+        # Validate columns before proceeding
+        self.validate_dataframe_columns_against_dlo(dataframe, name)
+
         dataframe.show()
 
     def write_to_dmo(
         self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
     ) -> None:
+        # The DLO-style column validation does not apply here, because the
+        # target DMO may not exist yet, so just show the dataframe.
+
         dataframe.show()

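A short sketch of how the new validation surfaces to a caller of PrintDataCloudWriter.write_to_dlo. Only the writer API, WriteMode, and the shape of the error come from the diff above; the DLO name, columns, and sample data are illustrative assumptions.

# Illustrative only; the DLO name, columns, and data are assumptions.
from pyspark.sql import SparkSession

from datacustomcode.io.writer.base import WriteMode
from datacustomcode.io.writer.print import PrintDataCloudWriter

spark = SparkSession.builder.getOrCreate()
writer = PrintDataCloudWriter(spark)  # builds its own QueryAPIDataCloudReader

df = spark.createDataFrame([(1, "a", "x")], ["col1", "col2", "not_in_dlo"])

try:
    writer.write_to_dlo("Account_Home__dll", df, WriteMode.OVERWRITE)
except ValueError as exc:
    # Names the offending columns and suggests dropping them or extending the DLO
    print(exc)

# Dropping the extra column lets the (print-only) write proceed
writer.write_to_dlo("Account_Home__dll", df.drop("not_in_dlo"), WriteMode.OVERWRITE)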
src/datacustomcode/scan.py
Lines changed: 5 additions & 50 deletions

@@ -16,6 +16,7 @@
 
 import ast
 import os
+import sys
 from typing import (
     Any,
     ClassVar,
@@ -40,6 +41,8 @@
     },
 }
 
+STANDARD_LIBS = set(sys.stdlib_module_names)
+
 
 class DataAccessLayerCalls(pydantic.BaseModel):
     read_dlo: frozenset[str]
@@ -137,54 +140,6 @@ def found(self) -> DataAccessLayerCalls:
 class ImportVisitor(ast.NodeVisitor):
     """AST Visitor that extracts external package imports from Python code."""
 
-    # Standard library modules that should be excluded from requirements
-    STANDARD_LIBS: ClassVar[set[str]] = {
-        "abc",
-        "argparse",
-        "ast",
-        "asyncio",
-        "base64",
-        "collections",
-        "configparser",
-        "contextlib",
-        "copy",
-        "csv",
-        "datetime",
-        "enum",
-        "functools",
-        "glob",
-        "hashlib",
-        "http",
-        "importlib",
-        "inspect",
-        "io",
-        "itertools",
-        "json",
-        "logging",
-        "math",
-        "os",
-        "pathlib",
-        "pickle",
-        "random",
-        "re",
-        "shutil",
-        "site",
-        "socket",
-        "sqlite3",
-        "string",
-        "subprocess",
-        "sys",
-        "tempfile",
-        "threading",
-        "time",
-        "traceback",
-        "typing",
-        "uuid",
-        "warnings",
-        "xml",
-        "zipfile",
-    }
-
     # Additional packages to exclude from requirements.txt
     EXCLUDED_PACKAGES: ClassVar[set[str]] = {
         "datacustomcode",  # Internal package
@@ -200,7 +155,7 @@ def visit_Import(self, node: ast.Import) -> None:
             # Get the top-level package name
             package = name.name.split(".")[0]
             if (
-                package not in self.STANDARD_LIBS
+                package not in STANDARD_LIBS
                 and package not in self.EXCLUDED_PACKAGES
                 and not package.startswith("_")
             ):
@@ -213,7 +168,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
             # Get the top-level package
             package = node.module.split(".")[0]
             if (
-                package not in self.STANDARD_LIBS
+                package not in STANDARD_LIBS
                 and package not in self.EXCLUDED_PACKAGES
                 and not package.startswith("_")
             ):

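For context, sys.stdlib_module_names was added in Python 3.10 and is a frozenset of the top-level standard-library module names for the running interpreter, which is what lets the hand-maintained 44-entry set go away. A standalone sketch of the resulting filter, with arbitrary example package names:

# Standalone illustration; STANDARD_LIBS and the filter shape mirror scan.py,
# the package names below are arbitrary examples.
import sys

STANDARD_LIBS = set(sys.stdlib_module_names)  # requires Python 3.10+
EXCLUDED_PACKAGES = {"datacustomcode"}        # internal package, per the diff

for package in ("json", "pyspark", "_socket", "datacustomcode"):
    is_external = (
        package not in STANDARD_LIBS
        and package not in EXCLUDED_PACKAGES
        and not package.startswith("_")
    )
    print(f"{package}: {'external dependency' if is_external else 'skipped'}")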
tests/io/reader/test_query_api.py
Lines changed: 4 additions & 4 deletions

@@ -143,7 +143,7 @@ def test_read_dlo(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dlo")
+            SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
         )
 
         # Verify DataFrame was created with auto-inferred schema
@@ -172,7 +172,7 @@ def test_read_dlo_with_schema(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dlo")
+            SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
         )
 
         # Verify DataFrame was created with provided schema
@@ -192,7 +192,7 @@ def test_read_dmo(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dmo")
+            SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
         )
 
         # Verify DataFrame was created
@@ -220,7 +220,7 @@ def test_read_dmo_with_schema(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dmo")
+            SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
         )
 
         # Verify DataFrame was created with provided schema

tests/io/writer/test_print.py
Lines changed: 40 additions & 3 deletions

@@ -23,18 +23,33 @@ def mock_dataframe(self):
         return df
 
     @pytest.fixture
-    def print_writer(self, mock_spark_session):
+    def mock_reader(self):
+        """Create a mock QueryAPIDataCloudReader."""
+        reader = MagicMock()
+        mock_dlo_df = MagicMock()
+        mock_dlo_df.columns = ["col1", "col2"]
+        reader.read_dlo.return_value = mock_dlo_df
+        return reader
+
+    @pytest.fixture
+    def print_writer(self, mock_spark_session, mock_reader):
         """Create a PrintDataCloudWriter instance."""
-        return PrintDataCloudWriter(mock_spark_session)
+        return PrintDataCloudWriter(mock_spark_session, mock_reader)
 
     def test_write_to_dlo(self, print_writer, mock_dataframe):
         """Test write_to_dlo method calls dataframe.show()."""
+        # Mock the validate_dataframe_columns_against_dlo method
+        print_writer.validate_dataframe_columns_against_dlo = MagicMock()
+
         # Call the method
         print_writer.write_to_dlo("test_dlo", mock_dataframe, WriteMode.OVERWRITE)
 
         # Verify show() was called
         mock_dataframe.show.assert_called_once()
 
+        # Verify validate_dataframe_columns_against_dlo was called
+        print_writer.validate_dataframe_columns_against_dlo.assert_called_once()
+
     def test_write_to_dmo(self, print_writer, mock_dataframe):
         """Test write_to_dmo method calls dataframe.show()."""
         # Call the method
@@ -59,9 +74,31 @@ def test_ignores_name_and_write_mode(self, print_writer, mock_dataframe):
         for name, write_mode in test_cases:
             # Reset mock before each call
             mock_dataframe.show.reset_mock()
-
+            # Mock the validate_dataframe_columns_against_dlo method
+            print_writer.validate_dataframe_columns_against_dlo = MagicMock()
             # Call method
             print_writer.write_to_dlo(name, mock_dataframe, write_mode)
 
             # Verify show() was called with no arguments
             mock_dataframe.show.assert_called_once_with()
+
+            print_writer.validate_dataframe_columns_against_dlo.assert_called_once()
+
+    def test_validate_dataframe_columns_against_dlo(self, print_writer, mock_dataframe):
+        """Test validate_dataframe_columns_against_dlo method."""
+        # The mock_reader fixture returns a DLO with columns ["col1", "col2"]
+
+        # Set up mock dataframe columns, including one not present in the DLO
+        mock_dataframe.columns = ["col1", "col2", "col3"]
+
+        # Test that validation raises ValueError for extra columns
+        with pytest.raises(ValueError) as exc_info:
+            print_writer.validate_dataframe_columns_against_dlo(
+                mock_dataframe, "test_dlo"
+            )
+
+        assert "col3" in str(exc_info.value)
+
+        # Test successful validation with matching columns
+        mock_dataframe.columns = ["col1", "col2"]
+        print_writer.validate_dataframe_columns_against_dlo(mock_dataframe, "test_dlo")
