Skip to content

Commit

Permalink
fix model name
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghe90 committed Jul 16, 2021
1 parent d638b2a commit f93b31a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
12 changes: 6 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ def non_max_suppression(ious, scores, threshold):
logger.info('=> creating model ...')
logger.info('Classes: {}'.format(cfg.classes))

if model_name == 'pointgroup':
from model.pointgroup.pointgroup import PointGroup as Network
from model.pointgroup.pointgroup import model_fn_decorator
else:
print("Error: no model version " + model_name)
exit(0)
#if model_name == 'pointgroup':
from model.pointgroup.pointgroup import PointGroup as Network
from model.pointgroup.pointgroup import model_fn_decorator
#else:
# print("Error: no model version " + model_name)
# exit(0)
model = Network(cfg)

use_cuda = torch.cuda.is_available()
Expand Down
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ def eval_epoch(val_loader, model, model_fn, epoch):
##### model
logger.info('=> creating model ...')

if model_name == 'pointgroup':
from model.pointgroup.pointgroup import PointGroup as Network
from model.pointgroup.pointgroup import model_fn_decorator
else:
print("Error: no model - " + model_name)
exit(0)
#if model_name == 'pointgroup':
from model.pointgroup.pointgroup import PointGroup as Network
from model.pointgroup.pointgroup import model_fn_decorator
#else:
# print("Error: no model - " + model_name)
# exit(0)

model = Network(cfg)

Expand Down

0 comments on commit f93b31a

Please sign in to comment.