1
0
Форкнуть 0
leprosy-skin-lesion-ai-anal.../run_lr.R

84 строки
3.4 KiB
R

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
library(glmnet)
library(pROC)
library(caret)
#' Elastic-net logistic regression with cross-validated alpha and lambda
#'
#' Fits a binomial `glmnet` model on `train_df`. For every candidate `alpha`
#' the minimum cross-validated loss (`min(cv$cvm)`) is recorded; the alpha with
#' the lowest loss, averaged over the cross-validation repetitions, is used to
#' fit the final `cv.glmnet` model. If `test_df` is given, the final model is
#' evaluated on it at a 0.5 probability threshold.
#'
#' @param train_df A data frame for model training. The `target` column must
#'   be a factor whose levels are exactly `c(neg_level, pos_level)`, in that
#'   order.
#' @param test_df An optional data frame for model testing. It must contain
#'   the `target` column and the same predictors as `train_df`.
#' @param target The column name of the outcome.
#' @param pos_level Will be treated as positive events for classification.
#' @param neg_level Will be treated as negative events for classification.
#' @param nfolds Number of folds for cross-validation. Default is 10. Only
#'   used when `foldid` is `NULL`.
#' @param foldid An optional vector of values between 1 and `nfolds`
#'   identifying what fold each observation is in. If supplied, `nfolds` is
#'   ignored and, because the folds are fixed, the alpha search is run once
#'   regardless of `repeated_cv` (every repetition would be identical).
#' @param alphas Optional user-supplied sequence. Default is
#'   `c(0, .25, .5, .75, 1)`. Alpha is the elastic-net mixing parameter:
#'   `alpha = 1` is the lasso penalty and `alpha = 0` the ridge penalty.
#' @param lambdas Optional user-supplied lambda sequence; default is `NULL`,
#'   and glmnet chooses its own sequence.
#' @param show_var Whether to print the non-zero coefficients of the final
#'   model.
#' @param repeated_cv The number of times to repeat cross-validation (each
#'   time with freshly drawn stratified folds) when selecting the optimal
#'   alpha. Must be at least 1.
#' @return If `test_df` is supplied:
#'   `list(acc = accuracy, auc = auc, sp = specificity, se = sensitivity)`
#'   computed on the testing dataset. If `test_df` is `NULL`:
#'   `list(model = m)` where `m` is the trained model (an object of class
#'   "cv.glmnet").
run_logit <- function(train_df,
                      test_df = NULL,
                      target = "Diagnostic",
                      pos_level = "Leprosy",
                      neg_level = "OD",
                      nfolds = 10,
                      foldid = NULL,
                      alphas = seq(0, 1, length.out = 5),
                      lambdas = NULL,
                      show_var = FALSE,
                      repeated_cv = 1) {
  class_levels <- as.character(c(neg_level, pos_level))

  # identical() also rejects a missing column, a non-factor target and a
  # wrong number of levels, all of which slipped through an element-wise
  # `!=` comparison (levels() is NULL there, and any(logical(0)) is FALSE).
  if (!identical(levels(train_df[[target]]), class_levels)) {
    stop("The levels of the target col must match with c(neg_level,pos_level)",
         call. = FALSE)
  }
  if (!is.numeric(repeated_cv) || length(repeated_cv) != 1L ||
      is.na(repeated_cv) || repeated_cv < 1) {
    stop("repeated_cv must be a single number >= 1", call. = FALSE)
  }
  if (length(alphas) == 0L) {
    stop("alphas must contain at least one value", call. = FALSE)
  }

  y <- train_df[[target]]
  # drop = FALSE keeps a data frame even when only one predictor remains.
  x <- model.matrix(
    ~ .,
    data = train_df[, names(train_df) != target, drop = FALSE]
  )

  # Minimum CV loss reached by cv.glmnet for one alpha on the current folds.
  cv_loss <- function(a) {
    fit <- cv.glmnet(x, y,
                     foldid = foldid, alpha = a, family = "binomial",
                     lambda = lambdas,
                     maxit = 10000)
    min(fit$cvm)
  }

  user_folds <- !is.null(foldid)
  # With fixed user folds every repetition gives the same losses, so one
  # pass is enough and yields the same column means.
  n_rep <- if (user_folds) 1L else as.integer(repeated_cv)

  scores <- matrix(NA_real_, nrow = n_rep, ncol = length(alphas))
  for (rep_idx in seq_len(n_rep)) {
    if (!user_folds) {
      # Stratified sampling: draw fresh folds for every repetition so that
      # repeating the cross-validation actually averages over different
      # splits.
      foldid <- createFolds(y, k = nfolds, list = FALSE)
    }
    scores[rep_idx, ] <- vapply(alphas, cv_loss, numeric(1))
  }

  alpha_best <- alphas[which.min(colMeans(scores))]
  # The final fit reuses the most recent fold assignment.
  logit_model <- cv.glmnet(x, y,
                           foldid = foldid, alpha = alpha_best,
                           family = "binomial",
                           lambda = lambdas,
                           maxit = 10000)

  if (show_var) {
    model_coef <- coef(logit_model)
    print(model_coef[rowSums(model_coef != 0) != 0, ])
  }

  if (is.null(test_df)) {
    return(list(model = logit_model))
  }

  test_x <- model.matrix(
    ~ .,
    data = test_df[, names(test_df) != target, drop = FALSE]
  )
  # Pin the reference labels to the two expected levels; values outside
  # them become NA.
  test_y <- factor(as.character(test_df[[target]]), levels = class_levels)
  pred <- as.numeric(predict(logit_model, newx = test_x, type = "response"))

  # One thresholding rule (>= 0.5 is positive) drives every label-based
  # metric, so accuracy, sensitivity and specificity agree with each other.
  pred_output <- factor(ifelse(pred >= .5, pos_level, neg_level),
                        levels = class_levels)
  acc <- mean(pred_output == test_y)
  sp <- specificity(pred_output, test_y, negative = neg_level)
  se <- sensitivity(pred_output, test_y, positive = pos_level)

  # Fix control/case levels and direction explicitly: automatic direction
  # detection would report a below-chance model as better than chance.
  test_roc <- roc(response = test_y, predictor = pred,
                  levels = class_levels, direction = "<",
                  plot = FALSE, quiet = TRUE)
  auc_value <- as.numeric(test_roc$auc)

  list(acc = acc, auc = auc_value, sp = sp, se = se)
}