Via PharmCat on GitHub (the linked page has similar code, with comments in Russian and English).
begin
using LsqFit
using MortalityTables
using Plots
using Distributions
using Optim
using DataFrames
using PlutoUI; TableOfContents()
end
Sample data:
data = let
    # Observed survival proportions, one value per consecutive time step.
    proportions = [
        0.99, 0.98, 0.95, 0.9, 0.8, 0.65, 0.5,
        0.38, 0.25, 0.2, 0.1, 0.05, 0.02, 0.01,
    ]
    # Columns: `times` counts the observations from 1, `survival` holds the proportions.
    DataFrame(; times = eachindex(proportions), survival = proportions)
end
| | times | survival |
|---|---|---|
| 1 | 1 | 0.99 |
| 2 | 2 | 0.98 |
| 3 | 3 | 0.95 |
| 4 | 4 | 0.9 |
| 5 | 5 | 0.8 |
| 6 | 6 | 0.65 |
| 7 | 7 | 0.5 |
| 8 | 8 | 0.38 |
| 9 | 9 | 0.25 |
| 10 | 10 | 0.2 |
| 11 | 11 | 0.1 |
| 12 | 12 | 0.05 |
| 13 | 13 | 0.02 |
| 14 | 14 | 0.01 |
# Base plot of the observed survival proportions; later cells add series to `plt`.
plt = plot(
    data.times,
    data.survival;
    label = "observed survival proportion",
    xlabel = "time",
)
Define the two-parameter Weibull model. The model function has the form model(x, p), where:
- x: array of independent variables
- p: array of model parameters
model(x, p) receives the full data set as its first argument x, so the model function must apply the model to every observation. We use @. to apply ("broadcast") the calculations across all rows.
# Weibull survival model in the model(x, p) form used for curve fitting.
# p[1] is passed as `m` and p[2] as `σ`; `@.` broadcasts the calls over x.
@. function model1(x, p)
    return survival(MortalityTables.Weibull(; m = p[1], σ = p[2]), x)
end
model1 (generic function with 1 method)
And fit the model with LsqFit.jl:
# Least-squares fit of model1 to the observed proportions, starting from [m, σ] = [1, 1].
fit1 = let initial_guess = [1.0, 1.0]
    curve_fit(model1, data.times, data.survival, initial_guess)
end
LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Float64}}([8.147278225551966, 2.850707997584664], [0.007512289530944849, 0.002103664528583238, -0.005912835826727303, -0.022715351170030318, -0.019432525084235097, 0.00892696470217913, 0.02305783722851895, 0.0070507784438295085, 0.014721160196872363, -0.03405121675005396, -0.0054321263025712, -0.0014077035269912178, 0.0023337166511927764, -0.0008907842988607542], [0.0026998722498318465 -0.005225236787378377; 0.0149594476524495 -0.024973364570331415; … ; 0.01586690256401684 0.039774503565516456; 0.006885632016938745 0.023229578399732494], true, Float64[])
# Overlay the fitted curve on the plot of observed proportions.
plot!(
    plt,
    data.times,
    model1(data.times, fit1.param);
    label = "fitted model",
)
Generate 100 sample datapoints:
# fit1.param is ordered [m, σ] (see model1). σ is passed as the first argument
# of Distributions.Weibull and m as the second.
t = let m = fit1.param[1], σ = fit1.param[2]
    rand(Distributions.Weibull(σ, m), 100)
end
100-element Vector{Float64}:
2.556259607514606
2.5309311482725474
2.4455370093876034
2.4970570725324306
5.618234559090627
4.21231296819365
7.783889837397662
⋮
8.500761796101846
5.088117085444477
7.93132708622475
12.589368515560968
4.6399729274509856
7.80225583512748
# No censoring: maximum-likelihood Weibull fit using every sampled time in t.
fit_mle(Distributions.Weibull, t)
Distributions.Weibull{Float64}(α=2.679171649539969, θ=7.737055449262706)
# Event indicator with one flag per sampled time in t: true = event observed
# (not censored). Sized from t rather than a hard-coded 100 so the two vectors
# always have the same length.
c = fill(true, length(t))
100-element Vector{Bool}:
1
1
1
1
1
1
1
⋮
1
1
1
1
1
1
# Flag observations 1, 3, 7 and 9 as censored (c[i] == false).
fill!(view(c, [1, 3, 7, 9]), false)
4-element view(::Vector{Bool}, [1, 3, 7, 9]) with eltype Bool:
0
0
0
0
#ML function
# Negative log-likelihood of a Weibull model for a sample with censoring.
#
# Arguments:
#   x - parameter vector; x[2] is passed to Weibull as its first argument and
#       x[1] as its second.
# Reads the globals `t` (observation times) and `c` (flags: true = event
# observed, false = censored).
# Returns the negated log-likelihood, so minimising it maximises the likelihood.
# Throws DimensionMismatch when `t` and `c` do not share the same indices.
function survmle(x)
    dist = Weibull(x[2], x[1])  # depends only on x, so build it once per call
    loglik = 0.0
    for i in eachindex(t, c)
        if c[i]
            loglik += logpdf(dist, t[i])   # not censored: log f(t)
        else
            loglik += logccdf(dist, t[i])  # censored: log(1 - F(t))
        end
    end
    return -loglik
end
survmle (generic function with 1 method)
# Minimise survmle within box bounds; all three vectors follow the parameter
# order that survmle expects for x.
opt = let lower = [1.0, 1.0], upper = [15.0, 15.0], initial = [3.0, 3.0]
    Optim.optimize(survmle, lower, upper, initial)
end
* Status: success
* Candidate solution
Final objective value: 2.372043e+02
* Found with
Algorithm: Fminbox with L-BFGS
* Convergence measures
|x - x'| = 8.40e-08 ≰ 0.0e+00
|x - x'|/|x'| = 1.01e-08 ≰ 0.0e+00
|f(x) - f(x')| = 0.00e+00 ≤ 0.0e+00
|f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
|g(x)| = 7.73e-09 ≤ 1.0e-08
* Work counters
Seconds run: 1 (vs limit Inf)
Iterations: 4
f(x) calls: 45
∇f(x) calls: 45
The optimizer converges to parameter values similar to those of the distribution used to generate the synthetic data:
# Parameter estimates, in the same order as the argument of survmle.
opt |> Optim.minimizer
2-element Vector{Float64}:
7.875657121975123
2.732699769431714
KaplanMeier comes from Survival.jl
#4
# Fit the model to the empirical Kaplan-Meier (KM) survival function
using Survival
# t - time vector; c - event indicator vector (true = event observed,
# false = censored, matching how survmle reads c)
km = fit(Survival.KaplanMeier, t, c)
KaplanMeier{Float64}([0.7286945555964006, 0.8839999817215857, 1.218290583719745, 1.3724260990544466, 2.4455370093876034, 2.4645189345841256, 2.4970570725324306, 2.5309311482725474, 2.556259607514606, 2.839493850504482 … 10.747258781393816, 10.98718457817398, 11.052114559824426, 11.222495431503242, 11.297456019023278, 11.486065171744436, 11.615288776198556, 11.657078352388332, 12.589368515560968, 13.028749195785995], [1, 1, 1, 1, 0, 1, 1, 1, 0, 1 … 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 1, 0, 0, 0, 1, 0 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [100, 99, 98, 97, 96, 95, 94, 93, 92, 91 … 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], [0.99, 0.98, 0.97, 0.96, 0.96, 0.9498947368421052, 0.9397894736842105, 0.9296842105263158, 0.9296842105263158, 0.9194679005205322 … 0.10162539953121666, 0.09033368847219259, 0.07904197741316851, 0.06775026635414444, 0.056458555295120366, 0.04516684423609629, 0.03387513317707222, 0.02258342211804815, 0.011291711059024075, 0.011291711059024075], [0.01005037815259212, 0.014285714285714285, 0.017586311452816476, 0.020412414523193152, 0.020412414523193152, 0.022992362852334424, 0.02535821463029275, 0.027566575677517274, 0.027566575677517274, 0.02969875783066371 … 0.307754361284321, 0.3295476229293867, 0.3556104309993621, 0.3876445568366412, 0.42848761449825973, 0.4833235311656254, 0.5629697763750222, 0.6954147221467267, 0.9917669261365843, 0.9917669261365843])
# The Kaplan-Meier estimate is a step function that stays constant between
# consecutive time points, so draw it with :steppost rather than joining the
# points with straight segments. `label` matches the keyword used for `plt`.
plt2 = plot(km.times, km.survival; seriestype = :steppost, label = "Empirical")
# Survival curve of a MortalityTables Weibull law with p[1] as `m` and p[2] as
# `σ`, evaluated at every entry of x.
function model(x, p)
    law = MortalityTables.Weibull(; m = p[1], σ = p[2])
    return survival.(Ref(law), x)
end
model (generic function with 1 method)
# Least-squares fit of the Weibull model to the Kaplan-Meier curve, starting
# from [m, σ] = [2, 2].
mfit = let start = [2.0, 2.0]
    LsqFit.curve_fit(model, km.times, km.survival, start)
end
LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Float64}}([7.972766840162683, 2.8666885686740295], [0.008712043806245906, 0.017796761664607974, 0.024632324737230538, 0.03253172375255775, 0.0033129861993818066, 0.012635925169057982, 0.0213765047064024, 0.030029134385761935, 0.028921564359510987, 0.025480349412919656 … -0.0008130493213439338, -0.003155737675752862, 0.004691616638689486, 0.0074263016157703615, 0.015168851746469036, 0.018092584179631235, 0.024099241601380016, 0.03375713709653691, 0.0170723102097812, 0.00856349134655932], [0.0015232271475632741 -0.0029876315312149706; 0.002456184220284718 -0.004695921569508153; … ; 0.01914702755899436 0.044785061481114; 0.013813721434141207 0.0370785977974916], true, Float64[])
# Overlay the fitted Weibull curve on the empirical Kaplan-Meier plot.
plot!(
    plt2,
    km.times,
    model(km.times, mfit.param);
    labels = "Theoretical",
)
Built with Julia 1.8.0 and
DataFrames 1.3.2
To run this page locally, download this file and open it with Pluto.jl.