Skip to content

Commit

Permalink
fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst committed Nov 21, 2024
1 parent 783e6ca commit bcd0ba6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4380,8 +4380,9 @@ def to_tf(
If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
>>> ds = ds.add_column("sample weights", lambda df: 1)
>>> import pandas as pd
>>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
>>> ds.to_tf(feature_columns="features", label_columns="target", additional_columns="sample weights")
<_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def to_tf(
A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
>>> import pandas as pd
>>> ds = ds.add_column("sample weights", lambda df: pd.DataFrame([1] * len(df)))
>>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
>>> it = ds.iterator()
>>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights")
<_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
Expand Down

0 comments on commit bcd0ba6

Please sign in to comment.