(WIP) compute the backward pass only once for all three slots in negative sampling #141

Open · wants to merge 1 commit into base: master
kge/config-default.yaml: 7 additions & 0 deletions
@@ -312,6 +312,13 @@ negative_sampling:
   # 'triple' (e.g., for TransE or RotatE in the current implementation).
   implementation: triple
 
+  # Whether to compute the backward pass after scoring against each slot (S, P, O)
+  # or once after scoring against all three slots.
+  # Does not work with reciprocal relations.
Review comment (Member): Move to "false".

Review comment (Member): Also say that if false, positive triples are scored only once (else once per slot).
+  # - True: slower, but needs less memory
+  # - False: faster, but needs more memory
+  backward_pass_per_slot: True
+
   # Perform training in chunks of the specified size. When set, process each
   # batch in chunks of at most this size. This reduces memory consumption but
   # may increase runtime. Useful when there are many negative samples and/or
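For readers who want to see the tradeoff described in the new comment in isolation, here is a minimal standalone PyTorch sketch (not part of the PR; the model and losses are placeholders): calling backward() after each partial loss frees that loss's graph immediately but performs three separate backward passes, while collecting the losses and running a single backward keeps all three graphs in memory until the end.

```python
import torch

model = torch.nn.Linear(8, 1)
x = torch.randn(32, 8)

# backward_pass_per_slot: True -- one backward per partial loss.
# Each backward() releases that loss's graph right away (lower peak memory),
# but the backward work is done three times.
for _ in range(3):  # stand-in for the S, P, O slots
    loss = model(x).pow(2).mean()
    loss.backward()  # gradients accumulate across the three calls

model.zero_grad()

# backward_pass_per_slot: False -- collect the partial losses, then one backward.
# Faster overall, but all three computation graphs stay alive until this call.
losses = [model(x).pow(2).mean() for _ in range(3)]
torch.autograd.backward(losses)  # same accumulated gradients as above
```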
kge/job/train.py: 25 additions & 7 deletions
@@ -779,6 +779,7 @@ def __init__(self, config, dataset, parent_job=None, model=None):
             "'{}' scoring function ...".format(self._implementation)
         )
         self.type_str = "negative_sampling"
+        self.backward_pass_per_slot = self.config.get("negative_sampling.backward_pass_per_slot")
 
         if self.__class__ == TrainingJobNegativeSampling:
             for f in Job.job_created_hooks:
@@ -851,6 +852,13 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
             triples = batch_triples[chunk_indexes]
 
             # process the chunk
+            if not self.backward_pass_per_slot:
+                positive_scores = self.model.score_spo(
+                    triples[:, S],
+                    triples[:, P],
+                    triples[:, O],
+                )
+            loss_values_torch = []
             for slot in [S, P, O]:
                 num_samples = self._sampler.num_samples[slot]
                 if num_samples <= 0:
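The block above is where the deferred mode pays off on the forward side: the positive triples are scored once, without a direction argument, and the result is reused for all three slots (this is also what the second review comment asks the config documentation to state). The standalone sketch below (illustrative names, not the LibKGE API) contrasts the two paths; the closing comment hints at why the mode is presumably marked as not working with reciprocal relations.

```python
import torch

class ToyModel:
    """Stand-in for a KGE model; `direction` mimics the per-slot scoring hint."""
    def score_spo(self, s, p, o, direction=None):
        return (s + p + o).float()  # a real model would embed and score here

S, P, O = 0, 1, 2                  # slot indices, as in the training job
SLOT_STR = {S: "s", P: "p", O: "o"}
model = ToyModel()
triples = torch.randint(0, 100, (5, 3))

# backward_pass_per_slot=True: positives are scored once per slot (3 forward passes)
per_slot_scores = {
    slot: model.score_spo(
        triples[:, S], triples[:, P], triples[:, O], direction=SLOT_STR[slot]
    )
    for slot in (S, P, O)
}

# backward_pass_per_slot=False: positives are scored once and reused three times
positive_scores = model.score_spo(triples[:, S], triples[:, P], triples[:, O])

# A reciprocal-relations model has no single undirected spo score (it uses a
# different relation embedding per prediction direction), so the reused call is
# presumably unavailable there -- hence the "does not work with reciprocal
# relations" caveat in the config.
```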
@@ -872,12 +880,15 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
                 # compute the scores
                 forward_time -= time.time()
                 scores = torch.empty((chunk_size, num_samples + 1), device=self.device)
-                scores[:, 0] = self.model.score_spo(
-                    triples[:, S],
-                    triples[:, P],
-                    triples[:, O],
-                    direction=SLOT_STR[slot],
-                )
+                if not self.backward_pass_per_slot:
+                    scores[:, 0] = positive_scores
+                else:
+                    scores[:, 0] = self.model.score_spo(
+                        triples[:, S],
+                        triples[:, P],
+                        triples[:, O],
+                        direction=SLOT_STR[slot],
+                    )
                 forward_time += time.time()
                 scores[:, 1:] = batch_negative_samples[slot].score(
                     self.model, indexes=chunk_indexes
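For context on the tensors being filled here: each slot's score matrix holds the positive triple's score in column 0 and the sampled negatives in columns 1 through num_samples, and the labels passed to the loss below single out that first column. A tiny standalone illustration (the shapes are arbitrary and the 0/1 labels are a simplification; the actual label construction depends on the configured loss):

```python
import torch

chunk_size, num_samples = 4, 3
scores = torch.empty(chunk_size, num_samples + 1)
scores[:, 0] = torch.randn(chunk_size)                # positive triples -> column 0
scores[:, 1:] = torch.randn(chunk_size, num_samples)  # negative samples -> columns 1..n
labels = torch.zeros(chunk_size, num_samples + 1)
labels[:, 0] = 1.0                                    # only column 0 is a true triple
```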
@@ -891,12 +902,19 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
                     self.loss(scores, labels[slot], num_negatives=num_samples)
                     / batch_size
                 )
+                if not self.backward_pass_per_slot:
+                    loss_values_torch.append(loss_value_torch)
                 loss_value += loss_value_torch.item()
                 forward_time += time.time()
 
                 # backward pass for this chunk
+                if self.backward_pass_per_slot:
+                    backward_time -= time.time()
+                    loss_value_torch.backward()
+                    backward_time += time.time()
+            if not self.backward_pass_per_slot:
                 backward_time -= time.time()
-                loss_value_torch.backward()
+                torch.autograd.backward(loss_values_torch)
                 backward_time += time.time()
 
         # all done
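Finally, a self-contained sketch (not part of the PR) verifying the property the deferred path relies on: torch.autograd.backward over the list of collected per-slot losses accumulates exactly the same gradients as calling .backward() on each loss separately, so the change only affects speed and memory, not the computed gradients.

```python
import torch

w = torch.randn(3, requires_grad=True)

def slot_losses(w):
    # three stand-in per-slot losses, each with its own graph
    return [(w * weight).sum() for weight in (1.0, 2.0, 3.0)]

# deferred mode: one combined backward over all per-slot losses
torch.autograd.backward(slot_losses(w))
grad_combined = w.grad.clone()

# per-slot mode: one backward per loss, gradients accumulate
w.grad = None
for loss in slot_losses(w):
    loss.backward()

assert torch.allclose(grad_combined, w.grad)  # identical accumulated gradients
```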