优化plus

This commit is contained in:
pengxiaolong
2026-01-15 21:32:32 +08:00
parent a1fbc6417f
commit 673b4491d7
17 changed files with 649 additions and 346 deletions

View File

@@ -2,7 +2,15 @@ package com.example.myapplication.data
import android.content.Context
import java.io.BufferedReader
import java.io.FileInputStream
import java.io.FileNotFoundException
import java.io.InputStream
import java.io.InputStreamReader
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.Channels
import java.nio.channels.FileChannel
import kotlin.math.max
data class BigramModel(
val vocab: List<String>, // 保留全部词(含 <unk>, <s>, </s>),与二元矩阵索引对齐
@@ -30,39 +38,104 @@ object LanguageModelLoader {
}
/**
 * Loads the asset [name] as an array of 32-bit integers.
 *
 * The asset is first opened through a file descriptor so its bytes can be decoded
 * straight from a [java.nio.channels.FileChannel] by [readInt32Channel]. When the
 * descriptor cannot be obtained, the asset is decoded from a plain input stream by
 * [readInt32Stream] instead.
 *
 * @param context context whose asset manager is used to open the asset
 * @param name asset path to read
 * @return the decoded values, in file order
 */
private fun readInt32(context: Context, name: String): IntArray {
    try {
        context.assets.openFd(name).use { afd ->
            FileInputStream(afd.fileDescriptor).channel.use { channel ->
                // The descriptor may cover a larger file; only the asset's own
                // [startOffset, startOffset + length) window is decoded.
                return readInt32Channel(channel, afd.startOffset, afd.length)
            }
        }
    } catch (e: FileNotFoundException) {
        // Compressed assets do not support openFd; fall back to streaming.
    }
    return context.assets.open(name).use { input -> readInt32Stream(input) }
}
/**
 * Loads the asset [name] as an array of 32-bit floats.
 *
 * The asset is first opened through a file descriptor so its bytes can be decoded
 * straight from a [java.nio.channels.FileChannel] by [readFloat32Channel]. When the
 * descriptor cannot be obtained, the asset is decoded from a plain input stream by
 * [readFloat32Stream] instead.
 *
 * @param context context whose asset manager is used to open the asset
 * @param name asset path to read
 * @return the decoded values, in file order
 */
private fun readFloat32(context: Context, name: String): FloatArray {
    try {
        context.assets.openFd(name).use { afd ->
            FileInputStream(afd.fileDescriptor).channel.use { channel ->
                // The descriptor may cover a larger file; only the asset's own
                // [startOffset, startOffset + length) window is decoded.
                return readFloat32Channel(channel, afd.startOffset, afd.length)
            }
        }
    } catch (e: FileNotFoundException) {
        // Compressed assets do not support openFd; fall back to streaming.
    }
    return context.assets.open(name).use { input -> readFloat32Stream(input) }
}
/**
 * Decodes a region of [channel] as little-endian 32-bit integers.
 *
 * @param channel channel to map read-only
 * @param offset byte position in the channel where the region starts
 * @param length size of the region in bytes
 * @return one element per 4 bytes of the region
 * @throws IllegalArgumentException if [length] is not a multiple of 4 or exceeds [Int.MAX_VALUE]
 */
private fun readInt32Channel(channel: FileChannel, offset: Long, length: Long): IntArray {
    require(length % 4L == 0L) { "int32 length invalid: $length" }
    require(length <= Int.MAX_VALUE.toLong()) { "int32 asset too large: $length" }
    // View the mapped region as ints; byte order must be set before creating the view.
    val words = channel
        .map(FileChannel.MapMode.READ_ONLY, offset, length)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asIntBuffer()
    return IntArray((length / 4L).toInt()).also { words.get(it) }
}
/**
 * Decodes a region of [channel] as little-endian 32-bit floats.
 *
 * @param channel channel to map read-only
 * @param offset byte position in the channel where the region starts
 * @param length size of the region in bytes
 * @return one element per 4 bytes of the region
 * @throws IllegalArgumentException if [length] is not a multiple of 4 or exceeds [Int.MAX_VALUE]
 */
private fun readFloat32Channel(channel: FileChannel, offset: Long, length: Long): FloatArray {
    require(length % 4L == 0L) { "float32 length invalid: $length" }
    require(length <= Int.MAX_VALUE.toLong()) { "float32 asset too large: $length" }
    // View the mapped region as floats; byte order must be set before creating the view.
    val words = channel
        .map(FileChannel.MapMode.READ_ONLY, offset, length)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asFloatBuffer()
    return FloatArray((length / 4L).toInt()).also { words.get(it) }
}
/**
 * Drains [input] and decodes its bytes as little-endian 32-bit integers.
 *
 * The result array starts at the larger of 1024 elements and a quarter of
 * `input.available()`, doubles whenever it fills up, and is trimmed to the number
 * of values actually decoded before being returned.
 *
 * @param input stream to read until end-of-stream; it is closed via the wrapping channel
 * @return the decoded values, in stream order
 * @throws IllegalStateException if the stream ends with 1-3 bytes left over
 */
private fun readInt32Stream(input: InputStream): IntArray {
    var values = IntArray(max(1024, input.available() / 4))
    var filled = 0
    val chunk = ByteBuffer.allocateDirect(64 * 1024).order(ByteOrder.LITTLE_ENDIAN)
    Channels.newChannel(input).use { source ->
        var got = source.read(chunk)
        while (got != -1) {
            if (got != 0) {
                // Switch to draining: consume every complete 4-byte group.
                chunk.flip()
                while (chunk.remaining() >= 4) {
                    if (filled == values.size) values = values.copyOf(values.size * 2)
                    values[filled] = chunk.getInt()
                    filled++
                }
                // Keep any partial group at the front for the next read.
                chunk.compact()
            }
            got = source.read(chunk)
        }
    }
    chunk.flip()
    check(!chunk.hasRemaining()) { "truncated int32 stream" }
    return values.copyOf(filled)
}
/**
 * Drains [input] and decodes its bytes as little-endian 32-bit floats.
 *
 * The result array starts at the larger of 1024 elements and a quarter of
 * `input.available()`, doubles whenever it fills up, and is trimmed to the number
 * of values actually decoded before being returned.
 *
 * @param input stream to read until end-of-stream; it is closed via the wrapping channel
 * @return the decoded values, in stream order
 * @throws IllegalStateException if the stream ends with 1-3 bytes left over
 */
private fun readFloat32Stream(input: InputStream): FloatArray {
    var values = FloatArray(max(1024, input.available() / 4))
    var filled = 0
    val chunk = ByteBuffer.allocateDirect(64 * 1024).order(ByteOrder.LITTLE_ENDIAN)
    Channels.newChannel(input).use { source ->
        var got = source.read(chunk)
        while (got != -1) {
            if (got != 0) {
                // Switch to draining: consume every complete 4-byte group.
                chunk.flip()
                while (chunk.remaining() >= 4) {
                    if (filled == values.size) values = values.copyOf(values.size * 2)
                    values[filled] = chunk.getFloat()
                    filled++
                }
                // Keep any partial group at the front for the next read.
                chunk.compact()
            }
            got = source.read(chunk)
        }
    }
    chunk.flip()
    check(!chunk.hasRemaining()) { "truncated float32 stream" }
    return values.copyOf(filled)
}
}