Skip to content

Commit

Permalink
update dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
2toinf committed Mar 19, 2024
1 parent 599860a commit 4050a12
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions DecisionNCE/datasets/EpicKitchen_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from tqdm import tqdm
import random
import torch.nn.functional as F
import mmengine.fileio as fileio

class EpicKitchen(Dataset):
def __init__(self, root,
meta_file = "assets/EpicKitchen-100-train.csv",
img_size = 224,
num_frames=2,
file_client_args=dict(backend='petrel')):
num_frames=2):
"""Define EpicKiten Dataset
Args:
root (str): path of images
Expand All @@ -36,7 +36,6 @@ def __init__(self, root,
self.root = root
self.img_size = img_size
self.num_frames = num_frames
self.client = mmengine.fileio.FileClient(**file_client_args)
self._create_transform()
self._check()

Expand Down Expand Up @@ -87,7 +86,7 @@ def _get_single_img(self, img_dict, cur_idx):
frame_name = f"frame_{_tmp}{cur_idx}.jpg"
img_path = osp.join(img_dict, frame_name)
try:
value = self.client.get(img_path)
value = fileio.get(img_path)
img_bytes = np.frombuffer(value, np.uint8)
buff = io.BytesIO(img_bytes)
with Image.open(buff) as img:
Expand Down Expand Up @@ -155,7 +154,7 @@ def EpicKitchenDataLoader(root,
import clip
from typing import Iterable
def train_one_epoch(model: torch.nn.Module,
loss_model: torch.nn.Module,
loss_model: torch.nn.Module,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
Expand Down

0 comments on commit 4050a12

Please sign in to comment.