feat: allow multiple columns input for llm models#998
Conversation
bigframes/ml/llm.py
Outdated
| Args: | ||
| X (bigframes.dataframe.DataFrame or bigframes.series.Series): | ||
| Input DataFrame or Series, which contains only one column of prompts. | ||
| Input DataFrame or Series, can contain one or more columns. If multiple columns in the DataFrame, it must contain a "prompt" column for prediction. |
There was a problem hiding this comment.
nit: "If multiple columns are in the DataFrame, they must ..." and for other docs too
There was a problem hiding this comment.
"it" refers to the DataFrame. Can add "are" in "If multiple columns are in the DataFrame"
tests/system/small/ml/test_llm.py
Outdated
| assert "text_embedding" in df.columns | ||
| series = df["text_embedding"] | ||
| value = series[0] | ||
| assert len(value) == 768 |
There was a problem hiding this comment.
nit: maybe we could coalesce line 323 - 325 into a single line?
assert len(df[..][0]) == 768
There was a problem hiding this comment.
sure, actually I'll rewrite the tests. Also some are already removed in a recent PR.
| # BQML identified the column by name | ||
| col_label = cast(blocks.Label, X.columns[0]) | ||
| X = X.rename(columns={col_label: "prompt"}) | ||
| if len(X.columns) == 1: |
There was a problem hiding this comment.
I think we should make another check in the else clause - that the multi-column input does have a "prompt" column. Also add negative test for that scenario
There was a problem hiding this comment.
@tswast had a suggestion that we shouldn't do much client side checks. I'm trying to follow: if the error message is meaningful to the user, then rely on server side checks. Otherwise we have to wrap server error messages or return client side error messages.
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
Fixes #<issue_number_goes_here> 🦕