-
Notifications
You must be signed in to change notification settings - Fork 313
/
Copy pathfinetune_youtube_last.py
36 lines (29 loc) · 1.06 KB
/
finetune_youtube_last.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Finetuning example.

Trains the DeepMoji model on the SS-Youtube dataset, using the 'last'
finetuning method and the accuracy metric.

The 'last' method does the following:
0) Load all weights except for the softmax layer. Do not add tokens to the
   vocabulary and do not extend the embedding layer.
1) Freeze all layers except for the softmax layer.
2) Train.
"""
from __future__ import print_function

import io
import json
import os

# Imported only for its side effects (the name is never used below); keep it
# ahead of the deepmoji imports.
import example_helper  # noqa: F401
from deepmoji.model_def import deepmoji_transfer
from deepmoji.global_variables import PRETRAINED_PATH
from deepmoji.finetuning import (
    load_benchmark,
    finetune)
# Resolve data/model paths relative to this script instead of the current
# working directory, so the example works regardless of where it is launched
# from. (When run from the script's own directory this yields the same files
# as the original '../data/...' and '../model/...' paths.)
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH = os.path.join(_SCRIPT_DIR, '..', 'data', 'SS-Youtube',
                            'raw.pickle')
VOCAB_PATH = os.path.join(_SCRIPT_DIR, '..', 'model', 'vocabulary.json')

# Number of output classes used for the SS-Youtube benchmark.
nb_classes = 2

# JSON text is UTF-8; decode it explicitly rather than relying on the
# platform's default encoding, which breaks on non-ASCII tokens on some
# systems. io.open accepts `encoding` on both Python 2 and Python 3.
with io.open(VOCAB_PATH, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Load dataset. The returned dict is used below for the keys 'texts',
# 'labels', 'maxlen' and 'batch_size'.
# NOTE(review): the dataset file is a .pickle; presumably load_benchmark
# unpickles it, so only point DATASET_PATH at files you trust.
data = load_benchmark(DATASET_PATH, vocab)

# Set up the transfer model (pretrained weights from PRETRAINED_PATH) and
# finetune it with the 'last' method described in the module docstring.
model = deepmoji_transfer(nb_classes, data['maxlen'], PRETRAINED_PATH)
model.summary()
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='last')
print('Acc: {}'.format(acc))