diff --git a/android/app/src/main/java/com/example/ai_chat_assistant/MainActivity.java b/android/app/src/main/java/com/example/ai_chat_assistant/MainActivity.java index f9976c7..446a12d 100644 --- a/android/app/src/main/java/com/example/ai_chat_assistant/MainActivity.java +++ b/android/app/src/main/java/com/example/ai_chat_assistant/MainActivity.java @@ -1,31 +1,42 @@ package com.example.ai_chat_assistant; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder; import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; import android.util.Log; import io.flutter.embedding.android.FlutterActivity; import io.flutter.embedding.engine.FlutterEngine; import io.flutter.plugin.common.MethodChannel; +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONException; import com.alibaba.fastjson.JSONObject; +import com.alibaba.idst.nui.AsrResult; import com.alibaba.idst.nui.Constants; +import com.alibaba.idst.nui.INativeNuiCallback; import com.alibaba.idst.nui.INativeStreamInputTtsCallback; +import com.alibaba.idst.nui.KwsResult; import com.alibaba.idst.nui.NativeNui; -import com.example.ai_chat_assistant.token.AccessToken; import java.util.Map; -public class MainActivity extends FlutterActivity { - private static final String CHANNEL = "com.example.ai_chat_assistant/tts"; - private static final String TAG = "StreamInputTts"; +public class MainActivity extends FlutterActivity implements INativeNuiCallback { + private static final String TTS_CHANNEL = "com.example.ai_chat_assistant/tts"; + private static final String ASR_CHANNEL = "com.example.ai_chat_assistant/asr"; + private static final String TAG = "AliyunSDK"; private static final String APP_KEY = "bXFFc1V65iYbW6EF"; private static final String ACCESS_KEY = "LTAI5t71JHxXRvt2mGuEVz9X"; private static final String ACCESS_KEY_SECRET = "WQOUWvngxmCg4CxIG0qkSlkcH5hrVT"; private static final String URL = "wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"; private final NativeNui streamInputTtsInstance = new NativeNui(Constants.ModeType.MODE_STREAM_INPUT_TTS); - private final AudioPlayer mAudioTrack = new AudioPlayer(new AudioPlayerCallback() { + private final NativeNui asrInstance = new NativeNui(); + + private final AudioPlayer ttsAudioTrack = new AudioPlayer(new AudioPlayerCallback() { @Override public void playStart() { Log.i(TAG, "start play"); @@ -41,11 +52,19 @@ public class MainActivity extends FlutterActivity { } }); + private MethodChannel asrMethodChannel; + private final static int ASR_SAMPLE_RATE = 16000; + private final static int ASR_WAVE_FRAM_SIZE = 20 * 2 * 1 * ASR_SAMPLE_RATE / 1000; //20ms audio for 16k/16bit/mono + private AudioRecord asrAudioRecorder = null; + private boolean asrStopping = false; + private Handler asrHandler; + private String asrText = ""; + @Override public void configureFlutterEngine(FlutterEngine flutterEngine) { super.configureFlutterEngine(flutterEngine); - - new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), CHANNEL) + System.out.println("??????????????????????????????????"); + new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), TTS_CHANNEL) .setMethodCallHandler((call, result) -> { Map args = (Map) call.arguments; switch (call.method) { @@ -54,7 +73,7 @@ public class MainActivity extends FlutterActivity { if (isChinese == null || isChinese.toString().isBlank()) { return; } - boolean isSuccess = start(Boolean.parseBoolean(isChinese.toString())); + boolean isSuccess = startTts(Boolean.parseBoolean(isChinese.toString())); result.success(isSuccess); break; case "send": @@ -62,10 +81,10 @@ public class MainActivity extends FlutterActivity { if (textArg == null || textArg.toString().isBlank()) { return; } - send(textArg.toString()); + sendTts(textArg.toString()); break; case "stop": - stop(); + stopTts(); result.success("已停止"); break; default: @@ -73,29 +92,61 @@ public class MainActivity extends FlutterActivity { break; } }); + asrMethodChannel = new MethodChannel(flutterEngine.getDartExecutor().getBinaryMessenger(), ASR_CHANNEL); + asrMethodChannel.setMethodCallHandler((call, result) -> { + switch (call.method) { + case "start": + startAsr(); + break; + case "stop": + stopAsr(); + break; + default: + result.notImplemented(); + break; + } + }); } @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); + HandlerThread asrHandlerThread = new HandlerThread("process_thread"); + asrHandlerThread.start(); + asrHandler = new Handler(asrHandlerThread.getLooper()); + } + + @Override + protected void onStart() { + Log.i(TAG, "onStart"); + super.onStart(); + asrInstance.initialize(this, genAsrInitParams(), + Constants.LogLevel.LOG_LEVEL_NONE, false); + } + + @Override + protected void onStop() { + Log.i(TAG, "onStop"); + super.onStop(); + asrInstance.release(); } @Override protected void onDestroy() { super.onDestroy(); - mAudioTrack.stop(); - mAudioTrack.releaseAudioTrack(); + ttsAudioTrack.stop(); + ttsAudioTrack.releaseAudioTrack(); streamInputTtsInstance.stopStreamInputTts(); } - private boolean start(boolean isChinese) { + private boolean startTts(boolean isChinese) { int ret = streamInputTtsInstance.startStreamInputTts(new INativeStreamInputTtsCallback() { @Override public void onStreamInputTtsEventCallback(INativeStreamInputTtsCallback.StreamInputTtsEvent event, String task_id, String session_id, int ret_code, String error_msg, String timestamp, String all_response) { Log.i(TAG, "stream input tts event(" + event + ") session id(" + session_id + ") task id(" + task_id + ") retCode(" + ret_code + ") errMsg(" + error_msg + ")"); if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_SYNTHESIS_STARTED) { Log.i(TAG, "STREAM_INPUT_TTS_EVENT_SYNTHESIS_STARTED"); - mAudioTrack.play(); + ttsAudioTrack.play(); Log.i(TAG, "start play"); } else if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_SENTENCE_SYNTHESIS) { Log.i(TAG, "STREAM_INPUT_TTS_EVENT_SENTENCE_SYNTHESIS:" + timestamp); @@ -106,7 +157,7 @@ public class MainActivity extends FlutterActivity { Log.i(TAG, "play end"); // 表示推送完数据, 当播放器播放结束则会有playOver回调 - mAudioTrack.isFinishSend(true); + ttsAudioTrack.isFinishSend(true); if (event == StreamInputTtsEvent.STREAM_INPUT_TTS_EVENT_TASK_FAILED) { Log.e(TAG, "STREAM_INPUT_TTS_EVENT_TASK_FAILED: " + "error_code(" + ret_code + ") error_message(" + error_msg + ")"); @@ -121,10 +172,10 @@ public class MainActivity extends FlutterActivity { @Override public void onStreamInputTtsDataCallback(byte[] data) { if (data.length > 0) { - mAudioTrack.setAudioData(data); + ttsAudioTrack.setAudioData(data); } } - }, genTicket(), genParameters(isChinese), "", Constants.LogLevel.toInt(Constants.LogLevel.LOG_LEVEL_NONE), false); + }, genTtsTicket(), genTtsParameters(isChinese), "", Constants.LogLevel.toInt(Constants.LogLevel.LOG_LEVEL_NONE), false); if (Constants.NuiResultCode.SUCCESS != ret) { Log.i(TAG, "start tts failed " + ret); return false; @@ -133,16 +184,135 @@ public class MainActivity extends FlutterActivity { } } - private void send(String text) { + private void sendTts(String text) { streamInputTtsInstance.sendStreamInputTts(text); } - private void stop() { + private void stopTts() { streamInputTtsInstance.cancelStreamInputTts(); - mAudioTrack.stop(); + ttsAudioTrack.stop(); } - private String genTicket() { + private void startAsr() { + asrText = ""; + if (asrAudioRecorder == null) { + asrAudioRecorder = new AudioRecord(MediaRecorder.AudioSource.DEFAULT, + ASR_SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + ASR_WAVE_FRAM_SIZE * 4); + Log.d(TAG, "AudioRecorder new ..."); + } else { + Log.w(TAG, "AudioRecord has been new ..."); + } + asrHandler.post(() -> { + String setParamsString = genAsrParams(); + Log.i(TAG, "nui set params " + setParamsString); + asrInstance.setParams(setParamsString); + int ret = asrInstance.startDialog(Constants.VadMode.TYPE_P2T, genAsrDialogParams()); + Log.i(TAG, "start done with " + ret); + }); + } + + private void stopAsr() { + asrHandler.post(() -> { + asrStopping = true; + long ret = asrInstance.stopDialog(); + runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrStop", null)); + Log.i(TAG, "cancel dialog " + ret + " end"); + }); + } + + @Override + public void onNuiEventCallback(Constants.NuiEvent event, final int resultCode, + final int arg2, KwsResult kwsResult, + AsrResult asrResult) { + Log.i(TAG, "event=" + event + " resultCode=" + resultCode); + if (event == Constants.NuiEvent.EVENT_TRANSCRIBER_STARTED) { + + } else if (event == Constants.NuiEvent.EVENT_TRANSCRIBER_COMPLETE) { + asrStopping = false; + } else if (event == Constants.NuiEvent.EVENT_ASR_PARTIAL_RESULT) { + JSONObject jsonObject = JSON.parseObject(asrResult.allResponse); + JSONObject payload = jsonObject.getJSONObject("payload"); + String result = payload.getString("result"); + if (asrMethodChannel != null && result != null && !result.isBlank()) { + runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrResult", asrText + result)); + } + } else if (event == Constants.NuiEvent.EVENT_SENTENCE_END) { + JSONObject jsonObject = JSON.parseObject(asrResult.allResponse); + JSONObject payload = jsonObject.getJSONObject("payload"); + String result = payload.getString("result"); + if (asrMethodChannel != null && result != null && !result.isBlank()) { + asrText += result; + runOnUiThread(() -> asrMethodChannel.invokeMethod("onAsrResult", asrText)); + } + } else if (event == Constants.NuiEvent.EVENT_VAD_START) { + + } else if (event == Constants.NuiEvent.EVENT_VAD_END) { + + } else if (event == Constants.NuiEvent.EVENT_ASR_ERROR) { + asrStopping = false; + } else if (event == Constants.NuiEvent.EVENT_MIC_ERROR) { + asrStopping = false; + } else if (event == Constants.NuiEvent.EVENT_DIALOG_EX) { /* unused */ + Log.i(TAG, "dialog extra message = " + asrResult.asrResult); + } + } + + //当调用NativeNui的start后,会一定时间反复回调该接口,底层会提供buffer并告知这次需要数据的长度 + //返回值告知底层读了多少数据,应该尽量保证return的长度等于需要的长度,如果返回<=0,则表示出错 + @Override + public int onNuiNeedAudioData(byte[] buffer, int len) { + if (asrAudioRecorder == null) { + return -1; + } + if (asrAudioRecorder.getState() != AudioRecord.STATE_INITIALIZED) { + Log.e(TAG, "audio recorder not init"); + return -1; + } + return asrAudioRecorder.read(buffer, 0, len); + } + + //当录音状态发送变化的时候调用 + @Override + public void onNuiAudioStateChanged(Constants.AudioState state) { + Log.i(TAG, "onNuiAudioStateChanged"); + if (state == Constants.AudioState.STATE_OPEN) { + Log.i(TAG, "audio recorder start"); + if (asrAudioRecorder != null) { + asrAudioRecorder.startRecording(); + } + Log.i(TAG, "audio recorder start done"); + } else if (state == Constants.AudioState.STATE_CLOSE) { + Log.i(TAG, "audio recorder close"); + if (asrAudioRecorder != null) { + asrAudioRecorder.release(); + } + } else if (state == Constants.AudioState.STATE_PAUSE) { + Log.i(TAG, "audio recorder pause"); + if (asrAudioRecorder != null) { + asrAudioRecorder.stop(); + } + } + } + + @Override + public void onNuiAudioRMSChanged(float val) { +// Log.i(TAG, "onNuiAudioRMSChanged vol " + val); + } + + @Override + public void onNuiVprEventCallback(Constants.NuiVprEvent event) { + Log.i(TAG, "onNuiVprEventCallback event " + event); + } + + @Override + public void onNuiLogTrackCallback(Constants.LogLevel level, String log) { + Log.i(TAG, "onNuiLogTrackCallback log level:" + level + ", message -> " + log); + } + + private String genTtsTicket() { String str = ""; try { Auth.GetTicketMethod method = Auth.GetTicketMethod.GET_ACCESS_IN_CLIENT_FOR_ONLINE_FEATURES; @@ -163,7 +333,7 @@ public class MainActivity extends FlutterActivity { return str; } - private String genParameters(boolean isChinese) { + private String genTtsParameters(boolean isChinese) { String str = ""; try { JSONObject object = new JSONObject(); @@ -180,4 +350,87 @@ public class MainActivity extends FlutterActivity { Log.i(TAG, "user parameters:" + str); return str; } + + private String genAsrInitParams() { + String str = ""; + try { + Auth.GetTicketMethod method = Auth.GetTicketMethod.GET_ACCESS_IN_CLIENT_FOR_ONLINE_FEATURES; + Auth.setAppKey(APP_KEY); + Auth.setAccessKey(ACCESS_KEY); + Auth.setAccessKeySecret(ACCESS_KEY_SECRET); + Log.i(TAG, "Use method:" + method); + JSONObject object = Auth.getTicket(method); + if (!object.containsKey("token")) { + Log.e(TAG, "Cannot get token !!!"); + } + object.put("device_id", "empty_device_id"); // 必填, 推荐填入具有唯一性的id, 方便定位问题 + object.put("url", URL); + object.put("service_mode", Constants.ModeFullCloud); // 必填 + str = object.toString(); + } catch (JSONException e) { + e.printStackTrace(); + } + Log.i(TAG, "InsideUserContext:" + str); + return str; + } + + private String genAsrParams() { + String params = ""; + try { + JSONObject nls_config = new JSONObject(); + + //参数可根据实际业务进行配置 + //接口说明可见https://help.aliyun.com/document_detail/173528.html + //查看 2.开始识别 + + // 是否返回中间识别结果,默认值:False。 + nls_config.put("enable_intermediate_result", true); + // 是否在后处理中添加标点,默认值:False。 + nls_config.put("enable_punctuation_prediction", true); + + nls_config.put("sample_rate", 16000); + nls_config.put("sr_format", "pcm"); +// nls_config.put("enable_inverse_text_normalization", true); +// nls_config.put("max_sentence_silence", 800); +// nls_config.put("enable_words", false); + + /*若文档中不包含某些参数,但是此功能支持这个参数,可以用如下万能接口设置参数*/ +// JSONObject extend_config = new JSONObject(); +// extend_config.put("custom_test", true); +// nls_config.put("extend_config", extend_config); + + JSONObject tmp = new JSONObject(); + tmp.put("nls_config", nls_config); + tmp.put("service_type", Constants.kServiceTypeSpeechTranscriber); // 必填 + +// 如果有HttpDns则可进行设置 +// tmp.put("direct_ip", Utils.getDirectIp()); + + params = tmp.toString(); + } catch (JSONException e) { + e.printStackTrace(); + } + return params; + } + + private String genAsrDialogParams() { + String params = ""; + try { + JSONObject dialog_param = new JSONObject(); + // 运行过程中可以在startDialog时更新临时参数,尤其是更新过期token + // 注意: 若下一轮对话不再设置参数,则继续使用初始化时传入的参数 + long distance_expire_time_30m = 1800; + dialog_param = Auth.refreshTokenIfNeed(dialog_param, distance_expire_time_30m); + + // 注意: 若需要更换appkey和token,可以直接传入参数 +// dialog_param.put("app_key", ""); +// dialog_param.put("token", ""); + params = dialog_param.toString(); + } catch (JSONException e) { + e.printStackTrace(); + } + + Log.i(TAG, "dialog params: " + params); + return params; + } } \ No newline at end of file diff --git a/lib/services/chat_sse_service.dart b/lib/services/chat_sse_service.dart index 8f52997..88d77cd 100644 --- a/lib/services/chat_sse_service.dart +++ b/lib/services/chat_sse_service.dart @@ -4,12 +4,10 @@ import 'dart:io'; import 'dart:math'; import 'package:ai_chat_assistant/utils/tts_util.dart'; -import 'package:flutter/services.dart'; import '../utils/common_util.dart'; class ChatSseService { - static const MethodChannel _channel = MethodChannel('com.example.ai_chat_assistant/tts'); // 缓存用户ID和会话ID String? _cachedUserId; String? _cachedConversationId; diff --git a/lib/services/message_service.dart b/lib/services/message_service.dart index 6ab5331..c5c3dcf 100644 --- a/lib/services/message_service.dart +++ b/lib/services/message_service.dart @@ -1,9 +1,8 @@ -import 'dart:io'; - import 'package:ai_chat_assistant/utils/common_util.dart'; import 'package:ai_chat_assistant/utils/tts_util.dart'; import 'package:basic_intl/intl.dart'; import 'package:flutter/foundation.dart'; +import 'package:flutter/services.dart'; import 'package:permission_handler/permission_handler.dart'; import 'package:uuid/uuid.dart'; import '../enums/vehicle_command_type.dart'; @@ -14,22 +13,35 @@ import '../models/vehicle_cmd.dart'; import '../services/chat_sse_service.dart'; import '../services/classification_service.dart'; import '../services/control_recognition_service.dart'; -import '../services/audio_recorder_service.dart'; -import '../services/voice_recognition_service.dart'; +// import '../services/audio_recorder_service.dart'; +// import '../services/voice_recognition_service.dart'; import 'command_service.dart'; import 'package:fluttertoast/fluttertoast.dart'; class MessageService extends ChangeNotifier { + static const MethodChannel _channel = MethodChannel('com.example.ai_chat_assistant/asr'); + static final MessageService _instance = MessageService._internal(); factory MessageService() => _instance; - MessageService._internal(); + MessageService._internal() { + _channel.setMethodCallHandler((call) async { + switch (call.method) { + case "onAsrResult": + replaceMessage(id: _latestUserMessageId!, text: call.arguments); + break; + case "onAsrStop": + replaceMessage( + id: _latestUserMessageId!, status: MessageStatus.normal); + } + }); + } final ChatSseService _chatSseService = ChatSseService(); // final LocalTtsService _ttsService = LocalTtsService(); - final AudioRecorderService _audioService = AudioRecorderService(); - final VoiceRecognitionService _recognitionService = VoiceRecognitionService(); + // final AudioRecorderService _audioService = AudioRecorderService(); + // final VoiceRecognitionService _recognitionService = VoiceRecognitionService(); final TextClassificationService _classificationService = TextClassificationService(); final VehicleCommandService _vehicleCommandService = VehicleCommandService(); @@ -78,9 +90,9 @@ class MessageService extends ChangeNotifier { _latestAssistantMessageId = null; _isReplyAborted = false; changeState(MessageServiceState.recording); - await _audioService.startRecording(); addMessage("", true, MessageStatus.listening); _latestUserMessageId = messages.last.id; + _channel.invokeMethod("start"); } catch (e) { print('录音开始出错: $e'); } @@ -108,25 +120,27 @@ class MessageService extends ChangeNotifier { } try { changeState(MessageServiceState.recognizing); - final audioData = await _audioService.stopRecording(); - replaceMessage( - id: _latestUserMessageId!, - text: "", - status: MessageStatus.recognizing); - if (audioData == null || audioData.isEmpty) { - removeMessageById(_latestUserMessageId!); - return; - } - final recognizedText = - await _recognitionService.recognizeSpeech(audioData); - if (recognizedText == null || recognizedText.isEmpty) { - removeMessageById(_latestUserMessageId!); - return; - } - replaceMessage( - id: _latestUserMessageId!, - text: recognizedText, - status: MessageStatus.normal); + _channel.invokeMethod("stop"); + final recognizedText = messages.last.text; + // final audioData = await _audioService.stopRecording(); + // replaceMessage( + // id: _latestUserMessageId!, + // text: "", + // status: MessageStatus.recognizing); + // if (audioData == null || audioData.isEmpty) { + // removeMessageById(_latestUserMessageId!); + // return; + // } + // final recognizedText = + // await _recognitionService.recognizeSpeech(audioData); + // if (recognizedText == null || recognizedText.isEmpty) { + // removeMessageById(_latestUserMessageId!); + // return; + // } + // replaceMessage( + // id: _latestUserMessageId!, + // text: recognizedText, + // status: MessageStatus.normal); changeState(MessageServiceState.replying); await reply(recognizedText); } catch (e) {