Skip to content

Commit

Permalink
refactors for #525 (#534)
Browse files Browse the repository at this point in the history
* refactor(fakes): clean up add_terminal_input

* refactor(fakes): append whole buf to output_buffer in FakeStdoutWriter::write

* refactor(fakes): append whole buf to output_buffer in FakeInputOutput::write_to_tty_stdin

* fix(fakes): allow partial reads in read_from_tty_stdout

This patch fixes two bugs in read_from_tty_stdout:
* if there was a partial read (ie. `bytes.read_position` is not 0 but
less than `bytes.content.len()`), subsequent calls to would fill `buf`
starting at index `bytes.read_position` instead of 0, leaving range
0..`bytes.read_position` untouched.
* if `buf` was smaller than `bytes.content.len()`, a panic would occur.

* refactor(channels): use crossbeam instead of mpsc

This patch replaces mpsc with crossbeam channels because crossbeam
supports selecting on multiple channels which will be necessary in a
subsequent patch.

* refactor(threadbus): allow multiple receivers in Bus

This patch changes Bus to use multiple receivers. Method `recv` returns
data from all of them. This will be used in a subsequent patch for
receiving from bounded and unbounded queues at the same time.

* refactor(channels): remove SenderType enum

This enum has only one variant, so the entire enum can be replaced with
the innards of said variant.

* refactor(channels): remove Send+Sync trait implementations

The implementation of these traits is not necessary, as
SenderWithContext is automatically Send and Sync for every T and
ErrorContext that's Send and Sync.
  • Loading branch information
kxt authored May 27, 2021
1 parent 9bdb40b commit 0c0355d
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 95 deletions.
25 changes: 25 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 20 additions & 33 deletions src/tests/fakes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
use std::sync::{mpsc, Arc, Condvar, Mutex};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};

use zellij_utils::{nix, zellij_tile};
Expand All @@ -14,7 +14,7 @@ use zellij_server::os_input_output::{async_trait, AsyncReader, Pid, ServerOsApi}
use zellij_tile::data::Palette;
use zellij_utils::{
async_std,
channels::{ChannelWithContext, SenderType, SenderWithContext},
channels::{self, ChannelWithContext, SenderWithContext},
errors::ErrorContext,
interprocess::local_socket::LocalSocketStream,
ipc::{ClientToServerMsg, ServerToClientMsg},
Expand Down Expand Up @@ -52,13 +52,9 @@ impl FakeStdoutWriter {

impl Write for FakeStdoutWriter {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let mut bytes_written = 0;
let mut output_buffer = self.output_buffer.lock().unwrap();
for byte in buf {
bytes_written += 1;
output_buffer.push(*byte);
}
Ok(bytes_written)
output_buffer.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> Result<(), std::io::Error> {
let mut output_buffer = self.output_buffer.lock().unwrap();
Expand All @@ -83,9 +79,11 @@ pub struct FakeInputOutput {
possible_tty_inputs: HashMap<u16, Bytes>,
last_snapshot_time: Arc<Mutex<Instant>>,
send_instructions_to_client: SenderWithContext<ServerToClientMsg>,
receive_instructions_from_server: Arc<Mutex<mpsc::Receiver<(ServerToClientMsg, ErrorContext)>>>,
receive_instructions_from_server:
Arc<Mutex<channels::Receiver<(ServerToClientMsg, ErrorContext)>>>,
send_instructions_to_server: SenderWithContext<ClientToServerMsg>,
receive_instructions_from_client: Arc<Mutex<mpsc::Receiver<(ClientToServerMsg, ErrorContext)>>>,
receive_instructions_from_client:
Arc<Mutex<channels::Receiver<(ClientToServerMsg, ErrorContext)>>>,
should_trigger_sigwinch: Arc<(Mutex<bool>, Condvar)>,
sigwinch_event: Option<PositionAndSize>,
}
Expand All @@ -96,11 +94,11 @@ impl FakeInputOutput {
let last_snapshot_time = Arc::new(Mutex::new(Instant::now()));
let stdout_writer = FakeStdoutWriter::new(last_snapshot_time.clone());
let (client_sender, client_receiver): ChannelWithContext<ServerToClientMsg> =
mpsc::channel();
let send_instructions_to_client = SenderWithContext::new(SenderType::Sender(client_sender));
channels::unbounded();
let send_instructions_to_client = SenderWithContext::new(client_sender);
let (server_sender, server_receiver): ChannelWithContext<ClientToServerMsg> =
mpsc::channel();
let send_instructions_to_server = SenderWithContext::new(SenderType::Sender(server_sender));
channels::unbounded();
let send_instructions_to_server = SenderWithContext::new(server_sender);
win_sizes.insert(0, winsize); // 0 is the current terminal
FakeInputOutput {
read_buffers: Arc::new(Mutex::new(HashMap::new())),
Expand All @@ -125,10 +123,7 @@ impl FakeInputOutput {
self
}
pub fn add_terminal_input(&mut self, input: &[&[u8]]) {
let mut stdin_commands: VecDeque<Vec<u8>> = VecDeque::new();
for command in input.iter() {
stdin_commands.push_back(command.iter().copied().collect())
}
let stdin_commands = input.iter().map(|i| i.to_vec()).collect();
self.stdin_commands = Arc::new(Mutex::new(stdin_commands));
}
pub fn add_terminal(&self, fd: RawFd) {
Expand Down Expand Up @@ -281,26 +276,18 @@ impl ServerOsApi for FakeInputOutput {
fn write_to_tty_stdin(&self, pid: RawFd, buf: &[u8]) -> Result<usize, nix::Error> {
let mut stdin_writes = self.stdin_writes.lock().unwrap();
let write_buffer = stdin_writes.get_mut(&pid).unwrap();
let mut bytes_written = 0;
for byte in buf {
bytes_written += 1;
write_buffer.push(*byte);
}
Ok(bytes_written)
Ok(write_buffer.write(buf).unwrap())
}
fn read_from_tty_stdout(&self, pid: RawFd, buf: &mut [u8]) -> Result<usize, nix::Error> {
fn read_from_tty_stdout(&self, pid: RawFd, mut buf: &mut [u8]) -> Result<usize, nix::Error> {
let mut read_buffers = self.read_buffers.lock().unwrap();
let mut bytes_read = 0;
match read_buffers.get_mut(&pid) {
Some(bytes) => {
for i in bytes.read_position..bytes.content.len() {
bytes_read += 1;
buf[i] = bytes.content[i];
}
if bytes_read > bytes.read_position {
bytes.set_read_position(bytes_read);
let available_range = bytes.read_position..bytes.content.len();
let len = buf.write(&bytes.content[available_range]).unwrap();
if len > bytes.read_position {
bytes.set_read_position(len);
}
return Ok(bytes_read);
return Ok(len);
}
None => Err(nix::Error::Sys(nix::errno::Errno::EAGAIN)),
}
Expand Down
10 changes: 4 additions & 6 deletions zellij-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::env::current_exe;
use std::io::{self, Write};
use std::path::Path;
use std::process::Command;
use std::sync::mpsc;
use std::thread;

use crate::{
Expand All @@ -16,7 +15,7 @@ use crate::{
};
use zellij_utils::cli::CliArgs;
use zellij_utils::{
channels::{SenderType, SenderWithContext, SyncChannelWithContext},
channels::{self, ChannelWithContext, SenderWithContext},
consts::{SESSION_NAME, ZELLIJ_IPC_PIPE},
errors::{ClientContext, ContextType, ErrorInstruction},
input::{actions::Action, config::Config, options::Options},
Expand Down Expand Up @@ -149,11 +148,10 @@ pub fn start_client(
.write(bracketed_paste.as_bytes())
.unwrap();

let (send_client_instructions, receive_client_instructions): SyncChannelWithContext<
let (send_client_instructions, receive_client_instructions): ChannelWithContext<
ClientInstruction,
> = mpsc::sync_channel(50);
let send_client_instructions =
SenderWithContext::new(SenderType::SyncSender(send_client_instructions));
> = channels::bounded(50);
let send_client_instructions = SenderWithContext::new(send_client_instructions);

#[cfg(not(any(feature = "test", test)))]
std::panic::set_hook({
Expand Down
29 changes: 14 additions & 15 deletions zellij-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ mod wasm_vm;

use zellij_utils::zellij_tile;

use std::path::PathBuf;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::{path::PathBuf, sync::mpsc};
use wasmer::Store;
use zellij_tile::data::{Event, InputMode, PluginCapabilities};

Expand All @@ -27,7 +27,8 @@ use crate::{
};
use route::route_thread_main;
use zellij_utils::{
channels::{ChannelWithContext, SenderType, SenderWithContext, SyncChannelWithContext},
channels,
channels::{ChannelWithContext, SenderWithContext},
cli::CliArgs,
errors::{ContextType, ErrorInstruction, ServerContext},
input::{get_mode_info, options::Options},
Expand Down Expand Up @@ -117,9 +118,8 @@ pub fn start_server(os_input: Box<dyn ServerOsApi>, socket_path: PathBuf) {

std::env::set_var(&"ZELLIJ", "0");

let (to_server, server_receiver): SyncChannelWithContext<ServerInstruction> =
mpsc::sync_channel(50);
let to_server = SenderWithContext::new(SenderType::SyncSender(to_server));
let (to_server, server_receiver): ChannelWithContext<ServerInstruction> = channels::bounded(50);
let to_server = SenderWithContext::new(to_server);
let session_data: Arc<RwLock<Option<SessionMetaData>>> = Arc::new(RwLock::new(None));
let session_state = Arc::new(RwLock::new(SessionState::Uninitialized));

Expand Down Expand Up @@ -301,13 +301,12 @@ fn init_session(
client_attributes: ClientAttributes,
session_state: Arc<RwLock<SessionState>>,
) -> SessionMetaData {
let (to_screen, screen_receiver): ChannelWithContext<ScreenInstruction> = mpsc::channel();
let to_screen = SenderWithContext::new(SenderType::Sender(to_screen));

let (to_plugin, plugin_receiver): ChannelWithContext<PluginInstruction> = mpsc::channel();
let to_plugin = SenderWithContext::new(SenderType::Sender(to_plugin));
let (to_pty, pty_receiver): ChannelWithContext<PtyInstruction> = mpsc::channel();
let to_pty = SenderWithContext::new(SenderType::Sender(to_pty));
let (to_screen, screen_receiver): ChannelWithContext<ScreenInstruction> = channels::unbounded();
let to_screen = SenderWithContext::new(to_screen);
let (to_plugin, plugin_receiver): ChannelWithContext<PluginInstruction> = channels::unbounded();
let to_plugin = SenderWithContext::new(to_plugin);
let (to_pty, pty_receiver): ChannelWithContext<PtyInstruction> = channels::unbounded();
let to_pty = SenderWithContext::new(to_pty);

// Determine and initialize the data directory
let data_dir = opts.data_dir.unwrap_or_else(get_default_data_dir);
Expand All @@ -334,7 +333,7 @@ fn init_session(
.spawn({
let pty = Pty::new(
Bus::new(
pty_receiver,
vec![pty_receiver],
Some(&to_screen),
None,
Some(&to_plugin),
Expand All @@ -352,7 +351,7 @@ fn init_session(
.name("screen".to_string())
.spawn({
let screen_bus = Bus::new(
screen_receiver,
vec![screen_receiver],
None,
Some(&to_pty),
Some(&to_plugin),
Expand All @@ -377,7 +376,7 @@ fn init_session(
.name("wasm".to_string())
.spawn({
let plugin_bus = Bus::new(
plugin_receiver,
vec![plugin_receiver],
Some(&to_screen),
Some(&to_pty),
None,
Expand Down
27 changes: 16 additions & 11 deletions zellij-server/src/thread_bus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use crate::{
os_input_output::ServerOsApi, pty::PtyInstruction, screen::ScreenInstruction,
wasm_vm::PluginInstruction, ServerInstruction,
};
use std::sync::mpsc;
use zellij_utils::{channels::SenderWithContext, errors::ErrorContext};
use zellij_utils::{channels, channels::SenderWithContext, errors::ErrorContext};

/// A container for senders to the different threads in zellij on the server side
#[derive(Clone)]
Expand All @@ -20,50 +19,50 @@ impl ThreadSenders {
pub fn send_to_screen(
&self,
instruction: ScreenInstruction,
) -> Result<(), mpsc::SendError<(ScreenInstruction, ErrorContext)>> {
) -> Result<(), channels::SendError<(ScreenInstruction, ErrorContext)>> {
self.to_screen.as_ref().unwrap().send(instruction)
}

pub fn send_to_pty(
&self,
instruction: PtyInstruction,
) -> Result<(), mpsc::SendError<(PtyInstruction, ErrorContext)>> {
) -> Result<(), channels::SendError<(PtyInstruction, ErrorContext)>> {
self.to_pty.as_ref().unwrap().send(instruction)
}

pub fn send_to_plugin(
&self,
instruction: PluginInstruction,
) -> Result<(), mpsc::SendError<(PluginInstruction, ErrorContext)>> {
) -> Result<(), channels::SendError<(PluginInstruction, ErrorContext)>> {
self.to_plugin.as_ref().unwrap().send(instruction)
}

pub fn send_to_server(
&self,
instruction: ServerInstruction,
) -> Result<(), mpsc::SendError<(ServerInstruction, ErrorContext)>> {
) -> Result<(), channels::SendError<(ServerInstruction, ErrorContext)>> {
self.to_server.as_ref().unwrap().send(instruction)
}
}

/// A container for a receiver, OS input and the senders to a given thread
pub(crate) struct Bus<T> {
pub receiver: mpsc::Receiver<(T, ErrorContext)>,
receivers: Vec<channels::Receiver<(T, ErrorContext)>>,
pub senders: ThreadSenders,
pub os_input: Option<Box<dyn ServerOsApi>>,
}

impl<T> Bus<T> {
pub fn new(
receiver: mpsc::Receiver<(T, ErrorContext)>,
receivers: Vec<channels::Receiver<(T, ErrorContext)>>,
to_screen: Option<&SenderWithContext<ScreenInstruction>>,
to_pty: Option<&SenderWithContext<PtyInstruction>>,
to_plugin: Option<&SenderWithContext<PluginInstruction>>,
to_server: Option<&SenderWithContext<ServerInstruction>>,
os_input: Option<Box<dyn ServerOsApi>>,
) -> Self {
Bus {
receiver,
receivers,
senders: ThreadSenders {
to_screen: to_screen.cloned(),
to_pty: to_pty.cloned(),
Expand All @@ -74,7 +73,13 @@ impl<T> Bus<T> {
}
}

pub fn recv(&self) -> Result<(T, ErrorContext), mpsc::RecvError> {
self.receiver.recv()
pub fn recv(&self) -> Result<(T, ErrorContext), channels::RecvError> {
let mut selector = channels::Select::new();
self.receivers.iter().for_each(|r| {
selector.recv(r);
});
let oper = selector.select();
let idx = oper.index();
oper.recv(&self.receivers[idx])
}
}
1 change: 1 addition & 0 deletions zellij-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ license = "MIT"
backtrace = "0.3.55"
bincode = "1.3.1"
colors-transform = "0.2.5"
crossbeam = "0.8.0"
directories-next = "2.0"
interprocess = "1.1.1"
lazy_static = "1.4.0"
Expand Down
Loading

0 comments on commit 0c0355d

Please sign in to comment.