Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Dataset with schema #124

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions sgkit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import xarray as xr

from .utils import check_array_like
from .typing import SgkitSchema

DIM_VARIANT = "variants"
DIM_SAMPLE = "samples"
Expand Down Expand Up @@ -53,11 +53,6 @@ def create_genotype_call_dataset(
The dataset of genotype calls.

"""
check_array_like(variant_contig, kind="i", ndim=1)
check_array_like(variant_position, kind="i", ndim=1)
check_array_like(variant_alleles, kind={"S", "O"}, ndim=2)
check_array_like(sample_id, kind={"U", "O"}, ndim=1)
check_array_like(call_genotype, kind="i", ndim=3)
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
Expand All @@ -69,17 +64,25 @@ def create_genotype_call_dataset(
call_genotype < 0,
),
}
schema = {
SgkitSchema.variant_contig,
SgkitSchema.variant_position,
SgkitSchema.variant_allele,
SgkitSchema.sample_id,
SgkitSchema.call_genotype,
SgkitSchema.call_genotype_mask,
}
if call_genotype_phased is not None:
check_array_like(call_genotype_phased, kind="b", ndim=2)
data_vars["call_genotype_phased"] = (
[DIM_VARIANT, DIM_SAMPLE],
call_genotype_phased,
)
schema.add(SgkitSchema.call_genotype_phased)
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
schema.add(SgkitSchema.variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return xr.Dataset(data_vars=data_vars, attrs=attrs)
return SgkitSchema.spec(xr.Dataset(data_vars=data_vars, attrs=attrs), *schema)


def create_genotype_dosage_dataset(
Expand Down Expand Up @@ -115,14 +118,9 @@ def create_genotype_dosage_dataset(
Returns
-------
xr.Dataset
The dataset of genotype calls.
The dataset of genotype dosage.

"""
check_array_like(variant_contig, kind="i", ndim=1)
check_array_like(variant_position, kind="i", ndim=1)
check_array_like(variant_alleles, kind={"S", "O"}, ndim=2)
check_array_like(sample_id, kind={"U", "O"}, ndim=1)
check_array_like(call_dosage, kind="f", ndim=2)
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
Expand All @@ -131,8 +129,16 @@ def create_genotype_dosage_dataset(
"call_dosage": ([DIM_VARIANT, DIM_SAMPLE], call_dosage),
"call_dosage_mask": ([DIM_VARIANT, DIM_SAMPLE], np.isnan(call_dosage),),
}
schema = {
SgkitSchema.variant_contig,
SgkitSchema.variant_position,
SgkitSchema.variant_allele,
SgkitSchema.sample_id,
SgkitSchema.call_dosage,
SgkitSchema.call_dosage_mask,
}
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
schema.add(SgkitSchema.variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return xr.Dataset(data_vars=data_vars, attrs=attrs)
return SgkitSchema.spec(xr.Dataset(data_vars=data_vars, attrs=attrs), *schema)
15 changes: 11 additions & 4 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import xarray as xr
from xarray import DataArray, Dataset

from sgkit.typing import SgkitSchema


def count_alleles(ds: Dataset) -> DataArray:
"""Compute allele count from genotype calls.

Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
Genotype call dataset such as from `sgkit.create_genotype_call_dataset`.
Must hold:
* `sgkit.typing.SgkitSchema.call_genotype`
* `sgkit.typing.SgkitSchema.call_genotype_mask`

Returns
-------
Expand Down Expand Up @@ -40,10 +44,13 @@ def count_alleles(ds: Dataset) -> DataArray:
[2, 2],
[4, 0]])
"""
schm = SgkitSchema.schema_has(
ds, SgkitSchema.call_genotype, SgkitSchema.call_genotype_mask
)
# Count each allele index individually as a 1D vector and
# restack into new alleles dimension with same order
G = ds["call_genotype"].stack(calls=("samples", "ploidy"))
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
G = ds[schm["call_genotype"][0]].stack(calls=("samples", "ploidy"))
M = ds[schm["call_genotype_mask"][0]].stack(calls=("samples", "ploidy"))
n_variant, n_allele = G.shape[0], ds.dims["alleles"]
max_allele = n_allele + 1

Expand Down
62 changes: 25 additions & 37 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass
from typing import Optional, Sequence, Union
from typing import Optional

import dask.array as da
import numpy as np
import xarray as xr
from dask.array import Array, stats
from xarray import Dataset

from ..typing import ArrayLike
from ..typing import ArrayLike, SgkitSchema
from .utils import concat_2d


Expand Down Expand Up @@ -109,14 +109,7 @@ def _get_loop_covariates(ds: Dataset, dosage: Optional[str] = None) -> Array:
return da.asarray(G.data)


def gwas_linear_regression(
ds: Dataset,
*,
dosage: str,
covariates: Union[str, Sequence[str]],
traits: Union[str, Sequence[str]],
add_intercept: bool = True,
) -> Dataset:
def gwas_linear_regression(ds: Dataset, *, add_intercept: bool = True) -> Dataset:
"""Run linear regression to identify continuous trait associations with genetic variants.

This method solves OLS regressions for each variant simultaneously and reports
Expand All @@ -130,22 +123,10 @@ def gwas_linear_regression(
----------
ds : Dataset
Dataset containing necessary dependent and independent variables.
dosage : str
Dosage variable name, where the "dosage" array can represent
one of several possible quantities, e.g.:
- Alternate allele counts
- Recessive or dominant allele encodings
- True dosages as computed from imputed or probabilistic variant calls
- Any other custom encoding in a user-defined variable
covariates : Union[str, Sequence[str]]
Covariate variable names, must correspond to 1 or 2D dataset
variables of shape (samples[, covariates]). All covariate arrays
will be concatenated along the second axis (columns).
traits : Union[str, Sequence[str]]
Trait (e.g. phenotype) variable names, must all be continuous and
correspond to 1 or 2D dataset variables of shape (samples[, traits]).
2D trait arrays will be assumed to contain separate traits within columns
and concatenated to any 1D traits along the second axis (columns).
Must hold:
* `sgkit.typing.SgkitSchema.dosage`
* `sgkit.typing.SgkitSchema.covariates`
* `sgkit.typing.SgkitSchema.traits`
add_intercept : bool, optional
Add intercept term to covariate set, by default True.

Expand All @@ -164,11 +145,11 @@ def gwas_linear_regression(
-------
:class:`xarray.Dataset`
Dataset containing (N = num variants, O = num traits):
variant_beta : (N, O) array-like
SgkitSchema.variant_beta: (N, O)
Beta values associated with each variant and trait
variant_t_value : (N, O) array-like
SgkitSchema.variant_t_value: (N, O)
T statistics for each beta
variant_p_value : (N, O) array-like
SgkitSchema.variant_p_value: (N, O)
P values as float in [0, 1]

References
Expand All @@ -182,14 +163,15 @@ def gwas_linear_regression(
Nature Genetics 47 (3): 284–90.

"""
if isinstance(covariates, str):
covariates = [covariates]
if isinstance(traits, str):
traits = [traits]
schm = SgkitSchema.schema_has(
ds, SgkitSchema.dosage, SgkitSchema.covariates, SgkitSchema.traits,
)

G = _get_loop_covariates(ds, dosage=dosage)
G = _get_loop_covariates(ds, dosage=schm[SgkitSchema.dosage][0])

X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
X = da.asarray(
concat_2d(ds[schm[SgkitSchema.covariates]], dims=("samples", "covariates"))
)
if add_intercept:
X = da.concatenate([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
# Note: dask qr decomp (used by lstsq) requires no chunking in one
Expand All @@ -198,15 +180,21 @@ def gwas_linear_regression(
# should be removed from dim 1
X = X.rechunk((None, -1))

Y = da.asarray(concat_2d(ds[list(traits)], dims=("samples", "traits")))
Y = da.asarray(concat_2d(ds[schm[SgkitSchema.traits]], dims=("samples", "traits")))
# Like covariates, traits must also be tall-skinny arrays
Y = Y.rechunk((None, -1))

res = linear_regression(G.T, X, Y)
return xr.Dataset(
ds = xr.Dataset(
{
"variant_beta": (("variants", "traits"), res.beta),
"variant_t_value": (("variants", "traits"), res.t_value),
"variant_p_value": (("variants", "traits"), res.p_value),
}
)
return SgkitSchema.spec(
ds,
SgkitSchema.variant_beta,
SgkitSchema.variant_t_value,
SgkitSchema.variant_p_value,
)
40 changes: 21 additions & 19 deletions sgkit/stats/hwe.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Hashable, Optional

import dask.array as da
import numpy as np
import xarray as xr
from numba import njit
from numpy import ndarray
from xarray import Dataset

from sgkit.typing import SgkitSchema


def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
Expand Down Expand Up @@ -119,22 +119,17 @@ def hardy_weinberg_p_value_vec(
hardy_weinberg_p_value_vec_jit = njit(hardy_weinberg_p_value_vec, fastmath=True)


def hardy_weinberg_test(
ds: Dataset, genotype_counts: Optional[Hashable] = None
) -> Dataset:
def hardy_weinberg_test(ds: Dataset) -> Dataset:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].

Parameters
----------
ds : Dataset
Dataset containing genotype calls or precomputed genotype counts.
genotype_counts : Optional[Hashable], optional
Name of variable containing precomputed genotype counts, by default
None. If not provided, these counts will be computed automatically
from genotype calls. If present, must correspond to an (`N`, 3) array
where `N` is equal to the number of variants and the 3 columns contain
heterozygous, homozygous reference, and homozygous alternate counts
(in that order) across all samples for a variant.
May contain `sgkit.typing.SgkitSchema.genotype_counts`; otherwise,
`sgkit.typing.SgkitSchema.call_genotype` and `sgkit.typing.SgkitSchema.call_genotype_mask`
must be present to calculate genotype counts.


Warnings
--------
Expand All @@ -144,8 +139,8 @@ def hardy_weinberg_test(
-------
Dataset
Dataset containing (N = num variants):
variant_hwe_p_value : (N,) ArrayLike
P values from HWE test for each variant as float in [0, 1].
* `sgkit.typing.SgkitSchema.variant_hwe_p_value`: (N,)
P values from HWE test for each variant as float in [0, 1].

References
----------
Expand All @@ -164,15 +159,22 @@ def hardy_weinberg_test(
if ds.dims["alleles"] != 2:
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
# Use precomputed genotype counts if provided
if genotype_counts is not None:
obs = list(da.asarray(ds[genotype_counts]).T)
schm = SgkitSchema.get_schema(ds)
if SgkitSchema.genotype_counts in schm:
obs = list(da.asarray(ds[schm[SgkitSchema.genotype_counts][0]]).T)
# Otherwise compute genotype counts from calls
else:
SgkitSchema.schema_has(
ds, SgkitSchema.call_genotype, SgkitSchema.call_genotype_mask
)
# TODO: Use API genotype counting function instead, e.g.
# https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069
M = ds["call_genotype_mask"].any(dim="ploidy")
AC = xr.where(M, -1, ds["call_genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call]
M = ds[schm["call_genotype_mask"][0]].any(dim="ploidy")
AC = xr.where(M, -1, ds[schm["call_genotype"][0]].sum(dim="ploidy")) # type: ignore[no-untyped-call]
cts = [1, 0, 2] # arg order: hets, hom1, hom2
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
return xr.Dataset({"variant_hwe_p_value": ("variants", p)})
return SgkitSchema.spec(
xr.Dataset({"variant_hwe_p_value": ("variants", p)}),
SgkitSchema.variant_hwe_p_value,
)
Loading