17.3 The Cox Proportional Hazards Model with (Heterogeneous) Linear and Nonlinear Effects

This section reproduces the paper Subgroup detection in the heterogeneous partially linear additive Cox model [33].

\[ \lambda(t|X_i,Z_i)=\lambda_0(t)\exp\{X_i^T\beta_i+\sum_{j=1}^q f_j(Z_{ij})\} \tag{17.2} \]

Building on the Cox proportional hazards model, the paper not only allows heterogeneity in the regression coefficients but also introduces \(f(\cdot)\) to capture nonlinear effects. Each \(f(\cdot)\) is approximated with B-splines, a fusion penalty is imposed on the pairwise differences \(u_{ik}=\beta_i-\beta_k\), and the resulting problem is solved with a majorized ADMM algorithm (which also introduces the working variables \(Y\) and \(Y'\) defined below). A sketch of the penalized objective follows.
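
As a hedged reconstruction from model (17.2) and the penalty described above (see [33] for the precise form and scaling), the estimation problem can be written as a penalized negative log partial likelihood:

\[ \min_{\beta,\gamma}\; -\sum_{i=1}^n \delta_i\Big\{\eta_i-\log \sum_{l\in R_i}\exp(\eta_l)\Big\}+\sum_{i<k}p(\lVert \beta_i-\beta_k\rVert_2;\lambda), \qquad \eta_i=X_i^T\beta_i+\sum_{j=1}^q B_j(Z_{ij})^T\gamma_j \]

Here \(\delta_i\) is the event indicator, \(R_i\) the risk set at subject \(i\)'s observed time, \(B_j(\cdot)\) the B-spline basis used to approximate \(f_j\), and \(p(\cdot;\lambda)\) the SCAD or MCP penalty applied to each pairwise difference of coefficients.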

17.3.1 Custom Algorithm

The logic of the algorithm:

  1. Input arguments

    • T: matrix; it must contain columns named "time" and "status", giving the observed time and the final event status
    • X: covariate matrix with linear, heterogeneous effects
    • Z: covariate matrix with nonlinear effects
    • penalty: type of penalty function, SCAD or MCP
    • K: number of K-means clusters
    • \(\lambda\): tuning parameter of the penalty function
    • a: regularization factor in the fusion penalty; defaults to 3.7 for SCAD and 2.5 for MCP
    • \(\theta\): penalty parameter of the majorized ADMM algorithm
    • df: argument of splines::bs() controlling the number of basis functions, default 6; see splines::bs(), likewise below
    • degree: argument of splines::bs() setting the degree of the basis functions, default 3
    • tol: convergence tolerance, default 0.001
    • max_iter: maximum number of iterations, default 10000
  2. Other notation

    Besides the input arguments, the following symbols appear in the computation; their meanings are listed below.

    • \(\beta\): \((\beta_1^T,\cdots,\beta_n^T)^T\), a vector of length \(np\)

    • \(\gamma\): \((\gamma_1^T,\cdots,\gamma_q^T)^T\), a vector of length \(dq\), where \(\gamma_j=(\gamma_{j1},\cdots,\gamma_{jd})^T\)

    • \(u\): \((u_{ik}^T,i<k)^T\), a vector of length \(\frac{n(n-1)}{2}p\), where \(u_{ik}=\beta_i-\beta_k\)

    • \(Y\): \((Y_1,\cdots,Y_n)^T\), a vector of length \(n\), with \(Y_i=X_i^T\beta_i+B_i(Z_i)^T\gamma\)

    \(Y'\) is defined analogously; it differs from \(Y\) only in its update rule

    • \(w\): \((w_1,\cdots,w_n)^T\), a vector of Lagrange multipliers of length \(n\)

    • \(\nu\): \((\nu_{ik}^T, i<k)^T\), a vector of Lagrange multipliers of length \(\frac{n(n-1)}{2}p\)

    • \(B\): \((B_1(Z_1),\cdots,B_n(Z_n))^T\), an \(n\times (dq)\) matrix; each \(B_i(Z_i)\) has \(q\) components, and each component is expanded in \(d\) basis functions

    That is, \(B\) is the column-wise concatenation of the basis-function matrices, centered column by column

    • \(X\): \(\textrm{diag}\{X_1^T,\cdots,X_n^T\}\), an \(n \times (pn)\) matrix

    The \(\textrm{diag}\) here acts on the blocks \(X_i^T\); written out element by element it is not a diagonal matrix in the strict sense but a staircase-shaped block-diagonal matrix

    • \(D\): \(\{(e_i-e_j), i<j\}^T\), a \(\frac{n(n-1)}{2} \times n\) matrix, where \(e_i\) is the length-\(n\) vector whose \(i\)-th component is 1 and whose other components are 0

    • \(A\): \(D \otimes I_p\), a \(\frac{n(n-1)}{2}p \times np\) matrix (a small worked example is given right after the R code below)

    • \(Q\): \(I_n-B(B^TB)^{-1}B^T\)

    • \(\tilde g_j\): \(\sum_{i=1}^n \delta_iI_{j \in R_i}\), i.e. the number of risk sets of observed events in which subject \(j\) appears

    • \(\nabla_ig(Y'^{(m+1)})\): \(-\delta_i+\sum_{k=1}^n \delta_k[\exp (Y_i'^{(m+1)})\cdot I_{i \in R_k}]/[\sum_{l \in R_k} \exp (Y_l'^{(m+1)})]\)

    • \(c_{ik}\): \(\beta_i-\beta_k+\nu_{ik}/\theta\)

  3. Initial values

    Run K-means on the matrix \(X\); clustering into 2-5 groups is usually enough. Fit a basic Cox proportional hazards model within each cluster and use its regression coefficients as the initial value for every observation in that cluster. The remaining parameters are initialized as \(u^{(0)}=A\beta^{(0)}, w^{(0)}=0,\nu^{(0)}=0,Y^{(0)}=Y'^{(0)}=X\beta^{(0)}+B\gamma^{(0)}\)

The paper only mentions clustering on \(X\) first and then fitting Cox models. However, because \(Z\) is assumed to be homogeneous, clustering first and then fitting would yield group-specific coefficients that differ across groups and whose dimension does not even match the length of \(\gamma\). We therefore transform the covariates \(Z\) into \(B\) and fit a separate Cox model, taking its regression coefficients as the initial value of \(\gamma\).

  4. Iteration

    The update order is \(\gamma,\beta;Y',Y,u;w,\nu\). A sketch of the majorized augmented Lagrangian that these updates minimize is given after this list.

    \[ \begin{aligned} \gamma^{(m+1)} &= (B^TB)^{-1}B^T(Y^{(m)}-X\beta^{(m)}+\frac{w^{(m)}}{\theta}) \\ \beta^{(m+1)} &= (X^TQX+A^TA)^{-1}[X^TQ(\frac{w^{(m)}}{\theta}+Y^{(m)})+A^T(u^{(m)}-\frac{\nu ^{(m)}}{\theta})] \\ Y_i'^{(m+1)} &= X_i^T\beta_i^{(m+1)}+B_i(Z_i)^T\gamma^{(m+1)} \\ Y_i^{(m+1)} &= (\tilde g_i + \theta)^{-1} [-\nabla_ig(Y'^{(m+1)})+\tilde g_i Y_i'^{(m+1)} - w_i^{(m)}+\theta (X_i^T\beta_i^{(m+1)}+B_i(Z_i)^T\gamma^{(m+1)})] \\ u_{ik}^{(m+1)} &= \text{groupwise SCAD/MCP thresholding of } c_{ik}^{(m+1)} \\ w_i^{(m+1)} &= w_i^{(m)}+\theta (Y_i^{(m+1)}-X_i^T\beta_i^{(m+1)}-B_i(Z_i)^T\gamma^{(m+1)}) \\ \nu_{ik}^{(m+1)} &= \nu_{ik}^{(m)}+\theta (\beta_i^{(m+1)}-\beta_k^{(m+1)}-u_{ik}^{(m+1)}) \end{aligned} \]

  5. Stopping

    The algorithm stops once the maximum number of iterations is reached or the residual \(r\) meets the required tolerance.

    \[ r^{(m+1)} = ||A\beta^{(m+1)}-u^{(m+1)}||+||Y^{(m+1)}-X\beta^{(m+1)}-B\gamma^{(m+1)}|| \]

  6. Output

    • beta: list, heterogeneous regression coefficients for X
    • gamma: vector, regression coefficients of the basis functions
    • label: vector, subgroup label of each sample
    • alpha: data frame, the detected beta subgroups
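
The updates in step 4 are coordinate-wise minimizers of a majorized augmented Lagrangian. The following is a hedged reconstruction that is consistent with the update formulas above (the exact form and scaling are given in [33]): writing \(g(Y)=-\sum_{i=1}^n\delta_i\{Y_i-\log\sum_{l\in R_i}\exp(Y_l)\}\), the ADMM works on

\[ L_\theta(\beta,\gamma,Y,u;w,\nu)=g(Y)+\sum_{i<k}p(\lVert u_{ik}\rVert_2;\lambda)+w^T(Y-X\beta-B\gamma)+\frac{\theta}{2}\lVert Y-X\beta-B\gamma\rVert^2+\nu^T(A\beta-u)+\frac{\theta}{2}\lVert A\beta-u\rVert^2 \]

where the multipliers \(w\) and \(\nu\) enforce the constraints \(Y=X\beta+B\gamma\) and \(u=A\beta\). In the \(Y\)-step, \(g(Y)\) is replaced by the quadratic majorizer \(g(Y')+\nabla g(Y')^T(Y-Y')+\frac{1}{2}(Y-Y')^T\textrm{diag}(\tilde g)(Y-Y')\), which yields the closed-form update for \(Y_i\) listed above; \(w\) and \(\nu\) are then updated by dual ascent.
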
library(tidyverse)
library(survival)
library(Matrix)
library(splines)
library(R6)

SubgroupBeta <- R6Class(
  classname = 'SubgroupBeta',
  
  public = list(
    # Input arguments
    T = NULL,           # observed time and event status
    X = NULL,           # covariate matrix with linear, heterogeneous effects
    Z = NULL,           # covariate matrix with nonlinear effects
    n = NULL,           # sample size
    p = NULL,           # dimension p of X
    q = NULL,           # dimension q of Z
    penalty = NULL,     # penalty function, SCAD or MCP, default MCP
    K = NULL,           # number of K-means clusters, default 2
    lambda = NULL,      # tuning parameter of the penalty function
    a = NULL,           # regularization factor of the fusion penalty, default 3.7 for SCAD and 2.5 for MCP
    theta = NULL,       # penalty parameter of the majorized ADMM algorithm
    df = NULL,          # number of basis functions
    degree = NULL,      # degree of the basis functions
    tol = NULL,         # convergence tolerance, default 0.001
    max_iter = NULL,    # maximum number of iterations, default 10000
    
    # Initialization
    initialize = function(T, X, Z, penalty = 'MCP', K = 2, lambda, a = NULL, theta, df = 6, degree = 3, tol = 0.001, max_iter = 10000){
      self$T <- T
      self$X <- X
      self$Z <- Z
      self$n <- dim(T)[1]   # number of observations
      self$p <- dim(X)[2]   # dimension p
      self$q <- dim(Z)[2]   # dimension q
      self$penalty <- penalty
      self$K <- K
      self$lambda <- lambda
      self$a <- a
      self$theta <- theta
      self$df <- df
      self$degree <- degree
      self$tol <- tol
      self$max_iter <- max_iter
    },
    
    # Main function: run the algorithm
    run = function(trace = TRUE){
      start_time <- proc.time()
      
      # Check that the input is valid
      private$validate()
      
      # Get the initial values
      initial_value <- private$initial_value()
      private$beta <- initial_value[[1]]
      B <- initial_value[[2]]
      private$gamma <- initial_value[[3]]
      X_ls <- initial_value[[4]]
      private$Y <- initial_value[[5]]
      private$Y2 <- initial_value[[5]]
      A <- private$gen_A()
      private$u <- A %*% unlist(private$beta)
      private$w <- rep(0, self$n)
      private$nu <- rep(0, self$n * (self$n-1) * self$p / 2)
      
      # Compute the matrix Q
      Q <- diag(1, ncol = self$n, nrow = self$n)- B %*% solve(t(B) %*% B) %*% t(B)
      
      # Compute the vector g
      g <- private$gen_g()
      
      # Compute (B'B)^{-1}B', used in the gamma update
      B_ols <- solve(t(B) %*% B) %*% t(B)
      
      # Compute (X'QX+A'A)^{-1} and X'Q, used in the beta update
      X_diag <- map(X_ls, ~matrix(., nrow = 1, byrow = T))
      X_diag <- do.call(bdiag, X_diag) %>% as.matrix()         # block-diagonal form of X
      XQX_AA <- solve(t(X_diag) %*% Q %*% X_diag + t(A) %*% A)
      XQ <- t(X_diag) %*% Q
      
      # Parameter iterations
      for (i in 1:self$max_iter) {
        if(trace == TRUE) print(paste0('Running iteration [', i, ']'))
        private$gamma <- private$iter_gamma(X_ls, B_ols, private$Y, private$beta, private$w)
        private$beta <- private$iter_beta(XQX_AA, XQ, A, private$w, private$Y, private$u, private$nu)
        private$Y2 <- private$iter_Y2(X_ls, B, g, private$beta, private$gamma)
        private$Y <- private$iter_Y(g, private$Y2, private$w)
        private$u <- private$iter_u(self$penalty, private$beta, private$nu, A)
        private$w <- private$iter_w(private$w, private$Y, private$Y2)
        private$nu <- private$iter_nu(private$nu, private$beta, private$u, A)
        
        # Stopping criterion
        term_1 <- A %*% unlist(private$beta) - private$u
        term_2 <- private$Y - X_diag %*% unlist(private$beta) - B %*% private$gamma
        if((norm(term_1, type = '2') + norm(term_2, type = '2')) <= self$tol){
          if(trace == TRUE) print('Tolerance reached')
          break
        }
      }
      
      # Subgroups of beta
      private$beta <- lapply(private$beta, round, digits = 2)
      str_beta <- sapply(private$beta, paste, collapse = ',')
      label <- as.numeric(factor(str_beta, levels = unique(str_beta)))
      subgroup_beta <- tibble(label = label, beta = private$beta) %>% 
        group_by(label) %>% 
        summarise(size = n(), beta = unique(beta)) %>% 
        arrange(-size)
      K_hat <- unique(label) %>% length()
      result <- list(beta = private$beta, gamma = private$gamma, label = label, K_hat = K_hat, alpha = subgroup_beta)
      
      end_time <- proc.time()
      cost_time <- end_time - start_time
      print(cost_time)
      return(result)
    },
    
    # Main function: tune lambda
    tune_lambda = function(seq_lambda, trace = TRUE){
      start_time <- proc.time()
      
      # Check that the input is valid
      private$validate()
      
      # Get the initial values
      initial_value <- private$initial_value()
      private$beta <- initial_value[[1]]
      B <- initial_value[[2]]
      private$gamma <- initial_value[[3]]
      X_ls <- initial_value[[4]]
      private$Y <- initial_value[[5]]
      private$Y2 <- initial_value[[5]]
      A <- private$gen_A()
      private$u <- A %*% unlist(private$beta)
      private$w <- rep(0, self$n)
      private$nu <- rep(0, self$n * (self$n-1) * self$p / 2)
      
      # Compute the matrix Q
      Q <- diag(1, ncol = self$n, nrow = self$n)- B %*% solve(t(B) %*% B) %*% t(B)
      
      # Compute the vector g
      g <- private$gen_g()
      
      # Compute (B'B)^{-1}B', used in the gamma update
      B_ols <- solve(t(B) %*% B) %*% t(B)
      
      # Compute (X'QX+A'A)^{-1} and X'Q, used in the beta update
      X_diag <- map(X_ls, ~matrix(., nrow = 1, byrow = T))
      X_diag <- do.call(bdiag, X_diag) %>% as.matrix()         # block-diagonal form of X
      XQX_AA <- solve(t(X_diag) %*% Q %*% X_diag + t(A) %*% A)
      XQ <- t(X_diag) %*% Q
      
      bic_vec <- rep(0, length(seq_lambda))
      result_ls <- vector('list', length = length(seq_lambda))
      for (i in 1:length(seq_lambda)) {
        self$lambda <- seq_lambda[i]
        
        # Warm start: only beta and gamma are carried over from the previous lambda; everything else is reset to its initial setting
        private$Y <- private$Y2    # Y2 = X_beta + B_gamma is carried over, so Y restarts equal to Y2
        private$u <- A %*% unlist(private$beta)
        private$w <- rep(0, self$n)
        private$nu <- rep(0, self$n * (self$n-1) * self$p / 2)
        
        for (j in 1:self$max_iter) {
          if(trace == TRUE) print(paste0('lambda = ', self$lambda, '; iteration [', j, ']'))
          private$gamma <- private$iter_gamma(X_ls, B_ols, private$Y, private$beta, private$w)
          private$beta <- private$iter_beta(XQX_AA, XQ, A, private$w, private$Y, private$u, private$nu)
          private$Y2 <- private$iter_Y2(X_ls, B, g, private$beta, private$gamma)
          private$Y <- private$iter_Y(g, private$Y2, private$w)
          private$u <- private$iter_u(self$penalty, private$beta, private$nu, A)
          private$w <- private$iter_w(private$w, private$Y, private$Y2)
          private$nu <- private$iter_nu(private$nu, private$beta, private$u, A)
          
          # Stopping criterion
          term_1 <- A %*% unlist(private$beta) - private$u
          term_2 <- private$Y - X_diag %*% unlist(private$beta) - B %*% private$gamma
          if((norm(term_1, type = '2') + norm(term_2, type = '2')) <= self$tol){
            if(trace == TRUE) cat('==========\n')
            if(trace == TRUE) print(paste0('lambda = ', self$lambda, ' reached the tolerance'))
            break
          }
        }
        
        # Subgroups of beta
        private$beta <- lapply(private$beta, round, digits = 2)
        str_beta <- sapply(private$beta, paste, collapse = ',')
        label <- as.numeric(factor(str_beta, levels = unique(str_beta)))
        subgroup_beta <- tibble(label = label, beta = private$beta) %>% 
          group_by(label) %>% 
          summarise(size = n(), beta = unique(beta)) %>% 
          arrange(-size)
        K_hat <- unique(label) %>% length()
        
        bic_vec[i] <- private$Bic(private$Y2, K_hat)
        if(trace == TRUE) print(paste0('BIC: ', round(bic_vec[i],3)))
        if(trace == TRUE) cat('==========\n')
        result <- list(beta = private$beta, gamma = private$gamma, label = label, K_hat = K_hat, alpha = subgroup_beta, 
                       best_lambda = self$lambda, BIC = bic_vec[i])
        result_ls[[i]] <- result
      }
      
      best_index <- which.min(bic_vec)
      best_result <- result_ls[[best_index]]
      
      end_time <- proc.time()
      print(end_time - start_time)
      
      return(list(best_result = best_result, BIC = bic_vec))
    }
  ),
  
  private = list(
    # Parameters updated during the iterations
    gamma = NULL,        # regression coefficients of the B-splines
    beta = NULL,         # regression coefficients of X
    Y2 = NULL,           # Y'
    Y = NULL,            # Y
    u = NULL,            # fused penalty term
    w = NULL,            # Lagrange multipliers
    nu = NULL,           # Lagrange multipliers
    c = NULL,            # beta_i - beta_k + nu_ik / theta
    bic = NULL,          # BIC
    
    # Validate the input
    validate = function(){
      if(!(dim(self$T)[2] == 2 & all(colnames(self$T) %in% c('time', 'status')))) stop('T must be a two-column matrix containing time and status')
      if(!is.matrix(self$X)) stop('X must be a matrix')
      if(!is.matrix(self$Z)) stop('Z must be a matrix')
      if(!self$penalty %in% c('MCP','SCAD')) stop('Please choose a valid penalty function, SCAD or MCP')
      if(!(self$K > 0 & self$K == as.integer(self$K))) stop('K must be a positive integer')
      if(self$lambda <= 0) stop('Please choose a valid lambda value')
      if(!is.null(self$a) && self$a <= 0) stop('Please choose a valid a value')
      if(self$theta <= 0) stop('Please choose a valid theta value')
      if(!(self$df > 0 & self$df == as.integer(self$df))) stop('Please choose a valid df value')
      if(!(self$degree > 0 & self$degree == as.integer(self$degree))) stop('Please choose a valid degree value')
      if(self$tol <= 0) stop('Please choose a valid tolerance')
      if(self$max_iter <= 0) stop('Please choose a valid maximum number of iterations')
    },
    
    # Get the initial values
    initial_value = function(){
      # Initial value of beta
      result <- kmeans(self$X, centers = self$K)
      df <- cbind(self$T, self$X) %>% as.data.frame()
      df$label <- result$cluster        # add cluster labels
      df <- df %>% 
        group_nest(label) %>%           # group-wise regression, fit the models in batch
        mutate(model = map(data, ~coxph(Surv(time, status)~., data=.x)),
               coef = map(model, ~.x[['coefficients']]))
      coef <- df$coef %>% as.list()     # extract the coefficient vector of each cluster
      beta_0 <- coef[result$cluster]    # for later convenience, beta is stored as a list for now
      
      # Initial value of gamma and the matrix B
      B <- asplit(self$Z, 2)   # split Z by column and fit bs() to each column
      B <- B %>% map(~bs(., df = self$df, degree = self$degree))
      B <- do.call(cbind, B) %>% scale(center = TRUE, scale = FALSE)   # column-wise centering
      colnames(B) <- paste0('b_', 1:ncol(B))
      df <- cbind(self$T, B) %>% as.data.frame()
      model <- coxph(Surv(time, status)~., data = df)
      gamma_0 <- model[['coefficients']]
      
      # Initial values of Y and Y'
      X_ls <- asplit(self$X, 1)
      Y_0 <- map2_vec(X_ls, beta_0, ~.x %*% .y) + B %*% gamma_0
      
      return(list(beta_0 = beta_0, B = B, gamma_0 = gamma_0, X_ls = X_ls, Y_0 = Y_0))
    },
    
    # Compute the matrix A
    gen_A = function(){
      gen_mat <- function(n){
        mat_1 <- if(n == self$n-1){
          NULL
        }else{
          matrix(0, nrow = self$n-1-n, ncol = n)
        }
        mat_2 <- matrix(1, nrow = 1, ncol = n, byrow = T)
        mat_3 <- diag(-1, nrow=n, ncol = n)
        mat <- rbind(mat_1, mat_2, mat_3)
        return(mat)
      }
      D <- as.list(c((self$n-1):1)) %>% map(~gen_mat(.))
      D <- do.call(cbind, D) %>% t()
      
      A <- D %x% diag(1, ncol = self$p, nrow = self$p)
      
      return(A)
    },
    
    # Compute g
    gen_g = function(){
      status <- self$T[,'status'] %>% as.vector()
      time <- self$T[,'time'] %>% as.vector()
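      # g_j = sum_i delta_i * I(t_j >= t_i): the event-weighted number of risk sets that contain subject j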
      g <- as.list(time) %>% 
        map(~ifelse(. >= time, 1, 0)) %>% 
        map_vec(~status %*% .)
      
      return(g)
    },
    
    # Compute nabla_g
    gen_nabla_g = function(Y2){
      status <- self$T[,'status'] %>% as.vector()
      time <- self$T[,'time'] %>% as.vector()
      
      exp_Y2 <- as.list(exp(Y2))
      
      # the vector sum_{l in R_k} exp(Y2), one entry per k
      sum_l_Rk <- as.list(time) %>% 
        map(~which(. <= time)) %>% 
        map_vec(~sum(exp(Y2[.])))
      
      # Compute nabla_g
      nabla_g <- as.list(time) %>% 
        map(~ifelse(. >= time, 1, 0)) %>%    # indicator I_{i in R_k}
        map(~. %*% diag(status) %*% sum_l_Rk^(-1)) %>%     # combine delta_k, the indicator and 1 / sum_{l in R_k} exp(Y2_l)
        map2_vec(exp_Y2, ~.x * .y) - status
      
      return(nabla_g)
    },
    
    # Compute c, returned as a list
    gen_c = function(beta, nu, A){
      delta_beta <- A %*% unlist(beta)
      c <- delta_beta + nu / self$theta
      c <- matrix(c, ncol = self$p, byrow = T) %>% asplit(1)
      
      return(c)
    },
    
    # gamma update
    iter_gamma = function(X_ls, B_ols, Y_current, beta_current, w_current){
      # beta is passed in list form here
      X_beta <- map2_vec(X_ls, beta_current, ~.x %*% .y)
      gamma_next <- B_ols %*% (Y_current - X_beta + w_current / self$theta)
      
      return(gamma_next)
    },
    
    # beta update, returned as a list
    iter_beta = function(XQX_AA, XQ, A, w_current, Y_current, u_current, nu_current){
      beta_next <- XQX_AA %*% (XQ %*% (w_current / self$theta + Y_current) + t(A) %*% (u_current - nu_current / self$theta))
      beta_next <- matrix(beta_next, ncol = self$p, byrow = T) %>% asplit(1)   # return beta in list form
      
      return(beta_next)
    },
    
    # Y2 update
    iter_Y2 = function(X_ls, B, g, beta_next, gamma_next){
      term_1 <- map2_vec(X_ls, beta_next, ~.x %*% .y)   # X times beta
      term_2 <- asplit(B,1) %>% map_vec(~. %*% gamma_next)  # B times gamma
      Y2_next <- term_1 + term_2
      
      return(Y2_next)
    },
    
    # Y update
    iter_Y = function(g, Y2_next, w_current){
      nabla_g_next <- private$gen_nabla_g(Y2_next)
      Y_next <- (g + self$theta)^(-1) * (-nabla_g_next+ g * Y2_next - w_current + self$theta * Y2_next)
      
      return(Y_next)
    },
    
    # u update
    iter_u = function(penalty, beta_next, nu_current, A){
      S <- function(c, lambda){
       result <- max((1 - lambda / norm(c, type = '2')), 0) * c
       
        return(result)
      }
      
      c <- private$gen_c(beta_next, nu_current, A)
      
      switch(penalty,
             'SCAD' = {
               if(is.null(self$a)) self$a <- 3.7 
               u_next <- map(c, function(c_ik){
                 norm_c_ik <- norm(c_ik, type = '2')
                 if(norm_c_ik <= self$lambda + self$lambda / self$theta){
                   u_ik <- S(c_ik, self$lambda / self$theta)
                 }else if(self$lambda + self$lambda / self$theta < norm_c_ik & norm_c_ik <= self$a * self$lambda){
                   u_ik <- S(c_ik, self$a * self$lambda / ((self$a - 1) * self$theta)) / (1-1/((self$a - 1) * self$theta))
                 }else {
                   u_ik <- c_ik
                 }
                 
                 return(u_ik)
               }) %>% unlist()
             },
             'MCP' = {
               if(is.null(self$a)) self$a <- 2.5
               u_next <- map(c, function(c_ik){
                 norm_c_ik <- norm(c_ik, type = '2')
                 if(norm_c_ik <= self$a * self$lambda){
                   u_ik <- S(c_ik, self$lambda / self$theta) / (1-1/(self$a * self$theta))
                 }else {
                   u_ik <- c_ik
                 }
                 
                 return(u_ik)
               }) %>% unlist()
             }
        
      )
      
      return(u_next)
    },
    
    # w update
    iter_w = function(w_current, Y_next, Y2_next){
      w_next <- w_current + self$theta * (Y_next - Y2_next)
      
      return(w_next)
    },
    
    # nu update
    iter_nu = function(nu_current, beta_next, u_next, A){
      delta_beta <- A %*% unlist(beta_next)
      nu_next <- nu_current + self$theta * (delta_beta - u_next)
      
      return(nu_next)
    },
    
    # BIC criterion
    Bic = function(Y2, K){
      # X_beta + B_gamma is exactly Y2
      status <- self$T[,'status'] %>% as.vector()
      time <- self$T[,'time'] %>% as.vector()
      
      log_sum_l_Ri <- as.list(time) %>% 
        map(~which(. <= time)) %>% 
        map_vec(~log(sum(exp(Y2[.]))))
      term_1 <- -sum(status * (Y2 - log_sum_l_Ri))   # negative log partial likelihood; only observed events contribute
      term_2 <- log(self$n * K + self$q) * log(self$n) * (K * self$p + self$q) / self$n
      bic <- term_1 + term_2
      
      return(bic)
    }
  )
)
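
To make the pairwise-difference operator concrete, the following standalone sketch (not part of the class) builds \(D\) and \(A = D \otimes I_p\) directly for n = 3 and p = 2; it mirrors what the private method gen_A() constructs, just without the block-stacking tricks.

# Illustration of D and A = D %x% I_p for n = 3, p = 2
n_demo <- 3; p_demo <- 2
pairs <- t(combn(n_demo, 2))                   # all index pairs (i, k) with i < k
D <- matrix(0, nrow = nrow(pairs), ncol = n_demo)
D[cbind(seq_len(nrow(pairs)), pairs[, 1])] <- 1
D[cbind(seq_len(nrow(pairs)), pairs[, 2])] <- -1
A <- D %x% diag(p_demo)                        # Kronecker product with I_p
dim(A)                                         # 6 x 6: n(n-1)/2 * p rows, n * p columns
# A %*% c(beta_1, beta_2, beta_3) stacks all pairwise differences beta_i - beta_k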

The Python version of the code is given below.

import numpy as np
import pandas as pd
import time
from sklearn.cluster import KMeans
from sklearn.preprocessing import scale
from scipy import sparse
from scipy.linalg import block_diag
from lifelines import CoxPHFitter
from patsy import bs


class SubgroupBeta:
    
    def __init__(self, time_status, X, Z,  lambda_value, a = None, penalty = 'MCP', K = 2, theta = 1, df = 6, degree = 3, tol = 1e-4, max_iter = 10000):
        self.time_status = time_status
        self.X = X
        self.Z = Z
        self.n = X.shape[0]
        self.p = X.shape[1]
        self.q = Z.shape[1]
        self.lambda_value = lambda_value
        self.a = a
        self.penalty = penalty
        self.K = K
        self.theta = theta
        self.df = df
        self.degree = degree
        self.tol = tol
        self.max_iter = max_iter
        
        self.gamma = None
        self.beta = None
        self.Y2 = None
        self.Y = None
        self.u = None
        self.w = None
        self.nu = None
        self.c = None
        self.bic = None
    
    # Fit a Cox model and return its coefficient vector
    def fit_cox_model(self, data):
        cph = CoxPHFitter()
        cph.fit(data, duration_col = 'time', event_col = 'status')
        coef = np.array(cph.params_).reshape(-1,1)
        return coef
    
    # Generate the B-spline basis matrix
    def gen_B(self, z, df = 6, degree = 3, intercept = False):
        B = bs(z, df = df, degree = degree, include_intercept = intercept)
        
        return B
    
    def initial_value(self):
        # kmeans
        kmeans = KMeans(self.K, random_state = 1)
        kmeans.fit(self.X)
        labels = kmeans.labels_
        
        # Initial value of beta
        col_names = ['time', 'status'] + [f'X_{i+1}' for i in range(self.p)]
        df = pd.DataFrame(np.hstack((self.time_status, self.X)), columns = col_names)
        df['label'] = labels
        beta_cox = df.groupby('label', group_keys=False).apply(self.fit_cox_model, include_groups=False)
        beta_0 = np.hstack(beta_cox.tolist())
        # each column is the beta vector of the corresponding observation
        beta_0 = beta_0[:, labels]
        
        # Initial value of gamma
        B = []
        for col_Z in range(self.q):
            z = self.Z[:, col_Z]
            B_col_Z = self.gen_B(z, df = self.df, degree = self.degree)
            B.append(B_col_Z)
        B = np.hstack(B)
        B = scale(B, axis = 0, with_mean = True, with_std = False)
        col_names = ['time', 'status'] + [f'b_{i+1}' for i in range(B.shape[1])]
        df = pd.DataFrame(np.hstack((self.time_status, B)), columns = col_names)
        gamma_0 = self.fit_cox_model(df)
        
        # Initial values of Y and Y2
        Y_0 = np.einsum('ij,ji->i', self.X, beta_0).reshape(-1,1) + B @ gamma_0

        return beta_0, B, gamma_0, Y_0
        
    # Generate the matrix A
    def gen_A(self):
        # Build the sparse pairwise-difference matrix Delta
        rows = []
        for i in range(self.n-1):
            row = sparse.lil_matrix((self.n-1-i, self.n))
            for j in range(i+1, self.n):
                row[j-i-1, i] = 1
                row[j-i-1, j] = -1
            rows.append(row)
        Delta = sparse.vstack(rows)
        A = sparse.kron(Delta, np.eye(self.p))
        return A
    
    # Compute g
    def gen_g(self):
        time_obs = self.time_status[:, 0]
        # R encodes the risk sets: row i indicates which subjects are in the risk set of subject i
        R = (time_obs[:, np.newaxis] >= time_obs).astype(int)
        g = R @ self.time_status[:, 1].reshape(-1,1)
        return g
    
    # Compute nabla_g
    def gen_nabla_g(self, Y2):
        time_obs = self.time_status[:, 0].flatten()
        status = self.time_status[:, 1].reshape(-1,1)
        
        
        exp_Y2 = np.exp(Y2)
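        # The two indicator matrices below encode risk-set membership:
        #   R[k, l]     = I(t_k <= t_l), i.e. subject l is in risk set R_k, so R @ exp_Y2 gives the risk-set sums
        #   R_rev[i, k] = I(t_i >= t_k), i.e. subject i is in risk set R_k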
        R = (time_obs[:, np.newaxis] <= time_obs).astype(int)
        sum_l_Rk = 1 / (R @ exp_Y2)
        R_rev = (time_obs[:, np.newaxis] >= time_obs).astype(int)
        nabla_g = -status + exp_Y2 * (R_rev @ (status * sum_l_Rk))
        return nabla_g
    
    # Compute c
    def gen_c(self, beta, nu, A):
        delta_beta = A @ beta.flatten(order = 'F').reshape(-1,1)
        c = delta_beta + nu / self.theta
        # each row is c_ik = beta_i - beta_k + nu_ik / theta
        c = c.reshape(-1, self.p)
        return c, delta_beta
    
    # gamma update
    def iter_gamma(self, B_ols, Y_current, beta_current, w_current):
        X_beta = np.einsum('ij,ji->i', self.X, beta_current).reshape(-1, 1)
        gamma_next = B_ols @ (Y_current - X_beta + w_current / self.theta)
        return gamma_next
    
    # beta update
    def iter_beta(self, XQX_AA, XQ, A, w_current, Y_current, u_current, nu_current):
        beta_next = XQX_AA @ (XQ @ (w_current / self.theta + Y_current) + A.T @ (u_current - nu_current / self.theta))
        beta_next = beta_next.reshape(self.p, -1, order = 'F')
        return beta_next
    
    # Y2 update
    def iter_Y2(self, B, g, beta_next, gamma_next):
        term_1 = np.einsum('ij,ji->i', self.X, beta_next).reshape(-1, 1)
        term_2 = B @ gamma_next
        Y2_next = term_1 + term_2
        return Y2_next
    
    # Y update
    def iter_Y(self, g, Y2_next, w_current):
        nabla_g_next = self.gen_nabla_g(Y2_next)
        Y_next = 1/(g + self.theta) * (-nabla_g_next + g * Y2_next - w_current + self.theta * Y2_next)
        return Y_next
    
    # u update: SCAD / MCP thresholding functions
    def penalty_fun(self):
        def S(c_ik, lambda_val):
            result = np.max([(1 - lambda_val / np.linalg.norm(c_ik, ord = 2)), 0]) * c_ik
            return result
        
        def SCAD(c):
            if self.a is None:
                self.a = 3.7
            lamba_lambda_theta = self.lambda_value + self.lambda_value / self.theta
            a_lambda = self.a * self.lambda_value
            
            norm_c_vec = np.linalg.norm(c, ord = 2, axis = 1)
            
            cond_1 = norm_c_vec <= lamba_lambda_theta
            cond_2 = (lamba_lambda_theta < norm_c_vec) & (norm_c_vec <= a_lambda)
            cond_3 = norm_c_vec > a_lambda
            
            u_next = np.copy(c)
            if np.any(cond_1):
                u_next[cond_1, :] = np.apply_along_axis(S, axis = 1, arr = c[cond_1, :], lambda_val = self.lambda_value/self.theta)
            if np.any(cond_2):
                u_next[cond_2, :] = np.apply_along_axis(S, axis = 1, arr = c[cond_2, :], lambda_val = self.a * self.lambda_value/((self.a - 1) * self.theta)) / (1 - 1 / ((self.a-1) * self.theta))
            if np.any(cond_3):
                u_next[cond_3, :] = c[cond_3, :]
            return u_next
        
        def MCP(c):
            if self.a is None:
                self.a = 2.5
            a_lambda = self.a * self.lambda_value
            
            norm_c_vec = np.linalg.norm(c, ord = 2, axis = 1)
            
            cond_1 = norm_c_vec <= a_lambda
            cond_2 = norm_c_vec > a_lambda
            
            u_next = np.copy(c)
            if np.any(cond_1):
                u_next[cond_1, :] = np.apply_along_axis(S, axis = 1, arr = c[cond_1, :], lambda_val = self.lambda_value / self.theta) / (1 - 1 / (self.a * self.theta))
            if np.any(cond_2):
                u_next[cond_2, :] = c[cond_2, :]
            return u_next
        
        if self.penalty == 'SCAD':
            return SCAD
        elif self.penalty == 'MCP':
            return MCP
        else:
            raise ValueError("Invalid penalty function type")
    
    def iter_u(self, beta_next, nu_current, A, penalty_fun):
        c, delta_beta = self.gen_c(beta_next, nu_current, A)
        u_next = penalty_fun(c).flatten().reshape(-1,1)
        return u_next, delta_beta
    
    # w update
    def iter_w(self, w_current, Y_next, Y2_next):
        w_next = w_current + self.theta * (Y_next - Y2_next)
        return w_next
    
    # nu update
    def iter_nu(self, nu_current, delta_beta, u_next):
        u_next = u_next.flatten().reshape(-1,1)
        nu_next = nu_current + self.theta * (delta_beta - u_next)
        return nu_next
        
    # Main function: run the algorithm
    def run(self, trace = True):
        start_time = time.time()
        
        # Get the initial values
        self.beta, B, self.gamma, self.Y = self.initial_value()
        self.Y2 = self.Y
        A = self.gen_A()
        self.u = A @ self.beta.flatten(order = 'F').reshape(-1,1)
        self.w = np.zeros((self.n, 1))
        self.nu = np.zeros((int(self.n * (self.n - 1) * self.p / 2), 1))
        
        # Compute the matrix Q
        Q = np.eye(self.n) - B @ np.linalg.inv(B.T @ B) @ B.T
        
        # Compute g
        g = self.gen_g()
        
        # Compute (B'B)^(-1)B'
        B_ols = np.linalg.inv(B.T @ B) @ B.T
        
        # Compute (X'QX+A'A)^{-1} and X'Q
        X_diag = np.split(self.X, self.n, axis = 0)
        X_diag = block_diag(*X_diag)
        XQX_AA = np.linalg.inv(X_diag.T @ Q @ X_diag + A.T @ A)
        XQ = X_diag.T @ Q
        
        # Build the penalty function
        penalty_fun = self.penalty_fun()
        
        for i in range(self.max_iter):
            if trace:
                print(f'Iteration [{i+1}]')
            self.gamma = self.iter_gamma(B_ols, self.Y, self.beta, self.w)
            self.beta = self.iter_beta(XQX_AA, XQ, A, self.w, self.Y, self.u, self.nu)
            self.Y2 = self.iter_Y2(B, g, self.beta, self.gamma)
            self.Y = self.iter_Y(g, self.Y2, self.w)
            self.u, delta_beta = self.iter_u(self.beta, self.nu, A, penalty_fun)
            self.w = self.iter_w(self.w, self.Y, self.Y2)
            self.nu = self.iter_nu(self.nu, delta_beta, self.u)
            
            # Stopping criterion
            term_1 = delta_beta - self.u
            term_2 = self.Y - self.Y2
            norm_r = np.linalg.norm(term_1, ord=2) + np.linalg.norm(term_2, ord = 2)
            if norm_r <= self.tol:
                if trace == True:
                    print('Tolerance reached')
                break
        
        # Subgroups of beta
        self.beta = np.round(self.beta, 3)
        subgroup_beta = np.unique(self.beta, axis = 1, return_counts = True)
        size = subgroup_beta[1]
        subgroup_beta = subgroup_beta[0].T
        self.K = subgroup_beta.shape[0]
        df = pd.DataFrame([{'beta': tuple(row.tolist())} for row in subgroup_beta])
        df['size'] = size
        df = df.sort_values(by = 'size', ascending = False).reset_index(drop = True)
        df['label'] = range(self.K)
        
        df_label = pd.DataFrame([{'beta': tuple(row.tolist())} for row in self.beta.T])
        df_label = df_label.merge(df, on = 'beta', how = 'left')
        label = df_label['label'].tolist()
        
        result = {
            'beta' : self.beta.T,
            'gamma' : self.gamma,
            'alpha' : df,
            'K' : self.K,
            'label' : label
        }
        
        print(f'Elapsed: {time.time() - start_time:.2f}s')
        return result
    
    # Main function: tune lambda
    def tune_lambda(self, seq_lambda, seed = None, trace = True):
        start_time = time.time()
        
        # Get the initial values
        beta_0, B, gamma_0, Y_0 = self.initial_value()
        A = self.gen_A()
        u_0 = A @ beta_0.flatten(order = 'F').reshape(-1,1)
        w_0 = np.zeros((self.n, 1))
        nu_0 = np.zeros((int(self.n * (self.n - 1) * self.p / 2), 1))
        
        # Compute the matrix Q
        Q = np.eye(self.n) - B @ np.linalg.inv(B.T @ B) @ B.T
        
        # Compute g
        g = self.gen_g()
        
        # Compute (B'B)^(-1)B'
        B_ols = np.linalg.inv(B.T @ B) @ B.T
        
        # Compute (X'QX+A'A)^{-1} and X'Q
        X_diag = np.split(self.X, self.n, axis = 0)
        X_diag = block_diag(*X_diag)
        XQX_AA = np.linalg.inv(X_diag.T @ Q @ X_diag + A.T @ A)
        XQ = X_diag.T @ Q
        
        # Build the penalty function
        penalty_fun = self.penalty_fun()
        
        # Containers for the results
        bic_ls = []
        bic_log_ls = []
        bic_c_ls = []
        bic_logk_ls = []
        result_ls = []
        K_ls = []
        
        for i in range(len(seq_lambda)):
            self.lambda_value = seq_lambda[i]
            
            # Reset all parameters to their initial values
            self.beta = beta_0
            self.gamma = gamma_0
            self.Y = Y_0
            self.Y2 = Y_0
            self.u = u_0
            self.w = w_0
            self.nu = nu_0
            
            for j in range(self.max_iter):
                if trace:
                    print(f'seed_[{seed+1}]: lambda={seq_lambda[i]}--[{j+1}]')
                self.gamma = self.iter_gamma(B_ols, self.Y, self.beta, self.w)
                self.beta = self.iter_beta(XQX_AA, XQ, A, self.w, self.Y, self.u, self.nu)
                self.Y2 = self.iter_Y2(B, g, self.beta, self.gamma)
                self.Y = self.iter_Y(g, self.Y2, self.w)
                self.u, delta_beta = self.iter_u(self.beta, self.nu, A, penalty_fun)
                self.w = self.iter_w(self.w, self.Y, self.Y2)
                self.nu = self.iter_nu(self.nu, delta_beta, self.u)
                
                # Stopping criterion
                term_1 = delta_beta - self.u
                term_2 = self.Y - self.Y2
                norm_r = np.linalg.norm(term_1, ord=2) + np.linalg.norm(term_2, ord = 2)
                if norm_r <= self.tol:
                    if trace == True:
                        print('Tolerance reached')
                    break
            
            # Subgroups of beta
            self.beta = np.round(self.beta, 3)
            subgroup_beta = np.unique(self.beta, axis = 1, return_counts = True)
            size = subgroup_beta[1]
            subgroup_beta = subgroup_beta[0].T
            self.K = subgroup_beta.shape[0]
            df = pd.DataFrame([{'beta': tuple(row.tolist())} for row in subgroup_beta])
            df['size'] = size
            df = df.sort_values(by = 'size', ascending = False).reset_index(drop = True)
            df['label'] = range(self.K)
            
            df_label = pd.DataFrame([{'beta': tuple(row.tolist())} for row in self.beta.T])
            df_label = df_label.merge(df, on = 'beta', how = 'left')
            label = df_label['label'].tolist()
            
            # Compute the BIC
            time_obs = self.time_status[:, 0]
            status = self.time_status[:, 1].reshape(1,-1)
            R = (time_obs[:, np.newaxis] <= time_obs).astype(int)
            term_1 = - status @ (self.Y2 - np.log(R @ np.exp(self.Y2)))
            term_2 = np.log(self.n * self.K + self.q) * np.log(self.n) * (self.K * self.p + self.q) / self.n
            bic_value = term_1 + term_2
            bic_value = round(bic_value.item(), 3)
            bic_ls.append(bic_value)
            
            
            term_2 = np.log(np.log(self.n * self.K + self.q)) * np.log(self.n) * (self.K * self.p + self.q) / self.n
            bic_log_value = term_1 + term_2
            bic_log_value = round(bic_log_value.item(), 3)
            bic_log_ls.append(bic_log_value)
            
            
            term_2 = 0.5 * np.log(self.n * self.K + self.q) * np.log(self.n) * (self.K * self.p + self.q) / self.n
            bic_c_value = term_1 + term_2
            bic_c_value = round(bic_c_value.item(), 3)
            bic_c_ls.append(bic_c_value)
            
            term_2 = np.log(self.K) * np.log(self.n * self.K + self.q) * np.log(self.n) * (self.K * self.p + self.q) / self.n
            bic_logk_value = term_1 + term_2
            bic_logk_value = round(bic_logk_value.item(), 3)
            bic_logk_ls.append(bic_logk_value)
            
            K_ls.append(self.K)
            
            result = {
                'beta' : self.beta.T,
                'gamma' : self.gamma,
                'alpha' : df,
                'K' : self.K,
                'label' : label,
                'bic' : bic_value
            }
            result_ls.append(result)
        
        best_index = bic_ls.index(min(bic_ls))
        best_result = result_ls[best_index]
        print(f'Total elapsed: {time.time() - start_time:.2f}s')
        tune_result = {'bic' : bic_ls, 'result' : result_ls, 'best_result' : best_result, 'bic_log' : bic_log_ls, 'bic_c' : bic_c_ls, 'K_ls' : K_ls, 'bic_logk' : bic_logk_ls}
        
        return tune_result
        
if __name__ == "__main__":
    
    total_start_time = time.time()
    
    def gen_data(seed, n=100):
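        # Data-generating process: X has two N(0, 1) columns and is reordered by K-means cluster;
        # the first n/2 observations get beta_i = (-3, -3) and the remaining n/2 get beta_i = (3, 3);
        # the nonlinear effects are f1(z) = sin(pi*(z-0.5)) and f2(z) = cos(pi*(z-0.5)) - 2/pi;
        # survival times are exponential-type and roughly 20% of observations are censored.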
        np.random.seed(seed)
        X = np.random.normal(loc = 0, scale = 1, size = (n, 2))
        Z_1 = np.random.uniform(low = 0, high = 1, size = (n, 1))
        f_Z1 = np.sin(np.pi * (Z_1-0.5))
        Z_2 = np.random.uniform(low = 0, high = 1, size = (n, 1))
        f_Z2 = np.cos(np.pi * (Z_2 - 0.5)) - 2/np.pi
        Z = np.hstack((Z_1, Z_2))
        
        kmeans = KMeans(2, random_state = 1)
        kmeans.fit(X)
        labels = kmeans.labels_
        X_1 = X[labels == 0, :]
        X_2 = X[labels == 1, :]
        X = np.vstack((X_1, X_2))
        
        beta = 3 * np.ones((int(n/2),2))
        beta = np.vstack((-beta, beta))
        
        log_U = np.log(np.random.uniform(0, 1, size = (n,1)))
        X_beta = np.einsum('ij,ji->i', X, beta.T).reshape(-1,1)
        time_obs = -np.exp(-X_beta - f_Z1 - f_Z2) * log_U
        status = np.random.choice([0,1], size = (n,1), p = [0.2, 0.8])
        time_status = np.hstack((time_obs, status))
       
        return time_status, X, Z

    np.random.seed(564)
    seed_ls = np.random.randint(low = 1, high = 1000, size= 10).tolist()
    seq_lambda = np.arange(0.04, 0.075, 0.005)
    bic_mat = np.zeros((len(seed_ls), len(seq_lambda)))
    K_mat = np.zeros((len(seed_ls), len(seq_lambda)))
    result_ls = []
    
    for i in range(len(seed_ls)):
        seed = seed_ls[i]
        time_status, X, Z = gen_data(n=100,seed = seed)
        model = SubgroupBeta(time_status, X, Z, lambda_value = 0.06, tol = 1e-3)
        result_tune = model.tune_lambda(seq_lambda, seed = i)
        
        bic_mat[i] = result_tune['bic']
        K_mat[i] = result_tune['K_ls']
        result_ls.append(result_tune['result'])
    
    seed_ls = list(map(str, seed_ls))
    seq_lambda = list(map(str, seq_lambda.tolist()))
    bic_mat = pd.DataFrame(bic_mat, index = seed_ls, columns = seq_lambda)
    K_mat = pd.DataFrame(K_mat, index = seed_ls, columns = seq_lambda)
    
    total_end_time = time.time()
    
    print(f'Total time over all seeds: {total_end_time - total_start_time}')

17.3.2 Data Simulation

# case_3
# the observations are reordered by their K-means cluster
set.seed(123)
x_1 <- rnorm(100)
x_2 <- rnorm(100)
X <- cbind(x_1, x_2)
X_cluster <- kmeans(X, centers = 2)
X_1 <- X[which(X_cluster$cluster == 1),]
X_2 <- X[which(X_cluster$cluster == 2),]
X <- rbind(X_1,X_2)
X_diag <- map(asplit(X,1), ~matrix(., nrow = 1, byrow = T))
X_diag <- do.call(bdiag, X_diag) %>% as.matrix()
z_1 <- runif(100)
z_2 <- runif(100)
f_z1 <- sin((z_1 - 0.5) * pi)
f_z2 <- cos((z_2 - 0.5) * pi) - 2 / pi
beta <- c(rep(c(3,3), times = 50), rep(c(-3,-3), times = 50))
time <- as.vector(-exp(-X_diag %*% beta - f_z1 - f_z2)) * log(runif(100))
status <- sample(c(1,0), size = 100, replace = TRUE, prob = c(0.8, 0.2))

T <- cbind(time, status)
Z <- cbind(z_1, z_2)

case3 <- SubgroupBeta$new(T = T, X = X, Z = Z, penalty = 'MCP', K = 2, lambda = 0.1, a = 2.5, theta = 1, df = 6, degree = 3)
result3 <- case3$run(trace = FALSE)
##    user  system elapsed 
##  119.50    5.81  127.66
result3$alpha
##   label size         beta
## 1     1   56   2.26, 2.19
## 2     2   44 -2.56, -2.54
result3$alpha$beta
## [[1]]
## [1] 2.26 2.19
## 
## [[2]]
## [1] -2.56 -2.54
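
The R class also provides a tune_lambda() method that is not demonstrated above. Below is a minimal sketch of how it could be called on the same simulated data; the candidate grid is an arbitrary illustrative choice, and the lambda passed to new() is only a placeholder because tune_lambda() overrides it.

case3_tune <- SubgroupBeta$new(T = T, X = X, Z = Z, penalty = 'MCP', K = 2,
                               lambda = 0.1, a = 2.5, theta = 1, df = 6, degree = 3)
tune3 <- case3_tune$tune_lambda(seq_lambda = seq(0.06, 0.14, by = 0.02), trace = FALSE)
tune3$BIC                 # BIC value for each candidate lambda
tune3$best_result$alpha   # subgroup summary under the BIC-optimal lambda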

References

[33]
CAI T, HU T. Subgroup detection in the heterogeneous partially linear additive Cox model[J/OL]. Journal of Nonparametric Statistics, 2024, 0(0): 1-26. DOI: 10.1080/10485252.2024.2303103.