From 45800292de64f5e6a009399604e2761167bb3365 Mon Sep 17 00:00:00 2001 From: Emile Anclin Date: Fri, 27 May 2016 14:56:38 +0200 Subject: [PATCH] strongly improved network stripping --- ocrolib/lstm.py | 23 +++++++++++++++++++++-- ocropus-rtrain | 3 ++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/ocrolib/lstm.py b/ocrolib/lstm.py index 8380b994..cfc0d874 100644 --- a/ocrolib/lstm.py +++ b/ocrolib/lstm.py @@ -292,6 +292,14 @@ def __init__(self,Nh,No,initial_range=initial_range,rand=rand): self.No = No self.W2 = randu(No,Nh+1)*initial_range self.DW2 = zeros((No,Nh+1)) + def postLoad(self): + self.DW2 = zeros(self.W2.shape) + + def preSave(self): + for var in ('state', 'DW2'): + if hasattr(self, var): + delattr(self, var) + def ninputs(self): return self.Nh def noutputs(self): @@ -596,6 +604,12 @@ def walk(self): yield self for sub in self.nets: for x in sub.walk(): yield x + def preSave(self): + self.dstats = defaultdict(list) # reset + for delta in ('deltas', 'ldeltas'): + if hasattr(self, delta): + delattr(self, delta) + def ninputs(self): return self.nets[0].ninputs() def noutputs(self): @@ -859,11 +873,16 @@ def __init__(self,ninput,nstates,noutput=-1,codec=None,normalize=normalize_nfkc) self.clear_log() def walk(self): for x in self.lstm.walk(): yield x - def clear_log(self): + def clear_log(self, deallocate_tempvars=False): self.command_log = [] self.error_log = [] self.cerror_log = [] self.key_log = [] + if deallocate_tempvars: + for attrname in ('outputs', 'targets', 'aligned'): + if hasattr(self, attrname): + delattr(self, attrname) + def __setstate__(self,state): self.__dict__.update(state) self.upgrade() @@ -874,7 +893,7 @@ def upgrade(self): if "cerror_log" not in dir(self): self.cerror_log = [] if "key_log" not in dir(self): self.key_log = [] def info(self): - self.net.info() + self.lstm.info() def setLearningRate(self,r,momentum=0.9): self.lstm.setLearningRate(r,momentum) def predictSequence(self,xs): diff --git a/ocropus-rtrain b/ocropus-rtrain index f84bc41e..e3ae2491 100755 --- a/ocropus-rtrain +++ b/ocropus-rtrain @@ -156,7 +156,8 @@ def save_lstm(fname,network): network.lstm.save(fname) else: if args.strip: - network.clear_log() + print yellow('saving stripped network (without temporary variables)...') + network.clear_log(deallocate_tempvars=True) for x in network.walk(): x.preSave() ocrolib.save_object(fname,network) if args.strip: