Skip to content

Commit

Permalink
Latest changes
Browse files Browse the repository at this point in the history
  • Loading branch information
luiscosio committed Dec 9, 2023
1 parent 70a66f2 commit a4b95ec
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 240 deletions.
12 changes: 0 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,6 @@ To run the detector on a folder and a specific models:

`python main.py --image_folder image.jpg --model dMDetectorResults`

For just running the diffusion detector:

`python dmdetector.py --image_path image.jpg`

For just running the GAN detector:

`python gandetector.py --image_path image.jpg`

For just running the EXIF detector:

`python gandetector.py --image_path image.jpg`

If you want to use the API, first run the server:

`python api.py`
Expand Down
2 changes: 1 addition & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from waitress import serve
from flask import Flask, request, jsonify
import torch
from utils import (
from utils.general import (
setup_logger,
memory_usage,
validate_image_file,
Expand Down
2 changes: 1 addition & 1 deletion explainability.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import requests
import boto3
from botocore.exceptions import ClientError
from utils import encode_image, setup_logger
from utils.general import encode_image, setup_logger

logger = setup_logger(__name__)

Expand Down
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import glob
from time import time
import torch
from utils import (
from utils.general import (
setup_logger,
validate_image_file,
write_to_csv,
Expand Down Expand Up @@ -62,14 +62,14 @@ def preload_models():
dm_loaded_models[model_name] = load_dm_model(model_name, device)

logger.info("Loaded DM model: %s", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

# Preload GAN models
for model_name in gan_models_config:
gan_loaded_models[model_name] = load_gan_model(model_name, device)

logger.info("Loaded GAN model: %s", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

logger.info("Model preloading complete!")

Expand Down Expand Up @@ -126,7 +126,7 @@ def process_image(image_path, models):
return logger.error("Image %s is not valid: %s", image_path, e)

image_results = {}
# image_results["path"] = processed_image_path

if "dMDetectorResults" in models:
logger.info("Starting DM detection on %s", processed_image_path)
image_results["dMDetectorResults"] = dm_process_image(
Expand Down
45 changes: 3 additions & 42 deletions models/dmdetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@
"""
import traceback
import argparse
import time
import json
import logging
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import models.networks.resnet_mod as resnet_mod
from utils import (
from utils.general import (
setup_logger,
memory_usage,
calculate_sigmoid_probabilities,
Expand Down Expand Up @@ -61,7 +58,7 @@ def load_model(model_name, device):
model = model.to(device).eval()

logger.info("Model %s loaded", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

return model

Expand Down Expand Up @@ -141,7 +138,7 @@ def process_image(image_path, preloaded_models=None):
del model
torch.cuda.empty_cache()
logger.info("Model %s unloaded", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

execution_time = time.time() - start_time

Expand Down Expand Up @@ -202,39 +199,3 @@ def process_image(image_path, preloaded_models=None):
}

return detection_output


def main():
"""
Command-line interface for the Diffusor detector.
"""
parser = argparse.ArgumentParser(
description="Diffusion detector inference on a single image"
)
parser.add_argument(
"--image_path", type=str, required=True, help="Path to the image file"
)
parser.add_argument(
"--log_level",
type=str,
default="INFO",
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)

args = parser.parse_args()
# Configure logger
log_levels = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
setup_logger(__name__, log_levels.get(args.log_level.upper(), logging.INFO))

return process_image(args.image_path)


if __name__ == "__main__":
output = main()
print(json.dumps(output, indent=4))
50 changes: 8 additions & 42 deletions models/exifdetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
Exif Detector: Inference on a single image using meta information
"""

import os
import re
import argparse
import time
import json
import logging
from PIL import Image
from PIL.ExifTags import TAGS
from PIL.PngImagePlugin import PngImageFile
from utils import (
from utils.general import (
setup_logger,
validate_image_file,
)
Expand Down Expand Up @@ -105,10 +103,13 @@ def process_image(image_path):
execution_time = time.time() - start_time

# # Regex for DALL-E filename pattern
dalle_regex = r'^DALL.*\.png$'
# @TODO: Make it more specific
dalle_regex = r"DALL·E \d{4}-\d{2}-\d{2} \d{2}\.\d{2}\.\d{2}"

print(image.filename)
if re.match(dalle_regex, image.filename):
filename = os.path.basename(image_path)
logger.info("Applying regex to %s", filename)

if re.match(dalle_regex, filename):
is_synthetic_image = True
infered_model = "DALLE-3"
elif (
Expand Down Expand Up @@ -138,38 +139,3 @@ def process_image(image_path):
}

return detection_output


def main():
"""
Command-line interface for the GAN detector.
"""
parser = argparse.ArgumentParser(
description="EXIF detector inference on a single image"
)
parser.add_argument(
"--image_path", type=str, required=True, help="Path to the image file"
)
parser.add_argument(
"--log_level",
type=str,
default="INFO",
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)

args = parser.parse_args()
# Configure logger
log_levels = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
setup_logger(__name__, log_levels.get(args.log_level.upper(), logging.INFO))
return process_image(args.image_path)


if __name__ == "__main__":
output = main()
print(json.dumps(output, indent=4))
46 changes: 4 additions & 42 deletions models/gandetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
Based on: https://github.com/grip-unina/GANimageDetection
"""
import argparse
import time
import json
import logging
import torch
import numpy as np
from PIL import Image
from models.networks.resnet50nodown import resnet50nodown
from utils import (
from utils.general import (
setup_logger,
memory_usage,
calculate_sigmoid_probabilities,
Expand Down Expand Up @@ -49,7 +46,7 @@ def load_model(model_name, device):
model = resnet50nodown(device, model_config["model_path"])

logger.info("Model %s loaded", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

return model

Expand Down Expand Up @@ -94,7 +91,7 @@ def process_image(image_path, preloaded_models=None):
del model
torch.cuda.empty_cache()
logger.info("Model %s unloaded", model_name)
logger.info(memory_usage())
logger.info("Memory usage: %s", memory_usage())

execution_time = time.time() - start_time

Expand Down Expand Up @@ -130,39 +127,4 @@ def process_image(image_path, preloaded_models=None):
},
}

return detection_output


def main():
"""
Command-line interface for the GAN detector.
"""
parser = argparse.ArgumentParser(
description="GAN detector inference on a single image"
)
parser.add_argument(
"--image_path", type=str, required=True, help="Path to the image file"
)
parser.add_argument(
"--log_level",
type=str,
default="INFO",
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)

args = parser.parse_args()
# Configure logger
log_levels = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
setup_logger(__name__, log_levels.get(args.log_level.upper(), logging.INFO))
return process_image(args.image_path)


if __name__ == "__main__":
output = main()
print(json.dumps(output, indent=4))
return detection_output
Loading

0 comments on commit a4b95ec

Please sign in to comment.