debiasm.OnlineDebiasMClassifier


class debiasm.OnlineDebiasMClassifier(batch_str='infer',
                                      learning_rate=0.005,
                                      min_epochs=25,
                                      l2_strength=0,
                                      w_l2=0,
                                      random_state=None,
                                      prediction_loss=torch.nn.functional.binary_cross_entropy
                                      )


The Online DEBIAS-M Classifier.

This class allows a trained DEBIAS-M model to make predictions on samples from batches that were not observed during training. It does so by running an online step that infers the bias-correction factors for previously unobserved batches whenever the transform and predict_proba methods are called.

Like the other DEBIAS-M classes, this class implements multiplicative bias-correction via DEBIAS-M for classification. It receives as input an X matrix of n_samples x n_taxa read counts or relative abundances from multiple microbiome samples, along with a binary y label.

The 'batch_str' parameter weights the strength of the enforced cross-batch similarity, 'l2_strength' controls an l2 regularization of the predictive parameters, and 'w_l2' controls an l2 regularization of the bias-correction parameters. Unlike the standard DEBIAS-M classifier, this class does not require held-out inputs to be provided up front; biases for unseen batches are instead inferred online at prediction time.
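
As an illustration of the expected input layout (the specific values below are hypothetical), the first column of X carries the batch identifier and the remaining columns carry the taxon read counts:

## illustrative input layout (values are hypothetical):
## column 0 holds the batch id, columns 1..n_taxa hold read counts
import numpy as np

X_with_batch = np.array([[0, 130,  12,  58],   ## sample from batch 0
                         [0,  97,   4,  71],   ## sample from batch 0
                         [1, 210,  33,   9],   ## sample from batch 1
                         [2,   5, 180,  44]])  ## sample from batch 2
y = np.array([0, 1, 1, 0])                     ## binary labels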


Parameters

  • batch_str: {'infer' or float}, default='infer'
    • The weight of the enforced cross-batch similarity. Selecting 'infer' automatically sets the weight inversely proportional to the number of pairs of batches and the number of taxa in the input matrix. Larger values specify stronger regularization.

  • learning_rate: float, default=0.005
    • The learning rate used when training the DEBIAS-M model.

  • min_epochs: int, default=25
    • The minimum number of epochs completed during training.

  • l2_strength: float, default=0
    • The l2 regularization of the linear predictive layer’s parameters. Larger values specify stronger regularization.

  • w_l2: float, default=0
    • The l2 regularization of the multiplicative bias correction parameters (applied to the logarithm of the multiplicative parameters). Larger values specify stronger regularization.

  • random_state: int, default=None
    • Sets the random seed used during training, if specified.

  • prediction_loss: loss function, default=torch.nn.functional.binary_cross_entropy
    • The prediction loss function used during training; see the construction sketch below.
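
A minimal construction sketch using only the parameters documented above; the specific values are illustrative, not tuned recommendations:

## construct an OnlineDebiasMClassifier with explicit hyperparameters
## (values are illustrative, not recommendations)
from debiasm import OnlineDebiasMClassifier

odmc = OnlineDebiasMClassifier(batch_str='infer',   ## infer the cross-batch similarity weight
                               learning_rate=0.005,
                               min_epochs=25,
                               l2_strength=0.01,     ## l2 on the predictive layer
                               w_l2=0.001,           ## l2 on the bias-correction factors
                               random_state=42)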



Example

## import packages
import numpy as np
from sklearn.metrics import roc_auc_score
from debiasm import OnlineDebiasMClassifier

## generate data for the example
np.random.seed(123)
n_samples = 96*5
n_batches = 5
n_features = 100

## the read count matrix
X = ( np.random.rand(n_samples, n_features) * 1000 ).astype(int)

## the labels
y = np.random.rand(n_samples)>0.5

## the batches
batches = ( np.random.rand(n_samples) * n_batches ).astype(int)

## we assume the batches are numbered ints starting at '0',
## and they are in the first column of the input X matrices
X_with_batch = np.hstack((batches[:, np.newaxis], X))
## set the validation batch to '4'
val_inds = batches==4
X_train, X_val = X_with_batch[~val_inds], X_with_batch[val_inds]
y_train, y_val = y[~val_inds], y[val_inds]

### Run DEBIAS-M, using standard sklearn object methods
odmc = OnlineDebiasMClassifier() ## no held-out inputs are needed here;
                                 ## biases for unseen batches are inferred
                                 ## online at prediction time
odmc.fit(X_train, y_train)

## Assess results
### should be ~0.5 in this example, since the data is random
roc_auc_score(y_val, odmc.predict_proba(X_val)[:, 1]) 

## extract the 'DEBIAS-ed' data for other downstream analyses, if applicable 
X_debiassed = odmc.transform(X_with_batch)
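
As one illustrative downstream step (not part of the DEBIAS-M API), the bias-corrected matrix can be passed to standard tools such as scikit-learn's PCA:

## illustrative downstream analysis of the bias-corrected data;
## PCA is only an example of what 'downstream analyses' might look like
from sklearn.decomposition import PCA

pcs = PCA(n_components=2).fit_transform(X_debiassed)
print(pcs.shape)  ## (n_samples, 2)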


Methods

  • fit(X, y)

    • Fit the model according to the given training data.

      • Parameters:
        • X : {array-like, sparse matrix} of shape (n_samples, 1 + n_taxa)
          • Training samples, where n_samples is the number of samples and n_taxa is the number of taxa. The first column of X denotes the batch of each sample, as non-negative integers, while the remaining n_taxa describe the read counts of each taxon. DEBIAS-M also supports relative abundance inputs.
        • y : array-like of shape (n_samples,)
          • Target vector relative to X.

      • Returns:
        • self
          • Fitted DEBIAS-M preprocessor and estimator
  • transform(X)

    • Apply DEBIAS-M processing to X.

      • Parameters:
        • X : {array-like, sparse matrix} of shape (n_samples, 1 + n_taxa)
          • Samples to be transformed; n_samples is the number of samples and n_taxa is the number of taxa. The first column of X denotes the batch of each sample, as non-negative integers, while the remaining n_taxa describe the read counts of each taxon. DEBIAS-M also supports relative abundance inputs.

      • Returns:
        • X_debias
          • Matrix of shape (n_samples, n_taxa) containing the relative abundances of X following bias-correction; the batch column is not included in the output.
  • predict_proba(X)

    • Calculate DEBIAS-M classification probability estimates; the returned estimates for all classes are ordered by the label of classes.

      • Parameters:
        • X : {array-like, sparse matrix} of shape (n_samples, 1 + n_taxa)
          • Samples to obtain predictions for; n_samples is the number of samples and n_taxa is the number of taxa. The first column of X denotes the batch of each sample, as non-negative integers, while the remaining n_taxa describe the read counts of each taxon. DEBIAS-M also supports relative abundance inputs.

      • Returns:
        • T : array-like of shape (n_samples, n_classes)
          • The probability of each sample for each class in the model; see the usage sketch after this list.
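
A brief usage sketch of the three methods, assuming X_train, y_train, and X_val are laid out as in the example above (first column = batch id):

## usage of fit / transform / predict_proba, following the example above
odmc = OnlineDebiasMClassifier()
odmc.fit(X_train, y_train)              ## X_train has shape (n_samples, 1 + n_taxa)

X_val_debiased = odmc.transform(X_val)  ## shape (n_val_samples, n_taxa); batch column dropped
val_probs = odmc.predict_proba(X_val)   ## shape (n_val_samples, n_classes)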

See also:
DebiasMClassifierLogAdd
       Implementation of a DEBIAS-M classifier in log-space

For more background on DEBIAS-M, refer to our manuscript.