# covid19model/plot-forecast.r
#
# Plots observed and model-estimated daily deaths, plus a short forecast,
# for each country in a saved model fit.

library(ggplot2)
library(tidyr)
library(dplyr)
library(rstan)
library(data.table)
library(lubridate)
library(gdata)
library(EnvStats)
library(matrixStats)
library(scales)
library(gridExtra)
library(ggpubr)
library(bayesplot)
library(cowplot)
# Load the project's step-ribbon geom helper.
# NOTE(review): nothing in this script appears to call geom_stepribbon();
# confirm it is not needed before removing this line.
source("utils/geom-stepribbon.r")
#---------------------------------------------------------------------------
# Posterior mean and central 95% interval for each column of `draws`.
# Rows are treated as posterior draws and columns as days.
summarise_draws <- function(draws) {
  list(
    mean = colMeans(draws),
    lo = colQuantiles(draws, probs = 0.025),
    hi = colQuantiles(draws, probs = 0.975)
  )
}

# Load a saved model fit and draw one forecast plot per country.
#
# The results file name (relative to "results/") is taken from the first
# trailing command-line argument and load()ed into this function's
# environment. The names used below (dates, countries, prediction,
# estimated.deaths, out$Rt_adj, reported_cases, deaths_by_country) are
# expected to come from that file.
# NOTE(review): the arrays are indexed as [draw, day, country] here; that
# axis meaning is assumed from this usage, not checked.
#
# forecast_days: number of days past the last observed date to plot (default
#   7, the original hard-coded horizon). estimated.deaths must cover at least
#   length(dates[[i]]) + forecast_days days.
#
# Called for its side effects (make_single_plot() per country); returns NULL
# invisibly.
make_forecast_plot <- function(forecast_days = 7) {
  stopifnot(is.numeric(forecast_days), length(forecast_days) == 1L,
            forecast_days >= 0)

  args <- commandArgs(trailingOnly = TRUE)
  if (length(args) < 1L) {
    stop("usage: Rscript plot-forecast.r <results file name>", call. = FALSE)
  }
  filename <- args[1]
  load(paste0("results/", filename))

  # Iterate over every country in the fit rather than a hard-coded count.
  for (i in seq_along(countries)) {
    country <- countries[[i]]
    times <- as_date(as.character(dates[[i]]))
    N <- length(times)
    N2 <- N + forecast_days
    if (N2 > dim(estimated.deaths)[2]) {
      stop("estimated.deaths covers fewer than ", N2, " days for ", country,
           call. = FALSE)
    }

    observed <- seq_len(N)
    # The forecast segment starts on the last observed day so the forecast
    # line and ribbon join the fitted curve.
    ahead <- N:N2

    case_summary <- summarise_draws(prediction[, observed, i])
    # Summarise fitted + forecast days in one pass, then slice.
    death_summary <- summarise_draws(estimated.deaths[, seq_len(N2), i])
    rt_summary <- summarise_draws(out$Rt_adj[, observed, i])

    data_country <- data.frame(
      "time" = times,
      "country" = rep(country, N),
      "reported_cases" = reported_cases[[i]],
      "reported_cases_c" = cumsum(reported_cases[[i]]),
      "predicted_cases_c" = cumsum(case_summary$mean),
      "predicted_min_c" = cumsum(case_summary$lo),
      "predicted_max_c" = cumsum(case_summary$hi),
      "predicted_cases" = case_summary$mean,
      "predicted_min" = case_summary$lo,
      "predicted_max" = case_summary$hi,
      "deaths" = deaths_by_country[[i]],
      "deaths_c" = cumsum(deaths_by_country[[i]]),
      "estimated_deaths_c" = cumsum(death_summary$mean[observed]),
      "death_min_c" = cumsum(death_summary$lo[observed]),
      "death_max_c" = cumsum(death_summary$hi[observed]),
      "estimated_deaths" = death_summary$mean[observed],
      "death_min" = death_summary$lo[observed],
      "death_max" = death_summary$hi[observed],
      "rt" = rt_summary$mean,
      "rt_min" = rt_summary$lo,
      "rt_max" = rt_summary$hi
    )

    data_country_forecast <- data.frame(
      "time" = times[N] + 0:forecast_days,
      "country" = rep(country, forecast_days + 1),
      "estimated_deaths_forecast" = death_summary$mean[ahead],
      "death_min_forecast" = death_summary$lo[ahead],
      "death_max_forecast" = death_summary$hi[ahead]
    )

    make_single_plot(data_country = data_country,
                     data_country_forecast = data_country_forecast,
                     filename = filename,
                     country = country)
  }
  invisible(NULL)
}
# Plot observed and estimated daily deaths for one country, with the forecast
# segment appended, and save the figure to disk.
#
# data_country: one row per observed day; uses columns time, deaths,
#   estimated_deaths, death_min and death_max.
# data_country_forecast: one row per forecast day; uses columns time,
#   estimated_deaths_forecast, death_min_forecast and death_max_forecast.
# filename: results file name, embedded in the PNG file name.
# country: used in every output file name.
#
# Side effects: prints the plot, writes
# figures/<country>_forecast_<filename>.png, and writes desktop and mobile SVGs
# under web/figures/.
make_single_plot <- function(data_country, data_country_forecast,
                             filename, country) {
  # Last observed day; the dashed vertical line marks where the forecast
  # segment begins.
  last_date <- data_country$time[nrow(data_country)]

  p <- ggplot(data_country) +
    geom_bar(data = data_country, aes(x = time, y = deaths),
             fill = "coral4", stat = "identity", alpha = 0.5) +
    geom_line(data = data_country, aes(x = time, y = estimated_deaths),
              col = "deepskyblue4") +
    geom_line(data = data_country_forecast,
              aes(x = time, y = estimated_deaths_forecast),
              col = "black", alpha = 0.5) +
    geom_ribbon(data = data_country,
                aes(x = time, ymin = death_min, ymax = death_max),
                fill = "deepskyblue4", alpha = 0.3) +
    geom_ribbon(data = data_country_forecast,
                aes(x = time,
                    ymin = death_min_forecast,
                    ymax = death_max_forecast),
                fill = "black", alpha = 0.35) +
    geom_vline(xintercept = last_date,
               col = "black", linetype = "dashed", alpha = 0.5) +
    #scale_fill_manual(name = "",
    #                  labels = c("Confirmed deaths", "Predicted deaths"),
    #                  values = c("coral4", "deepskyblue4")) +
    xlab("Date") +
    ylab("Daily number of deaths\n") +
    scale_x_date(date_breaks = "weeks", labels = date_format("%e %b")) +
    # Equivalent to scale_y_continuous(trans = "log10") without relying on the
    # `trans` argument deprecated in ggplot2 3.5.
    scale_y_log10(labels = comma) +
    coord_cartesian(ylim = c(1, 100000), expand = FALSE) +
    theme_pubr(base_family = "sans") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    guides(fill = guide_legend(ncol = 1, reverse = TRUE)) +
    # The label is empty, but annotate() positions still expand the x scale.
    # This keeps the axis running 8 days past the last observed date.
    annotate(geom = "text", x = last_date + 8,
             y = 10000, label = "",
             color = "black")
  print(p)

  dir.create("figures", showWarnings = FALSE, recursive = TRUE)
  ggsave(filename = paste0("figures/", country, "_forecast_", filename, ".png"),
         plot = p, width = 10)

  # Produce plots for Website
  dir.create("web/figures/desktop/", showWarnings = FALSE, recursive = TRUE)
  save_plot(filename = paste0("web/figures/desktop/", country, "_forecast", ".svg"),
            p, base_height = 4, base_asp = 1.618 * 2 * 8/12)
  dir.create("web/figures/mobile/", showWarnings = FALSE, recursive = TRUE)
  save_plot(filename = paste0("web/figures/mobile/", country, "_forecast", ".svg"),
            p, base_height = 4, base_asp = 1.1)
}
# Entry point -------------------------------------------------------------
# Render forecast plots for the results file named on the command line.
invisible(make_forecast_plot())