优化plus
This commit is contained in:
@@ -2,7 +2,15 @@ package com.example.myapplication.data
|
||||
|
||||
import android.content.Context
|
||||
import java.io.BufferedReader
|
||||
import java.io.FileInputStream
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.InputStream
|
||||
import java.io.InputStreamReader
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.nio.channels.Channels
|
||||
import java.nio.channels.FileChannel
|
||||
import kotlin.math.max
|
||||
|
||||
data class BigramModel(
|
||||
val vocab: List<String>, // 保留全部词(含 <unk>, <s>, </s>),与二元矩阵索引对齐
|
||||
@@ -30,39 +38,104 @@ object LanguageModelLoader {
|
||||
}
|
||||
|
||||
private fun readInt32(context: Context, name: String): IntArray {
|
||||
context.assets.open(name).use { input ->
|
||||
val bytes = input.readBytes()
|
||||
val n = bytes.size / 4
|
||||
val out = IntArray(n)
|
||||
var i = 0; var j = 0
|
||||
while (i < n) {
|
||||
// 小端序
|
||||
val v = (bytes[j].toInt() and 0xFF) or
|
||||
((bytes[j+1].toInt() and 0xFF) shl 8) or
|
||||
((bytes[j+2].toInt() and 0xFF) shl 16) or
|
||||
((bytes[j+3].toInt() and 0xFF) shl 24)
|
||||
out[i++] = v
|
||||
j += 4
|
||||
try {
|
||||
context.assets.openFd(name).use { afd ->
|
||||
FileInputStream(afd.fileDescriptor).channel.use { channel ->
|
||||
return readInt32Channel(channel, afd.startOffset, afd.length)
|
||||
}
|
||||
}
|
||||
return out
|
||||
} catch (e: FileNotFoundException) {
|
||||
// Compressed assets do not support openFd; fall back to streaming.
|
||||
}
|
||||
|
||||
context.assets.open(name).use { input ->
|
||||
return readInt32Stream(input)
|
||||
}
|
||||
}
|
||||
|
||||
private fun readFloat32(context: Context, name: String): FloatArray {
|
||||
context.assets.open(name).use { input ->
|
||||
val bytes = input.readBytes()
|
||||
val n = bytes.size / 4
|
||||
val out = FloatArray(n)
|
||||
var i = 0; var j = 0
|
||||
while (i < n) {
|
||||
val bits = (bytes[j].toInt() and 0xFF) or
|
||||
((bytes[j+1].toInt() and 0xFF) shl 8) or
|
||||
((bytes[j+2].toInt() and 0xFF) shl 16) or
|
||||
((bytes[j+3].toInt() and 0xFF) shl 24)
|
||||
out[i++] = Float.fromBits(bits)
|
||||
j += 4
|
||||
try {
|
||||
context.assets.openFd(name).use { afd ->
|
||||
FileInputStream(afd.fileDescriptor).channel.use { channel ->
|
||||
return readFloat32Channel(channel, afd.startOffset, afd.length)
|
||||
}
|
||||
}
|
||||
return out
|
||||
} catch (e: FileNotFoundException) {
|
||||
// Compressed assets do not support openFd; fall back to streaming.
|
||||
}
|
||||
|
||||
context.assets.open(name).use { input ->
|
||||
return readFloat32Stream(input)
|
||||
}
|
||||
}
|
||||
|
||||
private fun readInt32Channel(channel: FileChannel, offset: Long, length: Long): IntArray {
|
||||
require(length % 4L == 0L) { "int32 length invalid: $length" }
|
||||
require(length <= Int.MAX_VALUE.toLong()) { "int32 asset too large: $length" }
|
||||
val count = (length / 4L).toInt()
|
||||
val mapped = channel.map(FileChannel.MapMode.READ_ONLY, offset, length)
|
||||
mapped.order(ByteOrder.LITTLE_ENDIAN)
|
||||
val out = IntArray(count)
|
||||
mapped.asIntBuffer().get(out)
|
||||
return out
|
||||
}
|
||||
|
||||
private fun readFloat32Channel(channel: FileChannel, offset: Long, length: Long): FloatArray {
|
||||
require(length % 4L == 0L) { "float32 length invalid: $length" }
|
||||
require(length <= Int.MAX_VALUE.toLong()) { "float32 asset too large: $length" }
|
||||
val count = (length / 4L).toInt()
|
||||
val mapped = channel.map(FileChannel.MapMode.READ_ONLY, offset, length)
|
||||
mapped.order(ByteOrder.LITTLE_ENDIAN)
|
||||
val out = FloatArray(count)
|
||||
mapped.asFloatBuffer().get(out)
|
||||
return out
|
||||
}
|
||||
|
||||
private fun readInt32Stream(input: InputStream): IntArray {
|
||||
val initialSize = max(1024, input.available() / 4)
|
||||
var out = IntArray(initialSize)
|
||||
var count = 0
|
||||
val buffer = ByteBuffer.allocateDirect(64 * 1024)
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
Channels.newChannel(input).use { channel ->
|
||||
while (true) {
|
||||
val read = channel.read(buffer)
|
||||
if (read == -1) break
|
||||
if (read == 0) continue
|
||||
buffer.flip()
|
||||
while (buffer.remaining() >= 4) {
|
||||
if (count == out.size) out = out.copyOf(out.size * 2)
|
||||
out[count++] = buffer.getInt()
|
||||
}
|
||||
buffer.compact()
|
||||
}
|
||||
}
|
||||
buffer.flip()
|
||||
check(buffer.remaining() == 0) { "truncated int32 stream" }
|
||||
return out.copyOf(count)
|
||||
}
|
||||
|
||||
private fun readFloat32Stream(input: InputStream): FloatArray {
|
||||
val initialSize = max(1024, input.available() / 4)
|
||||
var out = FloatArray(initialSize)
|
||||
var count = 0
|
||||
val buffer = ByteBuffer.allocateDirect(64 * 1024)
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
Channels.newChannel(input).use { channel ->
|
||||
while (true) {
|
||||
val read = channel.read(buffer)
|
||||
if (read == -1) break
|
||||
if (read == 0) continue
|
||||
buffer.flip()
|
||||
while (buffer.remaining() >= 4) {
|
||||
if (count == out.size) out = out.copyOf(out.size * 2)
|
||||
out[count++] = buffer.getFloat()
|
||||
}
|
||||
buffer.compact()
|
||||
}
|
||||
}
|
||||
buffer.flip()
|
||||
check(buffer.remaining() == 0) { "truncated float32 stream" }
|
||||
return out.copyOf(count)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user