Skip to content

Commit

Permalink
Fixing for_util
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Feb 8, 2024
1 parent 1c999ec commit d6769df
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 21 deletions.
36 changes: 22 additions & 14 deletions src/core/codec/postings/for_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::UnsafeCell;
use std::cmp::max;
use std::sync::{Arc, Once};

Expand Down Expand Up @@ -72,10 +73,7 @@ pub fn max_data_size() -> usize {
let iterations = compute_iterations(&decoder) as usize;
max_data_size = max(max_data_size, iterations * decoder.byte_value_count());
} else {
panic!(
"get_decoder({:?},{:?},{:?}) failed.",
format, version, bpv
);
panic!("get_decoder({:?},{:?},{:?}) failed.", format, version, bpv);
}
}
let format = Format::PackedSingleBlock;
Expand All @@ -84,10 +82,7 @@ pub fn max_data_size() -> usize {
let iterations = compute_iterations(&decoder) as usize;
max_data_size = max(max_data_size, iterations * decoder.byte_value_count());
} else {
panic!(
"get_decoder({:?},{:?},{:?}) failed.",
format, version, bpv
);
panic!("get_decoder({:?},{:?},{:?}) failed.", format, version, bpv);
}
}
}
Expand Down Expand Up @@ -132,8 +127,10 @@ impl ForUtilInstance {
let format = Format::with_id(format_id);
encoded_sizes[bpv] = encoded_size(format, packed_ints_version, bits_per_value);
unsafe {
decoders.assume_init_mut()[bpv] = get_decoder(format, packed_ints_version, bits_per_value)?;
encoders.assume_init_mut()[bpv] = get_encoder(format, packed_ints_version, bits_per_value)?;
decoders.assume_init_mut()[bpv] =
get_decoder(format, packed_ints_version, bits_per_value)?;
encoders.assume_init_mut()[bpv] =
get_encoder(format, packed_ints_version, bits_per_value)?;
iterations[bpv] = compute_iterations(&decoders.assume_init_ref()[bpv]);
}
}
Expand Down Expand Up @@ -168,8 +165,10 @@ impl ForUtilInstance {
debug_assert!(bits_per_value <= 32);
encoded_sizes[bpv - 1] = encoded_size(format, VERSION_CURRENT, bits_per_value);
unsafe {
decoders.assume_init_mut()[bpv - 1] = get_decoder(format, VERSION_CURRENT, bits_per_value)?;
encoders.assume_init_mut()[bpv - 1] = get_encoder(format, VERSION_CURRENT, bits_per_value)?;
decoders.assume_init_mut()[bpv - 1] =
get_decoder(format, VERSION_CURRENT, bits_per_value)?;
encoders.assume_init_mut()[bpv - 1] =
get_encoder(format, VERSION_CURRENT, bits_per_value)?;
iterations[bpv - 1] = compute_iterations(&decoders.assume_init_ref()[bpv - 1]);
}

Expand Down Expand Up @@ -334,6 +333,12 @@ impl ForUtil {
self.instance.read_block_by_simd(input, decoder)
}

unsafe fn get_self(
ptr: &UnsafeCell<EliasFanoEncoder>
) -> &mut EliasFanoEncoder {
unsafe { &mut *ptr.get() }
}

pub fn read_other_encode_block(
doc_in: &mut dyn IndexInput,
ef_decoder: &mut Option<EliasFanoDecoder>,
Expand All @@ -346,9 +351,12 @@ impl ForUtil {
let upper_bound = doc_in.read_vlong()?;
if ef_decoder.is_some() {
let encoder = unsafe {
&mut *(ef_decoder.as_mut().unwrap().get_encoder().as_ref()
let s =
ef_decoder.as_mut().unwrap().get_encoder().as_ref()
as *const EliasFanoEncoder
as *mut EliasFanoEncoder)
as *mut EliasFanoEncoder as *const UnsafeCell<EliasFanoEncoder>;
let t = ForUtil::get_self(s.as_ref().unwrap());
&mut *t
};
encoder.rebuild_not_with_check(BLOCK_SIZE as i64, upper_bound)?;
encoder.deserialize2(doc_in)?;
Expand Down
12 changes: 10 additions & 2 deletions src/core/index/writer/bufferd_updates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use core::store::directory::Directory;
use core::store::IOContext;
use core::util::DocId;

use std::cell::UnsafeCell;
use std::cmp::{min, Ordering as CmpOrdering};
use std::collections::{BinaryHeap, HashMap};
use std::fmt;
Expand Down Expand Up @@ -362,6 +363,12 @@ impl<C: Codec> BufferedUpdatesStream<C> {
self.num_terms.load(Ordering::Acquire)
}

unsafe fn get_self(
ptr: &UnsafeCell<BufferedUpdatesStream<C>>,
) -> &mut BufferedUpdatesStream<C> {
unsafe { &mut *ptr.get() }
}

pub fn apply_deletes_and_updates<D, MS, MP>(
&self,
pool: &ReaderPool<D, C, MS, MP>,
Expand All @@ -374,8 +381,9 @@ impl<C: Codec> BufferedUpdatesStream<C> {
{
let _l = self.lock.lock().unwrap();
let updates_stream = unsafe {
let stream = self as *const BufferedUpdatesStream<C> as *mut BufferedUpdatesStream<C>;
&mut *stream
let stream = self as *const BufferedUpdatesStream<C> as *mut BufferedUpdatesStream<C> as *const UnsafeCell<BufferedUpdatesStream<C>>;
let s = BufferedUpdatesStream::get_self(stream.as_ref().unwrap());
&mut *s
};
let mut seg_states = Vec::with_capacity(infos.len());
let gen = self.next_gen.load(Ordering::Acquire);
Expand Down
14 changes: 12 additions & 2 deletions src/core/index/writer/doc_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use core::index::writer::{
use core::search::query::Query;
use core::store::directory::{Directory, LockValidatingDirectoryWrapper};
use core::util::external::Volatile;
use std::cell::UnsafeCell;
use error::{ErrorKind::AlreadyClosed, ErrorKind::IllegalState, Result};

use crossbeam::queue::SegQueue;
Expand Down Expand Up @@ -181,10 +182,19 @@ where
self.index_writer.upgrade().unwrap()
}

unsafe fn get_self(
ptr: &UnsafeCell<DocumentsWriter<D, C, MS, MP>>,
) -> &mut DocumentsWriter<D, C, MS, MP> {
unsafe { &mut *ptr.get() }
}

#[allow(clippy::mut_from_ref)]
unsafe fn doc_writer_mut(&self, _l: &MutexGuard<()>) -> &mut DocumentsWriter<D, C, MS, MP> {
let w = self as *const DocumentsWriter<D, C, MS, MP> as *mut DocumentsWriter<D, C, MS, MP>;
&mut *w
let w = self as *const DocumentsWriter<D, C, MS, MP> as *mut DocumentsWriter<D, C, MS, MP> as *const UnsafeCell<DocumentsWriter<D, C, MS, MP>>;
unsafe {
let s = DocumentsWriter::get_self(w.as_ref().unwrap());
&mut *s
}
}

pub fn set_delete_queue(&self, delete_queue: Arc<DocumentsWriterDeleteQueue<C>>) {
Expand Down
15 changes: 12 additions & 3 deletions src/core/index/writer/doc_writer_per_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use core::{
},
};

use std::collections::{HashMap, HashSet};
use std::{cell::UnsafeCell, collections::{HashMap, HashSet}};
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockWriteGuard, Weak};
use std::time::SystemTime;
Expand Down Expand Up @@ -814,13 +814,22 @@ where
}
}

unsafe fn get_self(
ptr: &UnsafeCell<ThreadState<D, C, MS, MP>>,
) -> &mut ThreadState<D, C, MS, MP> {
unsafe { &mut *ptr.get() }
}

#[allow(clippy::mut_from_ref)]
pub fn thread_state_mut(
&self,
_lock: &MutexGuard<ThreadStateLock>,
) -> &mut ThreadState<D, C, MS, MP> {
let state = self as *const ThreadState<D, C, MS, MP> as *mut ThreadState<D, C, MS, MP>;
unsafe { &mut *state }
let state = self as *const ThreadState<D, C, MS, MP> as *mut ThreadState<D, C, MS, MP> as *const UnsafeCell<ThreadState<D, C, MS, MP>>;
unsafe {
let s = ThreadState::get_self(state.as_ref().unwrap());
&mut *s
}
}

pub fn dwpt(&self) -> &DocumentsWriterPerThread<D, C, MS, MP> {
Expand Down

0 comments on commit d6769df

Please sign in to comment.