From cbe01b6f96cd92f7fed11c480ad78c26905b238d Mon Sep 17 00:00:00 2001
From: yinhew <46698869+yinhew@users.noreply.github.com>
Date: Fri, 11 Oct 2024 14:06:29 +0800
Subject: [PATCH] [TTS Avatar][Live][Python] Update sample code to add option for client to communicate with server through WebSocket, and do STT on server side (#2597)

* [TalkingAvatar] Add sample code for TTS talking avatar real-time API
* sample code for batch avatar synthesis
* Address repository check failure
* update
* [Avatar] Update real time avatar sample code to support multi-lingual
* [avatar] update real time avatar chat sample to receive GPT response in streaming mode
* [Live Avatar] update chat sample to make some refinements
* [TTS Avatar] Update real-time sample to support 1. non-continuous recognition mode 2. a button to stop speaking 3. user can type a query without speech
* [TTS Avatar] Update real time avatar sample to support auto-reconnect
* Don't reset message history when re-connecting
* [talking avatar] update real time sample to support using cached local video for idle status, to help save customer cost
* Update chat.html and README.md
* Update batch avatar sample to use mp4 as default format, to avoid showing slow speed by default with vp9
* A minor refinement
* Some refinement
* Some bug fixing
* Refine the response receiving logic for AOAI streaming mode, to make it more robust
* [Talking Avatar] update real-time sample code to log result id (turn id) for ease of debugging
* [Talking Avatar] Update avatar live chat sample, to upgrade AOAI API version from 2023-03-15-preview to 2023-12-01-preview
* [Talking Avatar][Live Chat] Update AOAI API to the long-term support version 2023-06-01-preview
* [Talking Avatar] Add real time avatar sample code for server/client hybrid web app, with server code written in Python
* Some refinements
* Add README.md
* Fix repo check failure: files that are neither marked as binary nor text, please extend .gitattributes
* [Python][TTS Avatar] Add chat sample
* [Python][TTS Avatar] Add chat sample - continue
* Support management of multiple clients
* Update README.md
* [Python][TTS Avatar] Support customized ICE server
* [Talking Avatar][Python] Support stop speaking
* Tolerate the Speech SDK not supporting sending messages with the connection
* [Python][TTS Avatar] Send local SDP as POST body instead of header, to avoid the header size going over the limit
* [python][avatar] update requirements.txt to add the missing dependencies
* [python][avatar] update real-time sample to make auto-connection smoother
* [Python][Avatar] Fix some small bugs
* [python][avatar] Support AAD authorization on private endpoint
* [Java][Android][Avatar] Add Android sample code for real time avatar
* Code refinement
* More refinement
* More refinement
* Update README.md
* [Java][Android][Avatar] Remove AddStream method, which is not available with Unified Plan SDP semantics, and use AddTrack per suggestion
* [Python][Avatar][Live] Get speaking status from WebRTC event, and remove the checkSpeakingStatus API from backend code
* [Java][Android][Live Avatar] Update the sample to demonstrate switching audio output device to the loudspeaker
* [Python][Avatar][Live] Switch from REST API to SDK for calling AOAI
* [Python][Avatar][Live] Trigger barge-in at the first recognizing event, which is earlier
* [Python][Avatar][Live] Enable continuous conversation by default
* [Python][Avatar][Live] Disable multi-lingual by default for better latency
* [Python][Avatar][Live] Configure shorter segmentation silence timeout for quicker SR
* [Live
Avatar][Python, CSharp] Add logging for latency * [TTS Avatar][Live][Python, CSharp, JS] Fix a bug to correctly clean up audio player * [TTS Avatar][Live][JavaScript] Output display text with a slower rate, to follow the avatar speaking progress * Make the display text / speech alignment able for on/off * [TTS Avatar][Live][CSharp] Output display text with a slower rate, to follow the avatar speaking progress * Create an auto-deploy file * Unlink the containerApp yinhew-avatar-app from this repo * Delete unnecessary file * [talking avatar][python] Update real time sample to add option to connect with server through WebSocket, and do STT on server side --------- Co-authored-by: Yulin Li --- .../avatar/Controllers/AvatarController.cs | 51 +++-- .../web/avatar/Models/ClientSettings.cs | 4 + .../csharp/web/avatar/Views/Home/chat.cshtml | 2 +- samples/js/browser/avatar/js/chat.js | 24 ++- samples/python/web/avatar/app.py | 174 +++++++++++++++++- samples/python/web/avatar/chat.html | 2 + samples/python/web/avatar/requirements.txt | 1 + samples/python/web/avatar/static/js/chat.js | 170 ++++++++++++++++- 8 files changed, 402 insertions(+), 26 deletions(-) diff --git a/samples/csharp/web/avatar/Controllers/AvatarController.cs b/samples/csharp/web/avatar/Controllers/AvatarController.cs index 28e9700db..85d9d8245 100644 --- a/samples/csharp/web/avatar/Controllers/AvatarController.cs +++ b/samples/csharp/web/avatar/Controllers/AvatarController.cs @@ -153,9 +153,17 @@ public async Task ConnectAvatar() speechConfig.EndpointId = customVoiceEndpointId; } - var speechSynthesizer = new SpeechSynthesizer(speechConfig); + var speechSynthesizer = new SpeechSynthesizer(speechConfig, null); clientContext.SpeechSynthesizer = speechSynthesizer; + if (ClientSettings.EnableAudioAudit) + { + speechSynthesizer.Synthesizing += (o, e) => + { + Console.WriteLine($"Audio chunk received: {e.Result.AudioData.Length} bytes."); + }; + } + if (string.IsNullOrEmpty(GlobalVariables.IceToken)) { return BadRequest("IceToken is missing or invalid."); @@ -168,7 +176,7 @@ public async Task ConnectAvatar() { iceTokenObj = new Dictionary { - { "Urls", string.IsNullOrEmpty(_clientSettings.IceServerUrlRemote) ? [_clientSettings.IceServerUrl] : new[] { _clientSettings.IceServerUrlRemote } }, + { "Urls", string.IsNullOrEmpty(_clientSettings.IceServerUrlRemote) ? new JArray(_clientSettings.IceServerUrl) : new JArray(_clientSettings.IceServerUrlRemote) }, { "Username", _clientSettings.IceServerUsername }, { "Password", _clientSettings.IceServerPassword } }; @@ -189,7 +197,7 @@ public async Task ConnectAvatar() var videoCrop = Request.Headers["VideoCrop"].FirstOrDefault() ?? "false"; // Configure avatar settings - var urlsArray = iceTokenObj?.TryGetValue("Urls", out var value) == true ? value as string[] : null; + var urlsArray = iceTokenObj?.TryGetValue("Urls", out var value) == true ? 
value as JArray : null; var firstUrl = urlsArray?.FirstOrDefault()?.ToString(); @@ -213,7 +221,8 @@ public async Task ConnectAvatar() username = iceTokenObj!["Username"], credential = iceTokenObj["Password"] } - } + }, + auditAudio = ClientSettings.EnableAudioAudit } }, format = new @@ -255,7 +264,7 @@ public async Task ConnectAvatar() connection.SetMessageProperty("speech.config", "context", JsonConvert.SerializeObject(avatarConfig)); var speechSynthesisResult = speechSynthesizer.SpeakTextAsync("").Result; - Console.WriteLine($"Result ID: {speechSynthesisResult.ResultId}"); + Console.WriteLine($"Result ID: {speechSynthesisResult.ResultId}"); if (speechSynthesisResult.Reason == ResultReason.Canceled) { var cancellationDetails = SpeechSynthesisCancellationDetails.FromResult(speechSynthesisResult); @@ -456,7 +465,7 @@ public async Task HandleUserQuery(string userQuery, Guid clientId, HttpResponse // We return some quick reply here before the chat API returns to mitigate. if (ClientSettings.EnableQuickReply) { - await SpeakWithQueue(ClientSettings.QuickReplies[new Random().Next(ClientSettings.QuickReplies.Count)], 2000, clientId); + await SpeakWithQueue(ClientSettings.QuickReplies[new Random().Next(ClientSettings.QuickReplies.Count)], 2000, clientId, httpResponse); } // Process the responseContent as needed @@ -507,9 +516,13 @@ public async Task HandleUserQuery(string userQuery, Guid clientId, HttpResponse responseToken = ClientSettings.OydDocRegex.Replace(responseToken, string.Empty); } - await httpResponse.WriteAsync(responseToken).ConfigureAwait(false); + if (!ClientSettings.EnableDisplayTextAlignmentWithSpeech) + { + await httpResponse.WriteAsync(responseToken).ConfigureAwait(false); + } assistantReply.Append(responseToken); + spokenSentence.Append(responseToken); // build up the spoken sentence if (responseToken == "\n" || responseToken == "\n\n") { if (isFirstSentence) @@ -520,13 +533,12 @@ public async Task HandleUserQuery(string userQuery, Guid clientId, HttpResponse isFirstSentence = false; } - await SpeakWithQueue(spokenSentence.ToString().Trim(), 0, clientId); + await SpeakWithQueue(spokenSentence.ToString(), 0, clientId, httpResponse); spokenSentence.Clear(); } else { responseToken = responseToken.Replace("\n", string.Empty); - spokenSentence.Append(responseToken); // build up the spoken sentence if (responseToken.Length == 1 || responseToken.Length == 2) { foreach (var punctuation in ClientSettings.SentenceLevelPunctuations) @@ -541,7 +553,7 @@ public async Task HandleUserQuery(string userQuery, Guid clientId, HttpResponse isFirstSentence = false; } - await SpeakWithQueue(spokenSentence.ToString().Trim(), 0, clientId); + await SpeakWithQueue(spokenSentence.ToString(), 0, clientId, httpResponse); spokenSentence.Clear(); break; } @@ -553,11 +565,21 @@ public async Task HandleUserQuery(string userQuery, Guid clientId, HttpResponse if (spokenSentence.Length > 0) { - await SpeakWithQueue(spokenSentence.ToString().Trim(), 0, clientId); + await SpeakWithQueue(spokenSentence.ToString(), 0, clientId, httpResponse); } var assistantMessage = new AssistantChatMessage(assistantReply.ToString()); messages.Add(assistantMessage); + + if (ClientSettings.EnableDisplayTextAlignmentWithSpeech) + { + while (clientContext.SpokenTextQueue.Count > 0) + { + await Task.Delay(200); + } + + await Task.Delay(200); + } } public void InitializeChatContext(string systemPrompt, Guid clientId) @@ -572,7 +594,7 @@ public void InitializeChatContext(string systemPrompt, Guid clientId) } // Speak the given text. 
If there is already a speaking in progress, add the text to the queue. For chat scenario. - public Task SpeakWithQueue(string text, int endingSilenceMs, Guid clientId) + public Task SpeakWithQueue(string text, int endingSilenceMs, Guid clientId, HttpResponse httpResponse) { var clientContext = _clientService.GetClientContext(clientId); @@ -595,6 +617,11 @@ public Task SpeakWithQueue(string text, int endingSilenceMs, Guid clientId) while (spokenTextQueue.Count > 0) { var currentText = spokenTextQueue.Dequeue(); + if (ClientSettings.EnableDisplayTextAlignmentWithSpeech) + { + httpResponse.WriteAsync(currentText); + } + await SpeakText(currentText, ttsVoice!, personalVoiceSpeakerProfileId!, endingSilenceMs, clientId); clientContext.LastSpeakTime = DateTime.UtcNow; } diff --git a/samples/csharp/web/avatar/Models/ClientSettings.cs b/samples/csharp/web/avatar/Models/ClientSettings.cs index b71230ce1..22b6d08d7 100644 --- a/samples/csharp/web/avatar/Models/ClientSettings.cs +++ b/samples/csharp/web/avatar/Models/ClientSettings.cs @@ -19,6 +19,10 @@ public class ClientSettings public static readonly bool EnableQuickReply = false; + public static readonly bool EnableDisplayTextAlignmentWithSpeech = false; + + public static readonly bool EnableAudioAudit = false; + public string? SpeechRegion { get; set; } public string? SpeechKey { get; set; } diff --git a/samples/csharp/web/avatar/Views/Home/chat.cshtml b/samples/csharp/web/avatar/Views/Home/chat.cshtml index 049dfde5e..fbd1ad390 100644 --- a/samples/csharp/web/avatar/Views/Home/chat.cshtml +++ b/samples/csharp/web/avatar/Views/Home/chat.cshtml @@ -36,7 +36,7 @@
-
+

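The EnableDisplayTextAlignmentWithSpeech flag introduced above changes when display text reaches the client: with the flag on, the controller no longer writes assistant tokens to the HTTP response as they stream in from AOAI, but writes each sentence only when it is dequeued for TTS, so the visible transcript keeps pace with the avatar. Below is a minimal Python sketch of that queueing idea, not the sample's actual C# implementation; speak_text, the sentence splitting, and the single-client queue are simplified assumptions.

# Sketch of the display-text / speech alignment idea, assuming a single client.
import queue
import threading
import time

spoken_text_queue = queue.Queue()   # mirrors clientContext.SpokenTextQueue in the patch

def speak_text(sentence: str) -> None:
    time.sleep(0.3)                 # placeholder for the real avatar TTS request

def speak_worker(write_display_text) -> None:
    while True:
        sentence = spoken_text_queue.get()
        if sentence is None:        # sentinel to stop the worker
            break
        write_display_text(sentence)  # aligned mode: text is shown only when speaking starts
        speak_text(sentence)

def stream_tokens(tokens, write_display_text, align_with_speech=True):
    sentence = ''
    for token in tokens:
        if not align_with_speech:
            write_display_text(token)  # old behavior: show every token immediately
        sentence += token
        if token.endswith(('\n', '.', '?', '!')):  # simplified sentence boundary
            spoken_text_queue.put(sentence)
            sentence = ''
    if sentence:
        spoken_text_queue.put(sentence)

worker = threading.Thread(target=speak_worker, args=(print,), daemon=True)
worker.start()
stream_tokens(['Hello, ', 'how can I help?', '\n'], print)
while not spoken_text_queue.empty():  # like the patch, wait for the queue to drain before finishing
    time.sleep(0.2)
spoken_text_queue.put(None)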
diff --git a/samples/js/browser/avatar/js/chat.js b/samples/js/browser/avatar/js/chat.js index 1968ac7d0..c3aa8c9d1 100644 --- a/samples/js/browser/avatar/js/chat.js +++ b/samples/js/browser/avatar/js/chat.js @@ -9,6 +9,7 @@ var messages = [] var messageInitiated = false var dataSources = [] var sentenceLevelPunctuations = [ '.', '?', '!', ':', ';', '。', '?', '!', ':', ';' ] +var enableDisplayTextAlignmentWithSpeech = true var enableQuickReply = false var quickReplies = [ 'Let me take a look.', 'Let me check.', 'One moment, please.' ] var byodDocRegex = new RegExp(/\[doc(\d+)\]/g) @@ -322,6 +323,12 @@ function speakNext(text, endingSilenceMs = 0) { ssml = `${htmlEncode(text)}` } + if (enableDisplayTextAlignmentWithSpeech) { + let chatHistoryTextArea = document.getElementById('chatHistory') + chatHistoryTextArea.innerHTML += text.replace(/\n/g, '
') + chatHistoryTextArea.scrollTop = chatHistoryTextArea.scrollHeight + } + lastSpeakTime = new Date() isSpeaking = true document.getElementById('stopSpeaking').disabled = false @@ -506,17 +513,18 @@ function handleUserQuery(userQuery, userQueryHTML, imgUrlPath) { // console.log(`Current token: ${responseToken}`) if (responseToken === '\n' || responseToken === '\n\n') { - speak(spokenSentence.trim()) + spokenSentence += responseToken + speak(spokenSentence) spokenSentence = '' } else { - responseToken = responseToken.replace(/\n/g, '') spokenSentence += responseToken // build up the spoken sentence + responseToken = responseToken.replace(/\n/g, '') if (responseToken.length === 1 || responseToken.length === 2) { for (let i = 0; i < sentenceLevelPunctuations.length; ++i) { let sentenceLevelPunctuation = sentenceLevelPunctuations[i] if (responseToken.startsWith(sentenceLevelPunctuation)) { - speak(spokenSentence.trim()) + speak(spokenSentence) spokenSentence = '' break } @@ -531,9 +539,11 @@ function handleUserQuery(userQuery, userQueryHTML, imgUrlPath) { } }) - chatHistoryTextArea.innerHTML += `${displaySentence}` - chatHistoryTextArea.scrollTop = chatHistoryTextArea.scrollHeight - displaySentence = '' + if (!enableDisplayTextAlignmentWithSpeech) { + chatHistoryTextArea.innerHTML += displaySentence.replace(/\n/g, '
') + chatHistoryTextArea.scrollTop = chatHistoryTextArea.scrollHeight + displaySentence = '' + } // Continue reading the next chunk return read() @@ -545,7 +555,7 @@ function handleUserQuery(userQuery, userQueryHTML, imgUrlPath) { }) .then(() => { if (spokenSentence !== '') { - speak(spokenSentence.trim()) + speak(spokenSentence) spokenSentence = '' } diff --git a/samples/python/web/avatar/app.py b/samples/python/web/avatar/app.py index 000714244..20a891390 100644 --- a/samples/python/web/avatar/app.py +++ b/samples/python/web/avatar/app.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import azure.cognitiveservices.speech as speechsdk +import base64 import datetime import html import json @@ -15,12 +16,16 @@ import traceback import uuid from flask import Flask, Response, render_template, request +from flask_socketio import SocketIO, join_room from azure.identity import DefaultAzureCredential from openai import AzureOpenAI # Create the Flask app app = Flask(__name__, template_folder='.') +# Create the SocketIO instance +socketio = SocketIO(app) + # Environment variables # Speech resource (required) speech_region = os.environ.get('SPEECH_REGION') # e.g. westus2 @@ -43,6 +48,8 @@ ice_server_password = os.environ.get('ICE_SERVER_PASSWORD') # The ICE password # Const variables +enable_websockets = False # Enable websockets between client and server for real-time communication optimization +enable_token_auth_for_speech = False # Enable token authentication for speech service default_tts_voice = 'en-US-JennyMultilingualV2Neural' # Default TTS voice sentence_level_punctuations = [ '.', '?', '!', ':', ';', '。', '?', '!', ':', ';' ] # Punctuations that indicate the end of a sentence enable_quick_reply = False # Enable quick reply for certain chat models which take longer time to respond @@ -71,7 +78,7 @@ def basicView(): # The chat route, which shows the chat web page @app.route("/chat") def chatView(): - return render_template("chat.html", methods=["GET"], client_id=initializeClient()) + return render_template("chat.html", methods=["GET"], client_id=initializeClient(), enable_websockets=enable_websockets) # The API route to get the speech token @app.route("/api/getSpeechToken", methods=["GET"]) @@ -115,9 +122,21 @@ def connectAvatar() -> Response: try: if speech_private_endpoint: speech_private_endpoint_wss = speech_private_endpoint.replace('https://', 'wss://') - speech_config = speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'{speech_private_endpoint_wss}/tts/cognitiveservices/websocket/v1?enableTalkingAvatar=true') + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + speech_config = speechsdk.SpeechConfig(endpoint=f'{speech_private_endpoint_wss}/tts/cognitiveservices/websocket/v1?enableTalkingAvatar=true') + speech_config.authorization_token = speech_token + else: + speech_config = speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'{speech_private_endpoint_wss}/tts/cognitiveservices/websocket/v1?enableTalkingAvatar=true') else: - speech_config = speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'wss://{speech_region}.tts.speech.microsoft.com/cognitiveservices/websocket/v1?enableTalkingAvatar=true') + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + speech_config = speechsdk.SpeechConfig(endpoint=f'wss://{speech_region}.tts.speech.microsoft.com/cognitiveservices/websocket/v1?enableTalkingAvatar=true') + speech_config.authorization_token = speech_token + else: + speech_config = 
speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'wss://{speech_region}.tts.speech.microsoft.com/cognitiveservices/websocket/v1?enableTalkingAvatar=true') if custom_voice_endpoint_id: speech_config.endpoint_id = custom_voice_endpoint_id @@ -202,6 +221,107 @@ def connectAvatar() -> Response: except Exception as e: return Response(f"Result ID: {speech_sythesis_result.result_id}. Error message: {e}", status=400) +# The API route to connect the STT service +@app.route("/api/connectSTT", methods=["POST"]) +def connectSTT() -> Response: + global client_contexts + client_id = uuid.UUID(request.headers.get('ClientId')) + system_prompt = request.headers.get('SystemPrompt') + client_context = client_contexts[client_id] + try: + if speech_private_endpoint: + speech_private_endpoint_wss = speech_private_endpoint.replace('https://', 'wss://') + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + speech_config = speechsdk.SpeechConfig(endpoint=f'{speech_private_endpoint_wss}/stt/speech/universal/v2') + speech_config.authorization_token = speech_token + else: + speech_config = speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'{speech_private_endpoint_wss}/stt/speech/universal/v2') + else: + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + speech_config = speechsdk.SpeechConfig(endpoint=f'wss://{speech_region}.stt.speech.microsoft.com/speech/universal/v2') + speech_config.authorization_token = speech_token + else: + speech_config = speechsdk.SpeechConfig(subscription=speech_key, endpoint=f'wss://{speech_region}.stt.speech.microsoft.com/speech/universal/v2') + + audio_input_stream = speechsdk.audio.PushAudioInputStream() + client_context['audio_input_stream'] = audio_input_stream + + audio_config = speechsdk.audio.AudioConfig(stream=audio_input_stream) + speech_recognizer = speechsdk.SpeechRecognizer(speech_config=speech_config, audio_config=audio_config) + client_context['speech_recognizer'] = speech_recognizer + + speech_recognizer.session_started.connect(lambda evt: print(f'STT session started - session id: {evt.session_id}')) + speech_recognizer.session_stopped.connect(lambda evt: print(f'STT session stopped.')) + + speech_recognition_start_time = datetime.datetime.now(pytz.UTC) + + def stt_recognized_cb(evt): + if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech: + try: + user_query = evt.result.text.strip() + if user_query == '': + return + + socketio.emit("response", { 'path': 'api.chat', 'chatResponse': '\n\nUser: ' + user_query + '\n\n' }, room=client_id) + recognition_result_received_time = datetime.datetime.now(pytz.UTC) + speech_finished_offset = (evt.result.offset + evt.result.duration) / 10000 + stt_latency = round((recognition_result_received_time - speech_recognition_start_time).total_seconds() * 1000 - speech_finished_offset) + print(f'STT latency: {stt_latency}ms') + socketio.emit("response", { 'path': 'api.chat', 'chatResponse': f"{stt_latency}" }, room=client_id) + chat_initiated = client_context['chat_initiated'] + if not chat_initiated: + initializeChatContext(system_prompt, client_id) + client_context['chat_initiated'] = True + first_response_chunk = True + for chat_response in handleUserQuery(user_query, client_id): + if first_response_chunk: + socketio.emit("response", { 'path': 'api.chat', 'chatResponse': 'Assistant: ' }, room=client_id) + first_response_chunk = False + socketio.emit("response", { 'path': 'api.chat', 'chatResponse': chat_response }, room=client_id) + except Exception as e: + 
print(f"Error in handling user query: {e}") + speech_recognizer.recognized.connect(stt_recognized_cb) + + def stt_recognizing_cb(evt): + is_speaking = client_context['is_speaking'] + if is_speaking: + stopSpeakingInternal(client_id) + speech_recognizer.recognizing.connect(stt_recognizing_cb) + + def stt_canceled_cb(evt): + cancellation_details = speechsdk.CancellationDetails(evt.result) + print(f'STT connection canceled. Error message: {cancellation_details.error_details}') + speech_recognizer.canceled.connect(stt_canceled_cb) + + speech_recognizer.start_continuous_recognition() + return Response(status=200) + + except Exception as e: + return Response(f"STT connection failed. Error message: {e}", status=400) + +# The API route to disconnect the STT service +@app.route("/api/disconnectSTT", methods=["POST"]) +def disconnectSTT() -> Response: + global client_contexts + client_id = uuid.UUID(request.headers.get('ClientId')) + client_context = client_contexts[client_id] + speech_recognizer = client_context['speech_recognizer'] + audio_input_stream = client_context['audio_input_stream'] + try: + if speech_recognizer: + speech_recognizer.stop_continuous_recognition() + client_context['speech_recognizer'] = None + if audio_input_stream: + audio_input_stream.close() + client_context['audio_input_stream'] = None + return Response('STT Disconnected.', status=200) + except Exception as e: + return Response(f"STT disconnection failed. Error message: {e}", status=400) + # The API route to speak a given SSML @app.route("/api/speak", methods=["POST"]) def speak() -> Response: @@ -262,10 +382,44 @@ def disconnectAvatar() -> Response: except: return Response(traceback.format_exc(), status=400) +@socketio.on("connect") +def handleWsConnection(): + client_id = uuid.UUID(request.args.get('clientId')) + join_room(client_id) + print(f"WebSocket connected for client {client_id}.") + +@socketio.on("message") +def handleWsMessage(message): + global client_contexts + client_id = uuid.UUID(message.get('clientId')) + path = message.get('path') + client_context = client_contexts[client_id] + if path == 'api.audio': + chat_initiated = client_context['chat_initiated'] + audio_chunk = message.get('audioChunk') + audio_chunk_binary = base64.b64decode(audio_chunk) + audio_input_stream = client_context['audio_input_stream'] + if audio_input_stream: + audio_input_stream.write(audio_chunk_binary) + elif path == 'api.chat': + chat_initiated = client_context['chat_initiated'] + if not chat_initiated: + initializeChatContext(message.get('systemPrompt'), client_id) + client_context['chat_initiated'] = True + user_query = message.get('userQuery') + for chat_response in handleUserQuery(user_query, client_id): + socketio.emit("response", { 'path': 'api.chat', 'chatResponse': chat_response }, room=client_id) + elif path == 'api.stopSpeaking': + is_speaking = client_contexts[client_id]['is_speaking'] + if is_speaking: + stopSpeakingInternal(client_id) + # Initialize the client by creating a client id and an initial context def initializeClient() -> uuid.UUID: client_id = uuid.uuid4() client_contexts[client_id] = { + 'audio_input_stream': None, # Audio input stream for speech recognition + 'speech_recognizer': None, # Speech recognizer for user speech 'azure_openai_deployment_name': azure_openai_deployment_name, # Azure OpenAI deployment name 'cognitive_search_index_name': cognitive_search_index_name, # Cognitive search index name 'tts_voice': default_tts_voice, # TTS voice @@ -288,9 +442,19 @@ def initializeClient() -> uuid.UUID: def 
refreshIceToken() -> None: global ice_token if speech_private_endpoint: - ice_token = requests.get(f'{speech_private_endpoint}/tts/cognitiveservices/avatar/relay/token/v1', headers={'Ocp-Apim-Subscription-Key': speech_key}).text + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + ice_token = requests.get(f'{speech_private_endpoint}/tts/cognitiveservices/avatar/relay/token/v1', headers={'Authorization': f'Bearer {speech_token}'}).text + else: + ice_token = requests.get(f'{speech_private_endpoint}/tts/cognitiveservices/avatar/relay/token/v1', headers={'Ocp-Apim-Subscription-Key': speech_key}).text else: - ice_token = requests.get(f'https://{speech_region}.tts.speech.microsoft.com/cognitiveservices/avatar/relay/token/v1', headers={'Ocp-Apim-Subscription-Key': speech_key}).text + if enable_token_auth_for_speech: + while not speech_token: + time.sleep(0.2) + ice_token = requests.get(f'https://{speech_region}.tts.speech.microsoft.com/cognitiveservices/avatar/relay/token/v1', headers={'Authorization': f'Bearer {speech_token}'}).text + else: + ice_token = requests.get(f'https://{speech_region}.tts.speech.microsoft.com/cognitiveservices/avatar/relay/token/v1', headers={'Ocp-Apim-Subscription-Key': speech_key}).text # Refresh the speech token every 9 minutes def refreshSpeechToken() -> None: diff --git a/samples/python/web/avatar/chat.html b/samples/python/web/avatar/chat.html index ff378dcb1..8c76e21fe 100644 --- a/samples/python/web/avatar/chat.html +++ b/samples/python/web/avatar/chat.html @@ -7,11 +7,13 @@ +

Talking Avatar Chat Demo

+

Chat Configuration

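The new /api/connectSTT route in app.py feeds a speech recognizer from a PushAudioInputStream that the SocketIO 'api.audio' handler fills with PCM decoded from the browser. A trimmed sketch of that wiring follows; the explicit 16 kHz/16-bit/mono stream format is spelled out here for clarity (the patch relies on the SDK default, which is the same), and key/region handling is reduced to plain parameters.

# Sketch: server-side STT from a push stream, mirroring /api/connectSTT and the 'api.audio' path.
import base64
import azure.cognitiveservices.speech as speechsdk

def create_push_stream_recognizer(speech_key: str, speech_region: str):
    # Same STT v2 WebSocket endpoint shape as the patch.
    speech_config = speechsdk.SpeechConfig(
        subscription=speech_key,
        endpoint=f'wss://{speech_region}.stt.speech.microsoft.com/speech/universal/v2')

    # 16 kHz, 16-bit, mono PCM matches what the browser AudioWorklet sends.
    stream_format = speechsdk.audio.AudioStreamFormat(
        samples_per_second=16000, bits_per_sample=16, channels=1)
    audio_input_stream = speechsdk.audio.PushAudioInputStream(stream_format=stream_format)
    audio_config = speechsdk.audio.AudioConfig(stream=audio_input_stream)
    speech_recognizer = speechsdk.SpeechRecognizer(
        speech_config=speech_config, audio_config=audio_config)

    def on_recognized(evt):
        if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech:
            print(f'Recognized: {evt.result.text}')  # the sample forwards this to the chat pipeline

    speech_recognizer.recognized.connect(on_recognized)
    speech_recognizer.start_continuous_recognition()
    return speech_recognizer, audio_input_stream

def feed_audio_chunk(audio_input_stream, audio_chunk_base64: str) -> None:
    # Mirrors the 'api.audio' branch of the SocketIO message handler: base64 -> raw PCM -> push stream.
    audio_input_stream.write(base64.b64decode(audio_chunk_base64))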
diff --git a/samples/python/web/avatar/requirements.txt b/samples/python/web/avatar/requirements.txt index 3a3ef484e..eccd99ee9 100644 --- a/samples/python/web/avatar/requirements.txt +++ b/samples/python/web/avatar/requirements.txt @@ -1,6 +1,7 @@ azure-cognitiveservices-speech azure-identity flask +flask-socketio openai pytz requests diff --git a/samples/python/web/avatar/static/js/chat.js b/samples/python/web/avatar/static/js/chat.js index 262f2fee7..b82f310c6 100644 --- a/samples/python/web/avatar/static/js/chat.js +++ b/samples/python/web/avatar/static/js/chat.js @@ -3,14 +3,20 @@ // Global objects var clientId +var enableWebSockets +var socket +var audioContext +var isFirstResponseChunk var speechRecognizer var peerConnection var isSpeaking = false var sessionActive = false var recognitionStartedTime +var chatRequestSentTime var chatResponseReceivedTime var lastSpeakTime var isFirstRecognizingEvent = true +var sttLatencyRegex = new RegExp(/(\d+)<\/STTL>/) var firstTokenLatencyRegex = new RegExp(/(\d+)<\/FTL>/) var firstSentenceLatencyRegex = new RegExp(/(\d+)<\/FSL>/) @@ -87,6 +93,51 @@ function disconnectAvatar(closeSpeechRecognizer = false) { sessionActive = false } +function setupWebSocket() { + socket = io.connect(`${window.location.origin}?clientId=${clientId}`) + socket.on('connect', function() { + console.log('WebSocket connected.') + }) + + socket.on('response', function(data) { + let path = data.path + if (path === 'api.chat') { + let chatHistoryTextArea = document.getElementById('chatHistory') + let chunkString = data.chatResponse + if (sttLatencyRegex.test(chunkString)) { + let sttLatency = parseInt(sttLatencyRegex.exec(chunkString)[0].replace('', '').replace('', '')) + console.log(`STT latency: ${sttLatency} ms`) + let latencyLogTextArea = document.getElementById('latencyLog') + latencyLogTextArea.innerHTML += `STT latency: ${sttLatency} ms\n` + chunkString = chunkString.replace(sttLatencyRegex, '') + } + + if (firstTokenLatencyRegex.test(chunkString)) { + let aoaiFirstTokenLatency = parseInt(firstTokenLatencyRegex.exec(chunkString)[0].replace('', '').replace('', '')) + // console.log(`AOAI first token latency: ${aoaiFirstTokenLatency} ms`) + chunkString = chunkString.replace(firstTokenLatencyRegex, '') + } + + if (firstSentenceLatencyRegex.test(chunkString)) { + let aoaiFirstSentenceLatency = parseInt(firstSentenceLatencyRegex.exec(chunkString)[0].replace('', '').replace('', '')) + chatResponseReceivedTime = new Date() + console.log(`AOAI latency: ${aoaiFirstSentenceLatency} ms`) + let latencyLogTextArea = document.getElementById('latencyLog') + latencyLogTextArea.innerHTML += `AOAI latency: ${aoaiFirstSentenceLatency} ms\n` + latencyLogTextArea.scrollTop = latencyLogTextArea.scrollHeight + chunkString = chunkString.replace(firstSentenceLatencyRegex, '') + } + + chatHistoryTextArea.innerHTML += `${chunkString}` + if (chatHistoryTextArea.innerHTML.startsWith('\n\n')) { + chatHistoryTextArea.innerHTML = chatHistoryTextArea.innerHTML.substring(2) + } + + chatHistoryTextArea.scrollTop = chatHistoryTextArea.scrollHeight + } + }) +} + // Setup WebRTC function setupWebRTC(iceServerUrl, iceServerUsername, iceServerCredential) { // Create WebRTC peer connection @@ -278,7 +329,13 @@ function connectToAvatarService(peerConnection) { // Handle user query. Send user query to the chat API and display the response. 
function handleUserQuery(userQuery) { - let chatRequestSentTime = new Date() + chatRequestSentTime = new Date() + if (socket !== undefined) { + socket.emit('message', { clientId: clientId, path: 'api.chat', systemPrompt: document.getElementById('prompt').value, userQuery: userQuery }) + isFirstResponseChunk = true + return + } + fetch('/api/chat', { method: 'POST', headers: { @@ -391,12 +448,17 @@ function checkHung() { window.onload = () => { clientId = document.getElementById('clientId').value + enableWebSockets = document.getElementById('enableWebSockets').value === 'True' setInterval(() => { checkHung() }, 2000) // Check session activity every 2 seconds } window.startSession = () => { + if (enableWebSockets) { + setupWebSocket() + } + createSpeechRecognizer() if (document.getElementById('useLocalVideoForIdle').checked) { document.getElementById('startSession').disabled = true @@ -417,6 +479,11 @@ window.startSession = () => { window.stopSpeaking = () => { document.getElementById('stopSpeaking').disabled = true + if (socket !== undefined) { + socket.emit('message', { clientId: clientId, path: 'api.stopSpeaking' }) + return + } + fetch('/api/stopSpeaking', { method: 'POST', headers: { @@ -471,6 +538,26 @@ window.clearChatHistory = () => { window.microphone = () => { if (document.getElementById('microphone').innerHTML === 'Stop Microphone') { + // Stop microphone for websocket mode + if (socket !== undefined) { + document.getElementById('microphone').disabled = true + fetch('/api/disconnectSTT', { + method: 'POST', + headers: { + 'ClientId': clientId + }, + body: '' + }) + .then(() => { + document.getElementById('microphone').innerHTML = 'Start Microphone' + document.getElementById('microphone').disabled = false + if (audioContext !== undefined) { + audioContext.close() + audioContext = undefined + } + }) + } + // Stop microphone document.getElementById('microphone').disabled = true speechRecognizer.stopContinuousRecognitionAsync( @@ -485,6 +572,87 @@ window.microphone = () => { return } + // Start microphone for websocket mode + if (socket !== undefined) { + document.getElementById('microphone').disabled = true + // Audio worklet script (https://developer.chrome.com/blog/audio-worklet) for recording audio + const audioWorkletScript = `class MicAudioWorkletProcessor extends AudioWorkletProcessor { + constructor(options) { + super(options) + } + + process(inputs, outputs, parameters) { + const input = inputs[0] + const output = [] + for (let channel = 0; channel < input.length; channel += 1) { + output[channel] = input[channel] + } + this.port.postMessage(output[0]) + return true + } + } + + registerProcessor('mic-audio-worklet-processor', MicAudioWorkletProcessor)` + const audioWorkletScriptBlob = new Blob([audioWorkletScript], { type: 'application/javascript; charset=utf-8' }) + const audioWorkletScriptUrl = URL.createObjectURL(audioWorkletScriptBlob) + + fetch('/api/connectSTT', { + method: 'POST', + headers: { + 'ClientId': clientId, + 'SystemPrompt': document.getElementById('prompt').value + }, + body: '' + }) + .then(response => { + document.getElementById('microphone').disabled = false + if (response.ok) { + document.getElementById('microphone').innerHTML = 'Stop Microphone' + + navigator.mediaDevices + .getUserMedia({ + audio: { + echoCancellation: true, + noiseSuppression: true, + sampleRate: 16000 + } + }) + .then((stream) => { + audioContext = new AudioContext({ sampleRate: 16000 }) + const audioSource = audioContext.createMediaStreamSource(stream) + audioContext.audioWorklet + 
.addModule(audioWorkletScriptUrl) + .then(() => { + const audioWorkletNode = new AudioWorkletNode(audioContext, 'mic-audio-worklet-processor') + audioWorkletNode.port.onmessage = (e) => { + const audioDataFloat32 = e.data + const audioDataInt16 = new Int16Array(audioDataFloat32.length) + for (let i = 0; i < audioDataFloat32.length; i++) { + audioDataInt16[i] = Math.max(-0x8000, Math.min(0x7FFF, audioDataFloat32[i] * 0x7FFF)) + } + const audioDataBytes = new Uint8Array(audioDataInt16.buffer) + const audioDataBase64 = btoa(String.fromCharCode(...audioDataBytes)) + socket.emit('message', { clientId: clientId, path: 'api.audio', audioChunk: audioDataBase64 }) + } + + audioSource.connect(audioWorkletNode) + audioWorkletNode.connect(audioContext.destination) + }) + .catch((err) => { + console.log('Failed to add audio worklet module:', err) + }) + }) + .catch((err) => { + console.log('Failed to get user media:', err) + }) + } else { + throw new Error(`Failed to connect STT service: ${response.status} ${response.statusText}`) + } + }) + + return + } + if (document.getElementById('useLocalVideoForIdle').checked) { if (!sessionActive) { connectAvatar()
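To exercise the new WebSocket paths without the browser UI, a small python-socketio client along the lines of the sketch below can talk to the server. The client id must be one the server has already initialized via the /chat page (the UUID below is a placeholder), and localhost:5000 assumes Flask's default port.

# Sketch: driving the new SocketIO paths ('api.chat', 'api.stopSpeaking') from a test client.
# Requires: pip install "python-socketio[client]"
import socketio

CLIENT_ID = '00000000-0000-0000-0000-000000000000'  # placeholder; use the id rendered into chat.html

sio = socketio.Client()

@sio.on('response')
def on_response(data):
    # The server tags every emitted message with a 'path'; chat text arrives under 'api.chat'.
    if data.get('path') == 'api.chat':
        print(data.get('chatResponse'), end='', flush=True)

# clientId is passed as a query parameter so the server can join this socket to the client's room.
sio.connect(f'http://localhost:5000?clientId={CLIENT_ID}')

# Text query over the 'api.chat' path, matching the message shape used by static/js/chat.js.
sio.emit('message', {
    'clientId': CLIENT_ID,
    'path': 'api.chat',
    'systemPrompt': 'You are a helpful assistant.',
    'userQuery': 'Hello, what can you do?'
})

# Interrupt the avatar mid-sentence, same as the Stop Speaking button.
sio.emit('message', {'clientId': CLIENT_ID, 'path': 'api.stopSpeaking'})

sio.wait()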