53 changes: 43 additions & 10 deletions bigframes/core/local_data.py
@@ -26,11 +26,12 @@

 import geopandas  # type: ignore
 import numpy as np
-import pandas
+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet  # type: ignore

 import bigframes.core.schema as schemata
+import bigframes.core.utils as utils
 import bigframes.dtypes


@@ -58,15 +59,12 @@ class ManagedArrowTable:
     schema: schemata.ArraySchema = dataclasses.field(hash=False)
     id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)

-    def __post_init__(self):
-        self.validate()
-
     @functools.cached_property
     def metadata(self) -> LocalTableMetadata:
         return LocalTableMetadata.from_arrow(self.data)

     @classmethod
-    def from_pandas(cls, dataframe: pandas.DataFrame) -> ManagedArrowTable:
+    def from_pandas(cls, dataframe: pd.DataFrame) -> ManagedArrowTable:
         """Creates managed table from pandas. Ignores index, col names must be unique strings"""
         columns: list[pa.ChunkedArray] = []
         fields: list[schemata.SchemaItem] = []
@@ -78,9 +76,11 @@ def from_pandas(cls, dataframe: pandas.DataFrame) -> ManagedArrowTable:
             columns.append(new_arr)
             fields.append(schemata.SchemaItem(str(name), bf_type))

-        return ManagedArrowTable(
+        mat = ManagedArrowTable(
             pa.table(columns, names=column_names), schemata.ArraySchema(tuple(fields))
         )
+        mat.validate(include_content=True)
+        return mat

     @classmethod
     def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
@@ -91,10 +91,12 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
             columns.append(new_arr)
             fields.append(schemata.SchemaItem(name, bf_type))

-        return ManagedArrowTable(
+        mat = ManagedArrowTable(
             pa.table(columns, names=table.column_names),
             schemata.ArraySchema(tuple(fields)),
         )
+        mat.validate()
+        return mat

     def to_parquet(
         self,
@@ -140,8 +142,7 @@ def itertuples(
     ):
             yield tuple(row_dict.values())

-    def validate(self):
-        # TODO: Content-based validation for some datatypes (eg json, wkt, list) where logical domain is smaller than pyarrow type
+    def validate(self, include_content: bool = False):
         for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
             expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
             arrow_type = arrow_field.type
@@ -150,6 +151,38 @@ def validate(self):
                     f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
                 )

+        if include_content:
+            for batch in self.data.to_batches():
+                for field in self.schema.items:
+                    _validate_content(batch.column(field.column), field.dtype)


+def _validate_content(array: pa.Array, dtype: bigframes.dtypes.Dtype):
+    """
+    Recursively validates the content of a PyArrow Array based on the
+    expected BigFrames dtype, focusing on complex types like JSON, structs,
+    and arrays where the Arrow type alone isn't sufficient.
+    """
+    # TODO: validate GEO data context.
+    if dtype == bigframes.dtypes.JSON_DTYPE:
+        values = array.to_pandas()
+        for data in values:
+            # Skip scalar null values to avoid `TypeError` from json.load.
+            if not utils.is_list_like(data) and pd.isna(data):
    [Review thread on the `is_list_like` guard]

    Contributor: Isn't everything a string or a null? How do you get a list-like?

    Contributor (author): Now the data is formatted using pyarrow storage, so these checks are no longer necessary. They are only required to handle values like None or [] in pandas.

    Contributor (author): Re-adding null value checks to ensure compatibility with Python 3.9 tests.
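    An aside illustrating the guard with plain pandas (not part of the PR): `pd.isna` is elementwise on list-like inputs, so without the `is_list_like` check a value such as `[]` reaching this line would make the `if` ambiguous.

    import pandas as pd

    pd.isna(None)       # True -- scalar nulls are safely skipped
    pd.isna([1, None])  # array([False,  True]) -- elementwise on a list-like;
                        # `if pd.isna([1, None]):` raises "truth value ... is ambiguous"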

+                continue
+            try:
+                # Attempts JSON parsing.
+                json.loads(data)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON format found: {data!r}") from e
+    elif bigframes.dtypes.is_struct_like(dtype):
+        for field_name, dtype in bigframes.dtypes.get_struct_fields(dtype).items():
+            _validate_content(array.field(field_name), dtype)
+    elif bigframes.dtypes.is_array_like(dtype):
+        return _validate_content(
+            array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
+        )


 # Sequential iterator, but could split into batches and leverage parallelism for speed
 def _iter_table(
@@ -226,7 +259,7 @@ def _(


 def _adapt_pandas_series(
-    series: pandas.Series,
+    series: pd.Series,
 ) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
     # Mostly rely on pyarrow conversions, but have to convert geo without its help.
     if series.dtype == bigframes.dtypes.GEO_DTYPE:
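How the pieces fit together, as a rough sketch (module path taken from the diff header; the raising behavior is inferred from the tests below, so treat this as illustrative rather than authoritative). Note the asymmetry in the diff: from_pandas opts into the content check with validate(include_content=True), while from_pyarrow keeps the cheaper schema-only validate().

    import pandas as pd

    import bigframes.dtypes
    from bigframes.core.local_data import ManagedArrowTable

    # "False" is a valid Python literal but not valid JSON ("false"), so the
    # content check should reject it when the local table is built, rather
    # than when BigQuery first tries to parse the payload.
    df = pd.DataFrame({"col": pd.Series(["False"], dtype=bigframes.dtypes.JSON_DTYPE)})
    ManagedArrowTable.from_pandas(df)  # expected: ValueError: Invalid JSON format found: 'False'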
53 changes: 51 additions & 2 deletions tests/system/small/test_session.py
@@ -962,8 +962,8 @@ def test_read_pandas_json_series(session, write_engine):
     json_data = [
         "1",
         None,
-        '["1","3","5"]',
-        '{"a":1,"b":["x","y"],"c":{"x":[],"z":false}}',
+        '[1,"3",null,{"a":null}]',
+        '{"a":1,"b":["x","y"],"c":{"x":[],"y":null,"z":false}}',
     ]
     expected_series = pd.Series(json_data, dtype=bigframes.dtypes.JSON_DTYPE)

@@ -975,6 +975,28 @@ def test_read_pandas_json_series(session, write_engine):
     )


+@pytest.mark.parametrize(
+    ("write_engine"),
+    [
+        pytest.param("default"),
+        pytest.param("bigquery_inline"),
+        pytest.param("bigquery_load"),
+        pytest.param("bigquery_streaming"),
+    ],
+)
+def test_read_pandas_json_series_w_invalid_json(session, write_engine):
+    json_data = [
+        "False",  # Should be "false"
+    ]
+    pd_s = pd.Series(json_data, dtype=bigframes.dtypes.JSON_DTYPE)
+
+    with pytest.raises(
+        ValueError,
+        match="Invalid JSON format found",
+    ):
+        session.read_pandas(pd_s, write_engine=write_engine)
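The failure mode exercised here is plain JSON case-sensitivity: Python's boolean repr is not valid JSON, and `json.loads` inside `_validate_content` surfaces that as the ValueError. A standard-library illustration:

    import json

    json.loads("false")  # OK: parses to the Python value False
    json.loads("False")  # raises json.JSONDecodeError: Expecting value: line 1 column 1 (char 0)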


 @pytest.mark.parametrize(
     ("write_engine"),
     [
@@ -1056,6 +1078,33 @@ def test_read_pandas_w_nested_json(session, write_engine):
     pd.testing.assert_series_equal(bq_s, pd_s)


+@pytest.mark.parametrize(
+    ("write_engine"),
+    [
+        pytest.param("default"),
+        pytest.param("bigquery_inline"),
+        pytest.param("bigquery_load"),
+        pytest.param("bigquery_streaming"),
+    ],
+)
+def test_read_pandas_w_nested_invalid_json(session, write_engine):
+    # TODO: supply a reason why this isn't compatible with pandas 1.x
+    pytest.importorskip("pandas", minversion="2.0.0")
+    data = [
+        [{"json_field": "NULL"}],  # Should be "null"
+    ]
+    pa_array = pa.array(data, type=pa.list_(pa.struct([("json_field", pa.string())])))
+    pd_s = pd.Series(
+        arrays.ArrowExtensionArray(pa_array),  # type: ignore
+        dtype=pd.ArrowDtype(
+            pa.list_(pa.struct([("json_field", bigframes.dtypes.JSON_ARROW_TYPE)]))
+        ),
+    )
+
+    with pytest.raises(ValueError, match="Invalid JSON format found"):
+        session.read_pandas(pd_s, write_engine=write_engine)
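To trace why this input trips the validator: the dtype is list&lt;struct&lt;json_field: JSON&gt;&gt;, so `_validate_content` flattens the list, recurses into the struct field, and `json.loads` finally sees "NULL" (JSON requires lowercase null). The same steps hand-run with plain pyarrow (illustrative only; the real code dispatches on BigFrames dtypes):

    import json

    import pyarrow as pa

    arr = pa.array(
        [[{"json_field": "NULL"}]],
        type=pa.list_(pa.struct([("json_field", pa.string())])),
    )
    inner = arr.flatten()                 # list<struct> -> struct array
    json_col = inner.field("json_field")  # struct -> string array
    json.loads(json_col.to_pylist()[0])   # json.JSONDecodeError: Expecting value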


 @pytest.mark.parametrize(
     ("write_engine"),
     [