From 2f95e8aab0a6f18c60f0965a9e140e8d0c0950e4 Mon Sep 17 00:00:00 2001 From: Lanius Trolling Date: Tue, 30 Apr 2024 08:13:09 -0400 Subject: [PATCH] Rework CSRF and session tracking --- .../kotlin/info/mechyrdia/Factbooks.kt | 2 + .../info/mechyrdia/auth/SessionStorage.kt | 9 ++- .../kotlin/info/mechyrdia/auth/Sessions.kt | 59 +++++++++++++++++++ .../kotlin/info/mechyrdia/auth/sessions.kt | 10 ---- .../kotlin/info/mechyrdia/data/Nations.kt | 4 +- .../kotlin/info/mechyrdia/data/Visits.kt | 18 +----- .../kotlin/info/mechyrdia/lore/ViewNav.kt | 2 +- .../kotlin/info/mechyrdia/robot/ViewsRobot.kt | 2 +- .../info/mechyrdia/route/ResourceCsrf.kt | 52 ++++------------ 9 files changed, 83 insertions(+), 75 deletions(-) create mode 100644 src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt delete mode 100644 src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt diff --git a/src/jvmMain/kotlin/info/mechyrdia/Factbooks.kt b/src/jvmMain/kotlin/info/mechyrdia/Factbooks.kt index 87f44ce..63f7f44 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/Factbooks.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/Factbooks.kt @@ -37,6 +37,7 @@ import org.slf4j.event.Level import java.io.IOException import java.util.concurrent.atomic.AtomicLong import kotlin.random.Random +import kotlin.time.Duration.Companion.hours fun main() { System.setProperty("logback.statusListenerClass", "ch.qos.logback.core.status.NopStatusListener") @@ -108,6 +109,7 @@ fun Application.factbooks() { serializer = KotlinxSessionSerializer(UserSession.serializer(), JsonStorageCodec) + cookie.maxAge = 336.hours cookie.secure = true cookie.httpOnly = true cookie.extensions["SameSite"] = "Lax" diff --git a/src/jvmMain/kotlin/info/mechyrdia/auth/SessionStorage.kt b/src/jvmMain/kotlin/info/mechyrdia/auth/SessionStorage.kt index c8645d1..c0ab266 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/auth/SessionStorage.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/auth/SessionStorage.kt @@ -1,5 +1,6 @@ package info.mechyrdia.auth +import info.mechyrdia.JsonStorageCodec import info.mechyrdia.data.* import io.ktor.server.sessions.* import kotlinx.serialization.SerialName @@ -11,11 +12,13 @@ object SessionStorageMongoDB : SessionStorage { } override suspend fun read(id: String): String { - return SessionStorageDoc.Table.get(Id(id))?.session ?: throw NoSuchElementException("Session $id not found") + val value = SessionStorageDoc.Table.get(Id(id))?.session ?: throw NoSuchElementException("Session $id not found") + return JsonStorageCodec.encodeToString(UserSession.serializer(), value) } override suspend fun write(id: String, value: String) { - SessionStorageDoc.Table.put(SessionStorageDoc(Id(id), value)) + val session = JsonStorageCodec.decodeFromString(UserSession.serializer(), value) + SessionStorageDoc.Table.put(SessionStorageDoc(Id(id), session)) } } @@ -23,7 +26,7 @@ object SessionStorageMongoDB : SessionStorage { data class SessionStorageDoc( @SerialName(MONGODB_ID_KEY) override val id: Id, - val session: String + val session: UserSession, ) : DataDocument { companion object : TableHolder { override val Table = DocumentTable() diff --git a/src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt b/src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt new file mode 100644 index 0000000..7beb75c --- /dev/null +++ b/src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt @@ -0,0 +1,59 @@ +package info.mechyrdia.auth + +import info.mechyrdia.data.Id +import info.mechyrdia.data.InstantSerializer +import info.mechyrdia.data.NationData +import io.ktor.server.application.* +import io.ktor.server.plugins.* +import io.ktor.server.sessions.* +import kotlinx.serialization.Serializable +import java.time.Instant + +@Serializable +data class CsrfTokenEntry( + val targetRoute: String, + val expiresAt: @Serializable(with = InstantSerializer::class) Instant, +) + +@Serializable +data class UserSession( + val nationId: Id? = null, + val csrfTokens: Map = emptyMap(), +) + +var ApplicationCall.currentUserSession: UserSession + get() = sessions.get() ?: UserSession().also { sessions.set(it) } + set(value) = sessions.set(value) + +suspend fun ApplicationCall.updateUserSession(session: UserSession) { + sessionId()?.let { + SessionStorageDoc.Table.put(SessionStorageDoc(Id(it), session)) + } +} + +const val DEFAULT_CSRF_TOKEN_EXPIRY_SECONDS = 7200 + +fun ApplicationCall.createCsrfToken(targetRoute: String = request.origin.uri, expireSeconds: Int = DEFAULT_CSRF_TOKEN_EXPIRY_SECONDS): String { + val token = token() + val entry = CsrfTokenEntry( + targetRoute = targetRoute, + expiresAt = Instant.now().plusSeconds(expireSeconds.toLong()) + ) + + currentUserSession = currentUserSession.let { sess -> + sess.copy(csrfTokens = sess.csrfTokens + (token to entry)) + } + + return token +} + +suspend fun ApplicationCall.retrieveCsrfToken(token: String): CsrfTokenEntry? { + val session = currentUserSession + val entry = session.csrfTokens[token] ?: return null + + updateUserSession(session.let { sess -> + sess.copy(csrfTokens = sess.csrfTokens - token) + }) + + return entry +} diff --git a/src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt b/src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt deleted file mode 100644 index f47b9c1..0000000 --- a/src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt +++ /dev/null @@ -1,10 +0,0 @@ -package info.mechyrdia.auth - -import info.mechyrdia.data.Id -import info.mechyrdia.data.NationData -import kotlinx.serialization.Serializable - -@Serializable -data class UserSession( - val nationId: Id, -) diff --git a/src/jvmMain/kotlin/info/mechyrdia/data/Nations.kt b/src/jvmMain/kotlin/info/mechyrdia/data/Nations.kt index 50e8ba9..6c0d73a 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/data/Nations.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/data/Nations.kt @@ -53,9 +53,7 @@ val CallNationCacheAttribute = AttributeKey, NationDat val ApplicationCall.nationCache: MutableMap, NationData> get() = attributes.computeIfAbsent(CallNationCacheAttribute) { - ConcurrentHashMap, NationData>().also { cache -> - attributes.put(CallNationCacheAttribute, cache) - } + ConcurrentHashMap, NationData>() } suspend fun MutableMap, NationData>.getNation(id: Id): NationData { diff --git a/src/jvmMain/kotlin/info/mechyrdia/data/Visits.kt b/src/jvmMain/kotlin/info/mechyrdia/data/Visits.kt index 5b185f5..614cd2e 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/data/Visits.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/data/Visits.kt @@ -4,12 +4,11 @@ import com.mongodb.client.model.Accumulators import com.mongodb.client.model.Aggregates import com.mongodb.client.model.Filters import com.mongodb.client.model.Updates +import info.mechyrdia.auth.UserSession import info.mechyrdia.lore.dateTime import io.ktor.server.application.* -import io.ktor.server.plugins.* import io.ktor.server.request.* import io.ktor.server.sessions.* -import io.ktor.util.* import kotlinx.coroutines.flow.firstOrNull import kotlinx.html.FlowContent import kotlinx.html.p @@ -17,7 +16,6 @@ import kotlinx.html.style import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import org.intellij.lang.annotations.Language -import java.security.MessageDigest import java.time.Instant @Serializable @@ -76,24 +74,12 @@ data class PageVisitData( } } -private val messageDigestProvider = ThreadLocal.withInitial { MessageDigest.getInstance("SHA-256") } - -fun ApplicationCall.anonymizedClientId(): String { - val messageDigest = messageDigestProvider.get() - - messageDigest.reset() - messageDigest.update(request.origin.remoteAddress.encodeToByteArray()) - request.userAgent()?.encodeToByteArray()?.let { messageDigest.update(it) } - - return hex(messageDigest.digest()) -} - suspend fun ApplicationCall.processGuestbook(): PageVisitTotals { val path = request.path() val totals = PageVisitData.totalVisits(path) if (!RobotDetector.isRobot(request.userAgent())) - PageVisitData.visit(path, anonymizedClientId()) + sessionId()?.let { PageVisitData.visit(path, it) } return totals } diff --git a/src/jvmMain/kotlin/info/mechyrdia/lore/ViewNav.kt b/src/jvmMain/kotlin/info/mechyrdia/lore/ViewNav.kt index 2aefb7c..ffa4d7b 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/lore/ViewNav.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/lore/ViewNav.kt @@ -2,13 +2,13 @@ package info.mechyrdia.lore import info.mechyrdia.JsonFileCodec import info.mechyrdia.OwnerNationId +import info.mechyrdia.auth.createCsrfToken import info.mechyrdia.data.FileStorage import info.mechyrdia.data.StoragePath import info.mechyrdia.data.currentNation import info.mechyrdia.robot.RobotService import info.mechyrdia.robot.RobotServiceStatus import info.mechyrdia.route.Root -import info.mechyrdia.route.createCsrfToken import info.mechyrdia.route.href import io.ktor.server.application.* import kotlinx.html.* diff --git a/src/jvmMain/kotlin/info/mechyrdia/robot/ViewsRobot.kt b/src/jvmMain/kotlin/info/mechyrdia/robot/ViewsRobot.kt index 4dffd6a..fb063b0 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/robot/ViewsRobot.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/robot/ViewsRobot.kt @@ -1,12 +1,12 @@ package info.mechyrdia.robot +import info.mechyrdia.auth.createCsrfToken import info.mechyrdia.data.currentNation import info.mechyrdia.lore.page import info.mechyrdia.lore.redirectHref import info.mechyrdia.lore.standardNavBar import info.mechyrdia.route.Root import info.mechyrdia.route.checkCsrfToken -import info.mechyrdia.route.createCsrfToken import info.mechyrdia.route.href import io.ktor.server.application.* import io.ktor.server.websocket.* diff --git a/src/jvmMain/kotlin/info/mechyrdia/route/ResourceCsrf.kt b/src/jvmMain/kotlin/info/mechyrdia/route/ResourceCsrf.kt index 80e9809..52db438 100644 --- a/src/jvmMain/kotlin/info/mechyrdia/route/ResourceCsrf.kt +++ b/src/jvmMain/kotlin/info/mechyrdia/route/ResourceCsrf.kt @@ -1,79 +1,49 @@ package info.mechyrdia.route -import info.mechyrdia.auth.UserSession -import info.mechyrdia.auth.token -import info.mechyrdia.data.Id -import info.mechyrdia.data.NationData +import info.mechyrdia.auth.createCsrfToken +import info.mechyrdia.auth.retrieveCsrfToken import io.ktor.server.application.* -import io.ktor.server.plugins.* import io.ktor.server.request.* -import io.ktor.server.sessions.* import kotlinx.html.A import kotlinx.html.FORM import kotlinx.html.FlowContent import kotlinx.html.hiddenInput import java.time.Instant -import java.util.concurrent.ConcurrentHashMap import kotlin.collections.set -data class CsrfPayload( - val route: String, - val remoteAddress: String, - val userAgent: String?, - val userAccount: Id?, - val expires: Instant -) - -fun ApplicationCall.csrfPayload(route: String, withExpiration: Instant = Instant.now().plusSeconds(7200)) = - CsrfPayload( - route = route, - remoteAddress = request.origin.remoteAddress, - userAgent = request.userAgent(), - userAccount = sessions.get()?.nationId, - expires = withExpiration - ) - -private val csrfMap = ConcurrentHashMap() - data class CsrfFailedException(override val message: String, val payload: CsrfProtectedResourcePayload?) : RuntimeException(message) interface CsrfProtectedResourcePayload { val csrfToken: String? - fun ApplicationCall.verifyCsrfToken(route: String = request.uri) { + suspend fun ApplicationCall.verifyCsrfToken(route: String = request.uri) { val token = csrfToken ?: throw CsrfFailedException("The submitted CSRF token is not present", this@CsrfProtectedResourcePayload) - val check = csrfMap.remove(token) ?: throw CsrfFailedException("The submitted CSRF token is not valid", this@CsrfProtectedResourcePayload) - val payload = csrfPayload(route, check.expires) - if (check != payload) + val entry = retrieveCsrfToken(token) ?: throw CsrfFailedException("The submitted CSRF token is not valid", this@CsrfProtectedResourcePayload) + if (entry.targetRoute != route) throw CsrfFailedException("The submitted CSRF token does not match", this@CsrfProtectedResourcePayload) - if (payload.expires < Instant.now()) + if (entry.expiresAt < Instant.now()) throw CsrfFailedException("The submitted CSRF token has expired", this@CsrfProtectedResourcePayload) } fun FlowContent.displayRetryData() {} } -fun ApplicationCall.checkCsrfToken(csrfToken: String?, route: String = request.uri): Boolean { +suspend fun ApplicationCall.checkCsrfToken(csrfToken: String?, route: String = request.uri): Boolean { val token = csrfToken ?: return false - val check = csrfMap.remove(token) ?: return false - val payload = csrfPayload(route, check.expires) - return check == payload && payload.expires >= Instant.now() -} - -fun ApplicationCall.createCsrfToken(route: String = request.origin.uri): String { - return token().also { csrfMap[it] = csrfPayload(route) } + val entry = retrieveCsrfToken(token) ?: return false + return entry.targetRoute == route && entry.expiresAt >= Instant.now() } context(ApplicationCall) fun A.installCsrfToken(route: String = href) { attributes["data-method"] = "post" - attributes["data-csrf-token"] = token().also { csrfMap[it] = csrfPayload(route) } + attributes["data-csrf-token"] = createCsrfToken(route) } context(ApplicationCall) fun FORM.installCsrfToken(route: String = action) { hiddenInput { name = "csrfToken" - value = token().also { csrfMap[it] = csrfPayload(route) } + value = createCsrfToken(route) } } -- 2.25.1