Skip to content

Commit

Permalink
Adds the restricted damerau-levenshtein distance
Browse files Browse the repository at this point in the history
* Add the restricted demerau-levenshtein distance

* Additional code comments + impl. alignment with wikipedia pseudocode

* Remove gleam/io import

* Address PR review comments

* Consistent spacing between tests
  • Loading branch information
NicklasXYZ authored Oct 2, 2024
1 parent ff4c2ea commit 62aed50
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 0 deletions.
201 changes: 201 additions & 0 deletions src/edit_distance/osa.gleam
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import gleam/dict.{type Dict}
import gleam/int
import gleam/list
import gleam/string

/// Compute the edit distance between two strings using the Restricted Damerau–Levenshtein, also
/// called the [Optimal String Alignment (OSA) distance](https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance#Optimal_string_alignment_distance).
///
/// The OSA distance counts operations: Insertions, deletions, substitutions, and limited
/// transpositions (only adjacent transpositions are allowed).
///
/// ## Examples
///
/// ```gleam
/// > distance("gleam", "beam")
/// 2
/// ```
///
/// ```gleam
/// > distance("cat", "cap")
/// 1
/// ```
///
/// ```gleam
/// > distance("wibble", "wibbel")
/// 2
/// ```
///
pub fn distance(one: String, other: String) -> Int {
case one, other {
// Case: Both strings are empty, so the distance is 0
"", "" -> 0

// Case: One string is empty, so the distance is the length of the other. That is, equal to the
// cost of inserting all characters of the 'other' string and vice versa
"", _ -> string.length(other)
_, "" -> string.length(one)

// Case: Strings are identical, so the distance is 0
one, other if one == other -> 0

// General case: Compute the Restricted Damerau-Levenshtein distance recursively
_, _ -> calculate_distance(one, other)
}
}

fn calculate_distance(one: String, other: String) -> Int {
let m = string.length(one)
let n = string.length(other)

// Convert both strings to dictionaries for key-value (index-character) access
// Indexing starts from 0. The indices are:
// - 0 to m - 1 for 'one'
// - 0 to n - 1 for 'other'
let one_dict = string_to_dict(one)
let other_dict = string_to_dict(other)

// Initialize a dictionary, representing a (distance) matrix, to store distance calculations
// Each key is a tuple (i, j) where i and j represent the lengths of substrings
// Each corresponding value represents the minimum cost to transform one substring into the
// other, using a set of allowed operations
let distance_dict = dict.new()

// Compute the distances between substrings
let distance_dict =
compute_distances(one_dict, other_dict, distance_dict, m, n)

// Retrieve the final distance between the full strings
let assert Ok(result) = dict.get(distance_dict, #(m, n))
result
}

// A function that converts a string into a dictionary of key-value (index-character) pairs
fn string_to_dict(str: String) -> Dict(Int, String) {
// Split the string into its graphemes
let string_graphemes = string.to_graphemes(str)

// Create an index-character dictionary
use acc, character, index <- list.index_fold(string_graphemes, dict.new())
dict.insert(acc, index, character)
}

// A function that gets the minimum of three values (used for the recursive case)
fn min_three(v1: Int, v2: Int, v3: Int) -> Int {
int.min(v1, int.min(v2, v3))
}

// A function that checks if two characters match and returns the appropriate cost
fn match_cost(
one: Dict(Int, String),
i: Int,
other: Dict(Int, String),
j: Int,
) -> Int {
// Return 1 if characters at positions i and j do not match, otherwise returns 0
case dict.get(one, i), dict.get(other, j) {
Ok(c0), Ok(c1) if c0 != c1 -> 1
_, _ -> 0
}
}

// A function that calculates the edit distance taking into account edit operations:
// - insertion
// - deletion
// - substitution
// - limited transposition
fn compute_cost(
one: Dict(Int, String),
other: Dict(Int, String),
d: Dict(#(Int, Int), Int),
i: Int,
j: Int,
) -> Int {
let cost = match_cost(one, i - 1, other, j - 1)

// Calculate insertion cost
let assert Ok(d_ij1) = dict.get(d, #(i, j - 1))

// Calculate deletion cost
let assert Ok(d_i1j) = dict.get(d, #(i - 1, j))

// Calculate substitution cost
let assert Ok(d_i1j1) = dict.get(d, #(i - 1, j - 1))

// Find the minimum cost between insertion, deletion, and substitution
let min_cost = min_three(d_i1j + 1, d_ij1 + 1, d_i1j1 + cost)

// Finally, take into account the cost of a transposition
let updated_cost = case i > 1 && j > 1 {
True -> {
case check_transposition(one, i - 1, other, j - 1) {
True -> {
let assert Ok(d_i2j2) = dict.get(d, #(i - 2, j - 2))
int.min(min_cost, d_i2j2 + 1)
}
False -> min_cost
}
}
False -> min_cost
}
updated_cost
}

// Check if a transposition is valid by comparing adjacent characters in both strings
fn check_transposition(
one: Dict(Int, String),
i: Int,
other: Dict(Int, String),
j: Int,
) -> Bool {
let one_i = dict.get(one, i)
let other_j = dict.get(other, j)
let one_i1 = dict.get(one, i - 1)
let other_j1 = dict.get(other, j - 1)

// Check if transposing adjacent characters results in a match
case one_i, other_j1, one_i1, other_j {
Ok(c0), Ok(c1), Ok(c2), Ok(c3) if c0 == c1 && c2 == c3 -> True
_, _, _, _ -> False
}
}

fn fold_range(start: Int, end: Int, acc: a, fun: fn(a, Int) -> a) -> a {
case start > end {
True -> acc
False -> {
let updated_acc = fun(acc, start)
fold_range(start + 1, end, updated_acc, fun)
}
}
}

// Compute the edit distance by filling the distance matrix with the minimum costs
fn compute_distances(
one: Dict(Int, String),
other: Dict(Int, String),
d: Dict(#(Int, Int), Int),
m: Int,
n: Int,
) -> Dict(#(Int, Int), Int) {
// Initialize base cases, to avoid continuously checking these later on
// Fixing i = 0 or j = 0, set the cost of transforming one string into the other
let d = {
use acc, i <- fold_range(0, m, d)
dict.insert(acc, #(i, 0), i)
}
let d = {
use acc, j <- fold_range(1, n, d)
dict.insert(acc, #(0, j), j)
}

// Iterate over the matrix to calculate the minimum costs for all substrings
use acc, i <- fold_range(1, m, d)
use acc, j <- fold_range(1, n, acc)

// Calculate the cost for this position (i, j)
let cost = compute_cost(one, other, acc, i, j)

// Update the distance dictionary with the calculated cost
dict.insert(acc, #(i, j), cost)
}
83 changes: 83 additions & 0 deletions test/edit_distance_test/osa_test.gleam
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import edit_distance/osa
import gleeunit/should

pub fn osa_test() {
osa.distance("", "hello")
|> should.equal(5)

osa.distance("hello", "")
|> should.equal(5)

// Test symmetry
osa.distance("a", "bbb")
|> should.equal(3)

osa.distance("bbb", "a")
|> should.equal(3)

osa.distance("hello", "hello")
|> should.equal(0)

// Test distinction between uppercase and lowercase letters
osa.distance("hello", "HELLO")
|> should.equal(5)

osa.distance("CA", "ABC")
|> should.equal(3)

// Test both strings are empty
osa.distance("", "")
|> should.equal(0)

// Test strings with only one character
osa.distance("a", "a")
|> should.equal(0)

osa.distance("a", "b")
|> should.equal(1)

// Test one character, one empty string
osa.distance("a", "")
|> should.equal(1)

osa.distance("", "b")
|> should.equal(1)

// Test repeating characters
osa.distance("aaa", "aaaa")
|> should.equal(1)

osa.distance("aaaa", "aaa")
|> should.equal(1)

// Test strings with spaces
osa.distance("hello world", "hello_world")
|> should.equal(1)

osa.distance(" a ", "a")
|> should.equal(2)

// Test strings with special characters
osa.distance("!@#", "!@")
|> should.equal(1)

osa.distance("!@#", "#@!")
|> should.equal(2)

// Test very long strings with minimal differences
osa.distance("aaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaa")
|> should.equal(0)

osa.distance("aaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaab")
|> should.equal(1)

// Test transpositions
osa.distance("abc", "acb")
|> should.equal(1)

osa.distance("abcdef", "badcfe")
|> should.equal(3)

osa.distance("abcdef", "abcfed")
|> should.equal(2)
}

0 comments on commit 62aed50

Please sign in to comment.