Skip to content

Commit

Permalink
Merge pull request #65 from luozhiya/commit_opts
Browse files Browse the repository at this point in the history
Refactor `SuggestionsPreprocessing`
luozhiya authored May 20, 2024
2 parents bfdddff + fe8796c commit 31da331
Showing 5 changed files with 286 additions and 256 deletions.
44 changes: 23 additions & 21 deletions lua/fittencode/engines/actions.lua
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ local Promise = require('fittencode.concurrency.promise')
local PromptProviders = require('fittencode.prompt_providers')
local Sessions = require('fittencode.sessions')
local Status = require('fittencode.status')
local SuggestionsPreprocessing = require('fittencode.suggestions_preprocessing')
local TaskScheduler = require('fittencode.tasks')

local schedule = Base.schedule
@@ -93,22 +94,22 @@ end
---@param task_id integer
---@param suggestions Suggestions
---@return Suggestions?, integer?
local function filter_suggestions(task_id, suggestions)
if not suggestions then
return
end
local function filter_suggestions(window, buffer, task_id, suggestions)
local matched, ms = tasks:match_clean(task_id, 0, 0)
if not matched then
Log.debug('Action request is outdated, discarding task: {}', task_id)
return
return nil, ms
end
if not suggestions then
return nil, ms
end
return vim.tbl_filter(function(s) return #s > 0 end, suggestions), ms
return SuggestionsPreprocessing.run(window, buffer, suggestions), ms
end

---@param action integer
---@param solved_prefix string
---@param on_error function
local function chain_actions(action, solved_prefix, on_error)
local function chain_actions(window, buffer, action, solved_prefix, on_error)
Log.debug('Chain Action({})...', get_action_name(action))
if depth >= MAX_DEPTH then
Log.debug('Max depth reached, stopping evaluation')
@@ -127,7 +128,7 @@ local function chain_actions(action, solved_prefix, on_error)
solved_prefix = solved_prefix,
}, function(_, prompt, suggestions)
-- Log.debug('Suggestions for Actions: {}', suggestions)
local lines, ms = filter_suggestions(task_id, suggestions)
local lines, ms = filter_suggestions(window, buffer, task_id, suggestions)
if not lines or #lines == 0 then
schedule(on_error)
else
@@ -137,9 +138,9 @@ local function chain_actions(action, solved_prefix, on_error)
else
elapsed_time = elapsed_time + ms
depth = depth + 1
chat:commit(lines, true)
local new_solved_prefix = prompt.prefix .. table.concat(lines, '\n') .. '\n'
chain_actions(action, new_solved_prefix, on_error)
chat:commit(lines)
local new_solved_prefix = prompt.prefix .. table.concat(lines, '\n')
chain_actions(window, buffer, action, new_solved_prefix, on_error)
end
end
end, function(err)
@@ -164,7 +165,8 @@ local function on_error(err)
end
Log.debug('Action elapsed time: {}', elapsed_time)
Log.debug('Action depth: {}', depth)
chat:commit('> Q.E.D.' .. '(' .. elapsed_time .. ' ms)' .. '\n', true, true)
local qed = '\n\n' .. '> Q.E.D.' .. '(' .. elapsed_time .. ' ms)' .. '\n\n'
chat:commit(qed)
current_eval = current_eval + 1
end

@@ -260,26 +262,26 @@ local function make_filetype(buffer, range)
return filetype
end

local function _start_action(action, prompt_opts)
local function _start_action(window, buffer, action, prompt_opts)
Promise:new(function(resolve, reject)
local task_id = tasks:create(0, 0)
Sessions.request_generate_one_stage(task_id, prompt_opts, function(_, prompt, suggestions)
-- Log.debug('Suggestions for Actions: {}', suggestions)
local lines, ms = filter_suggestions(task_id, suggestions)
local lines, ms = filter_suggestions(window, buffer, task_id, suggestions)
elapsed_time = elapsed_time + ms
if not lines or #lines == 0 then
reject()
else
depth = depth + 1
chat:commit(lines, true)
local solved_prefix = prompt.prefix .. table.concat(lines, '\n') .. '\n'
chat:commit(lines)
local solved_prefix = prompt.prefix .. table.concat(lines, '\n')
resolve(solved_prefix)
end
end, function(err)
reject(err)
end)
end):forward(function(solved_prefix)
chain_actions(action, solved_prefix, on_error)
chain_actions(window, buffer, action, solved_prefix, on_error)
end, function(err)
schedule(on_error, err)
end
@@ -292,10 +294,10 @@ local function chat_commit_inout(action_name, prompt_opts, range)
prompt_preview.filename = 'unnamed'
end
local source_info = ' (' .. prompt_preview.filename .. ' ' .. range.start[1] .. ':' .. range['end'][1] .. ')'
local c_in = '# In`[' .. current_eval .. ']`:= ' .. action_name .. source_info
local c_in = '# In`[' .. current_eval .. ']`:= ' .. action_name .. source_info .. '\n'
chat:commit(c_in)
chat:commit(prompt_preview.content)
local c_out = '# Out`[' .. current_eval .. ']`='
chat:commit(prompt_preview.content .. '\n')
local c_out = '# Out`[' .. current_eval .. ']`=' .. '\n'
chat:commit(c_out)
end

@@ -348,7 +350,7 @@ function ActionsEngine.start_action(action, opts)
}

chat_commit_inout(action_name, prompt_opts, range)
_start_action(action, prompt_opts)
_start_action(chat.window, chat.buffer, action, prompt_opts)
end

---@param opts? ActionOptions
155 changes: 23 additions & 132 deletions lua/fittencode/engines/inline.lua
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ local NetworkError = require('fittencode.client.network_error')
local Sessions = require('fittencode.sessions')
local Status = require('fittencode.status')
local SuggestionsCache = require('fittencode.suggestions_cache')
local SuggestionsPreprocessing = require('fittencode.suggestions_preprocessing')
local TaskScheduler = require('fittencode.tasks')
local PromptProviders = require('fittencode.prompt_providers')
local Unicode = require('fittencode.unicode')
@@ -28,119 +29,30 @@ local tasks = nil
---@type Status
local status = nil

local function _set_text(lines)
local window = api.nvim_get_current_win()
local buffer = api.nvim_win_get_buf(window)
Lines.set_text({
window = window,
buffer = buffer,
lines = lines,
})
end

function M.setup()
cache = SuggestionsCache:new()
tasks = TaskScheduler:new()
tasks:setup()
status = Status:new({ tag = 'InlineEngine' })
end

---@param suggestions string[]
local function condense_nl(suggestions)
if not suggestions or #suggestions == 0 then
return
end

local is_all_empty = true
for _, suggestion in ipairs(suggestions) do
if #suggestion ~= 0 then
is_all_empty = false
break
end
end

if is_all_empty then
return {}
end

local row, col = Base.get_cursor()
local prev_line = nil
local cur_line = api.nvim_buf_get_lines(0, row, row + 1, false)[1]
if row > 1 then
prev_line = api.nvim_buf_get_lines(0, row - 1, row, false)[1]
end

local nls = {}
local remove_all = false
local keep_first = false

if vim.bo.filetype == 'TelescopePrompt' then
remove_all = true
end

if #cur_line == 0 then
if not prev_line or #prev_line == 0 then
remove_all = true
end
else
if col == #cur_line then
keep_first = true
end
end

if not remove_all and not keep_first then
return suggestions
end

Log.debug('remove_all: {}, keep_first: {}', remove_all, keep_first)

local is_processed = false
for i, suggestion in ipairs(suggestions) do
if #suggestion == 0 and not is_processed then
if remove_all then
-- ignore
elseif keep_first and i ~= 1 then
-- ignore
else
table.insert(nls, suggestion)
end
else
is_processed = true
table.insert(nls, suggestion)
end
end

if vim.bo.filetype == 'TelescopePrompt' then
nls = { nls[1] }
end

return nls
end

---@param suggestions string[]
local function normalize_indent(suggestions)
if not suggestions or #suggestions == 0 then
return
end
if not vim.bo.expandtab then
return
end
local nor = {}
for i, suggestion in ipairs(suggestions) do
-- replace `\t` with space
suggestion = suggestion:gsub('\t', string.rep(' ', vim.bo.tabstop))
nor[i] = suggestion
end
return nor
end

local function replace_slash(suggestions)
if not suggestions or #suggestions == 0 then
return
end
local slash = {}
for i, suggestion in ipairs(suggestions) do
suggestion = suggestion:gsub('\\"', '"')
slash[i] = suggestion
end
return slash
end

---@param task_id integer
---@param suggestions? Suggestions
---@return Suggestions?
local function process_suggestions(task_id, suggestions)
local row, col = Base.get_cursor()
local window = api.nvim_get_current_win()
local buffer = api.nvim_win_get_buf(window)
local row, col = Base.get_cursor(window)
if not tasks:match_clean(task_id, row, col) then
Log.debug('Completion request is outdated, discarding; task_id: {}, row: {}, col: {}', task_id, row, col)
return
@@ -153,28 +65,7 @@ local function process_suggestions(task_id, suggestions)

Log.debug('Suggestions received; task_id: {}, suggestions: {}', task_id, suggestions)

local nls = condense_nl(suggestions)
if nls then
suggestions = nls
end

local nor = normalize_indent(suggestions)
if nor then
suggestions = nor
end

local slash = replace_slash(suggestions)
if slash then
suggestions = slash
end

if #suggestions == 0 then
return
end

Log.debug('Processed suggestions: {}', suggestions)

return suggestions
return SuggestionsPreprocessing.run(window, buffer, suggestions)
end

local function apply_suggestion(task_id, row, col, suggestion)
@@ -310,7 +201,7 @@ function M.accept_all_suggestions()
Log.debug('Pretreatment cached lines: {}', cache:get_lines())

Lines.clear_virt_text()
Lines.set_text(cache:get_lines())
_set_text(cache:get_lines())

M.reset()

@@ -356,16 +247,16 @@ function M.accept_line()
local stage = cache:get_count() - 1

if cur == stage then
Lines.set_text({ line })
_set_text({ line })
Log.debug('Set line: {}', line)
Lines.set_text({ '', '' })
_set_text({ '', '' })
Log.debug('Set empty new line')
else
if cur == 0 then
Lines.set_text({ line })
_set_text({ line })
Log.debug('Set line: {}', line)
else
Lines.set_text({ line, '' })
_set_text({ line, '' })
Log.debug('Set line and empty new line; line: {}', line)
end
end
@@ -454,15 +345,15 @@ function M.accept_word()
if string.len(line) == 0 then
cache:remove_line(1)
if M.has_suggestions() then
Lines.set_text({ word, '' })
_set_text({ word, '' })
Log.debug('Set word and empty new line, word: {}', word)
else
Lines.set_text({ word })
_set_text({ word })
Log.debug('Set word: {}', word)
end
else
cache:update_line(1, line)
Lines.set_text({ word })
_set_text({ word })
Log.debug('Set word: {}', word)
end

131 changes: 131 additions & 0 deletions lua/fittencode/suggestions_preprocessing.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
local api = vim.api

local Base = require('fittencode.base')
local Log = require('fittencode.log')

local M = {}

---@param suggestions string[]
local function condense_nl(window, buffer, suggestions)
if not suggestions or #suggestions == 0 then
return
end

local is_all_empty = true
for _, suggestion in ipairs(suggestions) do
if #suggestion ~= 0 then
is_all_empty = false
break
end
end

if is_all_empty then
return {}
end

local row, col = Base.get_cursor(window)
local prev_line = nil
local cur_line = api.nvim_buf_get_lines(buffer, row, row + 1, false)[1]
if row > 1 then
prev_line = api.nvim_buf_get_lines(buffer, row - 1, row, false)[1]
end

Log.debug('prev_line: {}, cur_line: {}, col: {}', prev_line, cur_line, col)

local nls = {}
local remove_all = false
local keep_first = true

local filetype = api.nvim_get_option_value('filetype', { buf = buffer })
if filetype == 'TelescopePrompt' then
remove_all = true
end

if #cur_line == 0 then
if not prev_line or #prev_line == 0 then
remove_all = true
end
end

Log.debug('remove_all: {}, keep_first: {}', remove_all, keep_first)

local count = 0
for _, suggestion in ipairs(suggestions) do
if #suggestion == 0 then
if remove_all then
-- ignore
elseif keep_first and count ~= 0 then
-- ignore
else
table.insert(nls, suggestion)
end
count = count + 1
else
count = 0
table.insert(nls, suggestion)
end
end

if filetype == 'TelescopePrompt' then
nls = { nls[1] }
end

return nls
end

---@param suggestions string[]
local function normalize_indent(buffer, suggestions)
if not suggestions or #suggestions == 0 then
return
end
local expandtab = api.nvim_get_option_value('expandtab', { buf = buffer })
local tabstop = api.nvim_get_option_value('tabstop', { buf = buffer })
if not expandtab then
return
end
local nor = {}
for i, suggestion in ipairs(suggestions) do
-- replace `\t` with space
suggestion = suggestion:gsub('\t', string.rep(' ', tabstop))
nor[i] = suggestion
end
return nor
end

local function replace_slash(suggestions)
if not suggestions or #suggestions == 0 then
return
end
local slash = {}
for i, suggestion in ipairs(suggestions) do
suggestion = suggestion:gsub('\\"', '"')
slash[i] = suggestion
end
return slash
end

function M.run(window, buffer, suggestions)
local nls = condense_nl(window, buffer, suggestions)
if nls then
suggestions = nls
end

local nor = normalize_indent(buffer, suggestions)
if nor then
suggestions = nor
end

local slash = replace_slash(suggestions)
if slash then
suggestions = slash
end

if #suggestions == 0 then
return
end

Log.debug('Processed suggestions: {}', suggestions)
return suggestions
end

return M
161 changes: 74 additions & 87 deletions lua/fittencode/views/chat.lua
Original file line number Diff line number Diff line change
@@ -1,122 +1,109 @@
local api = vim.api

local Base = require('fittencode.base')
local Lines = require('fittencode.views.lines')
local Log = require('fittencode.log')

---@class Chat
---@field win? integer
---@field window? integer
---@field buffer? integer
---@field text? string
---@field content string[]
---@field show function
---@field commit function
---@field is_repeated function

local M = {}

function M:new()
local o = {
text = {}
content = {}
}
self.__index = self
return setmetatable(o, self)
end

function M:show()
if self.win == nil then
if not self.buffer then
self.buffer = api.nvim_create_buf(false, true)
api.nvim_buf_set_name(self.buffer, 'FittenCodeChat')
end
local function _commit(window, buffer, lines)
if api.nvim_buf_is_valid(buffer) and api.nvim_win_is_valid(window) then
api.nvim_set_option_value('modifiable', true, { buf = buffer })
api.nvim_set_option_value('readonly', false, { buf = buffer })
Lines.set_text({
window = window,
buffer = buffer,
lines = lines,
is_undo_disabled = true,
is_last = true
})
api.nvim_set_option_value('modifiable', false, { buf = buffer })
api.nvim_set_option_value('readonly', true, { buf = buffer })
end
end

vim.cmd('topleft vsplit')
vim.cmd('vertical resize ' .. 40)
self.win = api.nvim_get_current_win()
api.nvim_win_set_buf(self.win, self.buffer)

api.nvim_set_option_value('filetype', 'markdown', { buf = self.buffer })
api.nvim_set_option_value('modifiable', false, { buf = self.buffer })
api.nvim_set_option_value('wrap', true, { win = self.win })
api.nvim_set_option_value('linebreak', true, { win = self.win })
api.nvim_set_option_value('cursorline', true, { win = self.win })
api.nvim_set_option_value('spell', false, { win = self.win })
api.nvim_set_option_value('number', false, { win = self.win })
api.nvim_set_option_value('relativenumber', false, { win = self.win })
api.nvim_set_option_value('conceallevel', 3, { win = self.win })

Base.map('n', 'q', function()
self:close()
end, { buffer = self.buffer })

if #self.text > 0 then
-- api.nvim_set_option_value('modifiable', true, { buf = self.buffer })
-- api.nvim_buf_set_lines(self.buffer, 0, -1, false, self.text)
api.nvim_win_set_cursor(self.win, { #self.text, 0 })
-- api.nvim_set_option_value('modifiable', false, { buf = self.buffer })
local function set_content(window, buffer, text)
if #text > 0 then
for _, lines in ipairs(text) do
_commit(window, buffer, lines)
end
end
end

function M:close()
if self.win == nil then
local function scroll_to_last(window, buffer)
local row = math.max(api.nvim_buf_line_count(buffer), 1)
local col = api.nvim_buf_get_lines(buffer, row - 1, row, false)[1]:len()
api.nvim_win_set_cursor(window, { row, col })
end

local function set_option_value(window, buffer)
api.nvim_set_option_value('filetype', 'markdown', { buf = buffer })
api.nvim_set_option_value('readonly', true, { buf = buffer })
api.nvim_set_option_value('modifiable', false, { buf = buffer })
api.nvim_set_option_value('wrap', true, { win = window })
api.nvim_set_option_value('linebreak', true, { win = window })
api.nvim_set_option_value('cursorline', true, { win = window })
api.nvim_set_option_value('spell', false, { win = window })
api.nvim_set_option_value('number', false, { win = window })
api.nvim_set_option_value('relativenumber', false, { win = window })
api.nvim_set_option_value('conceallevel', 3, { win = window })
end

function M:show()
if self.window then
return
end
if api.nvim_win_is_valid(self.win) then
api.nvim_win_close(self.win, true)

if not self.buffer then
self.buffer = api.nvim_create_buf(false, true)
api.nvim_buf_set_name(self.buffer, 'FittenCodeChat')
end
self.win = nil
-- api.nvim_buf_delete(self.buffer, { force = true })
-- self.buffer = nil
end

local stack = {}
vim.cmd('topleft vsplit')
vim.cmd('vertical resize ' .. 42)
self.window = api.nvim_get_current_win()
api.nvim_win_set_buf(self.window, self.buffer)

local function push_stack(x)
if #stack == 0 then
table.insert(stack, x)
else
table.remove(stack)
end
Base.map('n', 'q', function() self:close() end, { buffer = self.buffer })

set_option_value(self.window, self.buffer)
scroll_to_last(self.window, self.buffer)
end

---@param text? string|string[]
---@param linebreak? boolean
---@param force? boolean
function M:commit(text, linebreak, force)
local lines = nil
if type(text) == 'string' then
lines = vim.split(text, '\n')
elseif type(text) == 'table' then
lines = text
else
function M:close()
if self.window == nil then
return
end
vim.tbl_map(function(x)
if x:match('^```') then
push_stack(x)
end
end, lines)
if #stack > 0 and not force then
linebreak = false
if api.nvim_win_is_valid(self.window) then
api.nvim_win_close(self.window, true)
end
if linebreak and #self.text > 0 and #lines > 0 then
if lines[1] ~= '' and not string.match(lines[1], '^```') and self.text[#self.text] ~= '' and not string.match(self.text[#self.text], '^```') then
table.insert(lines, 1, '')
end
end
if self.buffer then
api.nvim_set_option_value('modifiable', true, { buf = self.buffer })
if #self.text == 0 then
api.nvim_buf_set_lines(self.buffer, 0, -1, false, lines)
else
api.nvim_buf_set_lines(self.buffer, -1, -1, false, lines)
end
api.nvim_set_option_value('modifiable', false, { buf = self.buffer })
end
table.move(lines, 1, #lines, #self.text + 1, self.text)
self.window = nil
-- api.nvim_buf_delete(self.buffer, { force = true })
-- self.buffer = nil
end

if api.nvim_win_is_valid(self.win) then
api.nvim_win_set_cursor(self.win, { #self.text, 0 })
function M:commit(lines)
if type(lines) == 'string' then
lines = vim.split(lines, '\n')
end
table.insert(self.content, lines)
_commit(self.window, self.buffer, lines)
Log.debug('Chat text: {}', self.content)
end

local function _sub_match(s, pattern)
@@ -145,13 +132,13 @@ function M:is_repeated(lines)
end

---@return string[]
function M:get_text()
return self.text
function M:get_content()
return self.content
end

---@return boolean
function M:has_text()
return #self.text > 0
function M:has_content()
return #self.content > 0
end

return M
51 changes: 35 additions & 16 deletions lua/fittencode/views/lines.lua
Original file line number Diff line number Diff line change
@@ -118,25 +118,25 @@ end
---@param row integer
---@param col integer
---@param lines string[]
local function append_text_at_pos(row, col, lines)
local function append_text_at_pos(buffer, row, col, lines)
local count = vim.tbl_count(lines)
for i = 1, count, 1 do
local line = lines[i]
local len = string.len(line)
if i == 1 then
if len ~= 0 then
api.nvim_buf_set_text(0, row, col, row, col, { line })
api.nvim_buf_set_text(buffer, row, col, row, col, { line })
end
else
local max = api.nvim_buf_line_count(0)
local max = api.nvim_buf_line_count(buffer)
local try_row = row + i - 1
if try_row >= max then
api.nvim_buf_set_lines(0, max, max, false, { line })
api.nvim_buf_set_lines(buffer, max, max, false, { line })
else
if string.len(api.nvim_buf_get_lines(0, try_row, try_row + 1, false)[1]) ~= 0 then
api.nvim_buf_set_lines(0, try_row, try_row, false, { line })
if string.len(api.nvim_buf_get_lines(buffer, try_row, try_row + 1, false)[1]) ~= 0 then
api.nvim_buf_set_lines(buffer, try_row, try_row, false, { line })
else
api.nvim_buf_set_text(0, try_row, 0, try_row, 0, { line })
api.nvim_buf_set_text(buffer, try_row, 0, try_row, 0, { line })
end
end
end
@@ -146,16 +146,16 @@ end
---@param row integer
---@param col integer
---@param lines string[]
local function move_cursor_to_text_end(row, col, lines)
local function move_cursor_to_text_end(window, row, col, lines)
local count = vim.tbl_count(lines)
if count == 1 then
local first_len = string.len(lines[1])
if first_len ~= 0 then
api.nvim_win_set_cursor(0, { row + 1, col + first_len })
api.nvim_win_set_cursor(window, { row + 1, col + first_len })
end
else
local last_len = string.len(lines[count])
api.nvim_win_set_cursor(0, { row + count, last_len })
api.nvim_win_set_cursor(window, { row + count, last_len })
end
end

@@ -184,14 +184,33 @@ local function format_wrap(fx)
return ret
end

---@param lines string[]
function M.set_text(lines)
---@class LinesSetTextOptions
---@field window integer
---@field buffer integer
---@field lines string[]
---@field is_undo_disabled? boolean
---@field is_last? boolean

---@param opts LinesSetTextOptions
function M.set_text(opts)
local window = opts.window
local buffer = opts.buffer
local lines = opts.lines or {}
local is_undo_disabled = opts.is_undo_disabled or false
local is_last = opts.is_last or false

format_wrap(function()
local row, col = Base.get_cursor()
undojoin()
local row, col = Base.get_cursor(window)
if is_last then
row = math.max(api.nvim_buf_line_count(buffer) - 1, 0)
col = api.nvim_buf_get_lines(buffer, row, row + 1, false)[1]:len()
end
if not is_undo_disabled then
undojoin()
end
-- Emit events `CursorMovedI` `CursorHoldI`
append_text_at_pos(row, col, lines)
move_cursor_to_text_end(row, col, lines)
append_text_at_pos(buffer, row, col, lines)
move_cursor_to_text_end(window, row, col, lines)
end)
end

0 comments on commit 31da331

Please sign in to comment.