Runs training (and validation) of a model on task(s) with the given data.

```python
runner.run(
    *,
    train_ds_provider: DatasetProvider,
    model_fn: Callable[[GraphTensorSpec], tf.keras.Model],
    optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
    trainer: Trainer,
    task: OneOrMappingOf[Task],
    loss_weights: Optional[Mapping[str, float]] = None,
    gtspec: GraphTensorSpec,
    global_batch_size: int,
    epochs: int = 1,
    drop_remainder: bool = False,
    export_dirs: Optional[Sequence[str]] = None,
    model_exporters: Optional[Sequence[ModelExporter]] = None,
    feature_processors: Optional[Sequence[GraphTensorProcessorFn]] = None,
    valid_ds_provider: Optional[DatasetProvider] = None,
    train_padding: Optional[GraphTensorPadding] = None,
    valid_padding: Optional[GraphTensorPadding] = None,
    tf_data_service_config: Optional[TFDataServiceConfig] = None,
    steps_per_execution: Optional[int] = None,
    run_eagerly: bool = False
)
```
This includes preprocessing the input data, appending any suitable head(s), and running training (and validation) with the requested distribution strategy.

The input data is processed in multiple stages, starting from the contents of the datasets provided by `train_ds_provider` and `valid_ds_provider`:

 1. Input examples are batched.
 2. If necessary, input batches are parsed as `GraphTensor` values and merged into components (see `GraphTensor.merge_batch_to_components`).
 3. If set, `train_padding` and `valid_padding`, respectively, are applied.
 4. The given `feature_processors` are applied in order for all non-trainable feature transformations on CPU (as part of `tf.data.Dataset.map(...)`).
 5. The `Task.preprocess(...)` method is applied to extract training targets (for supervised learning, that means: labels) and optionally transform the preprocessed `GraphTensor` into a model input (or multiple model inputs, for tasks like self-supervised contrastive losses).
 6. If the resulting `GraphTensor`s have any auxiliary pieces (as indicated by `tfgnn.get_aux_type_prefix(...)`), all features (typically labels) are removed from those graph pieces.

The base GNN (as built by `model_fn`) is run on all results from step (6). `Task.predict(...)` is called on the model outputs that correspond to the one or more graphs requested in step (5) by `Task.preprocess(...)`.
Trainable transformations of inputs (notably lookups in trainable embedding tables) are required to happen inside `model_fn`.

For supervised learning, training labels enter the pipeline as features on the `GraphTensor` that undergo the `feature_processors` (shared by all `Task`s) and are read out of the `GraphTensor` by `Task.preprocess(...)`.
Users are strongly encouraged to take one of the following two approaches to prevent the leakage of label information into the training:

 1. Store labels on the auxiliary `"_readout"` node set and let `Task.preprocess(...)` read them from there. (For library-supplied `Task`s, that means initializing with `label_feature_name="..."`.) If that is not already true for the input datasets, the label feature can be moved there by one of the `feature_processors`, using `tfgnn.structured_readout_into_feature(...)` or a similar helper function.
 2. For single-`Task` training only: let `Task.preprocess()` return modified `GraphTensor`s that no longer contain the separately returned labels. (Library-supplied `Task`s delegate this to the `label_fn="..."` passed in initialization.)
| Args | |
|---|---|
| `train_ds_provider` | A `DatasetProvider` for training. The `tf.data.Dataset` is not batched and contains scalar `GraphTensor` values conforming to `gtspec`, possibly serialized as `tf.train.Example` protos. |
| `model_fn` | Returns the base GNN `tf.keras.Model` for use in training and validation. |
| `optimizer_fn` | Returns a `tf.keras.optimizers.Optimizer` for use in training. |
| `trainer` | A `Trainer`. |
| `task` | A `Task` for single-Task training or a `Mapping[str, Task]` for multi-Task training. In multi-Task training, `Task.preprocess(...)` must return `GraphTensor`s with the same spec as its inputs; only the values may change (so that there remains a single spec for `model_fn`). |
| `loss_weights` | An optional `Mapping[str, float]` for multi-Task training. If given, this structure must match (with `tf.nest.assert_same_structure`) the structure of `task`. The mapping contains, for each task, a scalar coefficient to weight the loss contributions of that task. |
| `gtspec` | A `GraphTensorSpec` matching the elements of the train and valid datasets. If train or valid contain `tf.string` elements, this `GraphTensorSpec` is used for parsing; otherwise, train or valid are expected to contain `GraphTensor` elements whose relaxed spec matches `gtspec`. |
| `global_batch_size` | The `tf.data.Dataset` global batch size for both training and validation. |
| `epochs` | The number of epochs to train. |
| `drop_remainder` | Whether to drop a `tf.data.Dataset` remainder at batching. |
| `export_dirs` | Optional directories for exports (SavedModels); if unset, the default is `os.path.join(model_dir, "export")`. |
| `model_exporters` | Zero or more `ModelExporter`s for exporting (SavedModels) to `export_dirs`. If unset, the default is `[KerasModelExporter()]`. |
| `feature_processors` | A sequence of callables for feature processing with the Keras functional API. Each callable must accept and return a symbolic scalar `GraphTensor`. The callables are composed in order and may change the `GraphTensorSpec` (e.g., add or remove features). The resulting Keras model is executed on CPU as part of a `tf.data.Dataset.map` operation. |
| `valid_ds_provider` | A `DatasetProvider` for validation. The `tf.data.Dataset` is not batched and contains scalar `GraphTensor` values conforming to `gtspec`, possibly serialized as `tf.train.Example` protos. |
| `train_padding` | `GraphTensor` padding for training. Required if training on TPU. |
| `valid_padding` | `GraphTensor` padding for validation. Required if training on TPU. |
| `tf_data_service_config` | Optional configuration for the tf.data service, which speeds up the tf.data input pipeline and reduces input bottlenecks during model training; consider enabling it, particularly for training on accelerators. For more information, see: https://www.tensorflow.org/api_docs/python/tf/data/experimental/service. |
| `steps_per_execution` | The number of batches to run during each training iteration. If unset, defaults to 100 for the TPU strategy and to `None` otherwise. |
| `run_eagerly` | Whether to compile the model in eager mode, primarily for debugging purposes. Note that the symbolic model will still be run twice, so if you use a `breakpoint()` you will have to Continue twice before you are in a real eager execution. |
| Returns |
|---|
| A `RunResult` object containing models and information about this run. |
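Putting the pieces together, a minimal wiring of `runner.run(...)` might look like the sketch below. The file paths, node set name, and choice of a `RootNodeBinaryClassification` task are hypothetical, and the identity `model_fn` stands in for a real GNN; this is illustrative configuration, not a definitive recipe:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner

# Hypothetical schema and data locations.
graph_schema = tfgnn.read_schema("/tmp/graph_schema.pbtxt")
gtspec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)


def model_fn(spec: tfgnn.GraphTensorSpec) -> tf.keras.Model:
  inputs = tf.keras.layers.Input(type_spec=spec)
  # A real model would stack GNN layers here; identity keeps the sketch short.
  return tf.keras.Model(inputs, inputs)


runner.run(
    train_ds_provider=runner.TFRecordDatasetProvider(
        file_pattern="/tmp/train*"),
    model_fn=model_fn,
    optimizer_fn=tf.keras.optimizers.Adam,
    trainer=runner.KerasTrainer(
        strategy=tf.distribute.get_strategy(), model_dir="/tmp/model"),
    task=runner.RootNodeBinaryClassification(
        node_set_name="nodes", label_feature_name="label"),
    gtspec=gtspec,
    global_batch_size=32,
    epochs=1)
```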