Rework CSRF and session tracking
authorLanius Trolling <lanius@laniustrolling.dev>
Tue, 30 Apr 2024 12:13:09 +0000 (08:13 -0400)
committerLanius Trolling <lanius@laniustrolling.dev>
Tue, 30 Apr 2024 12:13:09 +0000 (08:13 -0400)
src/jvmMain/kotlin/info/mechyrdia/Factbooks.kt
src/jvmMain/kotlin/info/mechyrdia/auth/SessionStorage.kt
src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt [new file with mode: 0644]
src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt [deleted file]
src/jvmMain/kotlin/info/mechyrdia/data/Nations.kt
src/jvmMain/kotlin/info/mechyrdia/data/Visits.kt
src/jvmMain/kotlin/info/mechyrdia/lore/ViewNav.kt
src/jvmMain/kotlin/info/mechyrdia/robot/ViewsRobot.kt
src/jvmMain/kotlin/info/mechyrdia/route/ResourceCsrf.kt

index 87f44ce154f03ae46d8c61a261a484128f5583cc..63f7f44ff02d8f489a8b97b09fd5689e93b5e349 100644 (file)
@@ -37,6 +37,7 @@ import org.slf4j.event.Level
 import java.io.IOException
 import java.util.concurrent.atomic.AtomicLong
 import kotlin.random.Random
+import kotlin.time.Duration.Companion.hours
 
 fun main() {
        System.setProperty("logback.statusListenerClass", "ch.qos.logback.core.status.NopStatusListener")
@@ -108,6 +109,7 @@ fun Application.factbooks() {
                        
                        serializer = KotlinxSessionSerializer(UserSession.serializer(), JsonStorageCodec)
                        
+                       cookie.maxAge = 336.hours
                        cookie.secure = true
                        cookie.httpOnly = true
                        cookie.extensions["SameSite"] = "Lax"
index c8645d1c1577fe3d94b739ea4fde6f5acb1e38bd..c0ab266353c4104650358dc1b8d7504a91b19080 100644 (file)
@@ -1,5 +1,6 @@
 package info.mechyrdia.auth
 
+import info.mechyrdia.JsonStorageCodec
 import info.mechyrdia.data.*
 import io.ktor.server.sessions.*
 import kotlinx.serialization.SerialName
@@ -11,11 +12,13 @@ object SessionStorageMongoDB : SessionStorage {
        }
        
        override suspend fun read(id: String): String {
-               return SessionStorageDoc.Table.get(Id(id))?.session ?: throw NoSuchElementException("Session $id not found")
+               val value = SessionStorageDoc.Table.get(Id(id))?.session ?: throw NoSuchElementException("Session $id not found")
+               return JsonStorageCodec.encodeToString(UserSession.serializer(), value)
        }
        
        override suspend fun write(id: String, value: String) {
-               SessionStorageDoc.Table.put(SessionStorageDoc(Id(id), value))
+               val session = JsonStorageCodec.decodeFromString(UserSession.serializer(), value)
+               SessionStorageDoc.Table.put(SessionStorageDoc(Id(id), session))
        }
 }
 
@@ -23,7 +26,7 @@ object SessionStorageMongoDB : SessionStorage {
 data class SessionStorageDoc(
        @SerialName(MONGODB_ID_KEY)
        override val id: Id<SessionStorageDoc>,
-       val session: String
+       val session: UserSession,
 ) : DataDocument<SessionStorageDoc> {
        companion object : TableHolder<SessionStorageDoc> {
                override val Table = DocumentTable<SessionStorageDoc>()
diff --git a/src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt b/src/jvmMain/kotlin/info/mechyrdia/auth/Sessions.kt
new file mode 100644 (file)
index 0000000..7beb75c
--- /dev/null
@@ -0,0 +1,59 @@
+package info.mechyrdia.auth
+
+import info.mechyrdia.data.Id
+import info.mechyrdia.data.InstantSerializer
+import info.mechyrdia.data.NationData
+import io.ktor.server.application.*
+import io.ktor.server.plugins.*
+import io.ktor.server.sessions.*
+import kotlinx.serialization.Serializable
+import java.time.Instant
+
+@Serializable
+data class CsrfTokenEntry(
+       val targetRoute: String,
+       val expiresAt: @Serializable(with = InstantSerializer::class) Instant,
+)
+
+@Serializable
+data class UserSession(
+       val nationId: Id<NationData>? = null,
+       val csrfTokens: Map<String, CsrfTokenEntry> = emptyMap(),
+)
+
+var ApplicationCall.currentUserSession: UserSession
+       get() = sessions.get<UserSession>() ?: UserSession().also { sessions.set(it) }
+       set(value) = sessions.set(value)
+
+suspend fun ApplicationCall.updateUserSession(session: UserSession) {
+       sessionId<UserSession>()?.let {
+               SessionStorageDoc.Table.put(SessionStorageDoc(Id(it), session))
+       }
+}
+
+const val DEFAULT_CSRF_TOKEN_EXPIRY_SECONDS = 7200
+
+fun ApplicationCall.createCsrfToken(targetRoute: String = request.origin.uri, expireSeconds: Int = DEFAULT_CSRF_TOKEN_EXPIRY_SECONDS): String {
+       val token = token()
+       val entry = CsrfTokenEntry(
+               targetRoute = targetRoute,
+               expiresAt = Instant.now().plusSeconds(expireSeconds.toLong())
+       )
+       
+       currentUserSession = currentUserSession.let { sess ->
+               sess.copy(csrfTokens = sess.csrfTokens + (token to entry))
+       }
+       
+       return token
+}
+
+suspend fun ApplicationCall.retrieveCsrfToken(token: String): CsrfTokenEntry? {
+       val session = currentUserSession
+       val entry = session.csrfTokens[token] ?: return null
+       
+       updateUserSession(session.let { sess ->
+               sess.copy(csrfTokens = sess.csrfTokens - token)
+       })
+       
+       return entry
+}
diff --git a/src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt b/src/jvmMain/kotlin/info/mechyrdia/auth/sessions.kt
deleted file mode 100644 (file)
index f47b9c1..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-package info.mechyrdia.auth
-
-import info.mechyrdia.data.Id
-import info.mechyrdia.data.NationData
-import kotlinx.serialization.Serializable
-
-@Serializable
-data class UserSession(
-       val nationId: Id<NationData>,
-)
index 50e8ba9eead94a5a9cf25fea09d79dc3ed26e692..6c0d73a6e0d1722f5877f89a3e239008de80cbde 100644 (file)
@@ -53,9 +53,7 @@ val CallNationCacheAttribute = AttributeKey<MutableMap<Id<NationData>, NationDat
 
 val ApplicationCall.nationCache: MutableMap<Id<NationData>, NationData>
        get() = attributes.computeIfAbsent(CallNationCacheAttribute) {
-               ConcurrentHashMap<Id<NationData>, NationData>().also { cache ->
-                       attributes.put(CallNationCacheAttribute, cache)
-               }
+               ConcurrentHashMap<Id<NationData>, NationData>()
        }
 
 suspend fun MutableMap<Id<NationData>, NationData>.getNation(id: Id<NationData>): NationData {
index 5b185f59e2c6ef1d706383a916b10772209fd5ef..614cd2ebd47b0af4bc7d103e255d67c5750e6dec 100644 (file)
@@ -4,12 +4,11 @@ import com.mongodb.client.model.Accumulators
 import com.mongodb.client.model.Aggregates
 import com.mongodb.client.model.Filters
 import com.mongodb.client.model.Updates
+import info.mechyrdia.auth.UserSession
 import info.mechyrdia.lore.dateTime
 import io.ktor.server.application.*
-import io.ktor.server.plugins.*
 import io.ktor.server.request.*
 import io.ktor.server.sessions.*
-import io.ktor.util.*
 import kotlinx.coroutines.flow.firstOrNull
 import kotlinx.html.FlowContent
 import kotlinx.html.p
@@ -17,7 +16,6 @@ import kotlinx.html.style
 import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 import org.intellij.lang.annotations.Language
-import java.security.MessageDigest
 import java.time.Instant
 
 @Serializable
@@ -76,24 +74,12 @@ data class PageVisitData(
        }
 }
 
-private val messageDigestProvider = ThreadLocal.withInitial { MessageDigest.getInstance("SHA-256") }
-
-fun ApplicationCall.anonymizedClientId(): String {
-       val messageDigest = messageDigestProvider.get()
-       
-       messageDigest.reset()
-       messageDigest.update(request.origin.remoteAddress.encodeToByteArray())
-       request.userAgent()?.encodeToByteArray()?.let { messageDigest.update(it) }
-       
-       return hex(messageDigest.digest())
-}
-
 suspend fun ApplicationCall.processGuestbook(): PageVisitTotals {
        val path = request.path()
        
        val totals = PageVisitData.totalVisits(path)
        if (!RobotDetector.isRobot(request.userAgent()))
-               PageVisitData.visit(path, anonymizedClientId())
+               sessionId<UserSession>()?.let { PageVisitData.visit(path, it) }
        
        return totals
 }
index 2aefb7c4838eba77d0a63cbb3f4103e5a78bcd6a..ffa4d7b6596b1e25aae74d5760f572ffd02da115 100644 (file)
@@ -2,13 +2,13 @@ package info.mechyrdia.lore
 
 import info.mechyrdia.JsonFileCodec
 import info.mechyrdia.OwnerNationId
+import info.mechyrdia.auth.createCsrfToken
 import info.mechyrdia.data.FileStorage
 import info.mechyrdia.data.StoragePath
 import info.mechyrdia.data.currentNation
 import info.mechyrdia.robot.RobotService
 import info.mechyrdia.robot.RobotServiceStatus
 import info.mechyrdia.route.Root
-import info.mechyrdia.route.createCsrfToken
 import info.mechyrdia.route.href
 import io.ktor.server.application.*
 import kotlinx.html.*
index 4dffd6ae6f6738bb03ee23bb14f01093352d6b8c..fb063b08f8fb22eee57322f4a08885e9efe89222 100644 (file)
@@ -1,12 +1,12 @@
 package info.mechyrdia.robot
 
+import info.mechyrdia.auth.createCsrfToken
 import info.mechyrdia.data.currentNation
 import info.mechyrdia.lore.page
 import info.mechyrdia.lore.redirectHref
 import info.mechyrdia.lore.standardNavBar
 import info.mechyrdia.route.Root
 import info.mechyrdia.route.checkCsrfToken
-import info.mechyrdia.route.createCsrfToken
 import info.mechyrdia.route.href
 import io.ktor.server.application.*
 import io.ktor.server.websocket.*
index 80e9809b62f0c6e5e572a086aa6ab275f5e40b06..52db438a3dae01a70f36c206d12f533775c33d8b 100644 (file)
@@ -1,79 +1,49 @@
 package info.mechyrdia.route
 
-import info.mechyrdia.auth.UserSession
-import info.mechyrdia.auth.token
-import info.mechyrdia.data.Id
-import info.mechyrdia.data.NationData
+import info.mechyrdia.auth.createCsrfToken
+import info.mechyrdia.auth.retrieveCsrfToken
 import io.ktor.server.application.*
-import io.ktor.server.plugins.*
 import io.ktor.server.request.*
-import io.ktor.server.sessions.*
 import kotlinx.html.A
 import kotlinx.html.FORM
 import kotlinx.html.FlowContent
 import kotlinx.html.hiddenInput
 import java.time.Instant
-import java.util.concurrent.ConcurrentHashMap
 import kotlin.collections.set
 
-data class CsrfPayload(
-       val route: String,
-       val remoteAddress: String,
-       val userAgent: String?,
-       val userAccount: Id<NationData>?,
-       val expires: Instant
-)
-
-fun ApplicationCall.csrfPayload(route: String, withExpiration: Instant = Instant.now().plusSeconds(7200)) =
-       CsrfPayload(
-               route = route,
-               remoteAddress = request.origin.remoteAddress,
-               userAgent = request.userAgent(),
-               userAccount = sessions.get<UserSession>()?.nationId,
-               expires = withExpiration
-       )
-
-private val csrfMap = ConcurrentHashMap<String, CsrfPayload>()
-
 data class CsrfFailedException(override val message: String, val payload: CsrfProtectedResourcePayload?) : RuntimeException(message)
 
 interface CsrfProtectedResourcePayload {
        val csrfToken: String?
        
-       fun ApplicationCall.verifyCsrfToken(route: String = request.uri) {
+       suspend fun ApplicationCall.verifyCsrfToken(route: String = request.uri) {
                val token = csrfToken ?: throw CsrfFailedException("The submitted CSRF token is not present", this@CsrfProtectedResourcePayload)
-               val check = csrfMap.remove(token) ?: throw CsrfFailedException("The submitted CSRF token is not valid", this@CsrfProtectedResourcePayload)
-               val payload = csrfPayload(route, check.expires)
-               if (check != payload)
+               val entry = retrieveCsrfToken(token) ?: throw CsrfFailedException("The submitted CSRF token is not valid", this@CsrfProtectedResourcePayload)
+               if (entry.targetRoute != route)
                        throw CsrfFailedException("The submitted CSRF token does not match", this@CsrfProtectedResourcePayload)
-               if (payload.expires < Instant.now())
+               if (entry.expiresAt < Instant.now())
                        throw CsrfFailedException("The submitted CSRF token has expired", this@CsrfProtectedResourcePayload)
        }
        
        fun FlowContent.displayRetryData() {}
 }
 
-fun ApplicationCall.checkCsrfToken(csrfToken: String?, route: String = request.uri): Boolean {
+suspend fun ApplicationCall.checkCsrfToken(csrfToken: String?, route: String = request.uri): Boolean {
        val token = csrfToken ?: return false
-       val check = csrfMap.remove(token) ?: return false
-       val payload = csrfPayload(route, check.expires)
-       return check == payload && payload.expires >= Instant.now()
-}
-
-fun ApplicationCall.createCsrfToken(route: String = request.origin.uri): String {
-       return token().also { csrfMap[it] = csrfPayload(route) }
+       val entry = retrieveCsrfToken(token) ?: return false
+       return entry.targetRoute == route && entry.expiresAt >= Instant.now()
 }
 
 context(ApplicationCall)
 fun A.installCsrfToken(route: String = href) {
        attributes["data-method"] = "post"
-       attributes["data-csrf-token"] = token().also { csrfMap[it] = csrfPayload(route) }
+       attributes["data-csrf-token"] = createCsrfToken(route)
 }
 
 context(ApplicationCall)
 fun FORM.installCsrfToken(route: String = action) {
        hiddenInput {
                name = "csrfToken"
-               value = token().also { csrfMap[it] = csrfPayload(route) }
+               value = createCsrfToken(route)
        }
 }