Commit

Internal change
PiperOrigin-RevId: 298817372
tensorflower-gardener committed Mar 4, 2020
1 parent 75d1304 commit 238321c
Showing 1 changed file with 61 additions and 9 deletions: official/nlp/bert/run_classifier.py
@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir,
  return bert_model


def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
                               eval_steps):
  """Obtains predictions of trained model on evaluation data.

  Note that the list of labels is returned along with the predictions because
  the order changes on distributing the dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    eval_steps: Number of evaluation steps.

  Returns:
    predictions: List of predictions.
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      model_outputs = trained_model(inputs, training=False)
      return model_outputs, labels

    outputs, labels = strategy.experimental_run_v2(
        _test_step_fn, args=(next(iterator),))
    # outputs: current batch logits as a tuple of shard logits
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps."""
    preds, golds = list(), list()
    for _ in range(eval_steps):
      logits, labels = test_step(test_iterator)
      for cur_logits, cur_labels in zip(logits, labels):
        preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
        golds.extend(cur_labels.numpy().tolist())
    return preds, golds

  test_iter = iter(
      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels
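
For context, a minimal usage sketch of the new helper follows; it is not part of the commit. The one-device strategy, the toy two-class model, and the synthetic eval_input_fn are illustrative assumptions standing in for the real BERT classifier and its input pipeline, while the distribution-strategy calls mirror the ones used in this file at the time of the change.

# Sketch only: toy model and data, not the real BERT classifier inputs.
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with strategy.scope():
  # Stand-in for the trained classifier: maps 4 float features to 2 logits.
  toy_model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

def toy_eval_input_fn(input_context=None):
  """Yields (features, labels) batches, mirroring the real input function."""
  del input_context  # Per-replica InputContext is unused in this toy example.
  features = tf.random.uniform([32, 4])
  labels = tf.random.uniform([32], maxval=2, dtype=tf.int32)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

toy_predictions, toy_labels = get_predictions_and_labels(
    strategy, toy_model, toy_eval_input_fn, eval_steps=4)
print(len(toy_predictions), len(toy_labels))  # 32 predictions, 32 gold labels.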


def export_classifier(model_export_path, input_meta_data,
                      restore_model_using_load_weights, bert_config, model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
    restore_model_using_load_weights: Whether to use the checkpoint.restore()
      API for custom checkpoints or to use the model.load_weights() API. There
      are 2 different ways to save checkpoints. One is using tf.train.Checkpoint
      and another is using Keras model.save_weights(). The custom training loop
      implementation uses the tf.train.Checkpoint API, and the Keras
      ModelCheckpoint callback internally uses the model.save_weights() API.
      Since these two APIs cannot be used together, model loading logic must
      take into account how the model checkpoint was saved.
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.
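
To make the restore_model_using_load_weights distinction concrete, here is a short sketch of the two loading paths the docstring contrasts. It is not part of the commit; the toy classifier_model and both checkpoint directories are hypothetical stand-ins, not names used by run_classifier.py.

# Sketch only: hypothetical model and checkpoint directories.
classifier_model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
classifier_model.build(input_shape=(None, 4))

# Path 1: a custom training loop saves with tf.train.Checkpoint, so export
# restores through the same API (restore_model_using_load_weights=False).
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_ckpt = tf.train.latest_checkpoint('/tmp/ctl_model_dir')
if latest_ckpt:
  checkpoint.restore(latest_ckpt).assert_existing_objects_matched()

# Path 2: the Keras ModelCheckpoint callback saves via model.save_weights(),
# so export loads through model.load_weights()
# (restore_model_using_load_weights=True).
keras_ckpt = tf.train.latest_checkpoint('/tmp/keras_model_dir')
if keras_ckpt:
  classifier_model.load_weights(keras_ckpt)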
