From a1e8509d38ae1a72a2d45334ef687b8307e596ea Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Thu, 7 Dec 2023 01:48:09 +0000 Subject: [PATCH] fix: ml.sql logic --- bigframes/ml/sql.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 1c88eda4ab..5fb40624dd 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -153,14 +153,12 @@ def create_model( ) -> str: """Encode the CREATE OR REPLACE MODEL statement for BQML""" source_sql = source_df.sql - transform_sql = self.transform(*transforms) if transforms is not None else None - options_sql = self.options(**options) parts = [f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}"] - if transform_sql: - parts.append(transform_sql) - if options_sql: - parts.append(options_sql) + if transforms: + parts.append(self.transform(*transforms)) + if options: + parts.append(self.options(**options)) parts.append(f"AS {source_sql}") return "\n".join(parts) @@ -189,11 +187,10 @@ def create_imported_model( options: Mapping[str, Union[str, int, float, Iterable[str]]] = {}, ) -> str: """Encode the CREATE OR REPLACE MODEL statement for BQML remote model.""" - options_sql = self.options(**options) parts = [f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}"] - if options_sql: - parts.append(options_sql) + if options: + parts.append(self.options(**options)) return "\n".join(parts)