
Compare few-shot GPT4 features to embedding features for EFO term precision classification #8

Closed
eric-czech opened this issue Aug 8, 2023 · 17 comments

Comments

@eric-czech

@yonromai suggested trying to embed the text descriptions, labels, aliases, etc. associated with EFO terms and using those embeddings as features as a part of #2.

It would be very interesting to see how much a model like the one in #7 improves with embedding features compared to a model with only the few-shot labels from #6.

The LLM-derived features will definitely be harder to maintain/generate. On the other hand, I know the labels we provided in #5 are not perfect, and I expect that the few-shot features will be more helpful in figuring out which ones are most likely to be mislabeled and why (since they can be directly compared). Nevertheless, contrasting the predictive value of the two could be an important determining factor for how this project, or at least #2, evolves.

@eric-czech
Author

FYI @yonromai, this was the embedding model I had in mind when we last spoke: https://huggingface.co/michiyasunaga/BioLinkBERT-base.

That's from a top-tier group in the NLP space, and it's the model published by the first author (michiyasunaga) of LinkBERT: Pretraining Language Models with Document Links (Mar. 2022). The reported improvements over a recent SOTA model (PubMedBERT) are substantial, so it might be worth kicking the tires on it.

And to be clear, I have no allegiances to this over a LLaMA-derived model, OpenAI or some KG-based approach. Any performance baseline using embeddings would be helpful.
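
For reference, a minimal sketch of producing a term embedding with that model via Hugging Face transformers (mean pooling over the last hidden state is just an assumed choice here, not a settled one):

# Minimal sketch: embed an EFO term description with BioLinkBERT-base.
# Mean pooling over the last hidden state is an assumed pooling choice.
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "michiyasunaga/BioLinkBERT-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()

def embed(text: str) -> torch.Tensor:
    """Return a single pooled embedding vector for `text`."""
    inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state        # (1, seq_len, 768)
    mask = inputs["attention_mask"].unsqueeze(-1)         # (1, seq_len, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # mean over real tokens

print(embed("A chronic inflammatory disease of the airways.").shape)  # torch.Size([1, 768])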

@yonromai
Contributor

yonromai commented Aug 15, 2023

cc: @eric-czech @dhimmel

TL;DR:

  • I have some code which builds additional features by embedding nodes' descriptions using BioLinkBERT-base
  • The model performance improvement is currently insignificant (& the slowdown in model training time is significant)
  • Little to no investigation has gone into (hyper)parameters of the model/features or performance analysis
  • I'd like to work a little on metrics (Good metrics for model evaluation? #9) before investigating/troubleshooting these new features
  • Experiment code is in this branch (not ready to merge)

Details:

Idea behind the new features:

  • At training time:
    • Each node's description is embedded using BioLinkBERT-base
    • Each vector is used to build a Nearest Neighbors tree (using Meta's Faiss)
    • The labels of each indexed node are saved for later retrieval
  • At inference time, for a given input node:
    • Embed the input node
    • Use the input node vector to fetch its k nearest neighbors (k=50, not tuned)
    • Group the neighbors by label (from the training data)
    • Calculate, for each label, a set of descriptive statistics summarizing where the target vector sits with respect to each cluster: support (i.e. cluster size), min, max, 1st quartile, median and 3rd quartile (see the sketch after this list)
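
A minimal sketch of that pipeline (assuming pre-computed embeddings and illustrative names; the actual branch code may differ):

# Minimal sketch of the k-NN feature construction described above.
# Embeddings and labels are random stand-ins; names are illustrative.
import faiss
import numpy as np

rng = np.random.default_rng(0)
train_vecs = rng.normal(size=(1000, 768)).astype("float32")  # stand-in for BioLinkBERT embeddings
train_labels = rng.choice(["L1", "L2", "L3", "L4"], size=1000).tolist()

index = faiss.IndexFlatL2(train_vecs.shape[1])  # exact L2 index over training embeddings
index.add(train_vecs)

def knn_label_features(query_vec: np.ndarray, k: int = 50) -> dict[str, float]:
    """Summarize where the query sits w.r.t. each label's cluster of neighbors."""
    dists, ids = index.search(query_vec.reshape(1, -1).astype("float32"), k)
    by_label: dict[str, list[float]] = {}
    for dist, i in zip(dists[0], ids[0]):
        by_label.setdefault(train_labels[i], []).append(float(dist))
    feats: dict[str, float] = {}
    for label, ds in by_label.items():
        arr = np.array(ds)
        feats[f"{label}-support"] = float(len(arr))  # cluster size among the k neighbors
        feats[f"{label}-min"] = float(arr.min())
        feats[f"{label}-q1"] = float(np.quantile(arr, 0.25))
        feats[f"{label}-med"] = float(np.median(arr))
        feats[f"{label}-q3"] = float(np.quantile(arr, 0.75))
        feats[f"{label}-max"] = float(arr.max())
    return feats

print(knn_label_features(rng.normal(size=768)))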

Outcome:

  • This approach is slow (~25min to run training + test set eval)
  • The gain in terms of performance is pretty much nonexistent (see results below)

Comments:

  • Why use NN trees?
    • I assumed that the embedding size would be too large to be exploited "as is" by the model (given that we have ~20k training points)
    • This tree approach is amongst the simplest I could come up with (but it's still more convoluted than I'd like it to be)
  • More analysis needs to take place to understand if there is value in these new features (and whether the value added is worth the extra complexity)
  • More work on metrics (Good metrics for model evaluation? #9) and potentially experiment tracking (there are many Weights & Biases-type solutions in the OSS world) would be useful for the above

Experiment results:

With embedding features

Indexing Vectors: 100%|██████████| 17331/17331 [15:50<00:00, 18.40it/s]
Learning rate set to 0.091517
0:	learn: 1.2545993	total: 87.9ms	remaining: 1m 27s
100:	learn: 0.3817487	total: 1.74s	remaining: 15.4s
200:	learn: 0.3448022	total: 3.38s	remaining: 13.5s
300:	learn: 0.3226222	total: 5.06s	remaining: 11.7s
400:	learn: 0.3037497	total: 6.74s	remaining: 10.1s
500:	learn: 0.2870626	total: 8.47s	remaining: 8.43s
600:	learn: 0.2722362	total: 10.2s	remaining: 6.75s
700:	learn: 0.2571492	total: 11.9s	remaining: 5.07s
800:	learn: 0.2440707	total: 13.6s	remaining: 3.37s
900:	learn: 0.2329543	total: 15.3s	remaining: 1.68s
999:	learn: 0.2212473	total: 17s	remaining: 0us
> Feature importance:
                     Feature Id  Importances
0                        prefix    13.154424
1                    L4-support     8.085884
2                       n_roots     5.246222
3                    L2-support     4.897259
4          intrinsic_ic_sanchez     4.789902
5                    L1-support     4.645040
6                         depth     4.092159
7                   n_ancestors     3.673886
8   intrinsic_ic_sanchez_scaled     3.632503
9                        L2-min     3.356392
10                       L1-min     3.200558
11           xref__mondo__count     2.781731
12                   L3-support     2.633530
13                n_descendants     2.073415
14            xref__omim__count     1.899459
15                 intrinsic_ic     1.737778
16                       L3-min     1.673483
17                       L4-min     1.573646
18        xref__orphanet__count     1.524709
19            xref__ncit__count     1.478501
20          intrinsic_ic_scaled     1.456620
21                        L2-q1     1.455818
22                    n_parents     1.408397
23            xref__gard__count     1.368191
24            xref__doid__count     1.347312
25            xref__mesh__count     1.225176
26            xref__umls__count     1.107692
27          xref__omimps__count     1.101323
28                        L1-q1     0.880555
29            xref__icd9__count     0.803441
30                       L3-max     0.754765
31           xref__icd10__count     0.745556
32                       L1-med     0.744894
33                       L4-max     0.730284
34                   n_children     0.698697
35          xref__meddra__count     0.666259
36                       L2-max     0.643364
37        xref__snomedct__count     0.634071
38                        L3-q1     0.606200
39                       L3-med     0.597636
40                       L2-med     0.588662
41                        L4-q3     0.583918
42                is_gwas_trait     0.558833
43                        L2-q3     0.534196
44                        L3-q3     0.524524
45                        L4-q1     0.516574
46                       L1-max     0.450720
47                        L1-q3     0.402829
48                     n_leaves     0.377648
49                       L4-med     0.335366

> Classification report:
                    precision    recall  f1-score   support

01-disease-subtype       0.79      0.84      0.81      1085
   02-disease-root       0.70      0.66      0.68       797
   03-disease-area       0.79      0.72      0.76       257
    04-non-disease       0.98      0.98      0.98       920

          accuracy                           0.82      3059
         macro avg       0.82      0.80      0.81      3059
      weighted avg       0.82      0.82      0.82      3059

Without embedding features

Learning rate set to 0.091517
0:	learn: 1.2510844	total: 71.5ms	remaining: 1m 11s
100:	learn: 0.4182990	total: 1.06s	remaining: 9.43s
200:	learn: 0.3891503	total: 2.07s	remaining: 8.24s
300:	learn: 0.3713912	total: 3.05s	remaining: 7.09s
400:	learn: 0.3581384	total: 4.03s	remaining: 6.03s
500:	learn: 0.3451430	total: 5.01s	remaining: 4.99s
600:	learn: 0.3334726	total: 5.98s	remaining: 3.97s
700:	learn: 0.3224103	total: 7s	remaining: 2.98s
800:	learn: 0.3125697	total: 8.03s	remaining: 2s
900:	learn: 0.3042341	total: 9.04s	remaining: 993ms
999:	learn: 0.2964110	total: 10s	remaining: 0us
> Feature importance:
                     Feature Id  Importances
0                        prefix    19.235261
1                       n_roots     7.848109
2                   n_ancestors     7.547822
3                         depth     7.365341
4          intrinsic_ic_sanchez     5.770118
5            xref__mondo__count     4.362500
6             xref__omim__count     3.867349
7             xref__doid__count     3.414090
8         xref__orphanet__count     3.360891
9   intrinsic_ic_sanchez_scaled     3.086075
10            xref__mesh__count     3.016446
11                    n_parents     2.935589
12                is_gwas_trait     2.854420
13            xref__umls__count     2.633490
14            xref__ncit__count     2.608488
15                n_descendants     2.435195
16            xref__gard__count     2.336128
17                   n_children     2.183118
18                 intrinsic_ic     2.152569
19          xref__meddra__count     2.075885
20        xref__snomedct__count     1.854948
21           xref__icd10__count     1.849050
22          intrinsic_ic_scaled     1.791724
23            xref__icd9__count     1.153552
24                     n_leaves     1.142829
25          xref__omimps__count     1.119010

> Classification report:
                    precision    recall  f1-score   support

01-disease-subtype       0.80      0.85      0.82      1195
   02-disease-root       0.69      0.61      0.65       719
   03-disease-area       0.77      0.73      0.75       244
    04-non-disease       0.96      0.99      0.97       901

          accuracy                           0.82      3059
         macro avg       0.81      0.79      0.80      3059
      weighted avg       0.82      0.82      0.82      3059

@eric-czech
Author

Nice @yonromai!

More analysis needs to take place to understand if there is value in these new features (and whether the value added is worth the extra complexity)

I think this is clear above, but that does not include GPT-4 assignments of the labels as features (i.e. from #6) correct?

I assumed that the embedding size would be too large to be exploited "as is" by the model

I think it would be ok to include the embeddings or a reduction on them (e.g. PCA) as features directly. I like the tree/clustering approach, but my hunch is that it will be hard to show an improvement over that simpler method.
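
E.g. something like this (a minimal sketch with stand-in arrays and an arbitrary number of components):

# Minimal sketch: reduce the description embeddings with PCA and append the
# components to the existing tabular features. Shapes and n_components are illustrative.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X_tabular = rng.normal(size=(17331, 26))  # stand-in for the existing baseline features
X_embed = rng.normal(size=(17331, 768))   # stand-in for BioLinkBERT embeddings

pca = PCA(n_components=64, random_state=0).fit(X_embed)  # fit on training rows only
X = np.hstack([X_tabular, pca.transform(X_embed)])
print(X.shape)  # (17331, 90)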

Experiment results

Do you have a sense of how much macro F1 averages vary across resamplings (e.g. with cross_val_score(..., cv=5, scoring='f1_macro'))? It would be helpful to know what kind of performance loss in an ablation experiment like that should rate as substantial.
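
Something along these lines would give the spread (a minimal sketch with placeholder data; the real features and model would come from the pipeline above):

# Minimal sketch: estimate how much macro-F1 varies across stratified folds so
# we can judge whether a feature change is a real improvement. Data is a placeholder.
from catboost import CatBoostClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = make_classification(n_samples=3000, n_classes=4, n_informative=8, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(CatBoostClassifier(verbose=0), X, y, cv=cv, scoring="f1_macro")
print(scores.mean(), scores.std())  # compare the std to any between-model difference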

@yonromai
Contributor

I think this is clear above, but that does not include GPT-4 assignments of the labels as features (i.e. from #6) correct?

That's right, I think we should do that next.

I think it would be ok to include the embeddings or a reduction on them (e.g. PCA) as features directly. I like the tree/clustering approach, but my hunch is that it will be hard to show an improvement over that simpler method.

Sure, I'll give it a try!

Do you have a sense of how much macro F1 averages vary across resamplings (e.g. with cross_val_score(..., cv=5, scoring='f1_macro'))? It would be helpful to know what kind of performance loss in an ablation experiment like that should rate as substantial.

That's a great question! I just started using ROC AUC & MAE's from #9 to look into the performance of the model. I'll spend a little bit of time in notebook land looking at how features & model parameters influence metrics.

@eric-czech
Author

looking at how features & model parameters influence metrics

Awesome, sounds good! So we're clear though, I'm proposing that we compare distributions of F1, ROC, MAE, etc. scores between models, where the distributions come from multiple evaluations of those metrics on different folds. Given that this dataset is small, I think we'll need that to understand which changes are significant. Would you agree?

@yonromai
Contributor

Yes, totally agree: the idea is to "repeat the experiment" of training the same model on different (stratified) folds of the training set to get an idea of the spread of the metrics. Then we can use this spread to judge the significance of the metrics calculated once we change the model/features. Is that what you meant?

@eric-czech
Author

Is that what you meant?

Indeed 👍

@yonromai
Contributor

Okay, so after some time in notebook land, here is the gist of what I found out:

TL;DR

@eric-czech both of your hunches were 💯 :

  • There is quite a bit of metric variation when training the same model => training on multiple folds helps
  • The best way to leverage the node text embeddings is to do PCA on the embeddings (as opposed to doing KNN)

@dhimmel Implementing the MAE (with the class biases suggested by @eric-czech) has proven very useful!

Some results:

[image: experiment results comparing feature configurations]
For comparison:

  • A random classifier with the real class-weight prior has a biased MAE of 0.284
  • A random classifier with a uniform class-weight prior has a biased MAE of 0.348

More details about the best performing model (/ food for thought)

The model seems to max out on BiasedMAE on the training data (not on the eval/CV one):
[image: BiasedMAE curves]

The model seems to slightly overfit the objective function:
[image: objective function curves]

More details..

For more details about the experiments & findings, take a look at the notebook. All the (non-production ready) code is in my branch.

@yonromai
Contributor

@eric-czech I think I now have enough understanding of the performance of the pre-GPT model that it'd be worth running the training data through the GPT4 prompt you provided and seeing if that does better than the model out of the box!

I can estimate the cost of the procedure if useful.

@dhimmel
Member

dhimmel commented Aug 22, 2023

Nice finding that PCA is working better on the node text embeddings than KNN and that 64 dimensions captures much of the performance benefit.

it'd be worth running the training data through the GPT4 prompt

I'm excited to see how the GPT4 features perform!

@eric-czech
Author

here is the gist of what I found out

Very nice @yonromai! Great experimental setup and it's excellent to see some clear separation between those models.

Some results:

For posterity, I think it would be helpful to say more about what the lda_d7 and knn_d7 configurations were at this level. Presumably knn_d7 was the method in #8 (comment). What was lda_d7 though?

Noting the current details in the notebook:

[screenshot: current details from the notebook]

running the training data through the GPT4 prompt you provided and see if that does better than the model out of the box!

Awesome -- I'd love to see how it performs on its own and when included as a feature with the other baseline features in a catboost gbm.
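
For the combined case, a minimal sketch of what I have in mind, assuming the GPT-4 assignment arrives as one more categorical column (column values and the gpt4_label name are hypothetical):

# Minimal sketch: include a hypothetical GPT-4 label assignment as a categorical
# feature alongside baseline features in a CatBoost GBM. Column contents are illustrative.
import pandas as pd
from catboost import CatBoostClassifier

df = pd.DataFrame({
    "prefix": ["EFO", "MONDO", "EFO", "Orphanet"],
    "depth": [3, 5, 2, 6],
    "gpt4_label": ["01-disease-subtype", "02-disease-root",
                   "04-non-disease", "01-disease-subtype"],  # hypothetical feature
})
y = ["01-disease-subtype", "02-disease-root", "04-non-disease", "01-disease-subtype"]

model = CatBoostClassifier(iterations=50, verbose=0)
model.fit(df, y, cat_features=["prefix", "gpt4_label"])
print(model.predict(df))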

More details about the best performing model

OOC what is that UI you're looking at there? I don't see any obvious hints in https://github.com/related-sciences/nxontology-ml/tree/romain/embeddings/experimentation.

@yonromai
Contributor

For posterity, I think it would be helpful to say more about what the lda_d7 and knn_d7 configurations were at this level. Presumably knn_d7 was the method in #8 (comment). What was lda_d7 though?

@eric-czech Noted, I'll add more details in the notebook. (The LDA code directly applies Sklearn's LDA, similar to the PCA - see this code)

I'd like to clean up the code I have in my branch and merge it into the main branch. I'm probably going to end up deleting a lot of the code (e.g. the KNN part) in the near future, but that way it'll be saved in git history (along with the experimental setup).

@dhimmel: Would that be fine with you? (It's gonna be quite a big PR :( )

OOC what is that UI you're looking at there? I don't see any obvious hints in https://github.com/related-sciences/nxontology-ml/tree/romain/embeddings/experimentation.

The code which displays the model metrics is in the "CatBoost's MetricVisualizer" section of the notebook but it's JavaScript so it doesn't render in GH.

@dhimmel
Member

dhimmel commented Aug 24, 2023

@dhimmel: Would that be fine with you? (It's gonna be quite a big PR :( )

Yes sounds good.

@dhimmel
Member

dhimmel commented Aug 24, 2023

The code which displays the model metrics is in the "CatBoost's MetricVisualizer" section of the notebook but it's JavaScript so it doesn't render in GH.

Sometimes this will render in nbviewer, but not in this case.

@yonromai
Contributor

Are there short-term plans to work on this, or is it appropriate to close this issue?

@dhimmel
Member

dhimmel commented Sep 26, 2023

Given that GPT assignments were inferior to text embedding features and didn't add much when combined, I don't think we need to use GPT features at all. Saves on cost and complexity.

@eric-czech
Author

I don't think we need to use GPT features at all

I definitely agree. Noting #34 (comment) as the most recent experiment at the time of writing that still had these features.
