Skip to content

Commit

Permalink
Normalize range
Browse files Browse the repository at this point in the history
  • Loading branch information
luozhiya committed May 22, 2024
1 parent b6d178b commit fc2687a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
38 changes: 38 additions & 0 deletions lua/fittencode/engines/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ local Sessions = require('fittencode.sessions')
local Status = require('fittencode.status')
local SuggestionsPreprocessing = require('fittencode.suggestions_preprocessing')
local TaskScheduler = require('fittencode.tasks')
local Unicode = require('fittencode.unicode')

local schedule = Base.schedule

Expand Down Expand Up @@ -255,6 +256,40 @@ end

local VMODE = { ['v'] = true, ['V'] = true, [api.nvim_replace_termcodes('<C-V>', true, true, true)] = true }

---@param buffer number
---@param range ActionRange
local function normalize_range(buffer, range)
local start = range.start
local end_ = range['end']

if end_[1] < start[1] then
start[1], end_[1] = end_[1], start[1]
start[2], end_[2] = end_[2], start[2]
end
if end_[2] < start[2] and end_[1] == start[1] then
start[2], end_[2] = end_[2], start[2]
end

local utf_end_byte = function(row, col)
local line = api.nvim_buf_get_lines(buffer, row - 1, row, false)[1]
local byte_start = math.min(col + 1, #line)
local utf_index = Unicode.calculate_utf8_index(line)
local flag = utf_index[byte_start]
assert(flag == 0)
local byte_end = #line
local next = Unicode.find_zero(utf_index, byte_start + 1)
if next then
byte_end = next - 1
end
return byte_end
end

end_[2] = utf_end_byte(end_[1], end_[2])

range.start = start
range['end'] = end_
end

local function make_range(buffer)
local in_v = false
local region = nil
Expand All @@ -273,12 +308,15 @@ local function make_range(buffer)
local start = api.nvim_buf_get_mark(buffer, '<')
local end_ = api.nvim_buf_get_mark(buffer, '>')

---@type ActionRange
local range = {
start = start,
['end'] = end_,
vmode = in_v,
region = region,
}
normalize_range(buffer, range)

return range
end

Expand Down
12 changes: 1 addition & 11 deletions lua/fittencode/prompt_providers/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ function M:get_priority()
return self.priority
end

local function max_len(buffer, row, len)
local max = string.len(api.nvim_buf_get_lines(buffer, row - 1, row, false)[1])
if len > max then
return max
end
return len
end

---@param buffer integer
---@param range ActionRange
---@return string
Expand All @@ -44,14 +36,12 @@ local function make_range_content(buffer, range)
if range.vmode and range.region then
lines = range.region or {}
else
-- lines = api.nvim_buf_get_text(buffer, range.start[1] - 1, 0, range.start[1] - 1, -1, {})
local end_col = max_len(buffer, range['end'][1], range['end'][2])
lines = api.nvim_buf_get_text(
buffer,
range.start[1] - 1,
range.start[2],
range['end'][1] - 1,
end_col + 1, {})
range['end'][2], {})
end
return table.concat(lines, '\n')
end
Expand Down
6 changes: 3 additions & 3 deletions lua/fittencode/unicode.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function M.calculate_utf8_index_tbl(lines)
return index
end

local function find_zero(tbl, start_index)
function M.find_zero(tbl, start_index)
for i = start_index, #tbl do
if tbl[i] == 0 then
return i
Expand All @@ -30,14 +30,14 @@ function M.find_first_character(s, tbl, start_index)
return nil
end

local v1 = find_zero(tbl, start_index)
local v1 = M.find_zero(tbl, start_index)
assert(v1 == start_index)
if v1 == nil then
-- Invalid UTF-8 sequence
return nil
end

local v2 = find_zero(tbl, v1 + 1)
local v2 = M.find_zero(tbl, v1 + 1)
if v2 == nil then
v2 = #tbl
else
Expand Down

0 comments on commit fc2687a

Please sign in to comment.