From d339b9a28d0ad6336884ae1dabaef6609465c32a Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Thu, 29 May 2025 21:22:12 +0000 Subject: [PATCH 1/2] feat: implement ai.classify() --- bigframes/operations/ai.py | 93 +++++++++++++++++++++++- tests/system/large/operations/test_ai.py | 27 +++++++ tests/system/small/operations/test_ai.py | 52 +++++++++++++ 3 files changed, 171 insertions(+), 1 deletion(-) diff --git a/bigframes/operations/ai.py b/bigframes/operations/ai.py index c65947f53f..1ff2694299 100644 --- a/bigframes/operations/ai.py +++ b/bigframes/operations/ai.py @@ -16,7 +16,7 @@ import re import typing -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence import warnings import numpy as np @@ -258,6 +258,97 @@ def extract_logprob(s: bigframes.series.Series) -> bigframes.series.Series: return concat([self._df, *attach_columns], axis=1) + def classify( + self, + instruction: str, + model, + labels: Sequence[str], + output_column: str = "result", + ground_with_google_search: bool = False, + attach_logprobs=False, + ): + """ + Classifies the rows of dataframes based on user instruction into the provided labels. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + >>> bpd.options.experiments.ai_operators = True + >>> bpd.options.compute.ai_ops_confirmation_threshold = 25 + + >>> import bigframes.ml.llm as llm + >>> model = llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001") + + >>> df = bpd.DataFrame({ + ... "feedback_text": [ + ... "The product is amazing, but the shipping was slow.", + ... "I had an issue with my recent bill.", + ... "The user interface is very intuitive." + ... ], + ... }) + >>> df.ai.classify("{feedback_text}", model=model, labels=["Shipping", "Billing", "UI"]) + feedback_text result + 0 The product is amazing, but the shipping was s... Shipping + 1 I had an issue with my recent bill. Billing + 2 The user interface is very intuitive. 
UI + + [3 rows x 2 columns] + + Args: + instruction (str): + An instruction on how to classify the data. This value must contain + column references by name, which should be wrapped in a pair of braces. + For example, if you have a column "feedback", you can refer to this column + with "{feedback}". + + model (bigframes.ml.llm.GeminiTextGenerator): + A GeminiTextGenerator provided by Bigframes ML package. + + labels (Sequence[str]): + A collection of labels (categories). It must contain at least two and at most 20 elements. + + output_column (str, default "result"): + The name of column for the output. + + ground_with_google_search (bool, default False): + Enables Grounding with Google Search for the GeminiTextGenerator model. + When set to True, the model incorporates relevant information from Google + Search results into its responses, enhancing their accuracy and factualness. + Note: Using this feature may impact billing costs. Refer to the pricing + page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models + The default is `False`. + + attach_logprobs (bool, default False): + Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level + of the LLM for its responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0. + + + Returns: + bigframes.pandas.DataFrame: DataFrame with classification result. + + Raises: + NotImplementedError: when the AI operator experiment is off. + ValueError: when the instruction refers to a non-existing column, when no + columns are referred to, or when the count of labels does not meet the + requirement. + """ + + if len(labels) < 2 or len(labels) > 20: + raise ValueError( + f"The number of labels should be between 2 and 20 (inclusive), but {len(labels)} labels are provided." 
+ ) + + updated_instruction = f"Based on the user instruction {instruction}, you must provide an answer that must exist in the following list of labels: {labels}" + + return self.map( + updated_instruction, + model, + output_schema={output_column: "string"}, + ground_with_google_search=ground_with_google_search, + attach_logprobs=attach_logprobs, + ) + def join( self, other, diff --git a/tests/system/large/operations/test_ai.py b/tests/system/large/operations/test_ai.py index 1b1d3a3376..c0716220b1 100644 --- a/tests/system/large/operations/test_ai.py +++ b/tests/system/large/operations/test_ai.py @@ -398,6 +398,33 @@ def test_map_invalid_model_raise_error(): ) +def test_classify(gemini_flash_model, session): + df = dataframe.DataFrame(data={"creature": ["dog", "rose"]}, session=session) + + with bigframes.option_context( + AI_OP_EXP_OPTION, + True, + THRESHOLD_OPTION, + 10, + ): + actual_result = df.ai.classify( + "{creature}", + gemini_flash_model, + labels=["animal", "plant"], + output_column="result", + ).to_pandas() + + expected_result = pd.DataFrame( + { + "creature": ["dog", "rose"], + "result": ["animal", "plant"], + } + ) + pandas.testing.assert_frame_equal( + actual_result, expected_result, check_index_type=False, check_dtype=False + ) + + @pytest.mark.parametrize( "instruction", [ diff --git a/tests/system/small/operations/test_ai.py b/tests/system/small/operations/test_ai.py index 25d411bef8..b6927dbf8c 100644 --- a/tests/system/small/operations/test_ai.py +++ b/tests/system/small/operations/test_ai.py @@ -108,6 +108,58 @@ def test_map(session): ) +def test_classify(session): + df = dataframe.DataFrame({"col": ["A", "B"]}, session=session) + model = FakeGeminiTextGenerator( + dataframe.DataFrame( + { + "result": ["A", "B"], + "full_response": _create_dummy_full_response(2), + }, + session=session, + ), + ) + + with bigframes.option_context( + AI_OP_EXP_OPTION, + True, + THRESHOLD_OPTION, + 50, + ): + result = df.ai.classify( + "classify {col}", 
model=model, labels=["A", "B"] + ).to_pandas() + + pandas.testing.assert_frame_equal( + result, + pd.DataFrame( + {"col": ["A", "B"], "result": ["A", "B"]}, dtype=dtypes.STRING_DTYPE + ), + check_index_type=False, + ) + + +def test_classify_invalid_labels_raise_error(session): + df = dataframe.DataFrame({"col": ["A", "B"]}, session=session) + model = FakeGeminiTextGenerator( + dataframe.DataFrame( + { + "result": ["A", "B"], + "full_response": _create_dummy_full_response(2), + }, + session=session, + ), + ) + + with bigframes.option_context( + AI_OP_EXP_OPTION, + True, + THRESHOLD_OPTION, + 50, + ), pytest.raises(ValueError): + df.ai.classify("classify {col}", model=model, labels=[]) + + def test_join(session): left_df = dataframe.DataFrame({"col_A": ["A"]}, session=session) right_df = dataframe.DataFrame({"col_B": ["B"]}, session=session) From 6a5caec9862a5b9c425015d536e4c21834566e1b Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Mon, 2 Jun 2025 20:57:26 +0000 Subject: [PATCH 2/2] check label duplicity --- bigframes/operations/ai.py | 4 ++++ tests/system/small/operations/test_ai.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/bigframes/operations/ai.py b/bigframes/operations/ai.py index 1ff2694299..87245d104e 100644 --- a/bigframes/operations/ai.py +++ b/bigframes/operations/ai.py @@ -307,6 +307,7 @@ def classify( labels (Sequence[str]): A collection of labels (categories). It must contain at least two and at most 20 elements. + Labels are case sensitive. Duplicated labels are not allowed. output_column (str, default "result"): The name of column for the output. @@ -339,6 +340,9 @@ def classify( f"The number of labels should be between 2 and 20 (inclusive), but {len(labels)} labels are provided." 
) + if len(set(labels)) != len(labels): + raise ValueError("There are duplicate labels.") + updated_instruction = f"Based on the user instruction {instruction}, you must provide an answer that must exist in the following list of labels: {labels}" return self.map( diff --git a/tests/system/small/operations/test_ai.py b/tests/system/small/operations/test_ai.py index b6927dbf8c..83aca8b5b1 100644 --- a/tests/system/small/operations/test_ai.py +++ b/tests/system/small/operations/test_ai.py @@ -139,7 +139,14 @@ def test_classify(session): ) -def test_classify_invalid_labels_raise_error(session): +@pytest.mark.parametrize( + "labels", + [ + pytest.param([], id="empty-label"), + pytest.param(["A", "A", "B"], id="duplicate-labels"), + ], +) +def test_classify_invalid_labels_raise_error(session, labels): df = dataframe.DataFrame({"col": ["A", "B"]}, session=session) model = FakeGeminiTextGenerator( dataframe.DataFrame( @@ -157,7 +164,7 @@ def test_classify_invalid_labels_raise_error(session): THRESHOLD_OPTION, 50, ), pytest.raises(ValueError): - df.ai.classify("classify {col}", model=model, labels=[]) + df.ai.classify("classify {col}", model=model, labels=labels) def test_join(session):