Apply patches and build new release (#53)
* Resolves #51: Check empty sample frame and raise SampleEmptyError

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Resolves #52: Remove FrozenSeq2SeqTrainer since Seq2SeqTrainer is already fixed

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Bump version for release v0.1.5

Signed-off-by: Aivin V. Solatorio <[email protected]>

---------

Signed-off-by: Aivin V. Solatorio <[email protected]>
avsolatorio authored Nov 20, 2023
1 parent 4f3fa54 commit c8c56d8
Showing 4 changed files with 10 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/realtabformer/VERSION
@@ -1 +1 @@
-0.1.4
+0.1.5
7 changes: 4 additions & 3 deletions src/realtabformer/realtabformer.py
@@ -20,11 +20,12 @@
 from sklearn.metrics.pairwise import manhattan_distances
 
 # from sklearn.metrics import accuracy_score
-from transformers import (  # Seq2SeqTrainer,
+from transformers import (
     EarlyStoppingCallback,
     EncoderDecoderConfig,
     EncoderDecoderModel,
     PreTrainedModel,
+    Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     Trainer,
     TrainingArguments,
@@ -47,7 +48,7 @@
 from .rtf_datacollator import RelationalDataCollator
 from .rtf_exceptions import SampleEmptyLimitError
 from .rtf_sampler import RelationalSampler, TabularSampler
-from .rtf_trainer import FrozenSeq2SeqTrainer, ResumableTrainer
+from .rtf_trainer import ResumableTrainer
 from .rtf_validators import ObservationValidator
 
 
@@ -1024,7 +1025,7 @@ def _fit_relational(
         ]
 
         # instantiate trainer
-        trainer = FrozenSeq2SeqTrainer(
+        trainer = Seq2SeqTrainer(
             model=self.model,
             args=Seq2SeqTrainingArguments(**training_args_kwargs),
             callbacks=callbacks,
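Note on the FrozenSeq2SeqTrainer removal (#52): recent transformers releases already skip parameters with requires_grad=False when Trainer/Seq2SeqTrainer builds its optimizer groups, which is the behavior the custom subclass existed to provide. The sanity check below is only an illustrative sketch, not part of this commit: the tiny EncoderDecoderModel setup and the output_dir are made up, and the assertion is expected to hold on recent transformers versions.

# Illustrative sanity check (not part of this commit): a stock Seq2SeqTrainer
# built on a partially frozen model should only optimize trainable parameters.
from transformers import (
    EncoderDecoderConfig,
    EncoderDecoderModel,
    GPT2Config,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Tiny randomly initialized encoder-decoder model, purely for demonstration.
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    GPT2Config(n_layer=2, n_head=2, n_embd=64),
    GPT2Config(n_layer=2, n_head=2, n_embd=64),
)
model = EncoderDecoderModel(config=config)
for p in model.encoder.parameters():
    p.requires_grad = False  # freeze the encoder

trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(output_dir="/tmp/rtf-optimizer-check"),
)
optimizer = trainer.create_optimizer()

n_optimized = sum(p.numel() for group in optimizer.param_groups for p in group["params"])
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
assert n_optimized == n_trainable  # frozen encoder weights are excluded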
5 changes: 5 additions & 0 deletions src/realtabformer/rtf_sampler.py
@@ -433,6 +433,11 @@ def _decode_tokens(s):
             samples, columns=self.processed_columns, index=group_ids
         )
 
+        # Initial check for an empty sample frame.
+        if synth_sample.empty:
+            # Handle this exception in the sampling function.
+            raise SampleEmptyError(in_size=len(sample_outputs))
+
         # # Is this useful when we actually filter the vocabulary
         # # during generation???
         # # Let's try removing this for now and see what happens... XD
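For context on #51, here is a hedged sketch of how calling code might handle the new exception: the sampler raises SampleEmptyError when a decoded batch yields an empty frame, and the sampling loop can retry a bounded number of times before escalating. The sample_with_retries helper, the max_empty_retries budget, and the SampleEmptyLimitError call signature are illustrative assumptions; only the exception names come from the repository.

# Hypothetical caller-side retry loop; not REaLTabFormer's actual sampling code.
import pandas as pd

# Assumes both exceptions are exposed by realtabformer.rtf_exceptions.
from realtabformer.rtf_exceptions import SampleEmptyError, SampleEmptyLimitError


def sample_with_retries(sample_fn, n_samples: int, max_empty_retries: int = 10) -> pd.DataFrame:
    """Call `sample_fn` until it returns a non-empty frame or the retry budget runs out."""
    empty_rounds = 0
    while True:
        try:
            return sample_fn(n_samples)
        except SampleEmptyError:
            empty_rounds += 1
            if empty_rounds >= max_empty_retries:
                # Call signature assumed; the real exception may take other arguments.
                raise SampleEmptyLimitError(
                    f"Generated {empty_rounds} consecutive empty sample frames."
                )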
92 changes: 0 additions & 92 deletions src/realtabformer/rtf_trainer.py
@@ -9,20 +9,14 @@
     EvalPrediction,
     PreTrainedModel,
     PreTrainedTokenizerBase,
-    Seq2SeqTrainer,
     Trainer,
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
     logging,
 )
-from transformers.integrations import is_fairscale_available
 from transformers.optimization import get_scheduler
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.trainer_pt_utils import get_parameter_names
-from transformers.trainer_utils import ShardedDDPOption
-from transformers.utils import is_sagemaker_mp_enabled
 
 logger = logging.get_logger(__name__)
 
@@ -133,89 +127,3 @@ def create_scheduler(
             )
 
         return self.lr_scheduler
-
-
-class FrozenSeq2SeqTrainer(Seq2SeqTrainer):
-    """This trainer excludes all parameters that have
-    `.requires_grad=False` set.
-    """
-
-    def create_optimizer(self):
-        """
-        Setup the optimizer.
-        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
-        """
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-
-        if self.optimizer is None:
-            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
-            decay_parameters = [name for name in decay_parameters if "bias" not in name]
-            optimizer_grouped_parameters = [
-                {
-                    # Add here the `p.requires_grad` condition
-                    "params": [
-                        p
-                        for n, p in opt_model.named_parameters()
-                        if (n in decay_parameters and p.requires_grad)
-                    ],
-                    "weight_decay": self.args.weight_decay,
-                },
-                {
-                    # Add here the `p.requires_grad` condition
-                    "params": [
-                        p
-                        for n, p in opt_model.named_parameters()
-                        if (n not in decay_parameters and p.requires_grad)
-                    ],
-                    "weight_decay": 0.0,
-                },
-            ]
-
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args
-            )
-
-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                # Do the import here...
-                if is_fairscale_available():
-                    from fairscale.optim import OSS
-
-                    self.optimizer = OSS(
-                        params=optimizer_grouped_parameters,
-                        optim=optimizer_cls,
-                        **optimizer_kwargs,
-                    )
-            else:
-                self.optimizer = optimizer_cls(
-                    optimizer_grouped_parameters, **optimizer_kwargs
-                )
-                if optimizer_cls.__name__ == "Adam8bit":
-                    import bitsandbytes
-
-                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-
-                    skipped = 0
-                    for module in opt_model.modules():
-                        if isinstance(module, nn.Embedding):
-                            skipped += sum(
-                                {
-                                    p.data_ptr(): p.numel() for p in module.parameters()
-                                }.values()
-                            )
-                            print(f"skipped {module}: {skipped/2**20}M params")
-                            manager.register_module_override(
-                                module, "weight", {"optim_bits": 32}
-                            )
-                            logger.debug(
-                                f"bitsandbytes: will optimize {module} in fp32"
-                            )
-                    print(f"skipped: {skipped/2**20}M params")
-
-        if is_sagemaker_mp_enabled():
-            # Do the import here...
-            import smdistributed.modelparallel.torch as smp
-
-            self.optimizer = smp.DistributedOptimizer(self.optimizer)
-
-        return self.optimizer
