Skip to content

Commit

Permalink
Minor Code Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wi-re committed Aug 8, 2024
1 parent 4143102 commit 53c13bf
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/BasisConvolution/util/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
parser.add_argument('-w','--windowFunction', type=str, default=argparse.SUPPRESS, help='Window function [default = poly6]')
parser.add_argument('-c','--cutoff', type=int, default=argparse.SUPPRESS, help='Cutoff distance [default = 1800]')
parser.add_argument('-b','--batch_size', type=int, default=argparse.SUPPRESS, help='Batch size [default = 1]')
parser.add_argument('-o','--output', type = str, default = '../../trainingData_TGV/randomFlows/', help='Output directory [default = ""]')
parser.add_argument('-o','--output', type = str, default = argparse.SUPPRESS, help='Output directory [default = ""]')
parser.add_argument('--cutlassBatchSize', type=int, default=argparse.SUPPRESS, help='Cutlass batch size [default = 512]')
parser.add_argument('--lr', type=float, default=argparse.SUPPRESS, help='Learning rate [default = 0.01]')
parser.add_argument('--finalLR', type=float, default=argparse.SUPPRESS, help='Final learning rate [default = 0.0001]')
Expand Down
2 changes: 1 addition & 1 deletion src/BasisConvolution/util/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def loadAugmentedFrame(index, dataset, hyperParameterDict, unrollLength = 8):
# print('gt')
cState = currentState
for state in trajectoryStates:
state['fluid']['target'] = getFeatures(hyperParameterDict['groundTruth'].split(' '), state, cState, 'fluid', config, state['time'] - cState['time'], verbose = False, includeOther = 'boundary' in currentState and currentState['boundary'] is not None,)
state['fluid']['target'] = getFeatures(hyperParameterDict['groundTruth'].split(' '), state, cState, 'fluid', config, state['time'] - cState['time'] if state['time'] != cState['time'] else 1, verbose = False, includeOther = 'boundary' in currentState and currentState['boundary'] is not None,)
cState = state

return config, attributes, augmentedStates[0], augmentedStates[1] if priorState is not None else None, augmentedStates[2:] if priorState is not None else augmentedStates[1:]
Expand Down
2 changes: 1 addition & 1 deletion src/BasisConvolution/util/datautils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def getStyle(inFile):
def parseFile(inFile, hyperParameterDict):
frameDistance = hyperParameterDict['frameDistance'] if 'frameDistance' in hyperParameterDict else 1
frameSpacing = hyperParameterDict['dataDistance'] if 'dataDistance' in hyperParameterDict else 1
maxRollout = hyperParameterDict['maxRollOut'] if 'maxRollOut' in hyperParameterDict else 0
maxRollout = hyperParameterDict['maxUnroll'] if 'maxUnroll' in hyperParameterDict else 0

temporalData = isTemporalData(inFile)
skip = (1 if hyperParameterDict['zeroOffset'] and temporalData else 0) if 'zeroOffset' in hyperParameterDict else 0
Expand Down
11 changes: 4 additions & 7 deletions src/BasisConvolution/util/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def defaultHyperParameters():
'initialLR': 0.01,
'finalLR': 0.0001,
'lrStep': 10,
'maxRollOut': 10,
'epochs': 25,
'frameDistance': 4,
'iterations': 1000,
Expand All @@ -36,7 +35,7 @@ def defaultHyperParameters():
'weight_decay': 0,
'input': '',
'input': './',
'output': '../../trainingData_TGV/randomFlows/',
'output': 'training',
'outputBias': False,
'loss': 'mse',
'batchSize': 1,
Expand Down Expand Up @@ -82,7 +81,6 @@ def parseArguments(args, hyperParameterDict):
hyperParameterDict['lrStep'] = args.lrStep if hasattr(args, 'lrStep') else hyperParameterDict['lrStep']


hyperParameterDict['maxRollOut'] = args.maxUnroll if hasattr(args, 'maxUnroll') else hyperParameterDict['maxUnroll']
hyperParameterDict['epochs'] = args.epochs if hasattr(args, 'epochs') else hyperParameterDict['epochs']
hyperParameterDict['frameDistance'] = args.frameDistance if hasattr(args, 'frameDistance') else hyperParameterDict['frameDistance']
hyperParameterDict['iterations'] = args.iterations if hasattr(args, 'iterations') else hyperParameterDict['iterations']
Expand Down Expand Up @@ -376,7 +374,7 @@ def toPandaDict(hyperParameterDict):
'layers': hyperParameterDict['layers'],
'seed': hyperParameterDict['seed'],

'windowFunction': hyperParameterDict['windowFunction'],
'windowFunction': hyperParameterDict['windowFunction'] if hyperParameterDict['windowFunction'] is not None else 'None',
'coordinateMapping' : hyperParameterDict['coordinateMapping'],

'trainingFiles': hyperParameterDict['trainingFiles'],
Expand Down Expand Up @@ -415,7 +413,6 @@ def toPandaDict(hyperParameterDict):

'minUnroll': hyperParameterDict['minUnroll'],
'maxUnroll': hyperParameterDict['maxUnroll'],
'maxRollOut': hyperParameterDict['maxRollOut'],

'cutlassBatchSize': hyperParameterDict['cutlassBatchSize'],
'li' : hyperParameterDict['liLoss'] if 'liLoss' in hyperParameterDict else None,
Expand Down Expand Up @@ -588,9 +585,9 @@ def finalizeHyperParameters(hyperParameterDict, dataset):
hyperParameterDict['layers'] = [int(s) for s in hyperParameterDict['widths']]


hyperParameterDict['shortLabel'] = f'{hyperParameterDict["networkType"]:8s} [{hyperParameterDict["arch"]:14s}] -> [{hyperParameterDict["basisFunctions"]:8s}] x [{hyperParameterDict["basisTerms"]:2d}] @ {hyperParameterDict["coordinateMapping"]:4s}/{hyperParameterDict["windowFunction"]:4s}, {hyperParameterDict["fluidFeatures"]} -> {hyperParameterDict["groundTruth"]}'
hyperParameterDict['shortLabel'] = f'{hyperParameterDict["networkType"]:8s} [{hyperParameterDict["arch"]:14s}] -> [{hyperParameterDict["basisFunctions"]:8s}] x [{hyperParameterDict["basisTerms"]:2d}] @ {hyperParameterDict["coordinateMapping"]:4s}/{hyperParameterDict["windowFunction"] if hyperParameterDict["windowFunction"] is not None else "None":4s}, {hyperParameterDict["fluidFeatures"]} -> {hyperParameterDict["groundTruth"]}'

hyperParameterDict['progressLabel'] = f'{hyperParameterDict["networkType"]:8s} [{hyperParameterDict["arch"]:4s}] -> [{hyperParameterDict["basisFunctions"]:8s}] x [{hyperParameterDict["basisTerms"]:2d}] @ {hyperParameterDict["coordinateMapping"]:4s}/{hyperParameterDict["windowFunction"]:4s}'
hyperParameterDict['progressLabel'] = f'{hyperParameterDict["networkType"]:8s} [{hyperParameterDict["arch"]:4s}] -> [{hyperParameterDict["basisFunctions"]:8s}] x [{hyperParameterDict["basisTerms"]:2d}] @ {hyperParameterDict["coordinateMapping"]:4s}/{hyperParameterDict["windowFunction"] if hyperParameterDict["windowFunction"] is not None else "None":4s}'

hyperParameterDict['exportLabel'] = f'{hyperParameterDict["timestamp"]} - {hyperParameterDict["networkSeed"]} - {hyperParameterDict["shortLabel"]}'.replace(":", ".").replace("/", "_")

Expand Down

0 comments on commit 53c13bf

Please sign in to comment.