Skip to content

Commit

Permalink
started making training pool #20
Browse files Browse the repository at this point in the history
made ShakespeareGenerator
  • Loading branch information
Francesco215 committed May 15, 2023
1 parent 6422ce2 commit 9540e95
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/training_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from numpy.random import randint
import random
import torch
from torch.utils.data import Dataset

import numpy as np
from .encoder import NoiseEncoder

from typing import Any, Iterable, List, Tuple, Callable


class ShakespeareGenerator:
    """Sample fixed-length text snippets from a file and run them through a noise encoder.

    The file is read once at construction time and split by character count:
    the first 90% becomes ``train_data`` and the remainder ``val_data``.
    Calling the instance draws a random training snippet and returns
    ``encoder(snippet, noise)``.
    """

    def __init__(self, input_file_path, lenght, encoder):
        """Load the text file and prepare the train/validation splits.

        Args:
            input_file_path: Path of the text file to read.
            lenght: Number of characters per sampled snippet (converted with
                ``int``). The misspelled name is kept so that existing
                keyword callers and the ``self.lenght`` attribute keep working.
            encoder: A ``NoiseEncoder`` instance, called as
                ``encoder(text, noise)`` by ``__call__``.

        Raises:
            AssertionError: If ``encoder`` is not a ``NoiseEncoder``.
        """
        self.input_file_path = input_file_path
        self.lenght = int(lenght)

        assert isinstance(encoder,NoiseEncoder), f"The encoder must be of type NoiseEncoder, got {type(encoder)} instead"
        self.encoder = encoder

        with open(input_file_path, 'r') as f:
            data = f.read()

        # Split the dataset into train (first 90%) and validation (last 10%).
        split_index = int(len(data) * 0.9)
        self.train_data = data[:split_index]
        self.val_data = data[split_index:]

    def __call__(self):
        """Return the encoder's output for a random training snippet.

        The noise argument is a 0-dimensional tensor produced by
        ``torch.rand(())``.
        """
        target = self.sample_text()

        noise = torch.rand(())

        return self.encoder(target, noise)

    def sample_text(self, train=True):
        """Return a random contiguous snippet of ``self.lenght`` characters.

        Args:
            train: Sample from ``train_data`` when true, otherwise from
                ``val_data``.

        Returns:
            A substring of the selected split. If the split is not longer
            than ``self.lenght``, the whole split is returned.
        """
        data = self.train_data if train else self.val_data

        last_start = len(data) - self.lenght
        if last_start <= 0:
            # The split is too short for a full window (or fits exactly one):
            # return everything rather than letting randint raise on an
            # empty range.
            return data

        # numpy's randint excludes the upper bound, so add 1 to make the
        # final window (starting at ``last_start``) reachable.
        starting_index = randint(0, last_start + 1)

        return data[starting_index:starting_index + self.lenght]

0 comments on commit 9540e95

Please sign in to comment.