diff --git a/bigframes/session/_io/bigquery/read_gbq_table.py b/bigframes/session/_io/bigquery/read_gbq_table.py index 2dff16933f..6322040428 100644 --- a/bigframes/session/_io/bigquery/read_gbq_table.py +++ b/bigframes/session/_io/bigquery/read_gbq_table.py @@ -243,25 +243,17 @@ def get_index_cols( | int | bigframes.enums.DefaultIndexKind, *, - names: Optional[Iterable[str]] = None, + rename_to_schema: Optional[Dict[str, str]] = None, ) -> List[str]: """ If we can get a total ordering from the table, such as via primary key column(s), then return those too so that ordering generation can be avoided. """ - # Transform index_col -> index_cols so we have a variable that is # always a list of column names (possibly empty). schema_len = len(table.schema) - # If the `names` is provided, the index_col provided by the user is the new - # name, so we need to rename it to the original name in the table schema. - renamed_schema: Optional[Dict[str, str]] = None - if names is not None: - assert len(list(names)) == schema_len - renamed_schema = {name: field.name for name, field in zip(names, table.schema)} - index_cols: List[str] = [] if isinstance(index_col, bigframes.enums.DefaultIndexKind): if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64: @@ -278,8 +270,8 @@ def get_index_cols( f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}" ) elif isinstance(index_col, str): - if renamed_schema is not None: - index_col = renamed_schema.get(index_col, index_col) + if rename_to_schema is not None: + index_col = rename_to_schema.get(index_col, index_col) index_cols = [index_col] elif isinstance(index_col, int): if not 0 <= index_col < schema_len: @@ -291,8 +283,8 @@ def get_index_cols( elif isinstance(index_col, Iterable): for item in index_col: if isinstance(item, str): - if renamed_schema is not None: - item = renamed_schema.get(item, item) + if rename_to_schema is not None: + item = rename_to_schema.get(item, item) index_cols.append(item) elif isinstance(item, int): if not 0 <= item < schema_len: diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 814d44292e..add4efb6ab 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -96,22 +96,35 @@ def _to_index_cols( return index_cols -def _check_column_duplicates( - index_cols: Iterable[str], columns: Iterable[str], index_col_in_columns: bool -) -> Iterable[str]: - """Validates and processes index and data columns for duplicates and overlap. +def _check_duplicates(name: str, columns: Optional[Iterable[str]] = None): + """Check for duplicate column names in the provided iterable.""" + if columns is None: + return + columns_list = list(columns) + set_columns = set(columns_list) + if len(columns_list) > len(set_columns): + raise ValueError( + f"The '{name}' argument contains duplicate names. " + f"All column names specified in '{name}' must be unique." + ) - This function performs two main tasks: - 1. Ensures there are no duplicate column names within the `index_cols` list - or within the `columns` list. - 2. Based on the `index_col_in_columns` flag, it validates the relationship - between `index_cols` and `columns`. + +def _check_index_col_param( + index_cols: Iterable[str], + columns: Iterable[str], + *, + table_columns: Optional[Iterable[str]] = None, + index_col_in_columns: Optional[bool] = False, +): + """Checks for duplicates in `index_cols` and resolves overlap with `columns`. Args: index_cols (Iterable[str]): - An iterable of column names designated as the index. + Column names designated as the index columns. columns (Iterable[str]): - An iterable of column names designated as the data columns. + Used column names from table_columns. + table_columns (Iterable[str]): + A full list of column names in the table schema. index_col_in_columns (bool): A flag indicating how to handle overlap between `index_cols` and `columns`. @@ -121,40 +134,97 @@ def _check_column_duplicates( `columns`. An error is raised if an index column is not found in the `columns` list. """ - index_cols_list = list(index_cols) if index_cols is not None else [] - columns_list = list(columns) if columns is not None else [] - set_index = set(index_cols_list) - set_columns = set(columns_list) + _check_duplicates("index_col", index_cols) - if len(index_cols_list) > len(set_index): - raise ValueError( - "The 'index_col' argument contains duplicate names. " - "All column names specified in 'index_col' must be unique." - ) + if columns is not None and len(list(columns)) > 0: + set_index = set(list(index_cols) if index_cols is not None else []) + set_columns = set(list(columns) if columns is not None else []) - if len(columns_list) == 0: - return columns + if index_col_in_columns: + if not set_index.issubset(set_columns): + raise ValueError( + f"The specified index column(s) were not found: {set_index - set_columns}. " + f"Available columns are: {set_columns}" + ) + else: + if not set_index.isdisjoint(set_columns): + raise ValueError( + "Found column names that exist in both 'index_col' and 'columns' arguments. " + "These arguments must specify distinct sets of columns." + ) - if len(columns_list) > len(set_columns): - raise ValueError( - "The 'columns' argument contains duplicate names. " - "All column names specified in 'columns' must be unique." - ) + if not index_col_in_columns and table_columns is not None: + for key in index_cols: + if key not in table_columns: + possibility = min( + table_columns, + key=lambda item: bigframes._tools.strings.levenshtein_distance( + key, item + ), + ) + raise ValueError( + f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?" + ) - if index_col_in_columns: - if not set_index.issubset(set_columns): - raise ValueError( - f"The specified index column(s) were not found: {set_index - set_columns}. " - f"Available columns are: {set_columns}" + +def _check_columns_param(columns: Iterable[str], table_columns: Iterable[str]): + """Validates that the specified columns are present in the table columns. + + Args: + columns (Iterable[str]): + Used column names from table_columns. + table_columns (Iterable[str]): + A full list of column names in the table schema. + Raises: + ValueError: If any column in `columns` is not found in the table columns. + """ + for column_name in columns: + if column_name not in table_columns: + possibility = min( + table_columns, + key=lambda item: bigframes._tools.strings.levenshtein_distance( + column_name, item + ), ) - return [col for col in columns if col not in set_index] - else: - if not set_index.isdisjoint(set_columns): raise ValueError( - "Found column names that exist in both 'index_col' and 'columns' arguments. " - "These arguments must specify distinct sets of columns." + f"Column '{column_name}' is not found. Did you mean '{possibility}'?" ) - return columns + + +def _check_names_param( + names: Iterable[str], + index_col: Iterable[str] + | str + | Iterable[int] + | int + | bigframes.enums.DefaultIndexKind, + columns: Iterable[str], + table_columns: Iterable[str], +): + len_names = len(list(names)) + len_table_columns = len(list(table_columns)) + len_columns = len(list(columns)) + if len_names > len_table_columns: + raise ValueError( + f"Too many columns specified: expected {len_table_columns}" + f" and found {len_names}" + ) + elif len_names < len_table_columns: + if isinstance(index_col, bigframes.enums.DefaultIndexKind) or index_col != (): + raise KeyError( + "When providing both `index_col` and `names`, ensure the " + "number of `names` matches the number of columns in your " + "data." + ) + if len_columns != 0: + # The 'columns' must be identical to the 'names'. If not, raise an error. + if len_columns != len_names: + raise ValueError( + "Number of passed names did not match number of header " + "fields in the file" + ) + if set(list(names)) != set(list(columns)): + raise ValueError("Usecols do not match columns") @dataclasses.dataclass @@ -545,11 +615,14 @@ def read_gbq_table( f"`max_results` should be a positive number, got {max_results}." ) + _check_duplicates("columns", columns) + table_ref = google.cloud.bigquery.table.TableReference.from_string( table_id, default_project=self._bqclient.project ) columns = list(columns) + include_all_columns = columns is None or len(columns) == 0 filters = typing.cast(list, list(filters)) # --------------------------------- @@ -563,72 +636,58 @@ def read_gbq_table( cache=self._df_snapshot, use_cache=use_cache, ) - table_column_names = {field.name for field in table.schema} if table.location.casefold() != self._storage_manager.location.casefold(): raise ValueError( f"Current session is in {self._storage_manager.location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}" ) - for key in columns: - if key not in table_column_names: - possibility = min( - table_column_names, - key=lambda item: bigframes._tools.strings.levenshtein_distance( - key, item - ), - ) - raise ValueError( - f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?" - ) - - # TODO(b/408499371): check `names` work with `use_cols` for read_csv method. + table_column_names = [field.name for field in table.schema] + rename_to_schema: Optional[Dict[str, str]] = None if names is not None: + _check_names_param(names, index_col, columns, table_column_names) + + # Additional unnamed columns is going to set as index columns len_names = len(list(names)) - len_columns = len(table.schema) - if len_names > len_columns: - raise ValueError( - f"Too many columns specified: expected {len_columns}" - f" and found {len_names}" - ) - elif len_names < len_columns: - if ( - isinstance(index_col, bigframes.enums.DefaultIndexKind) - or index_col != () - ): - raise KeyError( - "When providing both `index_col` and `names`, ensure the " - "number of `names` matches the number of columns in your " - "data." - ) - index_col = range(len_columns - len_names) + len_schema = len(table.schema) + if len(columns) == 0 and len_names < len_schema: + index_col = range(len_schema - len_names) names = [ - field.name for field in table.schema[: len_columns - len_names] + field.name for field in table.schema[: len_schema - len_names] ] + list(names) + assert len_schema >= len_names + assert len_names >= len(columns) + + table_column_names = table_column_names[: len(list(names))] + rename_to_schema = dict(zip(list(names), table_column_names)) + + if len(columns) != 0: + if names is None: + _check_columns_param(columns, table_column_names) + else: + _check_columns_param(columns, names) + names = columns + assert rename_to_schema is not None + columns = [rename_to_schema[renamed_name] for renamed_name in columns] + # Converting index_col into a list of column names requires # the table metadata because we might use the primary keys # when constructing the index. index_cols = bf_read_gbq_table.get_index_cols( table=table, index_col=index_col, - names=names, + rename_to_schema=rename_to_schema, ) - columns = list( - _check_column_duplicates(index_cols, columns, index_col_in_columns) + _check_index_col_param( + index_cols, + columns, + table_columns=table_column_names, + index_col_in_columns=index_col_in_columns, ) - - for key in index_cols: - if key not in table_column_names: - possibility = min( - table_column_names, - key=lambda item: bigframes._tools.strings.levenshtein_distance( - key, item - ), - ) - raise ValueError( - f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?" - ) + if index_col_in_columns and not include_all_columns: + set_index = set(list(index_cols) if index_cols is not None else []) + columns = [col for col in columns if col not in set_index] # ----------------------------- # Optionally, execute the query @@ -715,7 +774,7 @@ def read_gbq_table( metadata_only=not self._scan_index_uniqueness, ) schema = schemata.ArraySchema.from_bq_table(table) - if columns: + if not include_all_columns: schema = schema.select(index_cols + columns) array_value = core.ArrayValue.from_table( table, @@ -767,14 +826,14 @@ def read_gbq_table( value_columns = [col for col in array_value.column_ids if col not in index_cols] if names is not None: - renamed_cols: Dict[str, str] = { - col: new_name for col, new_name in zip(array_value.column_ids, names) - } + assert rename_to_schema is not None + schema_to_rename = {value: key for key, value in rename_to_schema.items()} if index_col != bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64: index_names = [ - renamed_cols.get(index_col, index_col) for index_col in index_cols + schema_to_rename.get(index_col, index_col) + for index_col in index_cols ] - value_columns = [renamed_cols.get(col, col) for col in value_columns] + value_columns = [schema_to_rename.get(col, col) for col in value_columns] block = blocks.Block( array_value, @@ -898,9 +957,7 @@ def read_gbq_query( ) index_cols = _to_index_cols(index_col) - columns = _check_column_duplicates( - index_cols, columns, index_col_in_columns=False - ) + _check_index_col_param(index_cols, columns) filters_copy1, filters_copy2 = itertools.tee(filters) has_filters = len(list(filters_copy1)) != 0 diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index cbb441e5aa..809d08c6c1 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -146,9 +146,7 @@ def test_read_gbq_w_unknown_column( ): with pytest.raises( ValueError, - match=re.escape( - "Column 'int63_col' of `columns` not found in this table. Did you mean 'int64_col'?" - ), + match=re.escape("Column 'int63_col' is not found. Did you mean 'int64_col'?"), ): session.read_gbq( scalars_table_id, @@ -1365,6 +1363,132 @@ def test_read_csv_for_names_and_index_col( ) +@pytest.mark.parametrize( + "usecols", + [ + pytest.param(["a", "b", "c"], id="same"), + pytest.param(["a", "c"], id="less_than_names"), + ], +) +def test_read_csv_for_names_and_usecols( + session, usecols, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + + names = ["a", "b", "c"] + bf_df = session.read_csv(path, engine="bigquery", names=names, usecols=usecols) + + # Convert default pandas dtypes to match BigQuery DataFrames dtypes. + pd_df = session.read_csv( + path, names=names, usecols=usecols, dtype=bf_df.dtypes.to_dict() + ) + + assert bf_df.shape == pd_df.shape + assert bf_df.columns.tolist() == pd_df.columns.tolist() + + # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs + # (b/280889935) or guarantee row ordering. + bf_df = bf_df.set_index(names[0]).sort_index() + pd_df = pd_df.set_index(names[0]) + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas()) + + +def test_read_csv_for_names_and_invalid_usecols( + session, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + + names = ["a", "b", "c"] + usecols = ["a", "X"] + with pytest.raises( + ValueError, + match=re.escape("Column 'X' is not found. "), + ): + session.read_csv(path, engine="bigquery", names=names, usecols=usecols) + + +@pytest.mark.parametrize( + ("usecols", "index_col"), + [ + pytest.param(["a", "b", "c"], "a", id="same"), + pytest.param(["a", "b", "c"], ["a", "b"], id="same_two_index"), + pytest.param(["a", "c"], 0, id="less_than_names"), + ], +) +def test_read_csv_for_names_and_usecols_and_indexcol( + session, usecols, index_col, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + + names = ["a", "b", "c"] + bf_df = session.read_csv( + path, engine="bigquery", names=names, usecols=usecols, index_col=index_col + ) + + # Convert default pandas dtypes to match BigQuery DataFrames dtypes. + pd_df = session.read_csv( + path, + names=names, + usecols=usecols, + index_col=index_col, + dtype=bf_df.reset_index().dtypes.to_dict(), + ) + + assert bf_df.shape == pd_df.shape + assert bf_df.columns.tolist() == pd_df.columns.tolist() + + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas()) + + +def test_read_csv_for_names_less_than_columns_and_same_usecols( + session, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + names = ["a", "c"] + usecols = ["a", "c"] + bf_df = session.read_csv(path, engine="bigquery", names=names, usecols=usecols) + + # Convert default pandas dtypes to match BigQuery DataFrames dtypes. + pd_df = session.read_csv( + path, names=names, usecols=usecols, dtype=bf_df.dtypes.to_dict() + ) + + assert bf_df.shape == pd_df.shape + assert bf_df.columns.tolist() == pd_df.columns.tolist() + + # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs + # (b/280889935) or guarantee row ordering. + bf_df = bf_df.set_index(names[0]).sort_index() + pd_df = pd_df.set_index(names[0]) + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas()) + + +def test_read_csv_for_names_less_than_columns_and_mismatched_usecols( + session, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + names = ["a", "b"] + usecols = ["a"] + with pytest.raises( + ValueError, + match=re.escape("Number of passed names did not match number"), + ): + session.read_csv(path, engine="bigquery", names=names, usecols=usecols) + + +def test_read_csv_for_names_less_than_columns_and_different_usecols( + session, df_and_gcs_csv_for_two_columns +): + _, path = df_and_gcs_csv_for_two_columns + names = ["a", "b"] + usecols = ["a", "c"] + with pytest.raises( + ValueError, + match=re.escape("Usecols do not match columns"), + ): + session.read_csv(path, engine="bigquery", names=names, usecols=usecols) + + def test_read_csv_for_dtype(session, df_and_gcs_csv_for_two_columns): _, path = df_and_gcs_csv_for_two_columns