Skip to content

Commit

Permalink
refactor: gmm laser code restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccalmon committed Jan 17, 2025
1 parent cf1cbe9 commit 602f71b
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 143 deletions.
143 changes: 0 additions & 143 deletions pyfishsensedev/laser/bayesian_laser_detector.py

This file was deleted.

151 changes: 151 additions & 0 deletions pyfishsensedev/laser/gmm_laser_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import numpy as np
import cv2
import sys
import math
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import glob
import os
from pathlib import Path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from pyfishsensedev.plane_detector.slate_detector import SlateDetector
from pyfishsensedev.image.pdf import Pdf
from pyfishsensedev.laser.laser_detector import LaserDetector

def pdf_multidim(X, mean, cov_inv, cov_det):
X = X.astype(np.float32)
d = X.shape[1]
diff = X - mean
exponent = -0.5 * np.einsum('ij,ij->i', diff @ cov_inv, diff)
denom = np.sqrt((2.0 * np.pi)**d * cov_det)
return (1.0 / denom) * np.exp(exponent)

class GMM():
def __init__(self, k, samples):
self.gmm = BayesianGaussianMixture(n_components=k, max_iter=200, covariance_type='full')
self.gmm.fit(samples)

self.means = self.gmm.means_
self.covs = self.gmm.covariances_
self.weights = self.gmm.weights_

self.k = k
self.cov_invs = []
self.cov_dets = []
for i in range(self.k):
cov_inv = np.linalg.inv(self.covs[i])
cov_det = np.linalg.det(self.covs[i])
self.cov_invs.append(cov_inv)
self.cov_dets.append(cov_det)

def pdf(self, X):
N = X.shape[0]
p = np.zeros(N, dtype=np.float64)
for i in range(self.k):
p_comp = pdf_multidim(X, self.means[i], self.cov_invs[i], self.cov_dets[i])
p += self.weights[i] * p_comp
return p

class PDFMultiDim():
def __init__(self, mean, cov):
self.mean = np.asarray(mean, dtype=np.float64)
self.cov = np.asarray(cov, dtype=np.float64)
self.cov_inv = np.linalg.inv(self.cov)
self.cov_det = np.linalg.det(self.cov)

def pdf(self, X):
return pdf_multidim(X, self.mean, self.cov_inv, self.cov_det)

class GMMLaserDetector(LaserDetector):

def __init__(self, laser_samples_path):
super().__init__()

# Build laser model from GMM
laser_samples = self.__get_samples(laser_samples_path)
self.pdf_laser = GMM(k=10, samples=laser_samples)

def __find_mean_cov(self, img):
img32 = img.astype(np.float32)
pixels = img32.reshape(-1, 3)
mean_hsv = np.mean(pixels, axis=0)
cov_hsv = np.cov(pixels, rowvar=False)
return mean_hsv, cov_hsv

def __get_samples(self, path):
dirs = glob.glob(f"{path}*")
pixels = np.empty((0, 3), dtype=np.float32)

for p in dirs:
img = cv2.imread(p)
if img is None:
print(f"Error reading image: {p}")
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
# Convert to float32, then flatten
pixels = np.vstack((pixels, img.reshape(-1, 3).astype(np.float32)))
return pixels

def find_laser(self, img_bgr):
"""
Using Bayesian inference, classify pixels with Gaussian models and return laser point
"""
# Convert to RGB and HSV color spaces
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

# Flatten (H*W, 3)
H, W, _ = img_hsv.shape
pixels_hsv = img_hsv.reshape(-1, 3).astype(np.float32)

# Priors - right now hard coded, should be related to how many total pixels in image
prior_laser = 0.00001 # 0.00001
prior_bg = 0.99999 # 0.99999

# Build background model
bg_mean, bg_cov = self.__find_mean_cov(img_hsv)
pdf_bg = PDFMultiDim(bg_mean, bg_cov)

# Compute the likelihoods
likelihood_bg = pdf_bg.pdf(pixels_hsv)
likelihood_laser = self.pdf_laser.pdf(pixels_hsv)

# Combine with priors => posterior
numerator_laser = likelihood_laser * prior_laser
numerator_bg = likelihood_bg * prior_bg
norm_const = numerator_laser + numerator_bg

# Posterior for laser
posterior_laser = numerator_laser / norm_const
# Classify as laser if posterior_laser > posterior_bg => posterior_laser > 0.5
laser_mask = (posterior_laser > 0.5).reshape(H, W)

# Color the laser pixels in BGR
img_bgr[laser_mask] = [128, 0, 0]

# Return the centroid
laser_rows, laser_cols = np.where(laser_mask)
centroid_row = None
centroid_col = None
if len(laser_rows) > 0:
centroid_row = np.mean(laser_rows)
centroid_col = np.mean(laser_cols)

# Show results
plt.scatter(
[centroid_col], [centroid_row],
color='red', marker='x', s=50, label='Laser Pixel Mean'
)
plt.legend()
plt.imshow(img_rgb)
plt.title("Laser Pixels (Blue)")
plt.show()

return centroid_col, centroid_row

if __name__ == "__main__":
path = sys.argv[1]
img_bgr = cv2.imread(path)
gmm_laser_detector = GMMLaserDetector(laser_samples_path="laser_data/")
x, y = gmm_laser_detector.find_laser(img_bgr)

0 comments on commit 602f71b

Please sign in to comment.