Fix file uploads
authorLanius Trolling <lanius@laniustrolling.dev>
Wed, 18 Dec 2024 03:08:16 +0000 (22:08 -0500)
committerLanius Trolling <lanius@laniustrolling.dev>
Wed, 18 Dec 2024 03:08:16 +0000 (22:08 -0500)
src/main/kotlin/info/mechyrdia/data/ViewsFiles.kt
src/main/kotlin/info/mechyrdia/route/ResourceMultipart.kt
src/main/kotlin/info/mechyrdia/route/ResourceTypes.kt

index c10039c8ed242290d0fd2047d661e55b202e2567..dc1eb4959aa33f56cfd755e798487f40c8c2c17a 100644 (file)
@@ -5,6 +5,7 @@ import info.mechyrdia.lore.adminPage
 import info.mechyrdia.lore.dateTime
 import info.mechyrdia.lore.mapSuspend
 import info.mechyrdia.lore.redirectHref
+import info.mechyrdia.route.MultiPartPayloadPart
 import info.mechyrdia.route.Root
 import info.mechyrdia.route.href
 import info.mechyrdia.route.installCsrfToken
@@ -282,19 +283,18 @@ suspend fun ApplicationCall.adminDoCopyFile(from: StoragePath, into: StoragePath
                respond(HttpStatusCode.Conflict)
 }
 
-suspend fun ApplicationCall.adminUploadFile(path: StoragePath, part: PartData.FileItem) {
+suspend fun ApplicationCall.adminUploadFile(path: StoragePath, part: MultiPartPayloadPart.FileData) {
        val name = part.originalFileName ?: throw MissingRequestParameterException("originalFileName")
        val filePath = path / name
        
-       val content = part.provider().toByteArray()
-       if (FileStorage.instance.writeFile(filePath, content))
+       if (FileStorage.instance.writeFile(filePath, part.contents))
                redirectHref(Root.Admin.Vfs.View(filePath.elements), HttpStatusCode.SeeOther)
        else
                respond(HttpStatusCode.Conflict)
 }
 
-suspend fun ApplicationCall.adminOverwriteFile(path: StoragePath, part: PartData.FileItem) {
-       if (FileStorage.instance.writeFile(path, part.provider().toByteArray()))
+suspend fun ApplicationCall.adminOverwriteFile(path: StoragePath, part: MultiPartPayloadPart.FileData) {
+       if (FileStorage.instance.writeFile(path, part.contents))
                redirectHref(Root.Admin.Vfs.View(path.elements), HttpStatusCode.SeeOther)
        else
                respond(HttpStatusCode.Conflict)
index 667555000883671267ee0642192f0cd40fab0005..2257e5d0fecb35bb697bcddc51ac242318ca3485 100644 (file)
@@ -1,12 +1,17 @@
 package info.mechyrdia.route
 
+import io.ktor.http.ContentDisposition
+import io.ktor.http.ContentType
+import io.ktor.http.Headers
+import io.ktor.http.HttpHeaders
 import io.ktor.http.content.MultiPartData
 import io.ktor.http.content.PartData
 import io.ktor.http.content.forEachPart
+import io.ktor.utils.io.toByteArray
 import kotlin.reflect.full.companionObjectInstance
 
 interface MultiPartPayload : AutoCloseable {
-       val payload: List<PartData>
+       val payload: List<MultiPartPayloadPart>
        
        override fun close() {
                for (data in payload)
@@ -23,19 +28,51 @@ inline fun <reified P : MultiPartPayload> payloadProcessor(): MultiPartPayloadPr
        return P::class.companionObjectInstance as MultiPartPayloadProcessor<P>
 }
 
+sealed class MultiPartPayloadPart {
+       abstract val headers: Headers
+       abstract fun dispose()
+       
+       val contentDisposition: ContentDisposition?
+               get() = headers[HttpHeaders.ContentDisposition]?.let { ContentDisposition.parse(it) }
+       
+       val contentType: ContentType?
+               get() = headers[HttpHeaders.ContentType]?.let { ContentType.parse(it) }
+       
+       val name: String?
+               get() = contentDisposition?.name
+       
+       class FormData(val value: String, override val headers: Headers, private val disposer: () -> Unit) : MultiPartPayloadPart() {
+               override fun dispose() {
+                       disposer()
+               }
+       }
+       
+       class FileData(val contents: ByteArray, override val headers: Headers, private val disposer: () -> Unit) : MultiPartPayloadPart() {
+               override fun dispose() {
+                       disposer()
+               }
+               
+               val originalFileName: String? = contentDisposition?.parameter(ContentDisposition.Parameters.FileName)
+       }
+}
+
 data class CsrfProtectedMultiPartPayload(
        override val csrfToken: String? = null,
-       override val payload: List<PartData>
+       override val payload: List<MultiPartPayloadPart>
 ) : CsrfProtectedResourcePayload, MultiPartPayload {
        companion object : MultiPartPayloadProcessor<CsrfProtectedMultiPartPayload> {
                override suspend fun process(data: MultiPartData): CsrfProtectedMultiPartPayload {
                        var csrfToken: String? = null
-                       val payload = mutableListOf<PartData>()
+                       val payload = mutableListOf<MultiPartPayloadPart>()
                        
                        data.forEachPart { part ->
                                if (part is PartData.FormItem && part.name == "csrfToken")
                                        csrfToken = part.value
-                               else payload.add(part)
+                               else when (part) {
+                                       is PartData.FormItem -> MultiPartPayloadPart.FormData(part.value, part.headers, part.dispose)
+                                       is PartData.FileItem -> MultiPartPayloadPart.FileData(part.provider().toByteArray(), part.headers, part.dispose)
+                                       else -> null
+                               }?.let(payload::add)
                        }
                        
                        return CsrfProtectedMultiPartPayload(csrfToken, payload)
@@ -44,12 +81,20 @@ data class CsrfProtectedMultiPartPayload(
 }
 
 data class PlainMultiPartPayload(
-       override val payload: List<PartData>
+       override val payload: List<MultiPartPayloadPart>
 ) : MultiPartPayload {
        companion object : MultiPartPayloadProcessor<PlainMultiPartPayload> {
                override suspend fun process(data: MultiPartData): PlainMultiPartPayload {
-                       val payload = mutableListOf<PartData>()
-                       data.forEachPart { part -> payload.add(part) }
+                       val payload = mutableListOf<MultiPartPayloadPart>()
+                       
+                       data.forEachPart { part ->
+                               when (part) {
+                                       is PartData.FormItem -> MultiPartPayloadPart.FormData(part.value, part.headers, part.dispose)
+                                       is PartData.FileItem -> MultiPartPayloadPart.FileData(part.provider().toByteArray(), part.headers, part.dispose)
+                                       else -> null
+                               }?.let(payload::add)
+                       }
+                       
                        return PlainMultiPartPayload(payload)
                }
        }
index 58ad73b3fda2afaa8fe829d98e0272e3b9fa2ec6..3bc18db9b9bbd2dedbcdd72a282524bfed116e7d 100644 (file)
@@ -482,7 +482,7 @@ class Root : ResourceHandler, ResourceFilter {
                                        with(vfs) { call.filterCall() }
                                        with(payload) { call.verifyCsrfToken() }
                                        
-                                       val fileItem = payload.payload.filterIsInstance<PartData.FileItem>().singleOrNull()
+                                       val fileItem = payload.payload.filterIsInstance<MultiPartPayloadPart.FileData>().singleOrNull()
                                                ?: throw MissingRequestParameterException("file")
                                        
                                        call.adminUploadFile(StoragePath(path), fileItem)
@@ -495,7 +495,7 @@ class Root : ResourceHandler, ResourceFilter {
                                        with(vfs) { call.filterCall() }
                                        with(payload) { call.verifyCsrfToken() }
                                        
-                                       val fileItem = payload.payload.filterIsInstance<PartData.FileItem>().singleOrNull()
+                                       val fileItem = payload.payload.filterIsInstance<MultiPartPayloadPart.FileData>().singleOrNull()
                                                ?: throw MissingRequestParameterException("file")
                                        
                                        call.adminOverwriteFile(StoragePath(path), fileItem)