[update] support ASR via the Aliyun Android SDK

This commit is contained in:
2025-08-15 14:48:29 +08:00
parent ae3289dc29
commit 9715a18c65
3 changed files with 316 additions and 51 deletions

View File

@@ -1,31 +1,42 @@
package com.example.ai_chat_assistant;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;
import io.flutter.embedding.android.FlutterActivity;
import io.flutter.embedding.engine.FlutterEngine;
import io.flutter.plugin.common.MethodChannel;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.idst.nui.AsrResult;
import com.alibaba.idst.nui.Constants;
import com.alibaba.idst.nui.INativeNuiCallback;
import com.alibaba.idst.nui.INativeStreamInputTtsCallback;
import com.alibaba.idst.nui.KwsResult;
import com.alibaba.idst.nui.NativeNui;
import com.example.ai_chat_assistant.token.AccessToken;
import java.util.Map;
public class MainActivity extends FlutterActivity {
private static final String CHANNEL = "com.example.ai_chat_assistant/tts";
private static final String TAG = "StreamInputTts";
public class MainActivity extends FlutterActivity implements INativeNuiCallback {
private static final String TTS_CHANNEL = "com.example.ai_chat_assistant/tts";
private static final String ASR_CHANNEL = "com.example.ai_chat_assistant/asr";
private static final String TAG = "AliyunSDK";
private static final String APP_KEY = "bXFFc1V65iYbW6EF";
private static final String ACCESS_KEY = "LTAI5t71JHxXRvt2mGuEVz9X";
private static final String ACCESS_KEY_SECRET = "WQOUWvngxmCg4CxIG0qkSlkcH5hrVT";
private static final String URL = "wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1";
private final NativeNui streamInputTtsInstance = new NativeNui(Constants.ModeType.MODE_STREAM_INPUT_TTS);
private final AudioPlayer mAudioTrack = new AudioPlayer(new AudioPlayerCallback() {
private final NativeNui asrInstance = new NativeNui();
private final AudioPlayer ttsAudioTrack = new AudioPlayer(new AudioPlayerCallback() {
@Override
public void playStart() {
Log.i(TAG, "start play");
@@ -41,11 +52,19 @@ public class MainActivity extends FlutterActivity {
}
});
private MethodChannel asrMethodChannel;
private final static int ASR_SAMPLE_RATE = 16000;
private final static int ASR_WAVE_FRAM_SIZE = 20 * 2 * 1 * ASR_SAMPLE_RATE / 1000; //20ms audio for 16k/16bit/mono
private AudioRecord asrAudioRecorder = null;
private boolean asrStopping = false;
private Handler asrHandler;
private String asrText = "";
@Override
public void configureFlutterEngine(FlutterEngine flutterEngine) {
super.configureFlutterEngine(flutterEngine);
new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), CHANNEL)
System.out.println("??????????????????????????????????");
new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), TTS_CHANNEL)
.setMethodCallHandler((call, result) -> {
Map<String, Object> args = (Map<String, Object>) call.arguments;
switch (call.method) {
@@ -54,7 +73,7 @@ public class MainActivity extends FlutterActivity {
if (isChinese == null || isChinese.toString().isBlank()) {
return;
}
boolean isSuccess = start(Boolean.parseBoolean(isChinese.toString()));
boolean isSuccess = startTts(Boolean.parseBoolean(isChinese.toString()));
result.success(isSuccess);
break;
case "send":
@@ -62,10 +81,10 @@ public class MainActivity extends FlutterActivity {
if (textArg == null || textArg.toString().isBlank()) {
return;
}
send(textArg.toString());
sendTts(textArg.toString());
break;
case "stop":
stop();
stopTts();
result.success("已停止");
break;
default:
@@ -73,29 +92,61 @@ public class MainActivity extends FlutterActivity {
break;
}
});
asrMethodChannel = new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), ASR_CHANNEL);
asrMethodChannel.setMethodCallHandler((call, result) -> {
switch (call.method) {
case "start":
startAsr();
break;
case "stop":
stopAsr();
break;
default:
result.notImplemented();
break;
}
});
}
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    // Dedicated background thread for ASR SDK calls (setParams/startDialog/stopDialog)
    // so they never block the UI thread; startAsr()/stopAsr() post onto asrHandler.
    // NOTE(review): this HandlerThread is never quit (e.g. quitSafely() in onDestroy),
    // so it lives for the whole process — confirm whether it should be torn down.
    HandlerThread asrHandlerThread = new HandlerThread("process_thread");
    asrHandlerThread.start();
    asrHandler = new Handler(asrHandlerThread.getLooper());
}
@Override
protected void onStart() {
    Log.i(TAG, "onStart");
    super.onStart();
    // (Re-)initialize the NUI ASR engine every time the activity becomes visible;
    // the matching release() is in onStop(). `this` is the INativeNuiCallback.
    asrInstance.initialize(this, genAsrInitParams(),
            Constants.LogLevel.LOG_LEVEL_NONE, false);
}
@Override
protected void onStop() {
    Log.i(TAG, "onStop");
    super.onStop();
    // Release the NUI ASR engine; it is re-initialized in onStart().
    asrInstance.release();
}
@Override
protected void onDestroy() {
    super.onDestroy();
    // Stop and release the TTS audio sink, then cancel any in-flight synthesis.
    // (The pre-rename mAudioTrack calls were diff residue of the same field and
    // are dropped — the field is now ttsAudioTrack.)
    ttsAudioTrack.stop();
    ttsAudioTrack.releaseAudioTrack();
    streamInputTtsInstance.stopStreamInputTts();
}
private boolean start(boolean isChinese) {
private boolean startTts(boolean isChinese) {
int ret = streamInputTtsInstance.startStreamInputTts(new INativeStreamInputTtsCallback() {
@Override
public void onStreamInputTtsEventCallback(INativeStreamInputTtsCallback.StreamInputTtsEvent event, String task_id, String session_id, int ret_code, String error_msg, String timestamp, String all_response) {
Log.i(TAG, "stream input tts event(" + event + ") session id(" + session_id + ") task id(" + task_id + ") retCode(" + ret_code + ") errMsg(" + error_msg + ")");
if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_SYNTHESIS_STARTED) {
Log.i(TAG, "STREAM_INPUT_TTS_EVENT_SYNTHESIS_STARTED");
mAudioTrack.play();
ttsAudioTrack.play();
Log.i(TAG, "start play");
} else if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_SENTENCE_SYNTHESIS) {
Log.i(TAG, "STREAM_INPUT_TTS_EVENT_SENTENCE_SYNTHESIS:" + timestamp);
@@ -106,7 +157,7 @@ public class MainActivity extends FlutterActivity {
Log.i(TAG, "play end");
// 表示推送完数据, 当播放器播放结束则会有playOver回调
mAudioTrack.isFinishSend(true);
ttsAudioTrack.isFinishSend(true);
if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_TASK_FAILED) {
Log.e(TAG, "STREAM_INPUT_TTS_EVENT_TASK_FAILED: " + "error_code(" + ret_code + ") error_message(" + error_msg + ")");
@@ -121,10 +172,10 @@ public class MainActivity extends FlutterActivity {
@Override
public void onStreamInputTtsDataCallback(byte[] data) {
if (data.length > 0) {
mAudioTrack.setAudioData(data);
ttsAudioTrack.setAudioData(data);
}
}
}, genTicket(), genParameters(isChinese), "", Constants.LogLevel.toInt(Constants.LogLevel.LOG_LEVEL_NONE), false);
}, genTtsTicket(), genTtsParameters(isChinese), "", Constants.LogLevel.toInt(Constants.LogLevel.LOG_LEVEL_NONE), false);
if (Constants.NuiResultCode.SUCCESS != ret) {
Log.i(TAG, "start tts failed " + ret);
return false;
@@ -133,16 +184,135 @@ public class MainActivity extends FlutterActivity {
}
}
private void send(String text) {
private void sendTts(String text) {
streamInputTtsInstance.sendStreamInputTts(text);
}
private void stop() {
private void stopTts() {
streamInputTtsInstance.cancelStreamInputTts();
mAudioTrack.stop();
ttsAudioTrack.stop();
}
private String genTicket() {
/**
 * Starts a push-to-talk (P2T) recognition session: resets the committed
 * transcript, ensures a usable AudioRecord exists, then configures and starts
 * the NUI dialog on the ASR handler thread.
 */
private void startAsr() {
    asrText = "";
    // Recreate the recorder if none exists OR the previous one was released
    // (STATE_CLOSE in onNuiAudioStateChanged releases it, leaving it
    // uninitialized and unusable — reusing it would make every read fail).
    if (asrAudioRecorder == null
            || asrAudioRecorder.getState() != AudioRecord.STATE_INITIALIZED) {
        asrAudioRecorder = new AudioRecord(MediaRecorder.AudioSource.DEFAULT,
                ASR_SAMPLE_RATE,
                AudioFormat.CHANNEL_IN_MONO,
                AudioFormat.ENCODING_PCM_16BIT,
                ASR_WAVE_FRAM_SIZE * 4); // 4 x 20ms frames of buffer headroom
        Log.d(TAG, "AudioRecorder new ...");
    } else {
        Log.w(TAG, "AudioRecord has been new ...");
    }
    // All NUI calls run on the dedicated ASR thread (see onCreate).
    asrHandler.post(() -> {
        String setParamsString = genAsrParams();
        Log.i(TAG, "nui set params " + setParamsString);
        asrInstance.setParams(setParamsString);
        int ret = asrInstance.startDialog(Constants.VadMode.TYPE_P2T, genAsrDialogParams());
        Log.i(TAG, "start done with " + ret);
    });
}
/**
 * Ends the current recognition session. The SDK teardown runs on the ASR
 * handler thread; Flutter is notified via "onAsrStop" on the UI thread.
 */
private void stopAsr() {
    Runnable stopTask = () -> {
        asrStopping = true;
        final long stopResult = asrInstance.stopDialog();
        runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrStop", null));
        Log.i(TAG, "cancel dialog " + stopResult + " end");
    };
    asrHandler.post(stopTask);
}
/**
 * NUI SDK event callback. Forwards recognition text to Flutter through
 * {@code asrMethodChannel} and clears {@code asrStopping} on terminal events.
 */
@Override
public void onNuiEventCallback(Constants.NuiEvent event, final int resultCode,
                               final int arg2, KwsResult kwsResult,
                               AsrResult asrResult) {
    Log.i(TAG, "event=" + event + " resultCode=" + resultCode);
    switch (event) {
        case EVENT_TRANSCRIBER_COMPLETE:
        case EVENT_ASR_ERROR:
        case EVENT_MIC_ERROR:
            // Session reached a terminal state (cleanly or not).
            asrStopping = false;
            break;
        case EVENT_ASR_PARTIAL_RESULT: {
            // Intermediate hypothesis: show committed text + live partial,
            // but do not commit it yet.
            String result = extractAsrPayloadResult(asrResult);
            if (asrMethodChannel != null && result != null && !result.isBlank()) {
                runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrResult", asrText + result));
            }
            break;
        }
        case EVENT_SENTENCE_END: {
            // Final text for one sentence: append to the committed transcript.
            String result = extractAsrPayloadResult(asrResult);
            if (asrMethodChannel != null && result != null && !result.isBlank()) {
                asrText += result;
                runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrResult", asrText));
            }
            break;
        }
        case EVENT_DIALOG_EX:
            Log.i(TAG, "dialog extra message = " + asrResult.asrResult);
            break;
        default:
            // EVENT_TRANSCRIBER_STARTED / EVENT_VAD_START / EVENT_VAD_END: no-op.
            break;
    }
}

/**
 * Parses {@code payload.result} out of the SDK's JSON response.
 * Returns null when the payload or result field is absent (the original code
 * would have thrown an NPE on a missing payload).
 */
private static String extractAsrPayloadResult(AsrResult asrResult) {
    JSONObject jsonObject = JSON.parseObject(asrResult.allResponse);
    JSONObject payload = jsonObject.getJSONObject("payload");
    return payload == null ? null : payload.getString("result");
}
// Called repeatedly by the NUI engine after start(): the SDK supplies `buffer`
// and the number of bytes it currently wants. Return how many bytes were
// actually read — try to make it equal to `len`; a value <= 0 signals an error.
@Override
public int onNuiNeedAudioData(byte[] buffer, int len) {
    if (asrAudioRecorder == null) {
        return -1;
    }
    if (asrAudioRecorder.getState() != AudioRecord.STATE_INITIALIZED) {
        Log.e(TAG, "audio recorder not init");
        return -1;
    }
    // Blocking PCM read from the mic; may return fewer bytes than requested.
    return asrAudioRecorder.read(buffer, 0, len);
}
// Invoked when the SDK's desired recording state changes; open/pause/close the
// platform AudioRecord accordingly.
@Override
public void onNuiAudioStateChanged(Constants.AudioState state) {
    Log.i(TAG, "onNuiAudioStateChanged");
    if (state == Constants.AudioState.STATE_OPEN) {
        Log.i(TAG, "audio recorder start");
        if (asrAudioRecorder != null) {
            asrAudioRecorder.startRecording();
        }
        Log.i(TAG, "audio recorder start done");
    } else if (state == Constants.AudioState.STATE_CLOSE) {
        Log.i(TAG, "audio recorder close");
        if (asrAudioRecorder != null) {
            asrAudioRecorder.release();
            // A released AudioRecord is permanently unusable; drop the
            // reference so the next startAsr() creates a fresh instance
            // instead of reusing the dead one.
            asrAudioRecorder = null;
        }
    } else if (state == Constants.AudioState.STATE_PAUSE) {
        Log.i(TAG, "audio recorder pause");
        if (asrAudioRecorder != null) {
            asrAudioRecorder.stop();
        }
    }
}
// Mic input volume (RMS) callback; intentionally unused here — the log line is
// kept commented out for debugging.
@Override
public void onNuiAudioRMSChanged(float val) {
    // Log.i(TAG, "onNuiAudioRMSChanged vol " + val);
}
// Voiceprint (VPR) event callback; this app does not use voiceprint — log only.
@Override
public void onNuiVprEventCallback(Constants.NuiVprEvent event) {
    Log.i(TAG, "onNuiVprEventCallback event " + event);
}
// SDK internal log forwarding; mirrored to logcat for diagnostics.
@Override
public void onNuiLogTrackCallback(Constants.LogLevel level, String log) {
    Log.i(TAG, "onNuiLogTrackCallback log level:" + level + ", message -> " + log);
}
private String genTtsTicket() {
String str = "";
try {
Auth.GetTicketMethod method = Auth.GetTicketMethod.GET_ACCESS_IN_CLIENT_FOR_ONLINE_FEATURES;
@@ -163,7 +333,7 @@ public class MainActivity extends FlutterActivity {
return str;
}
private String genParameters(boolean isChinese) {
private String genTtsParameters(boolean isChinese) {
String str = "";
try {
JSONObject object = new JSONObject();
@@ -180,4 +350,87 @@ public class MainActivity extends FlutterActivity {
Log.i(TAG, "user parameters:" + str);
return str;
}
/**
 * Builds the JSON init-parameter string passed to {@code asrInstance.initialize}:
 * a service ticket (token) from Auth plus device id, gateway URL and service mode.
 * Returns "" if JSON assembly fails.
 *
 * NOTE(review): Auth is fed the hard-coded APP_KEY / ACCESS_KEY /
 * ACCESS_KEY_SECRET constants — shipping long-lived cloud credentials inside
 * the client APK is a security risk; tokens should be issued by a backend.
 */
private String genAsrInitParams() {
    String str = "";
    try {
        Auth.GetTicketMethod method = Auth.GetTicketMethod.GET_ACCESS_IN_CLIENT_FOR_ONLINE_FEATURES;
        Auth.setAppKey(APP_KEY);
        Auth.setAccessKey(ACCESS_KEY);
        Auth.setAccessKeySecret(ACCESS_KEY_SECRET);
        Log.i(TAG, "Use method:" + method);
        JSONObject object = Auth.getTicket(method);
        // Missing token is only logged, not treated as fatal — the SDK will
        // surface the failure later.
        if (!object.containsKey("token")) {
            Log.e(TAG, "Cannot get token !!!");
        }
        object.put("device_id", "empty_device_id"); // required; a unique id is recommended to ease troubleshooting
        object.put("url", URL);
        object.put("service_mode", Constants.ModeFullCloud); // required
        str = object.toString();
    } catch (JSONException e) {
        e.printStackTrace();
    }
    Log.i(TAG, "InsideUserContext:" + str);
    return str;
}
/**
 * Builds the recognition-parameter JSON for {@code asrInstance.setParams}:
 * 16kHz mono PCM speech-transcriber with intermediate results and punctuation.
 * Returns "" if JSON assembly fails.
 */
private String genAsrParams() {
    String params = "";
    try {
        JSONObject nls_config = new JSONObject();
        // Parameters can be tuned for the actual business scenario.
        // API reference: https://help.aliyun.com/document_detail/173528.html
        // (see section 2, "Start recognition").
        // Whether to return intermediate recognition results; default false.
        nls_config.put("enable_intermediate_result", true);
        // Whether to add punctuation during post-processing; default false.
        nls_config.put("enable_punctuation_prediction", true);
        nls_config.put("sample_rate", 16000);
        nls_config.put("sr_format", "pcm");
        // nls_config.put("enable_inverse_text_normalization", true);
        // nls_config.put("max_sentence_silence", 800);
        // nls_config.put("enable_words", false);
        /* For parameters the feature supports but the docs do not list,
           the generic extend_config passthrough can be used: */
        // JSONObject extend_config = new JSONObject();
        // extend_config.put("custom_test", true);
        // nls_config.put("extend_config", extend_config);
        JSONObject tmp = new JSONObject();
        tmp.put("nls_config", nls_config);
        tmp.put("service_type", Constants.kServiceTypeSpeechTranscriber); // required
        // Set a direct IP here if an HttpDns result is available.
        // tmp.put("direct_ip", Utils.getDirectIp());
        params = tmp.toString();
    } catch (JSONException e) {
        e.printStackTrace();
    }
    return params;
}
/**
 * Builds per-dialog parameters for {@code startDialog}, refreshing the token
 * if it expires within the next 30 minutes. Returns "" if JSON assembly fails.
 */
private String genAsrDialogParams() {
    String params = "";
    try {
        JSONObject dialog_param = new JSONObject();
        // Temporary parameters can be updated at startDialog time — most
        // importantly an expiring token.
        // Note: if omitted on the next dialog, the init-time parameters
        // remain in effect.
        long distance_expire_time_30m = 1800;
        dialog_param = Auth.refreshTokenIfNeed(dialog_param, distance_expire_time_30m);
        // Note: app_key/token can also be overridden directly:
        // dialog_param.put("app_key", "");
        // dialog_param.put("token", "");
        params = dialog_param.toString();
    } catch (JSONException e) {
        e.printStackTrace();
    }
    Log.i(TAG, "dialog params: " + params);
    return params;
}
}

View File

@@ -4,12 +4,10 @@ import 'dart:io';
import 'dart:math';
import 'package:ai_chat_assistant/utils/tts_util.dart';
import 'package:flutter/services.dart';
import '../utils/common_util.dart';
class ChatSseService {
static const MethodChannel _channel = MethodChannel('com.example.ai_chat_assistant/tts');
// 缓存用户ID和会话ID
String? _cachedUserId;
String? _cachedConversationId;

View File

@@ -1,9 +1,8 @@
import 'dart:io';
import 'package:ai_chat_assistant/utils/common_util.dart';
import 'package:ai_chat_assistant/utils/tts_util.dart';
import 'package:basic_intl/intl.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:permission_handler/permission_handler.dart';
import 'package:uuid/uuid.dart';
import '../enums/vehicle_command_type.dart';
@@ -14,22 +13,35 @@ import '../models/vehicle_cmd.dart';
import '../services/chat_sse_service.dart';
import '../services/classification_service.dart';
import '../services/control_recognition_service.dart';
import '../services/audio_recorder_service.dart';
import '../services/voice_recognition_service.dart';
// import '../services/audio_recorder_service.dart';
// import '../services/voice_recognition_service.dart';
import 'command_service.dart';
import 'package:fluttertoast/fluttertoast.dart';
class MessageService extends ChangeNotifier {
static const MethodChannel _channel = MethodChannel('com.example.ai_chat_assistant/asr');
static final MessageService _instance = MessageService._internal();
factory MessageService() => _instance;
MessageService._internal();
MessageService._internal() {
  // Receive streaming ASR callbacks from the Android side (MainActivity).
  _channel.setMethodCallHandler((call) async {
    // Guard against a platform callback racing in when no user message is
    // active — the original `_latestUserMessageId!` would throw here.
    final messageId = _latestUserMessageId;
    if (messageId == null) return;
    switch (call.method) {
      case "onAsrResult":
        // Live transcript update for the in-progress user message.
        replaceMessage(id: messageId, text: call.arguments);
        break;
      case "onAsrStop":
        // Recognition finished: mark the message as final.
        replaceMessage(id: messageId, status: MessageStatus.normal);
        break;
    }
  });
}
final ChatSseService _chatSseService = ChatSseService();
// final LocalTtsService _ttsService = LocalTtsService();
final AudioRecorderService _audioService = AudioRecorderService();
final VoiceRecognitionService _recognitionService = VoiceRecognitionService();
// final AudioRecorderService _audioService = AudioRecorderService();
// final VoiceRecognitionService _recognitionService = VoiceRecognitionService();
final TextClassificationService _classificationService =
TextClassificationService();
final VehicleCommandService _vehicleCommandService = VehicleCommandService();
@@ -78,9 +90,9 @@ class MessageService extends ChangeNotifier {
_latestAssistantMessageId = null;
_isReplyAborted = false;
changeState(MessageServiceState.recording);
await _audioService.startRecording();
addMessage("", true, MessageStatus.listening);
_latestUserMessageId = messages.last.id;
_channel.invokeMethod("start");
} catch (e) {
print('录音开始出错: $e');
}
@@ -108,25 +120,27 @@ class MessageService extends ChangeNotifier {
}
try {
changeState(MessageServiceState.recognizing);
final audioData = await _audioService.stopRecording();
replaceMessage(
id: _latestUserMessageId!,
text: "",
status: MessageStatus.recognizing);
if (audioData == null || audioData.isEmpty) {
removeMessageById(_latestUserMessageId!);
return;
}
final recognizedText =
await _recognitionService.recognizeSpeech(audioData);
if (recognizedText == null || recognizedText.isEmpty) {
removeMessageById(_latestUserMessageId!);
return;
}
replaceMessage(
id: _latestUserMessageId!,
text: recognizedText,
status: MessageStatus.normal);
_channel.invokeMethod("stop");
final recognizedText = messages.last.text;
// final audioData = await _audioService.stopRecording();
// replaceMessage(
// id: _latestUserMessageId!,
// text: "",
// status: MessageStatus.recognizing);
// if (audioData == null || audioData.isEmpty) {
// removeMessageById(_latestUserMessageId!);
// return;
// }
// final recognizedText =
// await _recognitionService.recognizeSpeech(audioData);
// if (recognizedText == null || recognizedText.isEmpty) {
// removeMessageById(_latestUserMessageId!);
// return;
// }
// replaceMessage(
// id: _latestUserMessageId!,
// text: recognizedText,
// status: MessageStatus.normal);
changeState(MessageServiceState.replying);
await reply(recognizedText);
} catch (e) {