package com.example.myapplication.data

import android.content.Context
import java.io.BufferedReader
import java.io.FileInputStream
import java.io.FileNotFoundException
import java.io.InputStream
import java.io.InputStreamReader
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.Channels
import java.nio.channels.FileChannel
import kotlin.math.max

/**
 * In-memory bigram language model.
 *
 * The bigram table is stored in CSR (compressed sparse row) form. For row index `i`, the entries
 * `biRowptr[i] until biRowptr[i + 1]` index into [biCols] (column token indices) and [biLogp]
 * (their values). Presumably row = preceding token and column = following token; confirm against
 * the exporter.
 *
 * [equals] and [hashCode] are overridden because the data-class defaults compare array properties
 * by reference, so two models with identical contents would never compare equal.
 *
 * @property vocab All vocabulary entries, kept in full so list indices line up with the bigram
 *   matrix indices. The original comment listed special tokens included here, but their names were
 *   lost; presumably `<s>`, `</s>`, `<unk>` — TODO confirm.
 * @property uniLogp Unigram values (log-probabilities, per the `logp` naming; log base not visible
 *   here). Length == `vocab.size`.
 * @property biRowptr CSR row pointers. Length == `vocab.size + 1`.
 * @property biCols CSR column indices. Length == nnz.
 * @property biLogp CSR values (bigram log-probabilities). Length == nnz.
 */
data class BigramModel(
    val vocab: List<String>,
    val uniLogp: FloatArray,
    val biRowptr: IntArray,
    val biCols: IntArray,
    val biLogp: FloatArray
) {
    override fun equals(other: Any?): Boolean {
        // Identity fast path: avoids an O(n) array comparison in the common same-instance case.
        if (this === other) return true
        if (other !is BigramModel) return false
        return vocab == other.vocab &&
            uniLogp.contentEquals(other.uniLogp) &&
            biRowptr.contentEquals(other.biRowptr) &&
            biCols.contentEquals(other.biCols) &&
            biLogp.contentEquals(other.biLogp)
    }

    override fun hashCode(): Int {
        var result = vocab.hashCode()
        result = 31 * result + uniLogp.contentHashCode()
        result = 31 * result + biRowptr.contentHashCode()
        result = 31 * result + biCols.contentHashCode()
        result = 31 * result + biLogp.contentHashCode()
        return result
    }
}

/** Loads a [BigramModel] from the application's bundled assets. */
object LanguageModelLoader {

    /**
     * Reads the model from these assets:
     * - `vocab.txt`: one entry per line, decoded as UTF-8 (Kotlin's `bufferedReader()` default).
     * - `uni_logp.bin`, `bi_rowptr.bin`, `bi_cols.bin`, `bi_logp.bin`: raw little-endian
     *   32-bit arrays.
     *
     * Performs synchronous asset I/O on the calling thread.
     *
     * @throws java.io.IOException if an asset is missing or cannot be read.
     * @throws IllegalArgumentException if any of the following holds:
     *   - array lengths disagree with each other or with the vocabulary;
     *   - the CSR structure is malformed;
     *   - an uncompressed binary asset's length is not a multiple of 4 or exceeds `Int.MAX_VALUE`.
     * @throws IllegalStateException if a compressed binary asset's length is not a multiple of 4.
     */
    fun load(context: Context): BigramModel {
        val vocab = context.assets.open("vocab.txt").bufferedReader()
            .use(BufferedReader::readLines)
        val uniLogp = readFloat32(context, "uni_logp.bin")
        val biRowptr = readInt32(context, "bi_rowptr.bin")
        val biCols = readInt32(context, "bi_cols.bin")
        val biLogp = readFloat32(context, "bi_logp.bin")

        require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
        require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
        require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }
        validateCsr(biRowptr, biCols, vocab.size)

        return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
    }

    /**
     * Checks the CSR invariants that the length checks in [load] do not cover.
     *
     * A corrupt asset then fails here with a clear message rather than with an out-of-bounds index
     * in whatever code later walks the matrix. Requires `rowptr.size >= 1`, which [load] guarantees
     * before calling this.
     */
    private fun validateCsr(rowptr: IntArray, cols: IntArray, vocabSize: Int) {
        require(rowptr.first() == 0) { "bi_rowptr must start at 0, got ${rowptr.first()}" }
        require(rowptr.last() == cols.size) { "bi_rowptr end ${rowptr.last()} != nnz ${cols.size}" }
        for (row in 0 until rowptr.size - 1) {
            require(rowptr[row] <= rowptr[row + 1]) { "bi_rowptr decreases at row $row" }
        }
        val badCol = cols.indexOfFirst { it !in 0 until vocabSize }
        require(badCol < 0) { "bi_cols[$badCol] = ${cols[badCol]} out of range [0, $vocabSize)" }
    }

    /** Reads a raw little-endian int32 array asset. */
    private fun readInt32(context: Context, name: String): IntArray =
        readLittleEndian(context, name, "int32") { bytes ->
            IntArray(bytes.remaining() / 4).also { bytes.asIntBuffer().get(it) }
        }

    /** Reads a raw little-endian float32 array asset. */
    private fun readFloat32(context: Context, name: String): FloatArray =
        readLittleEndian(context, name, "float32") { bytes ->
            FloatArray(bytes.remaining() / 4).also { bytes.asFloatBuffer().get(it) }
        }

    /**
     * Opens asset [name] and passes its full contents to [decode].
     *
     * The contents arrive as a little-endian [ByteBuffer] positioned at 0, whose remaining length
     * is a multiple of 4. [decode] runs while the asset is still open, so it must copy out
     * whatever it needs.
     *
     * Uncompressed assets are memory-mapped via `openFd`. Compressed assets, for which `openFd`
     * throws [FileNotFoundException], are read fully into a heap array instead.
     *
     * @param label Element-type name used only in error messages ("int32" / "float32").
     * @throws IllegalArgumentException on the mapped path if the length is not a multiple of 4 or
     *   exceeds `Int.MAX_VALUE`.
     * @throws IllegalStateException on the streamed path if the length is not a multiple of 4.
     */
    private fun <T> readLittleEndian(
        context: Context,
        name: String,
        label: String,
        decode: (ByteBuffer) -> T
    ): T {
        try {
            context.assets.openFd(name).use { afd ->
                val length = afd.length
                require(length % 4L == 0L) { "$label length invalid: $length" }
                require(length <= Int.MAX_VALUE.toLong()) { "$label asset too large: $length" }
                FileInputStream(afd.fileDescriptor).channel.use { channel ->
                    val mapped = channel.map(FileChannel.MapMode.READ_ONLY, afd.startOffset, length)
                    return decode(mapped.order(ByteOrder.LITTLE_ENDIAN))
                }
            }
        } catch (e: FileNotFoundException) {
            // Compressed assets do not support openFd; fall back to streaming.
        }
        val bytes = context.assets.open(name).use { it.readBytes() }
        check(bytes.size % 4 == 0) { "truncated $label stream" }
        return decode(ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN))
    }
}