diff --git a/cogdl/models/agc/daegc.py b/cogdl/models/agc/daegc.py index 1a6c4c6d..42e6f3ef 100644 --- a/cogdl/models/agc/daegc.py +++ b/cogdl/models/agc/daegc.py @@ -49,12 +49,8 @@ def __init__(self, num_features, hidden_size, embedding_size, num_heads, dropout self.embedding_size = embedding_size self.dropout = dropout self.num_clusters = num_clusters - self.att1 = GATLayer( - num_features, hidden_size, dropout=dropout, alpha=0.2, nhead=num_heads, concat=True, fast_mode=False - ) - self.att2 = GATLayer( - hidden_size * num_heads, embedding_size, dropout=dropout, alpha=0.2, nhead=1, concat=False, fast_mode=False - ) + self.att1 = GATLayer(num_features, hidden_size, dropout=dropout, alpha=0.2, nhead=num_heads, concat=True) + self.att2 = GATLayer(hidden_size * num_heads, embedding_size, dropout=dropout, alpha=0.2, nhead=1, concat=False) self.cluster_center = None def get_trainer(self, task, args): diff --git a/cogdl/models/nn/gat.py b/cogdl/models/nn/gat.py index 7f7a0a6a..b457b563 100644 --- a/cogdl/models/nn/gat.py +++ b/cogdl/models/nn/gat.py @@ -1,228 +1,202 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - -from .. import BaseModel, register_model -from cogdl.utils import add_remaining_self_loops, mul_edge_softmax, spmm - - -class GATLayer(nn.Module): - """ - Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 - """ - - def __init__( - self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False, fast_mode=False - ): - super(GATLayer, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.alpha = alpha - self.concat = concat - self.nhead = nhead - self.fast_mode = fast_mode - - self.W = nn.Parameter(torch.FloatTensor(in_features, out_features * nhead)) - - self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_features))) - self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_features))) - - self.dropout = nn.Dropout(dropout) - self.leakyrelu = nn.LeakyReLU(self.alpha) - - if residual: - out_features = out_features * nhead if concat else out_features - self.residual = nn.Linear(in_features, out_features) - else: - self.register_buffer("residual", None) - self.reset_parameters() - - def reset_parameters(self): - def reset(tensor): - stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) - tensor.data.uniform_(-stdv, stdv) - - reset(self.a_l) - reset(self.a_r) - reset(self.W) - - # nn.init.xavier_uniform_(self.W.data, gain=1.414) - # nn.init.xavier_uniform_(self.a_r.data, gain=1.414) - # nn.init.xavier_uniform_(self.a_l.data, gain=1.414) - - def forward(self, graph, x): - h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features) - # h: N * H * d - h[torch.isnan(h)] = 0.0 - - edge_index = graph.edge_index - # Self-attention on the nodes - Shared attention mechanism - h_l = (self.a_l * h).sum(dim=-1)[edge_index[0, :]] - h_r = (self.a_r * h).sum(dim=-1)[edge_index[1, :]] - edge_attention = self.leakyrelu(h_l + h_r) - # edge_e: E * H - edge_attention = mul_edge_softmax(graph, edge_attention) - num_edges = graph.num_edges - num_nodes = graph.num_nodes - - with graph.local_graph(): - if self.fast_mode: - edge_attention = edge_attention.view(-1) - edge_attention = self.dropout(edge_attention) - - edge_index = edge_index.view(-1) - edge_index = edge_index.unsqueeze(0).repeat(self.nhead, 1) - add_num = torch.arange(0, self.nhead * num_nodes, num_nodes).view(-1, 1).to(edge_index.device) - edge_index = edge_index + add_num - edge_index = 
edge_index.split((num_edges, num_edges), dim=1) - - row, col = edge_index - row = row.reshape(-1) - col = col.reshape(-1) - edge_index = torch.stack([row, col]) - - graph.edge_index = edge_index - graph.edge_weight = edge_attention - h_prime = spmm(graph, h.permute(1, 0, 2).reshape(num_nodes * self.nhead, -1)) - assert not torch.isnan(h_prime).any() - h_prime = h_prime.split([num_nodes] * self.nhead) - else: - edge_attention = self.dropout(edge_attention) - h_prime = [] - h = h.permute(1, 0, 2).contiguous() - for i in range(self.nhead): - edge_weight = edge_attention[i] - graph.edge_weight = edge_weight - hidden = h[i] - assert not torch.isnan(hidden).any() - h_prime.append(spmm(graph, hidden)) - if self.residual: - res = self.residual(x) - else: - res = 0 - - if self.concat: - out = torch.cat(h_prime, dim=1) + res - else: - out = sum(h_prime) / self.nhead + res - return out - - def __repr__(self): - return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")" - - -@register_model("gat") -class GAT(BaseModel): - r"""The GAT model from the `"Graph Attention Networks" - `_ paper - - Args: - num_features (int) : Number of input features. - num_classes (int) : Number of classes. - hidden_size (int) : The dimension of node representation. - dropout (float) : Dropout rate for model training. - alpha (float) : Coefficient of leaky_relu. - nheads (int) : Number of attention heads. - """ - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument("--num-features", type=int) - parser.add_argument("--num-layers", type=int, default=2) - parser.add_argument("--residual", action="store_true") - parser.add_argument("--num-classes", type=int) - parser.add_argument("--hidden-size", type=int, default=8) - parser.add_argument("--dropout", type=float, default=0.6) - parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--nhead", type=int, default=8) - parser.add_argument("--last-nhead", type=int, default=1) - parser.add_argument("--fast-mode", action="store_true", default=False) - # fmt: on - - @classmethod - def build_model_from_args(cls, args): - return cls( - args.num_features, - args.hidden_size, - args.num_classes, - args.num_layers, - args.dropout, - args.alpha, - args.nhead, - args.residual, - args.last_nhead, - args.fast_mode, - ) - - def __init__( - self, - in_feats, - hidden_size, - out_features, - num_layers, - dropout, - alpha, - nhead, - residual, - last_nhead, - fast_mode=False, - ): - """Sparse version of GAT.""" - super(GAT, self).__init__() - self.dropout = dropout - self.attentions = nn.ModuleList() - self.attentions.append( - GATLayer( - in_feats, - hidden_size, - nhead=nhead, - dropout=dropout, - alpha=alpha, - concat=True, - residual=residual, - fast_mode=fast_mode, - ) - ) - for i in range(num_layers - 2): - self.attentions.append( - GATLayer( - hidden_size * nhead, - hidden_size, - nhead=nhead, - dropout=dropout, - alpha=alpha, - concat=True, - residual=residual, - fast_mode=fast_mode, - ) - ) - self.attentions.append( - GATLayer( - hidden_size * nhead, - out_features, - dropout=dropout, - alpha=alpha, - concat=False, - nhead=last_nhead, - residual=False, - fast_mode=fast_mode, - ) - ) - self.num_layers = num_layers - self.last_nhead = last_nhead - self.residual = residual - - def forward(self, graph): - x = graph.x - for i, layer in enumerate(self.attentions): - x = F.dropout(x, p=self.dropout, training=self.training) - x = layer(graph, x) - if i != 
self.num_layers - 1: - x = F.elu(x) - return x - - def predict(self, graph): - return self.forward(graph) +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from .. import BaseModel, register_model +from cogdl.utils import mul_edge_softmax, spmm, mh_spmm +from cogdl.operators.mhspmm import csrmhspmm + + +class GATLayer(nn.Module): + """ + Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 + """ + + def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False): + super(GATLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.alpha = alpha + self.concat = concat + self.nhead = nhead + + self.W = nn.Parameter(torch.FloatTensor(in_features, out_features * nhead)) + + self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_features))) + self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_features))) + + self.dropout = nn.Dropout(dropout) + self.leakyrelu = nn.LeakyReLU(self.alpha) + + if residual: + out_features = out_features * nhead if concat else out_features + self.residual = nn.Linear(in_features, out_features) + else: + self.register_buffer("residual", None) + self.reset_parameters() + + def reset_parameters(self): + def reset(tensor): + stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) + tensor.data.uniform_(-stdv, stdv) + + reset(self.a_l) + reset(self.a_r) + reset(self.W) + + # nn.init.xavier_uniform_(self.W.data, gain=1.414) + # nn.init.xavier_uniform_(self.a_r.data, gain=1.414) + # nn.init.xavier_uniform_(self.a_l.data, gain=1.414) + + def forward(self, graph, x): + h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features) + h[torch.isnan(h)] = 0.0 + + edge_index = graph.edge_index + # Self-attention on the nodes - Shared attention mechanism + h_l = (self.a_l * h).sum(dim=-1)[edge_index[0, :]] + h_r = (self.a_r * h).sum(dim=-1)[edge_index[1, :]] + edge_attention = self.leakyrelu(h_l + h_r) + # edge_e: E * H + edge_attention = mul_edge_softmax(graph, edge_attention) + edge_attention = self.dropout(edge_attention) + + if csrmhspmm is not None: + if self.nhead > 1: + h_prime = mh_spmm(graph, edge_attention, h) + out = h_prime.view(h_prime.shape[0], -1) + else: + edge_weight = edge_attention[0] + with graph.local_graph(): + graph.edge_weight = edge_weight + out = spmm(graph, h.squeeze(1)) + else: + with graph.local_graph(): + h_prime = [] + h = h.permute(1, 0, 2).contiguous() + for i in range(self.nhead): + edge_weight = edge_attention[i] + graph.edge_weight = edge_weight + hidden = h[i] + assert not torch.isnan(hidden).any() + h_prime.append(spmm(graph, hidden)) + out = torch.cat(h_prime, dim=1) + + if self.residual: + res = self.residual(x) + out += res + return out + + def __repr__(self): + return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")" + + +@register_model("gat") +class GAT(BaseModel): + r"""The GAT model from the `"Graph Attention Networks" + `_ paper + + Args: + num_features (int) : Number of input features. + num_classes (int) : Number of classes. + hidden_size (int) : The dimension of node representation. + dropout (float) : Dropout rate for model training. + alpha (float) : Coefficient of leaky_relu. + nheads (int) : Number of attention heads. 
+ """ + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--num-features", type=int) + parser.add_argument("--num-layers", type=int, default=2) + parser.add_argument("--residual", action="store_true") + parser.add_argument("--num-classes", type=int) + parser.add_argument("--hidden-size", type=int, default=8) + parser.add_argument("--dropout", type=float, default=0.6) + parser.add_argument("--alpha", type=float, default=0.2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--last-nhead", type=int, default=1) + # fmt: on + + @classmethod + def build_model_from_args(cls, args): + return cls( + args.num_features, + args.hidden_size, + args.num_classes, + args.num_layers, + args.dropout, + args.alpha, + args.nhead, + args.residual, + args.last_nhead, + ) + + def __init__( + self, + in_feats, + hidden_size, + out_features, + num_layers, + dropout, + alpha, + nhead, + residual, + last_nhead, + ): + """Sparse version of GAT.""" + super(GAT, self).__init__() + self.dropout = dropout + self.attentions = nn.ModuleList() + self.attentions.append( + GATLayer( + in_feats, + hidden_size, + nhead=nhead, + dropout=dropout, + alpha=alpha, + concat=True, + residual=residual, + ) + ) + for i in range(num_layers - 2): + self.attentions.append( + GATLayer( + hidden_size * nhead, + hidden_size, + nhead=nhead, + dropout=dropout, + alpha=alpha, + concat=True, + residual=residual, + ) + ) + self.attentions.append( + GATLayer( + hidden_size * nhead, + out_features, + dropout=dropout, + alpha=alpha, + concat=False, + nhead=last_nhead, + residual=False, + ) + ) + self.num_layers = num_layers + self.last_nhead = last_nhead + self.residual = residual + + def forward(self, graph): + x = graph.x + for i, layer in enumerate(self.attentions): + x = F.dropout(x, p=self.dropout, training=self.training) + x = layer(graph, x) + if i != self.num_layers - 1: + x = F.elu(x) + return x + + def predict(self, graph): + return self.forward(graph) diff --git a/cogdl/oag/README.md b/cogdl/oag/README.md index a6131f5a..bd2332a0 100644 --- a/cogdl/oag/README.md +++ b/cogdl/oag/README.md @@ -45,6 +45,22 @@ sequence_output, pooled_output = model.bert.forward( position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0) ) ``` +If you want to encode various type of entities separately, you can use the following code instead +```python +from cogdl import oagbert + +tokenizer, model = oagbert("oagbert-v2") +title = 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding' +abstract = 'We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation...' +authors = ['Jacob Devlin', 'Ming-Wei Chang', 'Kenton Lee', 'Kristina Toutanova'] +venue = 'north american chapter of the association for computational linguistics' +affiliations = ['Google'] +concepts = ['language model', 'natural language inference', 'question answering'] +# encode paper +paper_info = model.encode_paper( + title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, reduction="max" +) +``` You can also use some integrated functions to use OAG-BERT v2 directly, such as using `decode_beamsearch` to generate entities based on existing context. 
For example, to generate concepts with 2 tokens for the BERT paper, run the following code ```python model.eval() diff --git a/cogdl/oag/oagbert_metainfo.py b/cogdl/oag/oagbert_metainfo.py index 25b4be59..8f8d9d9b 100644 --- a/cogdl/oag/oagbert_metainfo.py +++ b/cogdl/oag/oagbert_metainfo.py @@ -14,8 +14,9 @@ def __init__(self, bert_config, tokenizer): self.tokenizer = tokenizer self.spm = not isinstance(self.tokenizer, BertTokenizer) if self.spm: - self.tokenizer.cls_token_id, self.tokenizer.mask_token_id, self.tokenizer.sep_token_id = self.tokenizer.PieceToId([ - '[CLS]', '[MASK]', '[SEP]']) + self.tokenizer.cls_token_id, self.tokenizer.mask_token_id, self.tokenizer.sep_token_id = self.tokenizer.PieceToId( + [ + '[CLS]', '[MASK]', '[SEP]']) def __recursively_build_spm_token_ids(self, text, splitters=[]): """ @@ -41,7 +42,8 @@ def __recursively_build_spm_token_ids(self, text, splitters=[]): def _convert_text_to_token_ids(self, text): if self.spm: - return self.__recursively_build_spm_token_ids(text, splitters=['[PAD]', '[EOS]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '[BOS]']) + return self.__recursively_build_spm_token_ids(text, splitters=['[PAD]', '[EOS]', '[UNK]', '[CLS]', '[SEP]', + '[MASK]', '[BOS]']) else: return self.tokenizer(text, add_special_tokens=False)["input_ids"] if len(text) > 0 else [] @@ -68,14 +70,14 @@ def _convert_token_ids_to_text(self, token_ids): return self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(token_ids)) def print_oag_instance( - self, - input_ids, - token_type_ids, - input_masks, - masked_lm_labels, - position_ids, - position_ids_second, - predictions=None, + self, + input_ids, + token_type_ids, + input_masks, + masked_lm_labels, + position_ids, + position_ids_second, + predictions=None, ): COLORS = ["white", "green", "blue", "red", "yellow", "magenta"] try: @@ -111,7 +113,7 @@ def print_oag_instance( prediction_topk_strs = [[""] for _ in range(K)] current_length = 0 for pos, (input_token, position_id, position_id_second, token_type_id, mask) in enumerate( - zip(input_tokens, position_ids, position_ids_second, token_type_ids, masks_tokens) + zip(input_tokens, position_ids, position_ids_second, token_type_ids, masks_tokens) ): token_type = OAG_TOKEN_TYPE_NAMES[token_type_id] length = max( @@ -136,7 +138,8 @@ def print_oag_instance( position_ids_second_str[-1] += stringRjustCJK(str(position_id_second), length) token_type_ids_str[-1] += stringRjustCJK(token_type, length) masks_str[-1] += colored(stringRjustCJK(mask, length) if mask - != "[PAD]" else stringRjustCJK("", length), COLORS[token_type_id]) + != "[PAD]" else stringRjustCJK("", length), + COLORS[token_type_id]) for k in range(K): v = prediction_tokens[k][pos] if prediction_tokens[k][pos] != "[PAD]" else "" prediction_topk_strs[k][-1] += colored( @@ -159,17 +162,17 @@ def print_oag_instance( print("-" * termwidth) def build_inputs( - self, - title="", - abstract="", - venue="", - authors=[], - concepts=[], - affiliations=[], - decode_span_type="FOS", - decode_span_length=0, - max_seq_length=512, - mask_propmt_text="", + self, + title="", + abstract="", + venue="", + authors=[], + concepts=[], + affiliations=[], + decode_span_type="FOS", + decode_span_length=0, + max_seq_length=512, + mask_propmt_text="", ): """build inputs from text information for model to use @@ -224,7 +227,7 @@ def add_span(token_type_id, token_ids, is_mask=False): add_span( 0, (self._convert_text_to_token_ids(title) + self._convert_text_to_token_ids(abstract) + prompt_token_ids)[ - : max_seq_length - 
decode_span_length + : max_seq_length - decode_span_length ], ) add_span(2, self._convert_text_to_token_ids(venue)[: max_seq_length - len(input_ids) - decode_span_length]) @@ -251,21 +254,140 @@ def add_span(token_type_id, token_ids, is_mask=False): num_spans, ) + def encode_paper( + self, + title="", + abstract="", + venue="", + authors=[], + concepts=[], + affiliations=[], + decode_span_type="FOS", + decode_span_length=0, + max_seq_length=512, + mask_propmt_text="", + reduction="first", + ): + """encode paper from text information and run forward to get sequence and pool output for each entity + + Args: + title (str, optional): [paper title]. Defaults to ''. + abstract (str, optional): [paper abstract]. Defaults to ''. + venue (str, optional): [paper venue]. Defaults to ''. + authors (list, optional): [paper author]. Defaults to []. + concepts (list, optional): [paper concepts]. Defaults to []. + affiliations (list, optional): [paper affiliations]. Defaults to []. + decode_span_type (str, optional): [the span type to decode, choose from 'FOS','VENUE','AFF','AUTHOR']. Defaults to 'FOS'. + decode_span_length (int, optional): [the length of span to decode]. Defaults to 0. + max_seq_length (int, optional): [maximum sequence length for the input, the context information will be truncated if the total length exceeds this number]. Defaults to 512. + mask_propmt_text (str, optional): [the prompt text to add after title and abstract]. Defaults to ''. + reduction (str, optional): [the way to get pooled_output, choose from 'cls','max','mean']. Defaults to 'cls'. + + Raises: + Exception: [provided inputs are invalid] + + Returns: + [dictionary of list of dictionary]: { + 'text': text_item, + 'venue': venue_item, + 'authors': [authors_item, authors_item, ...] + 'concepts': [concepts_item, concepts_item, ...] + 'affiliations': [affiliations_item, affiliations_item, ...] 
+ } + """ + input_ids, input_masks, token_type_ids, masked_lm_labels, position_ids, position_ids_second, masked_positions, num_spans = self.build_inputs( + title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, + decode_span_type=decode_span_type, decode_span_length=decode_span_length, max_seq_length=max_seq_length, + mask_propmt_text=mask_propmt_text + ) + + search = { + 'text': [title + abstract], + 'venue': [venue], + 'authors': authors, + 'concepts': concepts, + 'affiliations': affiliations + } + + item = { + 'originalText': "", + 'inputText': "", + 'type': "", + 'tokens': [], + 'token_ids': [], + 'sequence_output': [], + 'pooled_output': [] + } + + output = { + 'text': [], + 'venue': [], + 'authors': [], + 'concepts': [], + 'affiliations': [] + } + + split_index = { + 'text': [], + 'venue': [], + 'authors': [], + 'concepts': [], + 'affiliations': [] + } + + sequence_output, pooled_output = self.bert.forward( + input_ids=torch.LongTensor(input_ids).unsqueeze(0), + token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0), + attention_mask=torch.LongTensor(input_masks).unsqueeze(0), + output_all_encoded_layers=False, + checkpoint_activations=False, + position_ids=torch.LongTensor(position_ids).unsqueeze(0), + position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0), + ) + + entities = {0: 'text', 2: 'venue', 1: 'authors', 4: 'concepts', 3: 'affiliations'} + for num, name in entities.items(): + if num in token_type_ids: + start_index = position_ids[token_type_ids.index(num)] + split_index[name].append(position_ids.index(start_index) - 1) + for i in range(0, len(search[name])): + split_index[name].append( + len(position_ids) - 1 - list(reversed(position_ids)).index(start_index + i)) + item = item.copy() + item['type'] = name.upper() + item['originalText'] = search[name][i] + item['token_ids'] = input_ids[ + split_index[name][i] + 1:split_index[name][i + 1] + 1] + item['tokens'] = self._convert_ids_to_tokens(item['token_ids']) + item['inputText'] = self._convert_token_ids_to_text(item['token_ids']) + item['sequence_output'] = sequence_output[:, + split_index[name][i] + 1:split_index[name][i + 1] + 1, + :].squeeze(0) + if reduction == "mean": + item['pooled_output'] = item['sequence_output'].mean(dim=0, keepdim=False) + elif reduction == "max": + item['pooled_output'], _ = item['sequence_output'].max(dim=0) + else: + item['pooled_output'] = pooled_output + output[name].append(item) + + return output + def calculate_span_prob( - self, - title="", - abstract="", - venue="", - authors=[], - concepts=[], - affiliations=[], - decode_span_type="FOS", - decode_span="", - force_forward=False, - max_seq_length=512, - mask_propmt_text="", - device=None, - debug=False, + self, + title="", + abstract="", + venue="", + authors=[], + concepts=[], + affiliations=[], + decode_span_type="FOS", + decode_span="", + force_forward=False, + max_seq_length=512, + mask_propmt_text="", + device=None, + debug=False, ): """calculate span probability by greedy algorithm @@ -361,21 +483,21 @@ def tensorize(x): return np.exp(logprobs), logproblist def decode_beamsearch( - self, - title="", - abstract="", - venue="", - authors=[], - concepts=[], - affiliations=[], - decode_span_type="", - decode_span_length=0, - beam_width=16, - force_forward=False, - max_seq_length=512, - mask_propmt_text="", - device=None, - debug=False, + self, + title="", + abstract="", + venue="", + authors=[], + concepts=[], + affiliations=[], + decode_span_type="", + 
decode_span_length=0, + beam_width=16, + force_forward=False, + max_seq_length=512, + mask_propmt_text="", + device=None, + debug=False, ): """decode span by using beamsearch @@ -492,20 +614,20 @@ def tensorize(x): return results def generate_title( # noqa C901 - self, - abstract="", - authors=[], - venue="", - affiliations=[], - concepts=[], - num_beams=1, - no_repeat_ngram_size=3, - num_return_sequences=1, - min_length=10, - max_length=30, - device=None, - early_stopping=False, - debug=False, + self, + abstract="", + authors=[], + venue="", + affiliations=[], + concepts=[], + num_beams=1, + no_repeat_ngram_size=3, + num_return_sequences=1, + min_length=10, + max_length=30, + device=None, + early_stopping=False, + debug=False, ): """generate paper titles given other information @@ -588,12 +710,12 @@ def tensorize(x): batch_attention_mask = torch.ones((current_total_length, current_total_length)) batch_attention_mask[ - decode_pos - current_entity_length + 1: decode_pos + 1, - decode_pos - current_entity_length + 1: decode_pos + 1, + decode_pos - current_entity_length + 1: decode_pos + 1, + decode_pos - current_entity_length + 1: decode_pos + 1, ] = torch.tril( batch_attention_mask[ - decode_pos - current_entity_length + 1: decode_pos + 1, - decode_pos - current_entity_length + 1: decode_pos + 1, + decode_pos - current_entity_length + 1: decode_pos + 1, + decode_pos - current_entity_length + 1: decode_pos + 1, ] ) batch_attention_mask = batch_attention_mask.unsqueeze(0).repeat(len(q), 1, 1).to(device or "cpu") diff --git a/cogdl/operators/edge_softmax.py b/cogdl/operators/edge_softmax.py index c214629a..26825ccf 100644 --- a/cogdl/operators/edge_softmax.py +++ b/cogdl/operators/edge_softmax.py @@ -7,18 +7,24 @@ # SPMM if not torch.cuda.is_available(): edge_softmax = None + csr_edge_softmax = None else: try: edge_softmax = load( name="edge_softmax", - sources=[os.path.join(path, "edge_softmax/edge_softmax.cc"), os.path.join(path, "edge_softmax/edge_softmax.cu")], + sources=[ + os.path.join(path, "edge_softmax/edge_softmax.cc"), + os.path.join(path, "edge_softmax/edge_softmax.cu"), + ], verbose=False, ) + def csr_edge_softmax(rowptr, h): return EdgeSoftmaxFunction.apply(rowptr, h) except Exception: edge_softmax = None + csr_edge_softmax = None class EdgeSoftmaxFunction(torch.autograd.Function): @@ -32,5 +38,5 @@ def forward(ctx, rowptr, h): def backward(ctx, grad_out): rowptr, out = ctx.backward_csc grad_out = grad_out.contiguous() - grad_softmax = edge_softmax.edge_softmax_backward(rowptr, out, grad_out) + grad_softmax = edge_softmax.edge_softmax_backward(rowptr, out, grad_out) return None, grad_softmax diff --git a/cogdl/operators/mhspmm.py b/cogdl/operators/mhspmm.py new file mode 100644 index 00000000..677dbeb0 --- /dev/null +++ b/cogdl/operators/mhspmm.py @@ -0,0 +1,62 @@ +import os +import torch +from torch.utils.cpp_extension import load + +path = os.path.join(os.path.dirname(__file__)) + +# SPMM +if not torch.cuda.is_available(): + spmm = None + mhspmm = None + csrmhspmm = None +else: + try: + mhspmm = load( + name="mhspmm", + sources=[os.path.join(path, "spmm/multiheadSpmm.cpp"), os.path.join(path, "spmm/multiheadSpmm.cu")], + verbose=False, + ) + mhsddmm = load( + name="mhsddmm", + sources=[os.path.join(path, "spmm/multiheadSddmm.cpp"), os.path.join(path, "spmm/multiheadSddmm.cu")], + verbose=False, + ) + mhtranspose = load( + name="mhtranspose", + sources=[os.path.join(path, "spmm/mhTranspose.cpp"), os.path.join(path, "spmm/mhTranspose.cu")], + verbose=False, + ) + + spmm = load( 
+ name="spmm", + sources=[os.path.join(path, "spmm/spmm.cpp"), os.path.join(path, "spmm/spmm_kernel.cu")], + verbose=False, + ) + + def csrmhspmm(rowptr, colind, feat, attention): + return MHSPMMFunction.apply(rowptr, colind, feat, attention) + + except Exception: + mhspmm = None + csrmhspmm = None + + +class MHSPMMFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, rowptr, colind, feat, attention): + out = mhspmm.mhspmm(rowptr, colind, attention, feat) + ctx.backward_csc = (rowptr, colind, feat, attention) + return out + + @staticmethod + def backward(ctx, grad_out): + rowptr, colind, feat, attention = ctx.backward_csc + grad_out = grad_out.contiguous() + numlist = torch.arange(colind.size(0), device=grad_out.device, dtype=torch.int32) + # colptr, rowind, permute = mhtranspose.csr2csc(rowptr, colind, numlist) + colptr, rowind, permute = spmm.csr2csc(rowptr, colind, numlist.float()) + permute = permute.int() + attention_csc = mhtranspose.mhtranspose(permute, attention) + grad_feat = mhspmm.mhspmm(colptr, rowind, attention_csc, grad_out) + grad_edge_weight = mhsddmm.mhsddmm(rowptr, colind, grad_out, feat) + return None, None, grad_feat, grad_edge_weight diff --git a/cogdl/operators/spmm/computeUtil.h b/cogdl/operators/spmm/computeUtil.h index 0d2309e3..634e8c52 100644 --- a/cogdl/operators/spmm/computeUtil.h +++ b/cogdl/operators/spmm/computeUtil.h @@ -1,136 +1,168 @@ -#ifndef computeUtil_H -#define computeUtil_H -#include -#include -#include "device_launch_parameters.h" -#include "device_atomic_functions.h" - -#define CEIL(x, y) (((x) + (y)-1) / (y)) - -#define MIN(a, b) ((a < b) ? a : b) -#define MAX(a, b) ((a < b) ? b : a) - -__device__ __forceinline__ int findRow(const int *S_csrRowPtr, int eid, int start, int end) -{ - int low = start, high = end; - if (low == high) - return low; - while (low < high) - { - int mid = (low + high) >> 1; - if (S_csrRowPtr[mid] <= eid) - low = mid + 1; - else - high = mid; - } - if (S_csrRowPtr[high] == eid) - return high; - else - return high - 1; -} - -template -__device__ __forceinline__ void Load(ldType &tmp, data *array, int offset) -{ - tmp = *(reinterpret_cast(array + offset)); -} - -template -__device__ __forceinline__ void Load(data *lhd, data *rhd, int offset) -{ - *(reinterpret_cast(lhd)) = *(reinterpret_cast(rhd + offset)); -} - -template -__device__ __forceinline__ void Store(data *lhd, data *rhd, int offset) -{ - *(reinterpret_cast(lhd + offset)) = *(reinterpret_cast(rhd)); -} - -template -__device__ __forceinline__ void Load4(ldType *tmp, data *array, int *offset, int offset2 = 0) -{ - Load(tmp[0], array, offset[0] + offset2); - Load(tmp[1], array, offset[1] + offset2); - Load(tmp[2], array, offset[2] + offset2); - Load(tmp[3], array, offset[3] + offset2); -} - -template -__device__ __forceinline__ data vecDot2(vecData &lhd, vecData &rhd) -{ - return lhd.x * rhd.x + lhd.y * rhd.y; -} - -template -__device__ __forceinline__ data vecDot4(vecData &lhd, vecData &rhd) -{ - return lhd.x * rhd.x + lhd.y * rhd.y + lhd.z * rhd.z + lhd.w * rhd.w; -} - -template -__device__ __forceinline__ void vec4Dot4(data *cal, vecData *lhd, vecData *rhd) -{ - cal[0] += vecDot4(lhd[0], rhd[0]); - cal[1] += vecDot4(lhd[1], rhd[1]); - cal[2] += vecDot4(lhd[2], rhd[2]); - cal[3] += vecDot4(lhd[3], rhd[3]); -} - -template -__device__ __forceinline__ void vec2Dot4(data *cal, vecData *lhd, vecData *rhd) -{ - cal[0] += vecDot2(lhd[0], rhd[0]); - cal[1] += vecDot2(lhd[1], rhd[1]); - cal[2] += vecDot2(lhd[2], rhd[2]); - cal[3] += vecDot2(lhd[3], rhd[3]); -} 
- -template -__device__ __forceinline__ void Dot4(data *cal, data *lhd, data *rhd) -{ - cal[0] += lhd[0] * rhd[0]; - cal[1] += lhd[1] * rhd[1]; - cal[2] += lhd[2] * rhd[2]; - cal[3] += lhd[3] * rhd[3]; -} - -template -__device__ __forceinline__ void selfMul4(data *lhd, data *rhd) -{ - lhd[0] *= rhd[0]; - lhd[1] *= rhd[1]; - lhd[2] *= rhd[2]; - lhd[3] *= rhd[3]; -} - -template -__device__ __forceinline__ void selfMulConst4(data *lhd, data Const) -{ - lhd[0] *= Const; - lhd[1] *= Const; - lhd[2] *= Const; - lhd[3] *= Const; -} - -template -__device__ __forceinline__ void AllReduce4(data *multi, int stride, int warpSize) -{ - for (; stride > 0; stride >>= 1) - { - multi[0] += __shfl_xor_sync(0xffffffff, multi[0], stride, warpSize); - multi[1] += __shfl_xor_sync(0xffffffff, multi[1], stride, warpSize); - multi[2] += __shfl_xor_sync(0xffffffff, multi[2], stride, warpSize); - multi[3] += __shfl_xor_sync(0xffffffff, multi[3], stride, warpSize); - } -} - -template -__device__ __forceinline__ void AllReduce(data multi, int stride, int warpSize) -{ - for(; stride > 0; stride >>= 1) - { - multi += shlf_xor_sync(0xffffffff, multi, stride, warpSize); - } -} -#endif computeUtil_H \ No newline at end of file +#ifndef computeUtil_H +#define computeUtil_H +#include +#include +#include "device_launch_parameters.h" +#include "device_atomic_functions.h" + +#define CEIL(x, y) (((x) + (y)-1) / (y)) + +#define MIN(a, b) ((a < b) ? a : b) +#define MAX(a, b) ((a < b) ? b : a) + +#define checkCudaError( a ) do { \ + if (cudaSuccess != (a)) { \ + fprintf(stderr, "Cuda runTime error in line %d of file %s \ + : %s \n", __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()) ); \ + exit(EXIT_FAILURE); \ + } \ +} while(0) + +#define checkCuSparseError( a ) do { \ + if (CUSPARSE_STATUS_SUCCESS != (a)) { \ + fprintf(stderr, "CuSparse runTime error in line %d of file %s \ + : %s \n", __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()) ); \ + exit(EXIT_FAILURE); \ + } \ +} while (0) +__device__ __forceinline__ float sum_reduce(float acc, float x) { + return acc + x; +} + +__device__ __forceinline__ float sum_init() { + return 0; +} + +__device__ __forceinline__ int findRow(const int *S_csrRowPtr, int eid, int start, int end) +{ + int low = start, high = end; + if (low == high) + return low; + while (low < high) + { + int mid = (low + high) >> 1; + if (S_csrRowPtr[mid] <= eid) + low = mid + 1; + else + high = mid; + } + if (S_csrRowPtr[high] == eid) + return high; + else + return high - 1; +} + +template +__device__ __forceinline__ void Load(ldType &tmp, data *array, int offset) +{ + tmp = *(reinterpret_cast(array + offset)); +} + +template +__device__ __forceinline__ void Load(data *lhd, data *rhd, int offset) +{ + *(reinterpret_cast(lhd)) = *(reinterpret_cast(rhd + offset)); +} + +template +__device__ __forceinline__ void Store(data *lhd, data *rhd, int offset) +{ + *(reinterpret_cast(lhd + offset)) = *(reinterpret_cast(rhd)); +} + +template +__device__ __forceinline__ void Load4(ldType *tmp, data *array, int *offset, int offset2 = 0) +{ + Load(tmp[0], array, offset[0] + offset2); + Load(tmp[1], array, offset[1] + offset2); + Load(tmp[2], array, offset[2] + offset2); + Load(tmp[3], array, offset[3] + offset2); +} + +template +__device__ __forceinline__ data vecDot2(vecData &lhd, vecData &rhd) +{ + return lhd.x * rhd.x + lhd.y * rhd.y; +} + +template +__device__ __forceinline__ data vecDot4(vecData &lhd, vecData &rhd) +{ + return lhd.x * rhd.x + lhd.y * rhd.y + lhd.z * rhd.z + lhd.w * rhd.w; +} + +template 
+__device__ __forceinline__ void vec4Dot4(data *cal, vecData *lhd, vecData *rhd) +{ + cal[0] += vecDot4(lhd[0], rhd[0]); + cal[1] += vecDot4(lhd[1], rhd[1]); + cal[2] += vecDot4(lhd[2], rhd[2]); + cal[3] += vecDot4(lhd[3], rhd[3]); +} + +template +__device__ __forceinline__ void vec2Dot4(data *cal, vecData *lhd, vecData *rhd) +{ + cal[0] += vecDot2(lhd[0], rhd[0]); + cal[1] += vecDot2(lhd[1], rhd[1]); + cal[2] += vecDot2(lhd[2], rhd[2]); + cal[3] += vecDot2(lhd[3], rhd[3]); +} + +template +__device__ __forceinline__ void Dot4(data *cal, data *lhd, data *rhd) +{ + cal[0] += lhd[0] * rhd[0]; + cal[1] += lhd[1] * rhd[1]; + cal[2] += lhd[2] * rhd[2]; + cal[3] += lhd[3] * rhd[3]; +} + +template +__device__ __forceinline__ void selfMul4(data *lhd, data *rhd) +{ + lhd[0] *= rhd[0]; + lhd[1] *= rhd[1]; + lhd[2] *= rhd[2]; + lhd[3] *= rhd[3]; +} + +template +__device__ __forceinline__ void selfMulConst4(data *lhd, data Const) +{ + lhd[0] *= Const; + lhd[1] *= Const; + lhd[2] *= Const; + lhd[3] *= Const; +} + +template +__device__ __forceinline__ void selfAddConst4(data *lhd, data Const) +{ + lhd[0] += Const; + lhd[1] += Const; + lhd[2] += Const; + lhd[3] += Const; +} + +template +__device__ __forceinline__ void AllReduce4(data *multi, int stride, int warpSize) +{ + for (; stride > 0; stride >>= 1) + { + multi[0] += __shfl_xor_sync(0xffffffff, multi[0], stride, warpSize); + multi[1] += __shfl_xor_sync(0xffffffff, multi[1], stride, warpSize); + multi[2] += __shfl_xor_sync(0xffffffff, multi[2], stride, warpSize); + multi[3] += __shfl_xor_sync(0xffffffff, multi[3], stride, warpSize); + } +} + +template +__device__ __forceinline__ void AllReduce(data multi, int stride, int warpSize) +{ + for(; stride > 0; stride >>= 1) + { + multi += shlf_xor_sync(0xffffffff, multi, stride, warpSize); + } +} +#endif //computeUtil_H \ No newline at end of file diff --git a/cogdl/operators/spmm/mhTranspose.cpp b/cogdl/operators/spmm/mhTranspose.cpp new file mode 100644 index 00000000..4c301479 --- /dev/null +++ b/cogdl/operators/spmm/mhTranspose.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include + +torch::Tensor mhtranspose_cuda( + torch::Tensor permute, + torch::Tensor attention // E * H +); + +torch::Tensor mhtranspose( + torch::Tensor permute, + torch::Tensor attention) +{ + assert(permute.device().type() == torch::kCUDA); + assert(attention.device().type() == torch::kCUDA); + assert(permute.is_contiguous()); + assert(attention.is_contiguous()); + assert(permute.dtype() == torch::kInt32); + assert(attention.dtype() == torch::kFloat32); + return mhtranspose_cuda(permute, attention); +} + +std::vector csr2csc_cuda( + torch::Tensor csrRowPtr, + torch::Tensor csrColInd, + torch::Tensor csrVal); + +std::vector csr2csc( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor csr_data) +{ + assert(rowptr.device().type() == torch::kCUDA); + assert(colind.device().type() == torch::kCUDA); + assert(csr_data.device().type() == torch::kCUDA); + assert(rowptr.is_contiguous()); + assert(colind.is_contiguous()); + assert(csr_data.is_contiguous()); + assert(rowptr.dtype() == torch::kInt32); + assert(colind.dtype() == torch::kInt32); + assert(csr_data.dtype() == torch::kInt32); + return csr2csc_cuda(rowptr, colind, csr_data); +} + +PYBIND11_MODULE(mhtranspose, m) +{ + m.doc() = "mhtranspose in CSR format. 
"; + m.def("mhtranspose", &mhtranspose, "CSR mhsddmm"); + m.def("csr2csc", &csr2csc, "csr2csc"); +} \ No newline at end of file diff --git a/cogdl/operators/spmm/mhTranspose.cu b/cogdl/operators/spmm/mhTranspose.cu new file mode 100644 index 00000000..adfc30b1 --- /dev/null +++ b/cogdl/operators/spmm/mhTranspose.cu @@ -0,0 +1,110 @@ +#include +#include +#include +#include "computeUtil.h" + +__global__ void mhtranspose(const int nnz, const int h, const int * permute, float * attention, float * out) +{ + int hid = blockIdx.y; + int nid = blockIdx.x * 32 + threadIdx.x; + if(nid < nnz) + { + int idx = permute[nid]; + out[nid * h + hid] = attention[idx * h + hid]; + } +} + +__global__ void mhtranspose4(const int nnz, const int h, int * permute, float * attention, float * out) +{ + int hid = threadIdx.y << 2; + int nid = blockIdx.x * 32 + threadIdx.x; + if(nid < nnz) + { + int idx = permute[nid]; + float att[4]; + Load(att, attention, idx * h + hid); + Store(out, att, nid * h + hid); + } +} + +torch::Tensor mhtranspose_cuda( + torch::Tensor permute, + torch::Tensor attention // E * H +) +{ + const auto nnz = permute.size(0); + const auto h = attention.size(1); + auto devid = permute.device().index(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); + auto out = torch::empty({nnz, h}, options); + if((h & 3) == 0) + { + mhtranspose4<<>>(nnz, h, permute.data_ptr(), attention.data_ptr(), out.data_ptr()); + } + else + { + mhtranspose<<>>(nnz, h, permute.data_ptr(), attention.data_ptr(), out.data_ptr()); + } + return out; +} + +void csr2cscKernel(int m, int n, int nnz, + int *csrRowPtr, int *csrColInd, int *csrVal, + int *cscColPtr, int *cscRowInd, int *cscVal +) +{ + cusparseHandle_t handle; + size_t bufferSize = 0; + void* buffer = NULL; + checkCuSparseError(cusparseCsr2cscEx2_bufferSize(handle, + m, + n, + nnz, + csrVal, + csrRowPtr, + csrColInd, + cscVal, + cscColPtr, + cscRowInd, + CUDA_R_32I, + CUSPARSE_ACTION_SYMBOLIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + &bufferSize + )); + checkCudaError(cudaMalloc((void**)&buffer, bufferSize * sizeof(float))); + checkCuSparseError(cusparseCsr2cscEx2(handle, + m, + n, + nnz, + csrVal, + csrRowPtr, + csrColInd, + cscVal, + cscColPtr, + cscRowInd, + CUDA_R_32I, + CUSPARSE_ACTION_NUMERIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + buffer + )); + checkCudaError(cudaFree(buffer)); +} + +std::vector csr2csc_cuda( + torch::Tensor csrRowPtr, + torch::Tensor csrColInd, + torch::Tensor csrVal) +{ + const auto n = csrRowPtr.size(0) - 1; + const auto nnz = csrColInd.size(0); + auto devid = csrRowPtr.device().index(); + auto optionsI = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid); + auto cscColPtr = torch::empty({n + 1}, optionsI); + auto cscRowInd = torch::empty({nnz}, optionsI); + auto cscVal = torch::empty({nnz}, optionsI); + csr2cscKernel(n, n, nnz, csrRowPtr.data_ptr(), csrColInd.data_ptr(), csrVal.data_ptr(), + cscColPtr.data_ptr(), cscRowInd.data_ptr(), cscVal.data_ptr()); + return {cscColPtr, cscRowInd, cscVal}; +} \ No newline at end of file diff --git a/cogdl/operators/spmm/multiheadSddmm.cpp b/cogdl/operators/spmm/multiheadSddmm.cpp new file mode 100644 index 00000000..f0716fc1 --- /dev/null +++ b/cogdl/operators/spmm/multiheadSddmm.cpp @@ -0,0 +1,39 @@ +#include +#include +#include +#include + +torch::Tensor mhsddmm_cuda( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor grad, // V * H * F + torch::Tensor feature // V * H * F +); + +torch::Tensor 
mhsddmm( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor grad, // V * H * F + torch::Tensor feature // V * H * F +) +{ + assert(rowptr.device().type() == torch::kCUDA); + assert(colind.device().type() == torch::kCUDA); + assert(grad.device().type() == torch::kCUDA); + assert(feature.device().type() == torch::kCUDA); + assert(rowptr.is_contiguous()); + assert(colind.is_contiguous()); + assert(grad.is_contiguous()); + assert(feature.is_contiguous()); + assert(rowptr.dtype() == torch::kInt32); + assert(colind.dtype() == torch::kInt32); + assert(grad.dtype() == torch::kFloat32); + assert(feature.dtype() == torch::kFloat32); + return mhsddmm_cuda(rowptr, colind, grad, feature); +} + +PYBIND11_MODULE(mhsddmm, m) +{ + m.doc() = "mhsddmm in CSR format. "; + m.def("mhsddmm", &mhsddmm, "CSR mhsddmm"); +} \ No newline at end of file diff --git a/cogdl/operators/spmm/multiheadSddmm.cu b/cogdl/operators/spmm/multiheadSddmm.cu new file mode 100644 index 00000000..89e38744 --- /dev/null +++ b/cogdl/operators/spmm/multiheadSddmm.cu @@ -0,0 +1,113 @@ +#include +#include +#include "computeUtil.h" + + +__global__ void mhsddmm(const int v, const int f, const int h, const int nnz, + int *rowptr, int *colind, float *grad, + float *feature, float *out) // V * H * F +{ + int eid = (blockIdx.x << 4) + (threadIdx.y << 2); + int cid = threadIdx.x; + int hid = blockIdx.y; + + if (blockIdx.x < nnz / 16) + { + float multi[4] = {0, 0, 0, 0}; + int offset1[4], offset2[4]; + float D1tmp[4], D2tmp[4]; + + Load(offset2, colind, eid); + + offset1[0] = findRow(rowptr, eid, 0, v); + offset1[3] = findRow(rowptr, eid + 3, offset1[0], v); + offset1[1] = findRow(rowptr, eid + 1, offset1[0], offset1[3]); + offset1[2] = findRow(rowptr, eid + 2, offset1[1], offset1[3]); + + selfMulConst4(offset1, f * h); + selfAddConst4(offset1, hid * f); + selfMulConst4(offset2, f * h); + selfAddConst4(offset2, hid * f); + for (int i = 0; i < (f >> 5); i++) + { + Load4(D1tmp, grad, offset1, cid); + Load4(D2tmp, feature, offset2, cid); + Dot4(multi, D1tmp, D2tmp); + cid += 32; + } + int res = f & 31; + if(res) + { + float D1[4] = {0, 0, 0, 0}, D2[4] = {0, 0, 0, 0}; + if(threadIdx.x < res) + { + Load4(D1, grad, offset1, cid); + Load4(D2, feature, offset2, cid); + Dot4(multi, D1, D2); + } + } + AllReduce4(multi, 16, 32); + if (threadIdx.x == 0) + { + out[eid * h + hid] = multi[0]; + out[(eid + 1) * h + hid] = multi[1]; + out[(eid + 2) * h + hid] = multi[2]; + out[(eid + 3) * h + hid] = multi[3]; + } + } + else // Dynamic parrallel? 
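+    // Tail branch (descriptive comment, added for clarity): the remaining (nnz & 15) edges are
+    // processed one edge per block; a single warp strides over the feature dimension, accumulates
+    // the partial dot products, and reduces them across lanes via __shfl_xor_sync before lane 0 writes the result.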
+ { + eid = nnz - (nnz & 15) + (blockIdx.x - (nnz / 16)); + int offset1 = findRow(rowptr, eid, 0, v) * f * h + hid * f; + int offset2 = colind[eid] * f * h + hid * f; + float multi = 0; + int off1 = cid = threadIdx.x; + float D1tmp0, D2tmp0; + for (int cc = 0; cc < (f >> 5); cc++) + { + D1tmp0 = grad[offset1 + cid]; + D2tmp0 = feature[offset2 + cid]; + multi += D1tmp0 * D2tmp0; + cid += 32; + } + int res = f & 31; + D1tmp0 = D2tmp0 = 0; + if(res) + { + if(off1 < res) + { + D1tmp0 = grad[offset1 + cid]; + D2tmp0 = feature[offset2 + cid]; + } + multi += D1tmp0 * D2tmp0; + } + for (int stride = 16; stride > 0; stride >>= 1) + { + multi += __shfl_xor_sync(0xffffffff, multi, stride, 32); + } + if (threadIdx.x == 0 && threadIdx.y == 0) + { + out[eid * h + hid] = multi; + } + } +} + + +torch::Tensor mhsddmm_cuda( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor grad, // V * H * F + torch::Tensor feature // V * H * F +) +{ + const auto v = rowptr.size(0) - 1; // V + const auto nnz = colind.size(0); // E + const auto h = feature.size(1); // H + const auto f = feature.size(2); // F + auto devid = feature.device().index(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); + auto out = torch::empty({nnz, h}, options); + mhsddmm<<>>(v, f, h, nnz, + rowptr.data_ptr(), colind.data_ptr(), grad.data_ptr(), feature.data_ptr(), out.data_ptr()); + return out; +} diff --git a/cogdl/operators/spmm/multiheadSpmm.cpp b/cogdl/operators/spmm/multiheadSpmm.cpp new file mode 100644 index 00000000..211cfd60 --- /dev/null +++ b/cogdl/operators/spmm/multiheadSpmm.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include + +torch::Tensor mhspmm_cuda( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor attention, + torch::Tensor infeat); + +torch::Tensor mhspmm( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor attention, + torch::Tensor infeat) +{ + assert(rowptr.device().type() == torch::kCUDA); + assert(colind.device().type() == torch::kCUDA); + assert(attention.device().type() == torch::kCUDA); + assert(infeat.device().type() == torch::kCUDA); + assert(rowptr.is_contiguous()); + assert(colind.is_contiguous()); + assert(attention.is_contiguous()); + assert(infeat.is_contiguous()); + assert(rowptr.dtype() == torch::kInt32); + assert(colind.dtype() == torch::kInt32); + assert(attention.dtype() == torch::kFloat32); + assert(infeat.dtype() == torch::kFloat32); + return mhspmm_cuda(rowptr, colind, attention, infeat); +} + +PYBIND11_MODULE(mhspmm, m) +{ + m.doc() = "mhtranspose in CSR format. 
"; + m.def("mhspmm", &mhspmm, "CSR mhsddmm"); +} \ No newline at end of file diff --git a/cogdl/operators/spmm/multiheadSpmm.cu b/cogdl/operators/spmm/multiheadSpmm.cu new file mode 100644 index 00000000..8d239934 --- /dev/null +++ b/cogdl/operators/spmm/multiheadSpmm.cu @@ -0,0 +1,45 @@ +#include +#include +#include "computeUtil.h" + + +__global__ void mhspmmSimple( + int v, int nnz, int h, int f, + int *rowptr, int *colind, float *attention /* E*H */, + float *infeat /* V*H*F */, + float *outfeat /* V*H*F */ +) +{ + int rid = blockIdx.x; + int hid = blockIdx.y; + int lb = rowptr[rid]; + int hb = rowptr[(rid + 1)]; + float acc = 0; + int offset1, offset2; + for (int ptr = lb; ptr < hb; ptr++) + { + offset1 = colind[ptr] * f * h + hid * f + threadIdx.x; + float att = attention[ptr * h + hid]; + acc = sum_reduce(acc, infeat[offset1] * att); + } + offset2 = rid * f * h + hid * f + threadIdx.x; + outfeat[offset2] = acc; +} + +torch::Tensor mhspmm_cuda( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor attention, + torch::Tensor infeat) +{ + const auto v = rowptr.size(0) - 1; + const auto nnz = colind.size(0); + const auto h = attention.size(1); + const auto f = infeat.size(2); + auto devid = infeat.device().index(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); + auto outfeat = torch::empty({v, h, f}, options); + mhspmmSimple<<>>(v, nnz, h, f, rowptr.data_ptr(), colind.data_ptr(), + attention.data_ptr(), infeat.data_ptr(), outfeat.data_ptr()); + return outfeat; +} diff --git a/cogdl/operators/spmm/spmm_kernel.cu b/cogdl/operators/spmm/spmm_kernel.cu index b3d161f6..5053e1bc 100644 --- a/cogdl/operators/spmm/spmm_kernel.cu +++ b/cogdl/operators/spmm/spmm_kernel.cu @@ -1,476 +1,452 @@ -#include -#include -#include - -#define checkCudaError( a ) do { \ - if (cudaSuccess != (a)) { \ - fprintf(stderr, "Cuda runTime error in line %d of file %s \ - : %s \n", __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()) ); \ - exit(EXIT_FAILURE); \ - } \ -} while(0) - -#define checkCuSparseError( a ) do { \ - if (CUSPARSE_STATUS_SUCCESS != (a)) { \ - fprintf(stderr, "CuSparse runTime error in line %d of file %s \ - : %s \n", __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()) ); \ - exit(EXIT_FAILURE); \ - } \ -} while (0) - - - -__device__ __forceinline__ float sum_reduce(float acc, float x) { - return acc + x; -} - -__device__ __forceinline__ float sum_init() { - return 0; -} - -__global__ void topoCacheCoarsenSPMMKernel( - int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C -) { - extern __shared__ int sh[]; - int sm_offset = (threadIdx.y<<5); - int thread_idx = sm_offset+threadIdx.x; - - int rid = blockDim.y*blockIdx.x+threadIdx.y; - if (rid0) { - acc1 = sum_reduce(acc1, B[offset]);} - // acc1 = sum_reduce(acc1, __ldg(B+offset)); } - if (nout>1) { - acc2 = sum_reduce(acc2, B[(offset+32)]);} - // acc2 = sum_reduce(acc2, __ldg(B+offset+32));} - } - __syncwarp(); - } - offset = rid*k+cid; - if (nout>0) { - C[offset] = acc1;} - if (nout>1) { - C[offset+32] = acc2;} - } - } -} - -__global__ void topoCacheSPMMKernel( - int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C -) { - extern __shared__ int sh[]; - int sm_offset = (threadIdx.y<<5); - int thread_idx = sm_offset + threadIdx.x; - - int cid = (blockIdx.y<<5)+threadIdx.x; - int rid = blockDim.y*blockIdx.x+threadIdx.y; - - if (rid0) { - acc1 = sum_reduce(acc1, B[offset]);} - // acc1 = sum_reduce(acc1, __ldg(B+offset)); } - } - 
__syncwarp(); - } - offset = rid*k+cid; - if (nout>0) { - C[offset] = acc1;} - } - } -} - -__global__ void topoSimpleSPMMKernel( - int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C -) { - int rid = blockDim.y*blockIdx.x+threadIdx.y; - if (rid>>( - m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } - if (k<64) { - const int tile_k = (k+31)/32; - const int n_block = (m+3)/4; - topoCacheSPMMKernel<<< dim3(n_block,tile_k,1), dim3(32,4,1), 128*sizeof(int)>>>( - m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } - else { - const int tile_k = (k+63)/64; - const int n_block = (m+8-1)/8; - topoCacheCoarsenSPMMKernel<<< dim3(n_block,tile_k,1), dim3(32,8,1), 8*32*sizeof(int)>>>( - m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } -} - - -__global__ void spmm_test0( - int A_nrows, int B_ncols, - int* A_csrRowPtr, int* A_csrColInd, float* A_csrVal, - float* B_dnVal, float* C_dnVal -) -{ - int rid = blockDim.y*blockIdx.x+threadIdx.y; - if (rid0) { - acc1 += val*B_dnVal[offset]; - } - if (nout>1) { - acc2 += val*B_dnVal[offset+32]; - } - } - __syncwarp(); - } - offset = rid*B_ncols+cid; - if (nout>0) { - C_dnVal[offset] = acc1; - } - if (nout>1) { - C_dnVal[(offset+32)] = acc2; - } - } - } -} - -void csr2cscKernel(int m, int n, int nnz, - int *csrRowPtr, int *csrColInd, float *csrVal, - int *cscColPtr, int *cscRowInd, float *cscVal -) -{ - cusparseHandle_t handle; - size_t bufferSize = 0; - void* buffer = NULL; - checkCuSparseError(cusparseCsr2cscEx2_bufferSize(handle, - m, - n, - nnz, - csrVal, - csrRowPtr, - csrColInd, - cscVal, - cscColPtr, - cscRowInd, - CUDA_R_32F, - CUSPARSE_ACTION_SYMBOLIC, - CUSPARSE_INDEX_BASE_ZERO, - CUSPARSE_CSR2CSC_ALG1, - &bufferSize - )); - checkCudaError(cudaMalloc((void**)&buffer, bufferSize * sizeof(float))); - checkCuSparseError(cusparseCsr2cscEx2(handle, - m, - n, - nnz, - csrVal, - csrRowPtr, - csrColInd, - cscVal, - cscColPtr, - cscRowInd, - CUDA_R_32F, - CUSPARSE_ACTION_NUMERIC, - CUSPARSE_INDEX_BASE_ZERO, - CUSPARSE_CSR2CSC_ALG1, - buffer - )); - checkCudaError(cudaFree(buffer)); -} - -torch::Tensor spmm_cuda( - torch::Tensor rowptr, - torch::Tensor colind, - torch::Tensor values, - torch::Tensor dense -) { - const auto m = rowptr.size(0)-1; - const auto k = dense.size(1); - auto devid = dense.device().index(); - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); - auto out = torch::empty({m,k}, options); - - if (k<32) { - const int row_per_block = 128/k; - const int n_block = (m+row_per_block-1)/row_per_block; - spmm_test0<<>>( - m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } - if (k<64) { - const int tile_k = (k+31)/32; - const int n_block = (m+4-1)/4; - spmm_test1<<>> ( - m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } - else { - const int tile_k = (k+63)/64; - const int n_block = (m+8-1)/8; - spmm_test2<<>> ( - m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); - return out; - } -} - -std::vector csr2csc_cuda( - torch::Tensor csrRowPtr, - torch::Tensor csrColInd, - torch::Tensor csrVal) -{ - const auto n = csrRowPtr.size(0) - 1; - const auto nnz = csrColInd.size(0); - auto devid = csrRowPtr.device().index(); - auto optionsF = 
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); - auto optionsI = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid); - auto cscColPtr = torch::empty({n + 1}, optionsI); - auto cscRowInd = torch::empty({nnz}, optionsI); - auto cscVal = torch::empty({nnz}, optionsF); - csr2cscKernel(n, n, nnz, csrRowPtr.data_ptr(), csrColInd.data_ptr(), csrVal.data_ptr(), - cscColPtr.data_ptr(), cscRowInd.data_ptr(), cscVal.data_ptr()); - return {cscColPtr, cscRowInd, cscVal}; +#include +#include +#include +#include "computeUtil.h" + + +__global__ void topoCacheCoarsenSPMMKernel( + int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C +) { + extern __shared__ int sh[]; + int sm_offset = (threadIdx.y<<5); + int thread_idx = sm_offset+threadIdx.x; + + int rid = blockDim.y*blockIdx.x+threadIdx.y; + if (rid0) { + acc1 = sum_reduce(acc1, B[offset]);} + // acc1 = sum_reduce(acc1, __ldg(B+offset)); } + if (nout>1) { + acc2 = sum_reduce(acc2, B[(offset+32)]);} + // acc2 = sum_reduce(acc2, __ldg(B+offset+32));} + } + __syncwarp(); + } + offset = rid*k+cid; + if (nout>0) { + C[offset] = acc1;} + if (nout>1) { + C[offset+32] = acc2;} + } + } +} + +__global__ void topoCacheSPMMKernel( + int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C +) { + extern __shared__ int sh[]; + int sm_offset = (threadIdx.y<<5); + int thread_idx = sm_offset + threadIdx.x; + + int cid = (blockIdx.y<<5)+threadIdx.x; + int rid = blockDim.y*blockIdx.x+threadIdx.y; + + if (rid0) { + acc1 = sum_reduce(acc1, B[offset]);} + // acc1 = sum_reduce(acc1, __ldg(B+offset)); } + } + __syncwarp(); + } + offset = rid*k+cid; + if (nout>0) { + C[offset] = acc1;} + } + } +} + +__global__ void topoSimpleSPMMKernel( + int m, int k, const int* A_indptr, const int* A_indices, const float* B, float* C +) { + int rid = blockDim.y*blockIdx.x+threadIdx.y; + if (rid>>( + m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } + if (k<64) { + const int tile_k = (k+31)/32; + const int n_block = (m+3)/4; + topoCacheSPMMKernel<<< dim3(n_block,tile_k,1), dim3(32,4,1), 128*sizeof(int)>>>( + m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } + else { + const int tile_k = (k+63)/64; + const int n_block = (m+8-1)/8; + topoCacheCoarsenSPMMKernel<<< dim3(n_block,tile_k,1), dim3(32,8,1), 8*32*sizeof(int)>>>( + m, k, rowptr.data_ptr(), colind.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } +} + + +__global__ void spmm_test0( + int A_nrows, int B_ncols, + int* A_csrRowPtr, int* A_csrColInd, float* A_csrVal, + float* B_dnVal, float* C_dnVal +) +{ + int rid = blockDim.y*blockIdx.x+threadIdx.y; + if (rid0) { + acc1 += val*B_dnVal[offset]; + } + if (nout>1) { + acc2 += val*B_dnVal[offset+32]; + } + } + __syncwarp(); + } + offset = rid*B_ncols+cid; + if (nout>0) { + C_dnVal[offset] = acc1; + } + if (nout>1) { + C_dnVal[(offset+32)] = acc2; + } + } + } +} + +void csr2cscKernel(int m, int n, int nnz, + int *csrRowPtr, int *csrColInd, float *csrVal, + int *cscColPtr, int *cscRowInd, float *cscVal +) +{ + cusparseHandle_t handle; + size_t bufferSize = 0; + void* buffer = NULL; + checkCuSparseError(cusparseCsr2cscEx2_bufferSize(handle, + m, + n, + nnz, + csrVal, + csrRowPtr, + csrColInd, + cscVal, + cscColPtr, + cscRowInd, + CUDA_R_32F, + CUSPARSE_ACTION_SYMBOLIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + &bufferSize + )); + checkCudaError(cudaMalloc((void**)&buffer, 
bufferSize * sizeof(float))); + checkCuSparseError(cusparseCsr2cscEx2(handle, + m, + n, + nnz, + csrVal, + csrRowPtr, + csrColInd, + cscVal, + cscColPtr, + cscRowInd, + CUDA_R_32F, + CUSPARSE_ACTION_NUMERIC, + CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG1, + buffer + )); + checkCudaError(cudaFree(buffer)); +} + +torch::Tensor spmm_cuda( + torch::Tensor rowptr, + torch::Tensor colind, + torch::Tensor values, + torch::Tensor dense +) { + const auto m = rowptr.size(0)-1; + const auto k = dense.size(1); + auto devid = dense.device().index(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); + auto out = torch::empty({m,k}, options); + + if (k<32) { + const int row_per_block = 128/k; + const int n_block = (m+row_per_block-1)/row_per_block; + spmm_test0<<>>( + m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } + if (k<64) { + const int tile_k = (k+31)/32; + const int n_block = (m+4-1)/4; + spmm_test1<<>> ( + m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } + else { + const int tile_k = (k+63)/64; + const int n_block = (m+8-1)/8; + spmm_test2<<>> ( + m, k, rowptr.data_ptr(), colind.data_ptr(), values.data_ptr(), dense.data_ptr(), out.data_ptr()); + return out; + } +} + +std::vector csr2csc_cuda( + torch::Tensor csrRowPtr, + torch::Tensor csrColInd, + torch::Tensor csrVal) +{ + const auto n = csrRowPtr.size(0) - 1; + const auto nnz = csrColInd.size(0); + auto devid = csrRowPtr.device().index(); + auto optionsF = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid); + auto optionsI = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid); + auto cscColPtr = torch::empty({n + 1}, optionsI); + auto cscRowInd = torch::empty({nnz}, optionsI); + auto cscVal = torch::empty({nnz}, optionsF); + csr2cscKernel(n, n, nnz, csrRowPtr.data_ptr(), csrColInd.data_ptr(), csrVal.data_ptr(), + cscColPtr.data_ptr(), cscRowInd.data_ptr(), cscVal.data_ptr()); + return {cscColPtr, cscRowInd, cscVal}; } \ No newline at end of file diff --git a/cogdl/tasks/README.md b/cogdl/tasks/README.md index 433c9cea..0daebd67 100644 --- a/cogdl/tasks/README.md +++ b/cogdl/tasks/README.md @@ -67,12 +67,13 @@ For multiplex node classification, we use macro F1 to evaluate models. 
We evalua | Rank | Method | DBLP | ACM | IMDB | | ---- | ------------------------------------------------------------------------------------------------------------------ | :-------: | :-------: | :-------: | -| 1 | GTN [(Yun et al, NeurIPS'19)](https://arxiv.org/abs/1911.06455) | **92.03** | **90.85** | **59.24** | -| 2 | HAN [(Xiao et al, WWW'19)](https://arxiv.org/abs/1903.07293) | 91.21 | 87.25 | 53.94 | -| 3 | GCC [(Qiu et al, KDD'20)](http://keg.cs.tsinghua.edu.cn/jietang/publications/KDD20-Qiu-et-al-GCC-GNN-pretrain.pdf) | 79.42 | 86.82 | 55.86 | -| 4 | PTE [(Tang et al, KDD'15)](https://arxiv.org/abs/1508.00200) | 78.65 | 87.44 | 48.91 | -| 5 | Metapath2vec [(Dong et al, KDD'17)](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) | 75.18 | 88.79 | 43.10 | -| 6 | Hin2vec [(Fu et al, CIKM'17)](https://dl.acm.org/doi/10.1145/3132847.3132953) | 74.31 | 84.66 | 44.04 | +| 1 | Simple-HGN [(Lv and Ding et al, KDD'21)](https://github.com/THUDM/HGB) | **95.09** | **92.57** | **58.61** | +| 2 | GTN [(Yun et al, NeurIPS'19)](https://arxiv.org/abs/1911.06455) | 92.03 | 90.85 | 57.53 | +| 3 | HAN [(Xiao et al, WWW'19)](https://arxiv.org/abs/1903.07293) | 91.21 | 87.25 | 53.94 | +| 4 | GCC [(Qiu et al, KDD'20)](http://keg.cs.tsinghua.edu.cn/jietang/publications/KDD20-Qiu-et-al-GCC-GNN-pretrain.pdf) | 79.42 | 86.82 | 55.86 | +| 5 | PTE [(Tang et al, KDD'15)](https://arxiv.org/abs/1508.00200) | 78.65 | 87.44 | 48.91 | +| 6 | Metapath2vec [(Dong et al, KDD'17)](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) | 75.18 | 88.79 | 43.10 | +| 7 | Hin2vec [(Fu et al, CIKM'17)](https://dl.acm.org/doi/10.1145/3132847.3132953) | 74.31 | 84.66 | 44.04 | ### Link Prediction diff --git a/cogdl/utils/utils.py b/cogdl/utils/utils.py index de20b12f..24277803 100644 --- a/cogdl/utils/utils.py +++ b/cogdl/utils/utils.py @@ -1,572 +1,567 @@ -import errno -import itertools -import os -import os.path as osp -import random -import shutil -from collections import defaultdict -from typing import Optional -from urllib import request - -import numpy as np -import scipy.sparse as sp -import torch -import torch.nn.functional as F -from tabulate import tabulate - -from cogdl.operators.sample import coo2csr_cpu, coo2csr_cpu_index - -try: - from cogdl.operators.edge_softmax import csr_edge_softmax -except Exception: - csr_edge_softmax = None - - -class ArgClass(object): - def __init__(self): - pass - - -def build_args_from_dict(dic): - args = ArgClass() - for key, value in dic.items(): - args.__setattr__(key, value) - return args - - -def untar(path, fname, deleteTar=True): - """ - Unpacks the given archive file to the same directory, then (by default) - deletes the archive file. - """ - print("unpacking " + fname) - fullpath = os.path.join(path, fname) - shutil.unpack_archive(fullpath, path) - if deleteTar: - os.remove(fullpath) - - -def makedirs(path): - try: - os.makedirs(osp.expanduser(osp.normpath(path))) - except OSError as e: - if e.errno != errno.EEXIST and osp.isdir(path): - raise e - - -def download_url(url, folder, name=None, log=True): - r"""Downloads the content of an URL to a specific folder. - - Args: - url (string): The url. - folder (string): The folder. - name (string): saved filename. - log (bool, optional): If :obj:`False`, will not print anything to the - console. 
(default: :obj:`True`) - """ - if log: - print("Downloading", url) - - makedirs(folder) - - try: - data = request.urlopen(url) - except Exception as e: - print(e) - print("Failed to download the dataset.") - print(f"Please download the dataset manually and put it under {folder}.") - exit(1) - - if name is None: - filename = url.rpartition("/")[2] - else: - filename = name - path = osp.join(folder, filename) - - with open(path, "wb") as f: - f.write(data.read()) - - return path - - -def alias_setup(probs): - """ - Compute utility lists for non-uniform sampling from discrete distributions. - Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ - for details - """ - K = len(probs) - q = np.zeros(K) - J = np.zeros(K, dtype=np.int) - - smaller = [] - larger = [] - for kk, prob in enumerate(probs): - q[kk] = K * prob - if q[kk] < 1.0: - smaller.append(kk) - else: - larger.append(kk) - - while len(smaller) > 0 and len(larger) > 0: - small = smaller.pop() - large = larger.pop() - - J[small] = large - q[large] = q[large] + q[small] - 1.0 - if q[large] < 1.0: - smaller.append(large) - else: - larger.append(large) - - return J, q - - -def alias_draw(J, q): - """ - Draw sample from a non-uniform discrete distribution using alias sampling. - """ - K = len(J) - - kk = int(np.floor(np.random.rand() * K)) - if np.random.rand() < q[kk]: - return kk - else: - return J[kk] - - -def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): - device = edge_index.device - if edge_weight is None: - edge_weight = torch.ones(edge_index.shape[1]).to(device) - if num_nodes is None: - num_nodes = torch.max(edge_index) + 1 - if fill_value is None: - fill_value = 1 - - N = num_nodes - self_weight = torch.full((num_nodes,), fill_value, dtype=edge_weight.dtype).to(edge_weight.device) - loop_index = torch.arange(0, N, dtype=edge_index.dtype, device=edge_index.device) - loop_index = loop_index.unsqueeze(0).repeat(2, 1) - edge_index = torch.cat([edge_index, loop_index], dim=1) - edge_weight = torch.cat([edge_weight, self_weight]) - return edge_index, edge_weight - - -def add_remaining_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): - device = edge_index.device - if edge_weight is None: - edge_weight = torch.ones(edge_index.shape[1], device=device) - if num_nodes is None: - num_nodes = torch.max(edge_index) + 1 - if fill_value is None: - fill_value = 1 - - N = num_nodes - row, col = edge_index[0], edge_index[1] - mask = row != col - - loop_index = torch.arange(0, N, dtype=edge_index.dtype, device=edge_index.device) - loop_index = loop_index.unsqueeze(0).repeat(2, 1) - edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1) - - inv_mask = ~mask - - loop_weight = torch.full((N,), fill_value, dtype=edge_weight.dtype, device=edge_weight.device) - remaining_edge_weight = edge_weight[inv_mask] - if remaining_edge_weight.numel() > 0: - loop_weight[row[inv_mask]] = remaining_edge_weight - edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0) - - return edge_index, edge_weight - - -def row_normalization(num_nodes, edge_index, edge_weight=None): - device = edge_index.device - if edge_weight is None: - edge_weight = torch.ones(edge_index.shape[1]).to(device) - row_sum = spmm_scatter(edge_index, edge_weight, torch.ones(num_nodes, 1).to(device)) - row_sum_inv = row_sum.pow(-1).view(-1) - row_sum_inv[torch.isinf(row_sum_inv)] = 0 - return edge_weight * row_sum_inv[edge_index[0]] - - -def 
symmetric_normalization(num_nodes, edge_index, edge_weight=None): - device = edge_index.device - if edge_weight is None: - edge_weight = torch.ones(edge_index.shape[1]).to(device) - row_sum = spmm_scatter(edge_index, edge_weight, torch.ones(num_nodes, 1).to(device)).view(-1) - row_sum_inv_sqrt = row_sum.pow(-0.5) - row_sum_inv_sqrt[row_sum_inv_sqrt == float("inf")] = 0 - return row_sum_inv_sqrt[edge_index[1]] * edge_weight * row_sum_inv_sqrt[edge_index[0]] - - -def spmm_scatter(indices, values, b): - r""" - Args: - indices : Tensor, shape=(2, E) - values : Tensor, shape=(E,) - b : Tensor, shape=(N, ) - """ - output = b.index_select(0, indices[1]) * values.unsqueeze(-1) - output = torch.zeros_like(b).scatter_add_(0, indices[0].unsqueeze(-1).expand_as(output), output) - return output - - -def spmm_adj(indices, values, x, num_nodes=None): - if num_nodes is None: - num_nodes = x.shape[0] - adj = torch.sparse_coo_tensor(indices=indices, values=values, size=(num_nodes, num_nodes)) - return torch.spmm(adj, x) - - -fast_spmm = None - - -def initialize_spmm(args): - if hasattr(args, "fast_spmm") and args.fast_spmm is True: - try: - from cogdl.operators.spmm import csrspmm - - global fast_spmm - fast_spmm = csrspmm - print("Using fast-spmm to speed up training") - except Exception: - print("Failed to load fast version of SpMM, use torch.spmm instead.") - - -def spmm(graph, x): - if graph.out_norm is not None: - x = graph.out_norm * x - - if fast_spmm is not None and str(x.device) != "cpu": - row_ptr, col_indices = graph.row_indptr, graph.col_indices - csr_data = graph.edge_weight - x = fast_spmm(row_ptr.int(), col_indices.int(), x, csr_data.contiguous(), graph.is_symmetric()) - else: - x = spmm_scatter(graph.edge_index, graph.edge_weight, x) - - if graph.in_norm is not None: - x = graph.in_norm * x - return x - - -def _coo2csr(edge_index, data, num_nodes=None, ordered=False, return_index=False): - if ordered: - return sorted_coo2csr(edge_index[0], edge_index[1], data, return_index=return_index) - if num_nodes is None: - num_nodes = torch.max(edge_index) + 1 - device = edge_index[0].device - sorted_index = torch.argsort(edge_index[0]) - sorted_index = sorted_index.long() - edge_index = edge_index[:, sorted_index] - indices = edge_index[1] - - row = edge_index[0] - indptr = torch.zeros(num_nodes + 1, dtype=torch.int32, device=device) - elements, counts = torch.unique(row, return_counts=True) - elements = elements.long() + 1 - indptr[elements] = counts.to(indptr.dtype) - indptr = indptr.cumsum(dim=0) - - if return_index: - return indptr, sorted_index - if data is not None: - data = data[sorted_index] - return indptr, indices, data - - -def coo2csr(row, col, data, num_nodes=None, ordered=False): - if ordered: - indptr, indices, data = sorted_coo2csr(row, col, data) - return indptr, indices, data - if num_nodes is None: - num_nodes = torch.max(torch.stack(row, col)).item() + 1 - if coo2csr_cpu is None: - return _coo2csr(torch.stack([row, col]), data, num_nodes) - device = row.device - row = row.long().cpu() - col = col.long().cpu() - data = data.float().cpu() - indptr, indices, data = coo2csr_cpu(row, col, data, num_nodes) - return indptr.to(device), indices.to(device), data.to(device) - - -def coo2csr_index(row, col, num_nodes=None): - if num_nodes is None: - num_nodes = torch.max(torch.stack(row, col)).item() + 1 - if coo2csr_cpu_index is None: - return _coo2csr(torch.stack([row, col]), None, num_nodes=num_nodes, return_index=True) - device = row.device - row = row.long().cpu() - col = col.long().cpu() 
- indptr, reindex = coo2csr_cpu_index(row, col, num_nodes) - return indptr.to(device), reindex.to(device) - - -def sorted_coo2csr(row, col, data, num_nodes=None, return_index=False): - indptr = torch.bincount(row) - indptr = indptr.cumsum(dim=0) - zero = torch.zeros(1, device=indptr.device) - indptr = torch.cat([zero, indptr]) - if return_index: - return indptr, torch.arange(0, row.shape[0]) - return indptr, col, data - - -def coo2csc(row, col, data, num_nodes=None, sorted=False): - return coo2csr(col, row, data, num_nodes, sorted) - - -def csr2csc(indptr, indices, data=None): - device = indices.device - indptr = indptr.cpu().numpy() - indices = indices.cpu().numpy() - num_nodes = indptr.shape[0] - 1 - if data is None: - data = np.ones(indices.shape[0]) - else: - data = data.cpu().numpy() - adj = sp.csr_matrix((data, indices, indptr), shape=(num_nodes, num_nodes)) - adj = adj.tocsc() - data = torch.as_tensor(adj.data, device=device) - col_indptr = torch.as_tensor(adj.indptr, device=device) - row_indices = torch.as_tensor(adj.indices, device=device) - return col_indptr, row_indices, data - - -def csr2coo(indptr, indices, data): - num_nodes = indptr.size(0) - 1 - row = torch.arange(num_nodes, device=indptr.device) - row_count = indptr[1:] - indptr[:-1] - row = row.repeat_interleave(row_count) - return row, indices, data - - -def get_degrees(indices, num_nodes=None): - device = indices.device - values = torch.ones(indices.shape[1]).to(device) - if num_nodes is None: - num_nodes = torch.max(indices).item() + 1 - b = torch.ones((num_nodes, 1)).to(device) - degrees = spmm_scatter(indices, values, b).view(-1) - return degrees - - -def edge_softmax(graph, edge_val): - """ - Args: - indices: Tensor, shape=(2, E) - values: Tensor, shape=(N,) - shape: tuple(int, int) - - Returns: - Softmax values of edge values for nodes - """ - edge_val_max = edge_val.max().item() - while edge_val_max > 10: - edge_val -= edge_val / 2 - edge_val_max = edge_val.max().item() - - with graph.local_graph(): - edge_val = torch.exp(edge_val) - graph.edge_weight = edge_val - x = torch.ones(graph.num_nodes, 1).to(edge_val.device) - node_sum = spmm(graph, x).squeeze() - row = graph.edge_index[0] - softmax_values = edge_val / node_sum[row] - return softmax_values - - -def mul_edge_softmax(graph, edge_val): - """ - Returns: - Softmax values of multi-dimension edge values. 
shape: [d, E] - """ - if csr_edge_softmax is not None: - val = csr_edge_softmax(graph.row_indptr.int(), edge_val) - return val.t() - else: - val = [] - for i in range(edge_val.shape[1]): - val.append(edge_softmax(graph, edge_val[:, i])) - return torch.stack(val) - - -def remove_self_loops(indices, values=None): - mask = indices[0] != indices[1] - indices = indices[:, mask] - if values is not None: - values = values[mask] - return indices, values - - -def filter_adj(row, col, edge_attr, mask): - return torch.stack([row[mask], col[mask]]), None if edge_attr is None else edge_attr[mask] - - -def dropout_adj( - edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None, drop_rate: float = 0.5, renorm: bool = True -): - if drop_rate < 0.0 or drop_rate > 1.0: - raise ValueError("Dropout probability has to be between 0 and 1, " "but got {}".format(drop_rate)) - - num_nodes = int(torch.max(edge_index)) + 1 - mask = edge_index.new_full((edge_index.size(1),), 1 - drop_rate, dtype=torch.float) - mask = torch.bernoulli(mask).to(torch.bool) - edge_index, edge_weight = filter_adj(edge_index[0], edge_index[1], edge_weight, mask) - if renorm: - edge_weight = symmetric_normalization(num_nodes, edge_index) - return edge_index, edge_weight - - -def coalesce(row, col, value=None): - row = row.numpy() - col = col.numpy() - indices = np.lexsort((col, row)) - row = torch.from_numpy(row[indices]) - col = torch.from_numpy(col[indices]) - - num = col.shape[0] + 1 - idx = torch.full((num,), -1, dtype=torch.float) - idx[1:] = row * num + col - mask = idx[1:] > idx[:-1] - - if mask.all(): - return row, col, value - row = row[mask] - if value is not None: - _value = torch.zeros(row.shape[0], dtype=torch.float).to(row.device) - value = _value.scatter_add_(dim=0, src=value, index=col) - col = col[mask] - return row, col, value - - -def to_undirected(edge_index, num_nodes=None): - r"""Converts the graph given by :attr:`edge_index` to an undirected graph, - so that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in - \mathcal{E}`. - - Args: - edge_index (LongTensor): The edge indices. - num_nodes (int, optional): The number of nodes, *i.e.* - :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) - - :rtype: :class:`LongTensor` - """ - - row, col = edge_index - row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) - row, col, _ = coalesce(row, col, None) - edge_index = torch.stack([row, col]) - return edge_index - - -def get_activation(act: str): - if act == "relu": - return F.relu - elif act == "sigmoid": - return torch.sigmoid - elif act == "tanh": - return torch.tanh - elif act == "gelu": - return F.gelu - elif act == "prelu": - return F.prelu - elif act == "identity": - return lambda x: x - else: - return F.relu - - -def cycle_index(num, shift): - arr = torch.arange(num) + shift - arr[-shift:] = torch.arange(shift) - return arr - - -def batch_sum_pooling(x, batch): - batch_size = int(torch.max(batch.cpu())) + 1 - # batch_size = len(torch.unique(batch)) - res = torch.zeros(batch_size, x.size(1)).to(x.device) - return res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x) - - -def batch_mean_pooling(x, batch): - values, counts = torch.unique(batch, return_counts=True) - res = torch.zeros(len(values), x.size(1)).to(x.device) - res = res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x) - return res / counts.unsqueeze(-1) - - -def negative_edge_sampling( - edge_index: torch.Tensor, - num_nodes: Optional[int] = None, - num_neg_samples: Optional[int] = None, - undirected: bool = False, -): - if num_nodes is None: - num_nodes = len(torch.unique(edge_index)) - if num_neg_samples is None: - num_neg_samples = edge_index.shape[1] - - size = num_nodes * num_nodes - num_neg_samples = min(num_neg_samples, size - edge_index.size(1)) - - row, col = edge_index - unique_pair = row * num_nodes + col - - num_samples = int(num_neg_samples * abs(1 / (1 - 1.1 * edge_index.size(1) / size))) - sample_result = torch.LongTensor(random.sample(range(size), min(num_samples, num_samples))) - mask = torch.from_numpy(np.isin(sample_result, unique_pair.to("cpu"))).to(torch.bool) - selected = sample_result[~mask][:num_neg_samples].to(edge_index.device) - - row = selected // num_nodes - col = selected % num_nodes - return torch.stack([row, col]).long() - - -def tabulate_results(results_dict): - # Average for different seeds - tab_data = [] - for variant in results_dict: - results = np.array([list(res.values()) for res in results_dict[variant]]) - tab_data.append( - [variant] - + list( - itertools.starmap( - lambda x, y: f"{x:.4f}±{y:.4f}", - zip( - np.mean(results, axis=0).tolist(), - np.std(results, axis=0).tolist(), - ), - ) - ) - ) - return tab_data - - -def print_result(results, datasets, model_name): - table_header = ["Variants"] + list(results[0].keys()) - - results_dict = defaultdict(list) - num_datasets = len(datasets) - num_seed = len(results) // num_datasets - for i, res in enumerate(results): - results_dict[(model_name, datasets[i // num_seed])].append(res) - tab_data = tabulate_results(results_dict) - print(tabulate(tab_data, headers=table_header, tablefmt="github")) - - -def set_random_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.determinstic = True - - -if __name__ == "__main__": - args = build_args_from_dict({"a": 1, "b": 2}) - print(args.a, args.b) +import errno +import itertools +import os +import os.path as osp +import random +import shutil +from collections import defaultdict +from typing import Optional +from urllib import request + +import numpy as np +import scipy.sparse as sp +import torch +import 
torch.nn.functional as F +from tabulate import tabulate + +from cogdl.operators.sample import coo2csr_cpu, coo2csr_cpu_index +from cogdl.operators.edge_softmax import csr_edge_softmax +from cogdl.operators.mhspmm import csrmhspmm + + +class ArgClass(object): + def __init__(self): + pass + + +def build_args_from_dict(dic): + args = ArgClass() + for key, value in dic.items(): + args.__setattr__(key, value) + return args + + +def untar(path, fname, deleteTar=True): + """ + Unpacks the given archive file to the same directory, then (by default) + deletes the archive file. + """ + print("unpacking " + fname) + fullpath = os.path.join(path, fname) + shutil.unpack_archive(fullpath, path) + if deleteTar: + os.remove(fullpath) + + +def makedirs(path): + try: + os.makedirs(osp.expanduser(osp.normpath(path))) + except OSError as e: + if e.errno != errno.EEXIST and osp.isdir(path): + raise e + + +def download_url(url, folder, name=None, log=True): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (string): The url. + folder (string): The folder. + name (string): saved filename. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + if log: + print("Downloading", url) + + makedirs(folder) + + try: + data = request.urlopen(url) + except Exception as e: + print(e) + print("Failed to download the dataset.") + print(f"Please download the dataset manually and put it under {folder}.") + exit(1) + + if name is None: + filename = url.rpartition("/")[2] + else: + filename = name + path = osp.join(folder, filename) + + with open(path, "wb") as f: + f.write(data.read()) + + return path + + +def alias_setup(probs): + """ + Compute utility lists for non-uniform sampling from discrete distributions. + Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ + for details + """ + K = len(probs) + q = np.zeros(K) + J = np.zeros(K, dtype=np.int) + + smaller = [] + larger = [] + for kk, prob in enumerate(probs): + q[kk] = K * prob + if q[kk] < 1.0: + smaller.append(kk) + else: + larger.append(kk) + + while len(smaller) > 0 and len(larger) > 0: + small = smaller.pop() + large = larger.pop() + + J[small] = large + q[large] = q[large] + q[small] - 1.0 + if q[large] < 1.0: + smaller.append(large) + else: + larger.append(large) + + return J, q + + +def alias_draw(J, q): + """ + Draw sample from a non-uniform discrete distribution using alias sampling. 
+ """ + K = len(J) + + kk = int(np.floor(np.random.rand() * K)) + if np.random.rand() < q[kk]: + return kk + else: + return J[kk] + + +def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): + device = edge_index.device + if edge_weight is None: + edge_weight = torch.ones(edge_index.shape[1]).to(device) + if num_nodes is None: + num_nodes = torch.max(edge_index) + 1 + if fill_value is None: + fill_value = 1 + + N = num_nodes + self_weight = torch.full((num_nodes,), fill_value, dtype=edge_weight.dtype).to(edge_weight.device) + loop_index = torch.arange(0, N, dtype=edge_index.dtype, device=edge_index.device) + loop_index = loop_index.unsqueeze(0).repeat(2, 1) + edge_index = torch.cat([edge_index, loop_index], dim=1) + edge_weight = torch.cat([edge_weight, self_weight]) + return edge_index, edge_weight + + +def add_remaining_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): + device = edge_index.device + if edge_weight is None: + edge_weight = torch.ones(edge_index.shape[1], device=device) + if num_nodes is None: + num_nodes = torch.max(edge_index) + 1 + if fill_value is None: + fill_value = 1 + + N = num_nodes + row, col = edge_index[0], edge_index[1] + mask = row != col + + loop_index = torch.arange(0, N, dtype=edge_index.dtype, device=edge_index.device) + loop_index = loop_index.unsqueeze(0).repeat(2, 1) + edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1) + + inv_mask = ~mask + + loop_weight = torch.full((N,), fill_value, dtype=edge_weight.dtype, device=edge_weight.device) + remaining_edge_weight = edge_weight[inv_mask] + if remaining_edge_weight.numel() > 0: + loop_weight[row[inv_mask]] = remaining_edge_weight + edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0) + + return edge_index, edge_weight + + +def row_normalization(num_nodes, edge_index, edge_weight=None): + device = edge_index.device + if edge_weight is None: + edge_weight = torch.ones(edge_index.shape[1]).to(device) + row_sum = spmm_scatter(edge_index, edge_weight, torch.ones(num_nodes, 1).to(device)) + row_sum_inv = row_sum.pow(-1).view(-1) + row_sum_inv[torch.isinf(row_sum_inv)] = 0 + return edge_weight * row_sum_inv[edge_index[0]] + + +def symmetric_normalization(num_nodes, edge_index, edge_weight=None): + device = edge_index.device + if edge_weight is None: + edge_weight = torch.ones(edge_index.shape[1]).to(device) + row_sum = spmm_scatter(edge_index, edge_weight, torch.ones(num_nodes, 1).to(device)).view(-1) + row_sum_inv_sqrt = row_sum.pow(-0.5) + row_sum_inv_sqrt[row_sum_inv_sqrt == float("inf")] = 0 + return row_sum_inv_sqrt[edge_index[1]] * edge_weight * row_sum_inv_sqrt[edge_index[0]] + + +def spmm_scatter(indices, values, b): + r""" + Args: + indices : Tensor, shape=(2, E) + values : Tensor, shape=(E,) + b : Tensor, shape=(N, ) + """ + output = b.index_select(0, indices[1]) * values.unsqueeze(-1) + output = torch.zeros_like(b).scatter_add_(0, indices[0].unsqueeze(-1).expand_as(output), output) + return output + + +fast_spmm = None + + +def initialize_spmm(args): + if hasattr(args, "fast_spmm") and args.fast_spmm is True: + try: + from cogdl.operators.spmm import csrspmm + + global fast_spmm + fast_spmm = csrspmm + print("Using fast-spmm to speed up training") + except Exception: + print("Failed to load fast version of SpMM, use torch.spmm instead.") + + +def spmm(graph, x): + if graph.out_norm is not None: + x = graph.out_norm * x + + if fast_spmm is not None and str(x.device) != "cpu": + row_ptr, col_indices = graph.row_indptr, 
graph.col_indices + csr_data = graph.edge_weight + x = fast_spmm(row_ptr.int(), col_indices.int(), x, csr_data.contiguous(), graph.is_symmetric()) + else: + x = spmm_scatter(graph.edge_index, graph.edge_weight, x) + + if graph.in_norm is not None: + x = graph.in_norm * x + return x + + +def _coo2csr(edge_index, data, num_nodes=None, ordered=False, return_index=False): + if ordered: + return sorted_coo2csr(edge_index[0], edge_index[1], data, return_index=return_index) + if num_nodes is None: + num_nodes = torch.max(edge_index) + 1 + device = edge_index[0].device + sorted_index = torch.argsort(edge_index[0]) + sorted_index = sorted_index.long() + edge_index = edge_index[:, sorted_index] + indices = edge_index[1] + + row = edge_index[0] + indptr = torch.zeros(num_nodes + 1, dtype=torch.int32, device=device) + elements, counts = torch.unique(row, return_counts=True) + elements = elements.long() + 1 + indptr[elements] = counts.to(indptr.dtype) + indptr = indptr.cumsum(dim=0) + + if return_index: + return indptr, sorted_index + if data is not None: + data = data[sorted_index] + return indptr, indices, data + + +def coo2csr(row, col, data, num_nodes=None, ordered=False): + if ordered: + indptr, indices, data = sorted_coo2csr(row, col, data) + return indptr, indices, data + if num_nodes is None: + num_nodes = torch.max(torch.stack(row, col)).item() + 1 + if coo2csr_cpu is None: + return _coo2csr(torch.stack([row, col]), data, num_nodes) + device = row.device + row = row.long().cpu() + col = col.long().cpu() + data = data.float().cpu() + indptr, indices, data = coo2csr_cpu(row, col, data, num_nodes) + return indptr.to(device), indices.to(device), data.to(device) + + +def coo2csr_index(row, col, num_nodes=None): + if num_nodes is None: + num_nodes = torch.max(torch.stack(row, col)).item() + 1 + if coo2csr_cpu_index is None: + return _coo2csr(torch.stack([row, col]), None, num_nodes=num_nodes, return_index=True) + device = row.device + row = row.long().cpu() + col = col.long().cpu() + indptr, reindex = coo2csr_cpu_index(row, col, num_nodes) + return indptr.to(device), reindex.to(device) + + +def sorted_coo2csr(row, col, data, num_nodes=None, return_index=False): + indptr = torch.bincount(row) + indptr = indptr.cumsum(dim=0) + zero = torch.zeros(1, device=indptr.device) + indptr = torch.cat([zero, indptr]) + if return_index: + return indptr, torch.arange(0, row.shape[0]) + return indptr, col, data + + +def coo2csc(row, col, data, num_nodes=None, sorted=False): + return coo2csr(col, row, data, num_nodes, sorted) + + +def csr2csc(indptr, indices, data=None): + device = indices.device + indptr = indptr.cpu().numpy() + indices = indices.cpu().numpy() + num_nodes = indptr.shape[0] - 1 + if data is None: + data = np.ones(indices.shape[0]) + else: + data = data.cpu().numpy() + adj = sp.csr_matrix((data, indices, indptr), shape=(num_nodes, num_nodes)) + adj = adj.tocsc() + data = torch.as_tensor(adj.data, device=device) + col_indptr = torch.as_tensor(adj.indptr, device=device) + row_indices = torch.as_tensor(adj.indices, device=device) + return col_indptr, row_indices, data + + +def csr2coo(indptr, indices, data): + num_nodes = indptr.size(0) - 1 + row = torch.arange(num_nodes, device=indptr.device) + row_count = indptr[1:] - indptr[:-1] + row = row.repeat_interleave(row_count) + return row, indices, data + + +def get_degrees(indices, num_nodes=None): + device = indices.device + values = torch.ones(indices.shape[1]).to(device) + if num_nodes is None: + num_nodes = torch.max(indices).item() + 1 + b = 
torch.ones((num_nodes, 1)).to(device) + degrees = spmm_scatter(indices, values, b).view(-1) + return degrees + + +def mh_spmm(graph, attention, h): + """ + Multi-head spmm + Args: + graph: Graph + attention: torch.Tensor([E, H]) + h: torch.Tensor([N, d]) + + Returns: + torch.Tensor([N, H, d]) + """ + return csrmhspmm(graph.row_indptr.int(), graph.col_indices.int(), h, attention) + + +def edge_softmax_ori(graph, edge_val): + edge_val_max = edge_val.max().item() + while edge_val_max > 10: + edge_val -= edge_val / 2 + edge_val_max = edge_val.max().item() + + with graph.local_graph(): + edge_val = torch.exp(edge_val) + graph.edge_weight = edge_val + x = torch.ones(graph.num_nodes, 1).to(edge_val.device) + node_sum = spmm(graph, x).squeeze() + row = graph.edge_index[0] + softmax_values = edge_val / node_sum[row] + return softmax_values + + +def mul_edge_softmax(graph, edge_val): + """ + Returns: + Softmax values of multi-dimension edge values. shape: [d, E] + """ + if csr_edge_softmax is not None: + val = csr_edge_softmax(graph.row_indptr.int(), edge_val) + return val.contiguous() + else: + val = [] + for i in range(edge_val.shape[1]): + val.append(edge_softmax_ori(graph, edge_val[:, i])) + return torch.stack(val) + + +def remove_self_loops(indices, values=None): + mask = indices[0] != indices[1] + indices = indices[:, mask] + if values is not None: + values = values[mask] + return indices, values + + +def filter_adj(row, col, edge_attr, mask): + return torch.stack([row[mask], col[mask]]), None if edge_attr is None else edge_attr[mask] + + +def dropout_adj( + edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None, drop_rate: float = 0.5, renorm: bool = True +): + if drop_rate < 0.0 or drop_rate > 1.0: + raise ValueError("Dropout probability has to be between 0 and 1, " "but got {}".format(drop_rate)) + + num_nodes = int(torch.max(edge_index)) + 1 + mask = edge_index.new_full((edge_index.size(1),), 1 - drop_rate, dtype=torch.float) + mask = torch.bernoulli(mask).to(torch.bool) + edge_index, edge_weight = filter_adj(edge_index[0], edge_index[1], edge_weight, mask) + if renorm: + edge_weight = symmetric_normalization(num_nodes, edge_index) + return edge_index, edge_weight + + +def coalesce(row, col, value=None): + row = row.numpy() + col = col.numpy() + indices = np.lexsort((col, row)) + row = torch.from_numpy(row[indices]) + col = torch.from_numpy(col[indices]) + + num = col.shape[0] + 1 + idx = torch.full((num,), -1, dtype=torch.float) + idx[1:] = row * num + col + mask = idx[1:] > idx[:-1] + + if mask.all(): + return row, col, value + row = row[mask] + if value is not None: + _value = torch.zeros(row.shape[0], dtype=torch.float).to(row.device) + value = _value.scatter_add_(dim=0, src=value, index=col) + col = col[mask] + return row, col, value + + +def to_undirected(edge_index, num_nodes=None): + r"""Converts the graph given by :attr:`edge_index` to an undirected graph, + so that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in + \mathcal{E}`. + + Args: + edge_index (LongTensor): The edge indices. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) + + :rtype: :class:`LongTensor` + """ + + row, col = edge_index + row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) + row, col, _ = coalesce(row, col, None) + edge_index = torch.stack([row, col]) + return edge_index + + +def get_activation(act: str): + if act == "relu": + return F.relu + elif act == "sigmoid": + return torch.sigmoid + elif act == "tanh": + return torch.tanh + elif act == "gelu": + return F.gelu + elif act == "prelu": + return F.prelu + elif act == "identity": + return lambda x: x + else: + return F.relu + + +def cycle_index(num, shift): + arr = torch.arange(num) + shift + arr[-shift:] = torch.arange(shift) + return arr + + +def batch_sum_pooling(x, batch): + batch_size = int(torch.max(batch.cpu())) + 1 + # batch_size = len(torch.unique(batch)) + res = torch.zeros(batch_size, x.size(1)).to(x.device) + return res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x) + + +def batch_mean_pooling(x, batch): + values, counts = torch.unique(batch, return_counts=True) + res = torch.zeros(len(values), x.size(1)).to(x.device) + res = res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x) + return res / counts.unsqueeze(-1) + + +def negative_edge_sampling( + edge_index: torch.Tensor, + num_nodes: Optional[int] = None, + num_neg_samples: Optional[int] = None, + undirected: bool = False, +): + if num_nodes is None: + num_nodes = len(torch.unique(edge_index)) + if num_neg_samples is None: + num_neg_samples = edge_index.shape[1] + + size = num_nodes * num_nodes + num_neg_samples = min(num_neg_samples, size - edge_index.size(1)) + + row, col = edge_index + unique_pair = row * num_nodes + col + + num_samples = int(num_neg_samples * abs(1 / (1 - 1.1 * edge_index.size(1) / size))) + sample_result = torch.LongTensor(random.sample(range(size), min(num_samples, num_samples))) + mask = torch.from_numpy(np.isin(sample_result, unique_pair.to("cpu"))).to(torch.bool) + selected = sample_result[~mask][:num_neg_samples].to(edge_index.device) + + row = selected // num_nodes + col = selected % num_nodes + return torch.stack([row, col]).long() + + +def tabulate_results(results_dict): + # Average for different seeds + tab_data = [] + for variant in results_dict: + results = np.array([list(res.values()) for res in results_dict[variant]]) + tab_data.append( + [variant] + + list( + itertools.starmap( + lambda x, y: f"{x:.4f}±{y:.4f}", + zip( + np.mean(results, axis=0).tolist(), + np.std(results, axis=0).tolist(), + ), + ) + ) + ) + return tab_data + + +def print_result(results, datasets, model_name): + table_header = ["Variants"] + list(results[0].keys()) + + results_dict = defaultdict(list) + num_datasets = len(datasets) + num_seed = len(results) // num_datasets + for i, res in enumerate(results): + results_dict[(model_name, datasets[i // num_seed])].append(res) + tab_data = tabulate_results(results_dict) + print(tabulate(tab_data, headers=table_header, tablefmt="github")) + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.determinstic = True + + +if __name__ == "__main__": + args = build_args_from_dict({"a": 1, "b": 2}) + print(args.a, args.b) diff --git a/examples/oagbert/calculate_paper_similarity.py b/examples/oagbert/calculate_paper_similarity.py old mode 100755 new mode 100644 diff --git a/examples/oagbert/generate_title.py b/examples/oagbert/generate_title.py old mode 100755 new mode 100644 
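
Note (illustrative, not part of the patch): the `cogdl/utils/utils.py` rewrite above introduces a fused multi-head SpMM path (`mh_spmm`, backed by the `csrmhspmm` operator) alongside the per-head softmax helpers. The docstring shapes there are terse, so the sketch below is only a plain-torch reference for the aggregation such a kernel is expected to perform. Assumptions: `edge_index` follows the same row/col convention as `spmm_scatter` in this file (row = target, col = source), `att` holds softmax-normalized attention scores of shape `[E, H]`, `h` carries per-head node features of shape `[N, H, d]`, and the helper name `mh_spmm_reference` is invented purely for this example.

```python
# Illustrative reference only -- not part of this patch and not the fused kernel itself.
import torch


def mh_spmm_reference(edge_index, att, h):
    """Per-head neighborhood aggregation: out[i, head] = sum_j att[e(i<-j), head] * h[j, head]."""
    row, col = edge_index[0], edge_index[1]          # row = target node, col = source node
    msg = h[col] * att.unsqueeze(-1)                 # gather source features per edge: [E, H, d]
    out = torch.zeros_like(h)                        # accumulator: [N, H, d]
    index = row.view(-1, 1, 1).expand_as(msg)        # scatter destination per message
    out.scatter_add_(0, index, msg)                  # add messages onto their target nodes
    return out


if __name__ == "__main__":
    edge_index = torch.tensor([[0, 1, 2, 2], [1, 0, 0, 1]])   # 4 edges over 3 nodes
    att = torch.rand(4, 2)                                    # [E=4, H=2] attention scores
    h = torch.randn(3, 2, 8)                                  # [N=3, H=2, d=8] per-head features
    print(mh_spmm_reference(edge_index, att, h).shape)        # torch.Size([3, 2, 8])
```

The fused operator presumably avoids materializing the intermediate `[E, H, d]` message tensor on the GPU, which would be the motivation for routing GAT-style aggregation through `csrmhspmm` rather than a per-head Python loop over `spmm`.
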
diff --git a/examples/oagbert/oagbert_encode_paper.py b/examples/oagbert/oagbert_encode_paper.py new file mode 100644 index 00000000..2a43c893 --- /dev/null +++ b/examples/oagbert/oagbert_encode_paper.py @@ -0,0 +1,18 @@ +from cogdl import oagbert + +tokenizer, model = oagbert("oagbert-v2") +title = 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding' +abstract = 'We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation...' +authors = ['Jacob Devlin', 'Ming-Wei Chang', 'Kenton Lee', 'Kristina Toutanova'] +venue = 'north american chapter of the association for computational linguistics' +affiliations = ['Google'] +concepts = ['language model', 'natural language inference', 'question answering'] +# encode paper +paper_info = model.encode_paper( + title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, reduction="max" +) + +for name, content in paper_info.items(): + print(name) + print(content) + diff --git a/examples/simple_hgn/README.md b/examples/simple_hgn/README.md new file mode 100644 index 00000000..6b0b6fa6 --- /dev/null +++ b/examples/simple_hgn/README.md @@ -0,0 +1,9 @@ +# Simple-HGN + +Simple-HGN code for heterogeneous node classification in cogdl [leaderboard](../../cogdl/tasks/README.md). + +```bash +CUDA_VISIBLE_DEVICES=0 python run.py --seed 0 1 2 3 4 -t heterogeneous_node_classification -dt gtn-acm -m simple_hgn --lr 0.001 +CUDA_VISIBLE_DEVICES=0 python run.py --seed 0 1 2 3 4 -t heterogeneous_node_classification -dt gtn-dblp -m simple_hgn --lr 0.001 +CUDA_VISIBLE_DEVICES=0 python run.py --seed 0 1 2 3 4 -t heterogeneous_node_classification -dt gtn-imdb -m simple_hgn --lr 0.001 +``` diff --git a/examples/simple_hgn/conv.py b/examples/simple_hgn/conv.py new file mode 100644 index 00000000..00511012 --- /dev/null +++ b/examples/simple_hgn/conv.py @@ -0,0 +1,141 @@ +"""Torch modules for graph attention networks(GAT).""" +# pylint: disable= no-member, arguments-differ, invalid-name +import torch as th +from torch import nn + +from dgl import function as fn +from dgl.nn.pytorch import edge_softmax +from dgl._ffi.base import DGLError +from dgl.nn.pytorch.utils import Identity +from dgl.utils import expand_as_pair + + +# pylint: enable=W0235 +class myGATConv(nn.Module): + """ + Adapted from + https://docs.dgl.ai/_modules/dgl/nn/pytorch/conv/gatconv.html#GATConv + """ + + def __init__( + self, + edge_feats, + num_etypes, + in_feats, + out_feats, + num_heads, + feat_drop=0.0, + attn_drop=0.0, + negative_slope=0.2, + residual=False, + activation=None, + allow_zero_in_degree=False, + bias=False, + alpha=0.0, + ): + super(myGATConv, self).__init__() + self._edge_feats = edge_feats + self._num_heads = num_heads + self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) + self._out_feats = out_feats + self._allow_zero_in_degree = allow_zero_in_degree + self.edge_emb = nn.Embedding(num_etypes, edge_feats) + if isinstance(in_feats, tuple): + self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) + self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False) + else: + self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) + self.fc_e = nn.Linear(edge_feats, edge_feats * num_heads, bias=False) + self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) + self.attn_r = nn.Parameter(th.FloatTensor(size=(1, 
num_heads, out_feats))) + self.attn_e = nn.Parameter(th.FloatTensor(size=(1, num_heads, edge_feats))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if residual: + if self._in_dst_feats != out_feats: + self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False) + else: + self.res_fc = Identity() + else: + self.register_buffer("res_fc", None) + self.reset_parameters() + self.activation = activation + self.bias = bias + if bias: + self.bias_param = nn.Parameter(th.zeros((1, num_heads, out_feats))) + self.alpha = alpha + + def reset_parameters(self): + gain = nn.init.calculate_gain("relu") + if hasattr(self, "fc"): + nn.init.xavier_normal_(self.fc.weight, gain=gain) + else: + nn.init.xavier_normal_(self.fc_src.weight, gain=gain) + nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + nn.init.xavier_normal_(self.attn_l, gain=gain) + nn.init.xavier_normal_(self.attn_r, gain=gain) + nn.init.xavier_normal_(self.attn_e, gain=gain) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + nn.init.xavier_normal_(self.fc_e.weight, gain=gain) + + def set_allow_zero_in_degree(self, set_value): + self._allow_zero_in_degree = set_value + + def forward(self, graph, feat, e_feat, res_attn=None): + with graph.local_scope(): + if not self._allow_zero_in_degree: + if (graph.in_degrees() == 0).any(): + raise DGLError( + "There are 0-in-degree nodes in the graph, " + "output for those nodes will be invalid. " + "This is harmful for some applications, " + "causing silent performance regression. " + "Adding self-loop on the input graph by " + "calling `g = dgl.add_self_loop(g)` will resolve " + "the issue. Setting ``allow_zero_in_degree`` " + "to be `True` when constructing this module will " + "suppress the check and let the code run." 
+ ) + + if isinstance(feat, tuple): + h_src = self.feat_drop(feat[0]) + h_dst = self.feat_drop(feat[1]) + if not hasattr(self, "fc_src"): + self.fc_src, self.fc_dst = self.fc, self.fc + feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) + feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) + else: + h_src = h_dst = self.feat_drop(feat) + feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats) + if graph.is_block: + feat_dst = feat_src[: graph.number_of_dst_nodes()] + e_feat = self.edge_emb(e_feat) + e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats) + ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1) + el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) + er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) + graph.srcdata.update({"ft": feat_src, "el": el}) + graph.dstdata.update({"er": er}) + graph.edata.update({"ee": ee}) + graph.apply_edges(fn.u_add_v("el", "er", "e")) + e = self.leaky_relu(graph.edata.pop("e") + graph.edata.pop("ee")) + # compute softmax + graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) + if res_attn is not None: + graph.edata["a"] = graph.edata["a"] * (1 - self.alpha) + res_attn * self.alpha + # message passing + graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft")) + rst = graph.dstdata["ft"] + # residual + if self.res_fc is not None: + resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) + rst = rst + resval + # bias + if self.bias: + rst = rst + self.bias_param + # activation + if self.activation: + rst = self.activation(rst) + return rst, graph.edata.pop("a").detach() diff --git a/examples/simple_hgn/run.py b/examples/simple_hgn/run.py new file mode 100644 index 00000000..5a1e1bbb --- /dev/null +++ b/examples/simple_hgn/run.py @@ -0,0 +1,212 @@ +import numpy as np +import scipy.sparse as sp +import torch +import torch.nn as nn +import dgl +from dgl.nn.pytorch import GraphConv + +import dgl.function as fn +from dgl.nn.pytorch import edge_softmax, GATConv +from conv import myGATConv + +import torch.nn.functional as F + +from cogdl import experiment, options +from cogdl.models import BaseModel, register_model +from cogdl.models.nn.gcn import GraphConvolution +from cogdl.utils import add_remaining_self_loops, symmetric_normalization, accuracy + + +@register_model("simple_hgn") +class SimpleHGN(BaseModel): + r"""The Simple-HGN model from the `"Are we really making much progress? Revisiting, benchmarking, and refining heterogeneous graph neural networks"`_ paper + + Args: + num_features (int) : Number of input features. + num_classes (int) : Number of classes. + hidden_size (int) : The dimension of node representation. + dropout (float) : Dropout rate for model training. 
+ """ + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--num-features", type=int) + parser.add_argument("--num-classes", type=int) + parser.add_argument("--num-nodes", type=int) + parser.add_argument("--hidden-size", type=int, default=64) + parser.add_argument("--num-layers", type=int, default=2) + parser.add_argument("--num-edge", type=int, default=2) + parser.add_argument("--num-heads", type=int, default=8) + parser.add_argument('--dropout', type=float, default=0.5) + parser.add_argument('--slope', type=float, default=0.05) + parser.add_argument('--edge-dim', type=int, default=64) + # fmt: on + + @classmethod + def build_model_from_args(cls, args): + heads = [args.num_heads] * args.num_layers + [1] + return cls( + args.edge_dim, + args.num_edge * 2 + 1, + [args.num_features], + args.hidden_size, + args.num_classes, + args.num_layers, + heads, + args.dropout, + args.dropout, + args.slope, + True, + 0.05, + True, + ) + + def __init__( + self, + edge_dim, + num_etypes, + in_dims, + num_hidden, + num_classes, + num_layers, + heads, + feat_drop, + attn_drop, + negative_slope, + residual, + alpha, + use_cuda, + ): + super(SimpleHGN, self).__init__() + self.cross_entropy_loss = nn.CrossEntropyLoss() + self.device = torch.device("cuda:0" if torch.cuda.is_available() and use_cuda else "cpu") + self.g = None + self.num_layers = num_layers + self.gat_layers = nn.ModuleList() + self.activation = F.elu + # self.fc_list = nn.ModuleList([nn.Linear(in_dim, num_hidden, bias=True) for in_dim in in_dims]) + # for fc in self.fc_list: + # nn.init.xavier_normal_(fc.weight, gain=1.414) + # input projection (no residual) + self.gat_layers.append( + myGATConv( + edge_dim, + num_etypes, + in_dims[0], + num_hidden, + heads[0], + feat_drop, + attn_drop, + negative_slope, + False, + self.activation, + alpha=alpha, + ) + ) + # hidden layers + for l in range(1, num_layers): # noqa E741 + # due to multi-head, the in_dim = num_hidden * num_heads + self.gat_layers.append( + myGATConv( + edge_dim, + num_etypes, + num_hidden * heads[l - 1], + num_hidden, + heads[l], + feat_drop, + attn_drop, + negative_slope, + residual, + self.activation, + alpha=alpha, + ) + ) + # output projection + self.gat_layers.append( + myGATConv( + edge_dim, + num_etypes, + num_hidden * heads[-2], + num_classes, + heads[-1], + feat_drop, + attn_drop, + negative_slope, + residual, + None, + alpha=alpha, + ) + ) + self.epsilon = torch.FloatTensor([1e-12]).to(self.device) + + def list_to_sp_mat(self, edges, weights): + data = [x for x in weights] + i = [x for x in edges[0]] + j = [x for x in edges[1]] + total = max(max(i), max(j)) + 1 + return sp.coo_matrix((data, (i, j)), shape=(total, total)).tocsr() + + def build_g_feat(self, A): + edge2type = {} + edges = [] + weights = [] + for k, mat in enumerate(A): + edges.append(mat[0].cpu().numpy()) + weights.append(mat[1].cpu().numpy()) + for u, v in zip(*edges[-1]): + edge2type[(u, v)] = k + edges = np.concatenate(edges, axis=1) + weights = np.concatenate(weights) + adjM = self.list_to_sp_mat(edges, weights) + g = dgl.DGLGraph(adjM) + g = dgl.remove_self_loop(g) + g = dgl.add_self_loop(g) + g = g.to(self.device) + e_feat = [] + for u, v in zip(*g.edges()): + u = u.cpu().item() + v = v.cpu().item() + e_feat.append(edge2type[(u, v)]) + e_feat = torch.tensor(e_feat, dtype=torch.long).to(self.device) + self.g = g + self.e_feat = e_feat + + def forward(self, A, X, target_x, target): # features_list, e_feat): + # h = [] + # for fc, 
feature in zip(self.fc_list, [X]): + # h.append(fc(feature)) + h = X # torch.cat(h, 0) + if self.g is None: + self.build_g_feat(A) + res_attn = None + for l in range(self.num_layers): # noqa E741 + h, res_attn = self.gat_layers[l](self.g, h, self.e_feat, res_attn=res_attn) + h = h.flatten(1) + # output projection + logits, _ = self.gat_layers[-1](self.g, h, self.e_feat, res_attn=None) + logits = logits.mean(1) + # This is an equivalent replacement for tf.l2_normalize, see https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/l2_normalize for more information. + logits = logits / (torch.max(torch.norm(logits, dim=1, keepdim=True), self.epsilon)) + y = logits[target_x] + loss = self.cross_entropy_loss(y, target) + return loss, y, None + + def loss(self, data): + loss, y, _ = self.forward(data.adj, data.x, data.train_node, data.train_target) + return loss + + def evaluate(self, data, nodes, targets): + loss, y, _ = self.forward(data.adj, data.x, nodes, targets) + f1 = accuracy(y, targets) + return loss.item(), f1 + + +if __name__ == "__main__": + # CUDA_VISIBLE_DEVICES=0 python custom_gcn.py --seed 0 1 2 3 4 -t heterogeneous_node_classification -dt gtn-acm -m simple_hgn --lr 0.001 + parser = options.get_training_parser() + args, _ = parser.parse_known_args() + args = options.parse_args_and_arch(parser, args) + experiment(task="heterogeneous_node_classification", dataset="gtn-acm", model="simple_hgn", args=args) + # experiment(task="node_classification", dataset="cora", model="mygcn") diff --git a/tests/tasks/test_encode_paper.py b/tests/tasks/test_encode_paper.py new file mode 100644 index 00000000..dfdb2131 --- /dev/null +++ b/tests/tasks/test_encode_paper.py @@ -0,0 +1,26 @@ +from cogdl import oagbert + + +def test_encode_paper(): + tokenizer, model = oagbert("oagbert-v2") + title = 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding' + abstract = 'We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation...' + authors = ['Jacob Devlin', 'Ming-Wei Chang', 'Kenton Lee', 'Kristina Toutanova'] + venue = 'north american chapter of the association for computational linguistics' + affiliations = ['Google'] + concepts = ['language model', 'natural language inference', 'question answering'] + # encode paper + paper_info = model.encode_paper( + title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, reduction="max" + ) + + assert len(paper_info) == 5 + assert paper_info['text'][0]['type'] == 'TEXT' + assert len(paper_info['authors']) == 4 + assert len(paper_info['venue'][0]['token_ids']) == 9 + assert tuple(paper_info['text'][0]['sequence_output'].shape) == (43, 768) + assert len(paper_info['text'][0]['pooled_output']) == 768 + + +if __name__ == "__main__": + test_encode_paper() diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py index 091c5753..faee97ce 100644 --- a/tests/tasks/test_node_classification.py +++ b/tests/tasks/test_node_classification.py @@ -84,12 +84,6 @@ def test_gat_cora(): args.nhead = 8 args.residual = False args.last_nhead = 2 - args.num_layers = 2 - for i in [True, False]: - args.fast_mode = i - task = build_task(args) - ret = task.train() - assert 0 <= ret["Acc"] <= 1 args.num_layers = 3 args.residual = True