diff --git a/bigframes/ml/__init__.py b/bigframes/ml/__init__.py index 55c8709d8d..b2c62ff961 100644 --- a/bigframes/ml/__init__.py +++ b/bigframes/ml/__init__.py @@ -26,4 +26,5 @@ "llm", "forecasting", "imported", + "remote", ] diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index d8135f7085..5aad77a394 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -294,6 +294,8 @@ def create_remote_model( self, session: bigframes.Session, connection_name: str, + input: Mapping[str, str] = {}, + output: Mapping[str, str] = {}, options: Mapping[str, Union[str, int, float, Iterable[str]]] = {}, ) -> BqmlModel: """Create a session-temporary BQML remote model with the CREATE OR REPLACE MODEL statement @@ -301,6 +303,10 @@ def create_remote_model( Args: connection_name: a BQ connection to talk with Vertex AI, of the format ... https://cloud.google.com/bigquery/docs/create-cloud-resource-connection + input: + input schema for general remote models + output: + output schema for general remote models options: a dict of options to configure the model. Generates a BQML OPTIONS clause @@ -311,6 +317,8 @@ def create_remote_model( sql = self._model_creation_sql_generator.create_remote_model( connection_name=connection_name, model_ref=model_ref, + input=input, + output=output, options=options, ) diff --git a/bigframes/ml/remote.py b/bigframes/ml/remote.py new file mode 100644 index 0000000000..d4c34bbd0d --- /dev/null +++ b/bigframes/ml/remote.py @@ -0,0 +1,157 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigFrames general remote models."""

from __future__ import annotations

from typing import Mapping, Optional, Union
import warnings

import bigframes
from bigframes import clients
from bigframes.core import log_adapter
from bigframes.ml import base, core, globals, utils
import bigframes.pandas as bpd

# Column types accepted in a remote model's input/output schema. Every scalar
# type has a matching element-typed array form; a bare "array" is not a
# complete column type, so the element type is always spelled out.
_SUPPORTED_DTYPES = (
    "bool",
    "string",
    "int64",
    "float64",
    "array<bool>",
    "array<string>",
    "array<int64>",
    "array<float64>",
)

# Column of the prediction result that is inspected for per-row failures.
_REMOTE_MODEL_STATUS = "remote_model_status"

# Human-readable shape of a fully qualified connection name (three
# dot-separated parts), used in error messages.
_CONNECTION_NAME_FORMAT = "<PROJECT_ID>.<LOCATION>.<CONNECTION_ID>"


def _standardize_type(dtype: str) -> str:
    """Normalize a user-supplied column type and check that it is supported.

    The type is lower-cased and every "boolean" spelling is rewritten to
    "bool" (this also covers "array<boolean>").

    Args:
        dtype (str):
            Column type as given by the caller.

    Returns:
        str: The normalized type, guaranteed to be in ``_SUPPORTED_DTYPES``.

    Raises:
        ValueError: If the normalized type is not supported.
    """
    normalized = dtype.lower().replace("boolean", "bool")
    if normalized not in _SUPPORTED_DTYPES:
        raise ValueError(
            f"Data type {normalized} is not supported. We only support {', '.join(_SUPPORTED_DTYPES)}."
        )
    return normalized


@log_adapter.class_logger
class VertexAIModel(base.BaseEstimator):
    """Remote model from a Vertex AI https endpoint. User must specify https endpoint, input schema and output schema.
    How to deploy a model in Vertex AI https://cloud.google.com/bigquery/docs/bigquery-ml-remote-model-tutorial#Deploy-Model-on-Vertex-AI.

    Args:
        endpoint (str):
            Vertex AI https endpoint.
        input ({column_name: column_type}):
            Input schema. Supported types are "bool", "string", "int64", "float64",
            "array<bool>", "array<string>", "array<int64>", "array<float64>".
        output ({column_name: column_type}):
            Output label schema. Supported the same types as the input.
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format
            <PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
            permission if the connection isn't fully setup.
    """

    def __init__(
        self,
        endpoint: str,
        input: Mapping[str, str],
        output: Mapping[str, str],
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        self.endpoint = endpoint
        self.input = input
        self.output = output
        self.session = session or bpd.get_global_session()

        self._bq_connection_manager = clients.BqConnectionManager(
            self.session.bqconnectionclient, self.session.resourcemanagerclient
        )
        # Fall back to the session's connection, then expand the name using
        # the session's project and location as defaults.
        connection_name = connection_name or self.session._bq_connection
        self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
            connection_name,
            default_project=self.session._project,
            default_location=self.session._location,
        )

        self._bqml_model_factory = globals.bqml_model_factory()
        # The BQML model is created eagerly, at construction time.
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Validate the configuration and create the underlying BQML remote model.

        Returns:
            The model returned by the BQML model factory's ``create_remote_model``.

        Raises:
            ValueError: If no connection name is available, if it does not have
                exactly three dot-separated parts, or if a schema column uses
                an unsupported type.
        """
        if not self.connection_name:
            raise ValueError(
                "Must provide connection_name, either in constructor or through session options."
            )
        connection_name_parts = self.connection_name.split(".")
        if len(connection_name_parts) != 3:
            raise ValueError(
                f"connection_name must be of the format {_CONNECTION_NAME_FORMAT}, got {self.connection_name}."
            )

        # Validate both schemas before touching the connection, so that an
        # unsupported column type fails without any connection side effects.
        self.input = {k: _standardize_type(v) for k, v in self.input.items()}
        self.output = {k: _standardize_type(v) for k, v in self.output.items()}

        # Parse and create connection if needed.
        project_id, location, connection_id = connection_name_parts
        self._bq_connection_manager.create_bq_connection(
            project_id=project_id,
            location=location,
            connection_id=connection_id,
            iam_role="aiplatform.user",
        )

        options = {
            "endpoint": self.endpoint,
        }

        return self._bqml_model_factory.create_remote_model(
            session=self.session,
            connection_name=self.connection_name,
            input=self.input,
            output=self.output,
            options=options,
        )

    def predict(
        self,
        X: Union[bpd.DataFrame, bpd.Series],
    ) -> bpd.DataFrame:
        """Predict the result from the input DataFrame.

        Emits a ``RuntimeWarning`` when any row of the result has a non-null
        status column.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
                Input DataFrame or Series, which needs to comply with the input parameter of the model.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
        """

        (X,) = utils.convert_to_dataframe(X)

        df = self._bqml_model.predict(X)

        # unlike LLM models, the general remote model status is null for successful runs.
        if (df[_REMOTE_MODEL_STATUS].notna()).any():
            warnings.warn(
                f"Some predictions failed. Check column {_REMOTE_MODEL_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                RuntimeWarning,
            )

        return df
# NOTE: in bigframes/ml/sql.py these are methods of the BQML SQL generator
# classes; the class headers are outside this view, so they are shown here
# at top level with an explicit ``self``.


def build_schema(self, **kwargs: str) -> str:
    """Encode column-name/type pairs as the body of a SQL schema clause.

    Each pair is rendered as ``<name> <type>``; entries are placed on their
    own indented line and separated by commas.
    """
    # NOTE(review): SOURCE is whitespace-collapsed, so the width of this
    # literal may originally have been larger - confirm against the file.
    indent_str = " "
    separator = f",\n{indent_str}"
    columns = (f"{name} {dtype}" for name, dtype in kwargs.items())
    return "\n" + indent_str + separator.join(columns)


def input(self, **kwargs: str) -> str:
    """Encode a BQML INPUT clause."""
    schema_sql = self.build_schema(**kwargs)
    return f"INPUT({schema_sql})"


def output(self, **kwargs: str) -> str:
    """Encode a BQML OUTPUT clause."""
    schema_sql = self.build_schema(**kwargs)
    return f"OUTPUT({schema_sql})"


def create_remote_model(
    self,
    connection_name: str,
    model_ref: "google.cloud.bigquery.ModelReference",
    input: Mapping[str, str] = {},
    output: Mapping[str, str] = {},
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> str:
    """Encode the CREATE OR REPLACE MODEL statement for BQML remote model.

    Clauses are emitted one per line in this order: model header, INPUT,
    OUTPUT, connection, OPTIONS. The INPUT, OUTPUT and OPTIONS clauses are
    left out entirely when their mapping is empty.
    """
    statement = [f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}"]

    # Schema clauses come before the connection clause; skip empty schemas.
    schema_clauses = ((input, self.input), (output, self.output))
    statement.extend(render(**schema) for schema, render in schema_clauses if schema)

    statement.append(self.connection(connection_name))

    if options:
        statement.append(self.options(**options))

    return "\n".join(statement)
# --- tests/system/small/ml/conftest.py ---


@pytest.fixture(scope="session")
def linear_remote_model_params() -> dict:
    """Return the schema and endpoint shared by the linear remote-model fixtures.

    The output column is an array of floats: the tests that use this fixture
    expect each prediction to be a single-element list of floats.
    """
    # Pre-deployed endpoint of linear reg model in Vertex.
    # bigframes-test-linreg2 -> bigframes-test-linreg-endpoint2
    return {
        "input": {"culmen_length_mm": "float64"},
        "output": {"predicted_body_mass_g": "array<float64>"},
        "endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/1084210331973/locations/us-central1/endpoints/3193318217619603456",
    }


@pytest.fixture(scope="session")
def bqml_linear_remote_model(
    session, bq_connection, linear_remote_model_params
) -> core.BqmlModel:
    """Create the linear remote model directly through the BQML model factory."""
    options = {
        "endpoint": linear_remote_model_params["endpoint"],
    }
    return globals.bqml_model_factory().create_remote_model(
        session=session,
        input=linear_remote_model_params["input"],
        output=linear_remote_model_params["output"],
        connection_name=bq_connection,
        options=options,
    )


@pytest.fixture(scope="session")
def linear_remote_vertex_model(
    session, bq_connection, linear_remote_model_params
) -> remote.VertexAIModel:
    """Create the linear remote model through the public ``VertexAIModel`` class."""
    return remote.VertexAIModel(
        endpoint=linear_remote_model_params["endpoint"],
        input=linear_remote_model_params["input"],
        output=linear_remote_model_params["output"],
        session=session,
        connection_name=bq_connection,
    )


# --- tests/system/small/ml/test_core.py ---


def test_remote_model_predict(
    bqml_linear_remote_model: core.BqmlModel, new_penguins_df
):
    """Check the factory-built remote model's predictions, within a 10% relative tolerance."""
    predictions = bqml_linear_remote_model.predict(new_penguins_df).to_pandas()
    expected = pd.DataFrame(
        {"predicted_body_mass_g": [[3739.54], [3675.79], [3619.54]]},
        index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"),
    )
    # Only the prediction column is compared; sorting by index makes the
    # comparison independent of result row order.
    pd.testing.assert_frame_equal(
        predictions[["predicted_body_mass_g"]].sort_index(),
        expected,
        check_exact=False,
        rtol=0.1,
    )
# --- tests/system/small/ml/test_remote.py ---

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd

from bigframes.ml import remote


def test_remote_linear_vertex_model_predict(
    linear_remote_vertex_model: remote.VertexAIModel, new_penguins_df
):
    """Check ``VertexAIModel.predict`` output, within a 10% relative tolerance."""
    expected_index = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    expected = pd.DataFrame(
        {"predicted_body_mass_g": [[3739.54], [3675.79], [3619.54]]},
        index=expected_index,
    )

    actual = linear_remote_vertex_model.predict(new_penguins_df).to_pandas()

    # Compare the prediction column only, ordered by index.
    actual_predictions = actual[["predicted_body_mass_g"]].sort_index()
    pd.testing.assert_frame_equal(
        actual_predictions,
        expected,
        check_exact=False,
        rtol=0.1,
    )


# --- tests/unit/ml/test_sql.py ---


def test_create_remote_model_with_params_produces_correct_sql(
    model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
):
    """Check the remote-model DDL when INPUT, OUTPUT and OPTIONS are all given."""
    model_ref = bigquery.ModelReference.from_string(
        "test-proj._anonXYZ.create_remote_model"
    )

    # NOTE(review): SOURCE is whitespace-collapsed. Line breaks below follow
    # the clause structure; the per-item indent width and the element type of
    # "array" may have been altered by extraction - confirm against the file.
    expected_sql = "\n".join(
        [
            "CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`",
            "INPUT(",
            " column1 int64)",
            "OUTPUT(",
            " result array)",
            "REMOTE WITH CONNECTION `my_project.us.my_connection`",
            "OPTIONS(",
            ' option_key1="option_value1",',
            " option_key2=2)",
        ]
    )

    actual_sql = model_creation_sql_generator.create_remote_model(
        connection_name="my_project.us.my_connection",
        model_ref=model_ref,
        input={"column1": "int64"},
        output={"result": "array"},
        options={"option_key1": "option_value1", "option_key2": 2},
    )

    assert actual_sql == expected_sql