First, import the needed dependencies and define a helper function to convert tokens into audio data.
!pip install torchaudio snac openai soundfile &> /dev/null
from IPython.display import Audio, display
import torch
import openai
import snac
import re
import torchaudio
# Matches Orpheus audio tokens of the form <custom_token_N> and captures N.
AUDIO_TOKENS_REGEX = re.compile(r"<custom_token_(\d+)>")


def convert_to_audio(audio_ids, model):
    """Decode a flat list of SNAC codebook ids into an audio waveform.

    Args:
        audio_ids: Flat list of integer codebook ids, 7 per audio frame,
            laid out as [c0, c1, c2, c3, c4, c5, c6, c0, c1, ...].
        model: A SNAC model exposing ``decode([codes_0, codes_1, codes_2])``.

    Returns:
        The first element of the decoder output (the decoded waveform
        tensor for the single batch item).
    """
    # Generation can stop mid-frame (e.g. at max_tokens); drop any trailing
    # partial frame so the reshape below cannot fail.
    usable = len(audio_ids) - (len(audio_ids) % 7)
    frames = torch.tensor(audio_ids[:usable], dtype=torch.int32).reshape(-1, 7)

    # De-interleave the 7 ids per frame into SNAC's three codebook streams:
    # 1 id/frame for level 0, 2 for level 1, 4 for level 2.
    codes_0 = frames[:, 0].unsqueeze(0)
    codes_1 = torch.stack((frames[:, 1], frames[:, 4])).t().flatten().unsqueeze(0)
    codes_2 = (
        torch.stack((frames[:, 2], frames[:, 3], frames[:, 5], frames[:, 6]))
        .t()
        .flatten()
        .unsqueeze(0)
    )

    with torch.inference_mode():
        audio_hat = model.decode([codes_0, codes_1, codes_2])
    return audio_hat[0]
Next, prompt the model to generate audio.
# Resolve the API key: prefer Colab's secret store when available, otherwise
# fall back to the OPENAI_API_KEY environment variable. A broad (but not bare)
# except is deliberate: on Colab, userdata.get also raises when the secret is
# missing, and we still want the environment-variable fallback in that case.
try:
    from google.colab import userdata
    API_KEY = userdata.get('OPENAI_API_KEY')
except Exception:
    import os
    API_KEY = os.getenv('OPENAI_API_KEY')

# SNAC vocoder that turns generated codebook ids back into a 24 kHz waveform.
snac_model = snac.SNAC.from_pretrained("hubertsiuzdak/snac_24khz")

# OpenAI-compatible client pointed at the Parasail inference endpoint.
oai_client = openai.OpenAI(
    base_url="https://api.parasail.io/v1",
    api_key=API_KEY)
# Voice options: "tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"
voice = "tara"

# Available emotive tag: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
text_prompt = """<chuckle> Hello world!"""

# Wrap the transcript in Orpheus' special-token template:
# start token, BOS, "<voice>: <text>", EOT, then the audio-start tokens.
prompt = (
    "<custom_token_3><|begin_of_text|>"
    + voice
    + ": "
    + text_prompt
    + "<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
)

# Tips for prompting [1]:
# - Sampling parameters 'temperature' and 'top_p' work just like regular LLMs.
# - 'repetition_penalty' >= 1.1 is required for stable generations.
# - Increasing 'repetition_penalty' and/or 'temperature' makes the model speak faster.
#
# [1]: https://github.com/canopyai/Orpheus-TTS/tree/main?tab=readme-ov-file#prompting
# Request audio tokens from the Orpheus model. `stop_token_ids` halts
# generation at the end-of-speech token; `repetition_penalty` >= 1.1 is
# needed for stable output (see the prompting tips above).
sampling_extras = {
    "stop_token_ids": [128258],
    "repetition_penalty": 1.1,
}
response = oai_client.completions.create(
    model="parasail-orpheus-3b-01-ft",
    prompt=prompt,
    temperature=0.6,
    top_p=0.9,
    max_tokens=10240,
    extra_body=sampling_extras,
)
# Map each <custom_token_N> back to a raw codebook id: subtract the fixed
# offset of 10 plus a 4096-wide slot chosen by the position within a 7-token
# frame.
text = response.choices[0].text
audio_ids = []
for position, raw_id in enumerate(AUDIO_TOKENS_REGEX.findall(text)):
    audio_ids.append(int(raw_id) - 10 - 4096 * (position % 7))

audio = convert_to_audio(audio_ids, snac_model)

# Persist the 24 kHz waveform, report its length, and play it back inline.
torchaudio.save("output.wav", audio, sample_rate=24000)
duration_seconds = audio.shape[1] / 24000
print(f"Audio duration: {duration_seconds:.2f}s")
display(Audio("output.wav"))