371 lines
14 KiB
Kotlin
371 lines
14 KiB
Kotlin
package com.acitelight.aether.service
|
|
|
|
import com.acitelight.aether.service.AuthManager.db64
|
|
import com.acitelight.aether.service.AuthManager.signChallenge
|
|
import com.acitelight.aether.service.AuthManager.signChallengeByte
|
|
import kotlinx.coroutines.*
|
|
import java.io.InputStream
|
|
import java.io.OutputStream
|
|
|
|
import java.net.Socket
|
|
import java.nio.ByteBuffer
|
|
import java.security.SecureRandom
|
|
import java.util.ArrayDeque
|
|
|
|
import org.bouncycastle.math.ec.rfc7748.X25519
|
|
import org.bouncycastle.crypto.generators.HKDFBytesGenerator
|
|
import org.bouncycastle.crypto.digests.SHA256Digest
|
|
import org.bouncycastle.crypto.params.HKDFParameters
|
|
import org.bouncycastle.crypto.params.KeyParameter
|
|
import org.bouncycastle.crypto.params.AEADParameters
|
|
import org.bouncycastle.crypto.modes.ChaCha20Poly1305
|
|
import java.io.EOFException
|
|
import java.util.concurrent.atomic.AtomicLong
|
|
|
|
class AbyssStream private constructor(
|
|
private val socket: Socket,
|
|
private val input: InputStream,
|
|
private val output: OutputStream,
|
|
private val aeadKey: ByteArray,
|
|
private val sendSalt: ByteArray,
|
|
private val recvSalt: ByteArray
|
|
) {
|
|
companion object {
|
|
private const val PUBLIC_KEY_LEN = 32
|
|
private const val AEAD_KEY_LEN = 32
|
|
private const val NONCE_SALT_LEN = 4
|
|
private const val AEAD_TAG_LEN = 16
|
|
private const val NONCE_LEN = 12
|
|
private const val MAX_PLAINTEXT_FRAME = 64 * 1024
|
|
|
|
private val secureRandom = SecureRandom()
|
|
|
|
/**
|
|
* Create and perform handshake on an already-connected socket.
|
|
* If privateKeyRaw is provided, it must be 32 bytes.
|
|
*/
|
|
suspend fun create(socket: Socket, privateKeyRaw: ByteArray? = null): AbyssStream = withContext(Dispatchers.IO) {
|
|
if (!socket.isConnected) throw IllegalArgumentException("socket is not connected")
|
|
val inStream = socket.getInputStream()
|
|
val outStream = socket.getOutputStream()
|
|
|
|
// 1) keypair (raw)
|
|
val localPriv = ByteArray(PUBLIC_KEY_LEN)
|
|
if (privateKeyRaw != null) {
|
|
if (privateKeyRaw.size != PUBLIC_KEY_LEN) {
|
|
throw IllegalArgumentException("privateKeyRaw must be $PUBLIC_KEY_LEN bytes")
|
|
}
|
|
System.arraycopy(privateKeyRaw, 0, localPriv, 0, PUBLIC_KEY_LEN)
|
|
} else {
|
|
X25519.generatePrivateKey(secureRandom, localPriv)
|
|
}
|
|
val localPub = ByteArray(PUBLIC_KEY_LEN)
|
|
X25519.scalarMultBase(localPriv, 0, localPub, 0)
|
|
|
|
// 2) exchange raw public keys (exact 32 bytes each) using blocking IO
|
|
writeExact(outStream, localPub, 0, PUBLIC_KEY_LEN)
|
|
val remotePub = ByteArray(PUBLIC_KEY_LEN)
|
|
readExact(inStream, remotePub, 0, PUBLIC_KEY_LEN)
|
|
|
|
val ch = ByteArray(32)
|
|
readExact(inStream, ch, 0, 32)
|
|
val signed = signChallengeByte(localPriv, ch)
|
|
writeExact(outStream, signed, 0, signed.size)
|
|
readExact(inStream, ch, 0, 16)
|
|
|
|
// 3) compute shared secret: X25519.scalarMult(private, remotePublic)
|
|
val shared = ByteArray(PUBLIC_KEY_LEN)
|
|
X25519.scalarMult(localPriv, 0, remotePub, 0, shared, 0)
|
|
|
|
// 4) HKDF-SHA256 -> AEAD key + saltA + saltB
|
|
val hkdf = HKDFBytesGenerator(SHA256Digest())
|
|
// AEAD key
|
|
hkdf.init(HKDFParameters(shared, null, "Abyss-AEAD-Key".toByteArray(Charsets.US_ASCII)))
|
|
val aeadKey = ByteArray(AEAD_KEY_LEN)
|
|
hkdf.generateBytes(aeadKey, 0, AEAD_KEY_LEN)
|
|
|
|
// salt A
|
|
hkdf.init(HKDFParameters(shared, null, "Abyss-Nonce-Salt-A".toByteArray(Charsets.US_ASCII)))
|
|
val saltA = ByteArray(NONCE_SALT_LEN)
|
|
hkdf.generateBytes(saltA, 0, NONCE_SALT_LEN)
|
|
|
|
// salt B
|
|
hkdf.init(HKDFParameters(shared, null, "Abyss-Nonce-Salt-B".toByteArray(Charsets.US_ASCII)))
|
|
val saltB = ByteArray(NONCE_SALT_LEN)
|
|
hkdf.generateBytes(saltB, 0, NONCE_SALT_LEN)
|
|
|
|
// Deterministic assignment by lexicographic comparison
|
|
val cmp = lexicographicCompare(localPub, remotePub)
|
|
val sendSalt: ByteArray
|
|
val recvSalt: ByteArray
|
|
if (cmp < 0) {
|
|
sendSalt = saltA
|
|
recvSalt = saltB
|
|
} else if (cmp > 0) {
|
|
sendSalt = saltB
|
|
recvSalt = saltA
|
|
} else {
|
|
// extremely unlikely
|
|
sendSalt = saltA
|
|
recvSalt = saltB
|
|
}
|
|
|
|
// zero sensitive buffers
|
|
localPriv.fill(0)
|
|
localPub.fill(0)
|
|
remotePub.fill(0)
|
|
shared.fill(0)
|
|
// keep aeadKey, sendSalt, recvSalt
|
|
|
|
return@withContext AbyssStream(socket, inStream, outStream, aeadKey, sendSalt, recvSalt)
|
|
}
|
|
|
|
private fun lexicographicCompare(a: ByteArray, b: ByteArray): Int {
|
|
val min = kotlin.math.min(a.size, b.size)
|
|
for (i in 0 until min) {
|
|
val av = a[i].toInt() and 0xff
|
|
val bv = b[i].toInt() and 0xff
|
|
if (av < bv) return -1
|
|
if (av > bv) return 1
|
|
}
|
|
if (a.size < b.size) return -1
|
|
if (a.size > b.size) return 1
|
|
return 0
|
|
}
|
|
|
|
private fun readExact(input: InputStream, buffer: ByteArray, offset: Int, count: Int) {
|
|
var read = 0
|
|
while (read < count) {
|
|
val n = input.read(buffer, offset + read, count - read)
|
|
if (n == -1) {
|
|
if (read == 0) throw EOFException("Remote closed connection while reading")
|
|
else throw EOFException("Remote closed connection unexpectedly during read")
|
|
}
|
|
read += n
|
|
}
|
|
}
|
|
|
|
private fun writeExact(output: OutputStream, buffer: ByteArray, offset: Int, count: Int) {
|
|
output.write(buffer, offset, count)
|
|
output.flush()
|
|
}
|
|
}
|
|
|
|
// internal state
|
|
private val sendCounter = AtomicLong(0L)
|
|
private val recvCounter = AtomicLong(0L)
|
|
private val sendLock = Any()
|
|
private val aeadLock = Any()
|
|
|
|
// leftover read queue
|
|
private val leftoverQueue = ArrayDeque<ByteArray>()
|
|
private var currentLeftover: ByteArray? = null
|
|
private var currentLeftoverOffset = 0
|
|
|
|
@Volatile
|
|
private var closed = false
|
|
|
|
// ---- high-level read/write APIs (suspendable) ----
|
|
|
|
suspend fun read(buffer: ByteArray, offset: Int, count: Int): Int = withContext(Dispatchers.IO) {
|
|
if (closed) throw IllegalStateException("AbyssStream closed")
|
|
if (buffer.size < offset + count) throw IndexOutOfBoundsException()
|
|
// serve leftover first
|
|
if (ensureCurrentLeftover()) {
|
|
val seg = currentLeftover!!
|
|
val avail = seg.size - currentLeftoverOffset
|
|
val toCopy = kotlin.math.min(avail, count)
|
|
System.arraycopy(seg, currentLeftoverOffset, buffer, offset, toCopy)
|
|
currentLeftoverOffset += toCopy
|
|
if (currentLeftoverOffset >= seg.size) {
|
|
currentLeftover = null
|
|
currentLeftoverOffset = 0
|
|
}
|
|
return@withContext toCopy
|
|
}
|
|
|
|
// read one frame and decrypt
|
|
val plaintext = readOneFrameAndDecrypt()
|
|
if (plaintext == null || plaintext.isEmpty()) {
|
|
// EOF
|
|
return@withContext 0
|
|
}
|
|
|
|
return@withContext if (plaintext.size <= count) {
|
|
System.arraycopy(plaintext, 0, buffer, offset, plaintext.size)
|
|
plaintext.size
|
|
} else {
|
|
System.arraycopy(plaintext, 0, buffer, offset, count)
|
|
val leftoverLen = plaintext.size - count
|
|
val leftover = ByteArray(leftoverLen)
|
|
System.arraycopy(plaintext, count, leftover, 0, leftoverLen)
|
|
synchronized(leftoverQueue) { leftoverQueue.addLast(leftover) }
|
|
count
|
|
}
|
|
}
|
|
|
|
private fun ensureCurrentLeftover(): Boolean {
|
|
if (currentLeftover != null && currentLeftoverOffset < currentLeftover!!.size) return true
|
|
synchronized(leftoverQueue) {
|
|
val next = leftoverQueue.pollFirst()
|
|
if (next != null) {
|
|
currentLeftover = next
|
|
currentLeftoverOffset = 0
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
private fun readOneFrameAndDecrypt(): ByteArray? {
|
|
// read 4-byte header
|
|
val header = ByteArray(4)
|
|
try {
|
|
readExact(input, header, 0, 4)
|
|
} catch (e: EOFException) {
|
|
return null
|
|
}
|
|
val payloadLen = ByteBuffer.wrap(header).int and 0xffffffff.toInt()
|
|
if (payloadLen > MAX_PLAINTEXT_FRAME + AEAD_TAG_LEN) throw IllegalStateException("payload too big")
|
|
if (payloadLen < AEAD_TAG_LEN) throw IllegalStateException("payload too small")
|
|
|
|
val payload = ByteArray(payloadLen)
|
|
readExact(input, payload, 0, payloadLen)
|
|
|
|
val ciphertextLen = payloadLen - AEAD_TAG_LEN
|
|
val ciphertext = ByteArray(ciphertextLen)
|
|
val tag = ByteArray(AEAD_TAG_LEN)
|
|
if (ciphertextLen > 0) System.arraycopy(payload, 0, ciphertext, 0, ciphertextLen)
|
|
System.arraycopy(payload, ciphertextLen, tag, 0, AEAD_TAG_LEN)
|
|
|
|
val remoteCounterValue = recvCounter.getAndIncrement()
|
|
|
|
val nonce = ByteArray(NONCE_LEN)
|
|
System.arraycopy(recvSalt, 0, nonce, 0, NONCE_SALT_LEN)
|
|
// write 8-byte big-endian counter at nonce[4..11]
|
|
val bb = ByteBuffer.wrap(nonce, NONCE_SALT_LEN, 8)
|
|
bb.putLong(remoteCounterValue)
|
|
|
|
val plaintext = try {
|
|
aeadDecrypt(nonce, ciphertext, tag)
|
|
} catch (ex: Exception) {
|
|
close()
|
|
throw SecurityException("AEAD authentication failed; connection closed.", ex)
|
|
} finally {
|
|
nonce.fill(0)
|
|
payload.fill(0)
|
|
ciphertext.fill(0)
|
|
tag.fill(0)
|
|
}
|
|
|
|
return plaintext
|
|
}
|
|
|
|
suspend fun write(buffer: ByteArray, offset: Int, count: Int) = withContext(Dispatchers.IO) {
|
|
if (closed) throw IllegalStateException("AbyssStream closed")
|
|
if (buffer.size < offset + count) throw IndexOutOfBoundsException()
|
|
var remaining = count
|
|
var idx = offset
|
|
while (remaining > 0) {
|
|
val chunk = kotlin.math.min(remaining, MAX_PLAINTEXT_FRAME)
|
|
val plaintext = buffer.copyOfRange(idx, idx + chunk)
|
|
sendPlaintextChunk(plaintext)
|
|
idx += chunk
|
|
remaining -= chunk
|
|
}
|
|
}
|
|
|
|
private fun sendPlaintextChunk(plaintext: ByteArray) {
|
|
if (closed) throw IllegalStateException("AbyssStream closed")
|
|
|
|
val nonce = ByteArray(NONCE_LEN)
|
|
val ciphertextAndTag: ByteArray
|
|
val counterValue: Long
|
|
synchronized(sendLock) {
|
|
counterValue = sendCounter.getAndIncrement()
|
|
}
|
|
|
|
System.arraycopy(sendSalt, 0, nonce, 0, NONCE_SALT_LEN)
|
|
ByteBuffer.wrap(nonce, NONCE_SALT_LEN, 8).putLong(counterValue)
|
|
|
|
try {
|
|
ciphertextAndTag = aeadEncrypt(nonce, plaintext)
|
|
} finally {
|
|
nonce.fill(0)
|
|
}
|
|
|
|
val payloadLen = ciphertextAndTag.size
|
|
// header + ciphertextAndTag 一次性合并
|
|
val packet = ByteBuffer.allocate(4 + payloadLen)
|
|
.putInt(payloadLen)
|
|
.put(ciphertextAndTag)
|
|
.array()
|
|
|
|
try {
|
|
synchronized(output) {
|
|
output.write(packet)
|
|
output.flush()
|
|
}
|
|
} finally {
|
|
// clear sensitive
|
|
ciphertextAndTag.fill(0)
|
|
plaintext.fill(0)
|
|
packet.fill(0)
|
|
}
|
|
}
|
|
|
|
// ---- AEAD helpers using BouncyCastle lightweight API ----
|
|
// ChaCha20-Poly1305 with 12-byte nonce. BouncyCastle ChaCha20Poly1305 produces ciphertext+tag.
|
|
|
|
private fun aeadEncrypt(nonce: ByteArray, plaintext: ByteArray): ByteArray {
|
|
synchronized(aeadLock) {
|
|
val cipher = ChaCha20Poly1305()
|
|
val params = AEADParameters(KeyParameter(aeadKey), AEAD_TAG_LEN * 8, nonce, null)
|
|
cipher.init(true, params)
|
|
val outBuf = ByteArray(cipher.getOutputSize(plaintext.size))
|
|
var len = cipher.processBytes(plaintext, 0, plaintext.size, outBuf, 0)
|
|
len += cipher.doFinal(outBuf, len)
|
|
if (len != outBuf.size) return outBuf.copyOf(len)
|
|
return outBuf
|
|
}
|
|
}
|
|
|
|
private fun aeadDecrypt(nonce: ByteArray, ciphertext: ByteArray, tag: ByteArray): ByteArray {
|
|
synchronized(aeadLock) {
|
|
val cipher = ChaCha20Poly1305()
|
|
val params = AEADParameters(KeyParameter(aeadKey), AEAD_TAG_LEN * 8, nonce, null)
|
|
cipher.init(false, params)
|
|
// input is ciphertext||tag
|
|
val input = ByteArray(ciphertext.size + tag.size)
|
|
if (ciphertext.isNotEmpty()) System.arraycopy(ciphertext, 0, input, 0, ciphertext.size)
|
|
System.arraycopy(tag, 0, input, ciphertext.size, tag.size)
|
|
val outBuf = ByteArray(cipher.getOutputSize(input.size))
|
|
var len = cipher.processBytes(input, 0, input.size, outBuf, 0)
|
|
try {
|
|
len += cipher.doFinal(outBuf, len)
|
|
} catch (ex: Exception) {
|
|
// authentication failure or other
|
|
throw ex
|
|
}
|
|
return if (len != outBuf.size) outBuf.copyOf(len) else outBuf
|
|
}
|
|
}
|
|
|
|
// ---- utility / lifecycle ----
|
|
|
|
fun close() {
|
|
if (!closed) {
|
|
closed = true
|
|
try { socket.close() } catch (_: Exception) {}
|
|
// clear secrets
|
|
aeadKey.fill(0)
|
|
sendSalt.fill(0)
|
|
recvSalt.fill(0)
|
|
synchronized(leftoverQueue) {
|
|
leftoverQueue.forEach { it.fill(0) }
|
|
leftoverQueue.clear()
|
|
}
|
|
currentLeftover = null
|
|
}
|
|
}
|
|
} |