From c2dd7dbccd58fb262c9146d52b5d8693e815a88a Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Mon, 22 Apr 2024 15:58:27 +0000 Subject: [PATCH] fix: llm fine tuning tests --- tests/system/load/test_llm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 62ef7d5c72..d56f6100c1 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -45,9 +45,12 @@ def llm_remote_text_pandas_df(): ) -def test_llm_palm_configure_fit( - llm_fine_tune_df_default_index, llm_remote_text_pandas_df -): +@pytest.fixture(scope="session") +def llm_remote_text_df(session, llm_remote_text_pandas_df): + return session.read_pandas(llm_remote_text_pandas_df) + + +def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df): model = bigframes.ml.llm.PaLM2TextGenerator( model_name="text-bison", max_iterations=1 ) @@ -59,7 +62,7 @@ def test_llm_palm_configure_fit( assert model is not None - df = model.predict(llm_remote_text_pandas_df).to_pandas() + df = model.predict(llm_remote_text_df["prompt"]).to_pandas() assert df.shape == (3, 4) assert "ml_generate_text_llm_result" in df.columns series = df["ml_generate_text_llm_result"]