Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Demo files #2

Merged
merged 2 commits into from
Jan 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ def get_data_config(self, data_name):
elif data_name == 'DSIFN':
self.label_transform = "norm"
self.root_dir = '/media/lidan/ssd2/CDData/DSIFN_256/'
elif data_name == 'quick_start':
self.root_dir = './samples/'
elif data_name == 'quick_start_LEVIR':
self.root_dir = './samples_LEVIR/'
elif data_name == 'quick_start_DSIFN':
self.root_dir = './samples_DSIFN/'
else:
raise TypeError('%s has not defined' % data_name)
return self
Expand Down
80 changes: 80 additions & 0 deletions demo_DSIFN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from argparse import ArgumentParser

import utils
import torch
from models.basic_model import CDEvaluator

import os

"""
quick start

sample files in ./samples

save prediction files in the ./samples/predict

"""


def get_args():
# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--project_name', default='ChangeFormer_DSIFN', type=str)
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
parser.add_argument('--output_folder', default='samples_DSIFN/predict', type=str)

# data
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--dataset', default='CDDataset', type=str)
parser.add_argument('--data_name', default='quick_start_DSIFN', type=str)

parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--split', default="demo", type=str)
parser.add_argument('--img_size', default=256, type=int)

# model
parser.add_argument('--n_class', default=2, type=int)
parser.add_argument('--embed_dim', default=256, type=int)
parser.add_argument('--net_G', default='ChangeFormerV6', type=str,
help='ChangeFormerV6 | base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')
parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)

args = parser.parse_args()
return args


if __name__ == '__main__':

args = get_args()
utils.get_device(args)
device = torch.device("cuda:%s" % args.gpu_ids[0]
if torch.cuda.is_available() and len(args.gpu_ids)>0
else "cpu")
args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name)
os.makedirs(args.output_folder, exist_ok=True)

log_path = os.path.join(args.output_folder, 'log_vis.txt')

data_loader = utils.get_loader(args.data_name, img_size=args.img_size,
batch_size=args.batch_size,
split=args.split, is_train=False)

model = CDEvaluator(args)
model.load_checkpoint(args.checkpoint_name)
model.eval()

for i, batch in enumerate(data_loader):
name = batch['name']
print('process: %s' % name)
score_map = model._forward_pass(batch)
model._save_predictions()







8 changes: 4 additions & 4 deletions demo.py → demo_LEVIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def get_args():
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--project_name', default='BIT_LEVIR', type=str)
parser.add_argument('--project_name', default='ChangeFormer_LEVIR', type=str)
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
parser.add_argument('--output_folder', default='samples/predict', type=str)
parser.add_argument('--output_folder', default='samples_LEVIR/predict', type=str)

# data
parser.add_argument('--num_workers', default=0, type=int)
Expand All @@ -38,8 +38,8 @@ def get_args():
# model
parser.add_argument('--n_class', default=2, type=int)
parser.add_argument('--embed_dim', default=256, type=int)
parser.add_argument('--net_G', default='ChangeFormerV5', type=str,
help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')
parser.add_argument('--net_G', default='ChangeFormerV6', type=str,
help='ChangeFormerV6 | base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')
parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)

args = parser.parse_args()
Expand Down
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.