
runner.run


Runs training (and validation) of a model on task(s) with the given data.

runner.run(
    *,
    train_ds_provider: DatasetProvider,
    model_fn: Callable[[GraphTensorSpec], tf.keras.Model],
    optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
    trainer: Trainer,
    task: OneOrMappingOf[Task],
    loss_weights: Optional[Mapping[str, float]] = None,
    gtspec: GraphTensorSpec,
    global_batch_size: int,
    epochs: int = 1,
    drop_remainder: bool = False,
    export_dirs: Optional[Sequence[str]] = None,
    model_exporters: Optional[Sequence[ModelExporter]] = None,
    feature_processors: Optional[Sequence[GraphTensorProcessorFn]] = None,
    valid_ds_provider: Optional[DatasetProvider] = None,
    train_padding: Optional[GraphTensorPadding] = None,
    valid_padding: Optional[GraphTensorPadding] = None,
    tf_data_service_config: Optional[TFDataServiceConfig] = None,
    steps_per_execution: Optional[int] = None,
    run_eagerly: bool = False
)

This includes preprocessing the input data, appending any suitable head(s), and running training (and validation) with the requested distribution strategy.

The input data is processed in multiple stages, starting from the contents of the datasets provided by train_ds_provider and valid_ds_provider:

  1. Input examples are batched.
  2. If necessary, input batches are parsed as GraphTensor values and merged into components (see: GraphTensor.merge_batch_to_components).
  3. If set, train_padding and valid_padding are applied to the training and validation datasets, respectively.
  4. The given feature_processors are applied in order for all non-trainable feature transformations on CPU (as part of tf.data.Dataset.map(...)).
  5. The Task.preprocess(...) method is applied to extract training targets (for supervised learning, that means: labels) and optionally transform the value of the preprocessed GraphTensor into a model input (or multiple model inputs for tasks like self-supervised contrastive losses).
  6. If the resulting GraphTensors have any auxiliary pieces (as indicated by tfgnn.get_aux_type_prefix(...)): all features (typically: labels) are removed from those graph pieces.
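The ordering of these stages can be sketched as a plain-Python composition. This is an illustrative sketch only: the dicts and callables below are toy stand-ins for the real tf.data/GraphTensor machinery, and names like `pad` and `strip_aux_features` are hypothetical, not TF-GNN APIs.

```python
# Toy sketch of the stage ordering in runner.run's input pipeline.
# Every callable here is a stand-in supplied by the caller, not a TF-GNN API.

def make_pipeline(batch, parse_and_merge, pad, feature_processors,
                  task_preprocess, strip_aux_features):
    def pipeline(examples):
        graph = parse_and_merge(batch(examples))         # stages 1-2
        if pad is not None:                              # stage 3
            graph = pad(graph)
        for fn in feature_processors:                    # stage 4, in order
            graph = fn(graph)
        model_inputs, targets = task_preprocess(graph)   # stage 5
        model_inputs = strip_aux_features(model_inputs)  # stage 6
        return model_inputs, targets
    return pipeline
```

In the real runner, stages 1 through 4 run on CPU inside tf.data.Dataset transformations; the sketch only captures their relative order.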

The base GNN (as built by model_fn) is run on all results from step (6). Task.predict(...) is called on the model outputs that correspond to the one or more graphs requested in step (5) by Task.preprocess(...).

Trainable transformations of inputs (notably lookups in trainable embedding tables) are required to happen inside model_fn.
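Conceptually, the runner then wires these pieces together roughly as follows. This is a hedged sketch with toy callables, not the library's actual code path; `run_forward` is a hypothetical helper for illustration.

```python
# Toy sketch: the base GNN built by model_fn runs on each preprocessed
# graph, and Task.predict consumes the corresponding model outputs.

def run_forward(model_fn, spec, task_predict, model_inputs):
    model = model_fn(spec)                      # base GNN, built once from the spec
    outputs = [model(g) for g in model_inputs]  # one output per requested graph
    return task_predict(*outputs)               # task head produces predictions
```

Tasks that request multiple model inputs in step (5), such as contrastive-learning tasks, receive one model output per input graph.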

For supervised learning, training labels enter the pipeline as features on the GraphTensor that undergo the feature_processors (shared by all Tasks) and are read out of the GraphTensor by Task.preprocess(...).

Users are strongly encouraged to take one of the following two approaches to prevent leakage of label information into training:

  • Store labels on the auxiliary "_readout" node set and let Task.preprocess(...) read them from there. (For library-supplied Tasks, that means initializing with label_feature_name="...".) If that is not already true for the input datasets, the label feature can be moved there by one of the feature_processors, using tfgnn.structured_readout_into_feature(...) or a similar helper function.
  • For single-Task training only: Let Task.preprocess() return modified GraphTensors that no longer contain the separately returned labels. (Library-supplied Tasks delegate this to the label_fn="..." passed in initialization.)
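For the second approach, the shape of a label_fn can be sketched as follows. Plain dicts stand in for GraphTensor features here, and `pop_label` is a hypothetical helper for illustration, not a TF-GNN API.

```python
def pop_label(graph_features, label_feature_name="label"):
    """Returns (features_without_label, label) -- the shape of a label_fn.

    Toy stand-in: a real label_fn receives a GraphTensor and returns a
    modified GraphTensor plus the separately extracted labels.
    """
    features = dict(graph_features)  # copy; do not mutate the input
    label = features.pop(label_feature_name)
    return features, label
```

The key property is that the returned features no longer contain the label, so the model built by model_fn cannot read it.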

Args

train_ds_provider A DatasetProvider for training. The tf.data.Dataset is not batched and contains scalar GraphTensor values conforming to gtspec, possibly serialized as a tf.train.Example proto.
model_fn Returns the base GNN tf.keras.Model for use in training and validation.
optimizer_fn Returns a tf.keras.optimizers.Optimizer for use in training.
trainer A Trainer.
task A Task for single-Task training or a Mapping[str, Task] for multi-Task training. In multi-Task training, Task.preprocess(...) must return GraphTensors with the same spec as its inputs, only the values may change (so that there remains a single spec for model_fn).
loss_weights An optional Mapping[str, float] for multi-Task training. If given, this structure must match (with tf.nest.assert_same_structure) the structure of task. The mapping contains, for each task, a scalar coefficient to weight the loss contributions of that task.
gtspec A GraphTensorSpec matching the elements of train and valid datasets. If train or valid contain tf.string elements, this GraphTensorSpec is used for parsing; otherwise, train or valid are expected to contain GraphTensor elements whose relaxed spec matches gtspec.
global_batch_size The tf.data.Dataset global batch size for both training and validation.
epochs The number of epochs to train.
drop_remainder Whether to drop a tf.data.Dataset remainder at batching.
export_dirs Optional directories for exports (SavedModels); if unset, default behavior is os.path.join(model_dir, "export").
model_exporters Zero or more ModelExporter for exporting (SavedModels) to export_dirs. If unset, default behavior is [KerasModelExporter()].
feature_processors A sequence of callables for feature processing with the Keras functional API. Each callable must accept and return a symbolic scalar GraphTensor. The callables are composed in order and may change the GraphTensorSpec (e.g., add/remove features). The resulting Keras model is executed on CPU as part of a tf.data.Dataset.map operation.
valid_ds_provider A DatasetProvider for validation. The tf.data.Dataset is not batched and contains scalar GraphTensor values conforming to gtspec, possibly serialized as a tf.train.Example proto.
train_padding GraphTensor padding for training. Required if training on TPU.
valid_padding GraphTensor padding for validation. Required if training on TPU.
tf_data_service_config Configuration for the tf.data service, which can speed up the tf.data input pipeline and reduce input bottlenecks during model training. Consider enabling it particularly when training on accelerators. For more information, see: https://www.tensorflow.org/api_docs/python/tf/data/experimental/service.
steps_per_execution The number of batches to run during each training iteration. If unset, defaults to 100 for a TPU strategy and to None otherwise.
run_eagerly Whether to compile the model in eager mode, primarily for debugging. Note that the symbolic model will still be run twice, so if you set a breakpoint() you will have to continue twice before reaching real eager execution.
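The effect of loss_weights in multi-Task training can be sketched as a weighted sum of per-task scalar losses. This is illustrative only; the actual weighting is applied by Keras when the model is compiled, and `combined_loss` is a hypothetical helper.

```python
def combined_loss(task_losses, loss_weights=None):
    """Weighted sum of per-task scalar losses, keyed like the `task` mapping.

    Toy stand-in for how loss_weights scales each task's loss contribution.
    """
    if loss_weights is None:
        return sum(task_losses.values())
    # Mirrors the tf.nest.assert_same_structure requirement on loss_weights.
    assert set(loss_weights) == set(task_losses), "structures must match"
    return sum(loss_weights[name] * loss for name, loss in task_losses.items())
```

With no loss_weights, every task contributes equally; with weights, each task's loss is scaled by its coefficient before summation.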

Returns

A RunResult object containing models and information about this run.