Skip to content

Commit

Permalink
Stable diffusion automatic API Server support (#43)
Browse files Browse the repository at this point in the history
* add sd automatic111111 as an option

* add --sd-api option for using external api server

* automatic111111 support.

* choose model for sd automatic11111

* readme updated for stable diffusion

* add ./bin/ with startup scripts for API servers

* add mimic3 server startup script

* add back image scaling for automatic1111111

---------

Co-authored-by: Chris Kennedy <[email protected]>
  • Loading branch information
groovybits and Chris Kennedy committed Mar 27, 2024
1 parent 7da437b commit 2326638
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 54 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ urlencoding = "2.1.3"
clap_builder = "4.5.2"
safetensors = "0.4.2"
ctrlc = "3.4.4"
base64 = "0.22.0"
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RsLLM is AI pipeline 100% in Rust for Transformer/Tensor code that is leveraging
- **Image Generation and NDI Output**: Supports generating images from text descriptions and outputting through NDI for a wide range of applications, including real-time content creation and broadcasting. (In Beta Testing)
- **TTS Support**: Candle implements TTS using MetaVoice (default, WIP), OpenAI TTS API (high-quality, real-time), and Mimic3 TTS API (local, free). MetaVoice is being optimized for Metal GPUs, while OpenAI TTS API generates premium speech at a cost. Mimic3 TTS API requires running the mimic3-server but offers a good alternative to OpenAI TTS API. [Mimic3 GitHub](https://github.com/MycroftAI/mimic3)
- **Twitch Chat Interactive AI**: Integrated Twitch chat for real-time AI interactions, enabling users to engage with the toolkit through chat commands and receive AI-generated responses.
- **Stable Diffusion Image Generation**: Supports either Candle stable diffusion or the AUTOMATIC1111 API server. <https://github.com/AUTOMATIC1111/stable-diffusion-webui/>

![RSLLM](https://storage.googleapis.com/groovybits/images/rsllm/rsllm.webp)

Expand Down
13 changes: 13 additions & 0 deletions bin/llama_start.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
#
# Launch the llama.cpp API server on port 8080 with a local GGUF model.
# Extra command-line arguments are forwarded to the server binary.
#
#MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.7-mixtral-8x7b.Q5_K_M.gguf
MODEL=/Volumes/BrahmaSSD/LLM/models/GGUF/dolphin-2.7-mixtral-8x7b.Q8_0.gguf

# Quote "$MODEL" and "$@" so paths/args containing spaces survive word splitting.
server \
    -m "$MODEL" \
    -c 0 \
    -np 2 \
    --port 8080 \
    -ngl 60 \
    -t 24 \
    --host 0.0.0.0 "$@"
71 changes: 71 additions & 0 deletions bin/music_player.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
import os
import subprocess
from hashlib import md5

# Configuration
music_dir = "/Volumes/BrahmaSSD/music/AiGen"        # directory scanned (recursively) for generated .wav files
output_file = "/tmp/combined_playlist.wav"          # concatenated playlist produced by ffmpeg
playlist_file = "/tmp/ffmpeg_playlist.txt"          # ffmpeg concat-demuxer input list
checksum_file = "/tmp/playlist_checksum.txt"        # MD5 of all source files; used to detect changes between runs

def get_files_sorted_by_mtime(directory, extension=".wav"):
    """Recursively collect files with the given extension, oldest-modified first."""
    matches = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
        if name.endswith(extension)
    ]
    matches.sort(key=os.path.getmtime)
    return matches

def generate_playlist(files, playlist_path):
    """Write an ffmpeg concat-demuxer playlist listing the given files.

    Single quotes in file paths are escaped using ffmpeg's '\\'' form so
    filenames containing apostrophes do not break the concat parser.
    """
    with open(playlist_path, 'w') as playlist:
        for file in files:
            escaped = file.replace("'", "'\\''")
            playlist.write(f"file '{escaped}'\n")

def calculate_checksum(files):
    """Return one MD5 hex digest covering the contents of all files, in order."""
    digest = md5()
    for path in files:
        with open(path, "rb") as handle:
            while True:
                chunk = handle.read(4096)
                if not chunk:
                    break
                digest.update(chunk)
    return digest.hexdigest()

def read_previous_checksum(checksum_path):
    """Load the previously recorded checksum; empty string if none was saved."""
    try:
        with open(checksum_path, 'r') as handle:
            content = handle.read()
    except FileNotFoundError:
        return ''
    return content.strip()

def write_new_checksum(checksum, checksum_path):
    """Persist the checksum so the next run can detect playlist changes."""
    with open(checksum_path, 'w') as handle:
        handle.write(checksum)

def concatenate_files(playlist_path, output_path):
    """Concatenate every playlist entry into one file via ffmpeg's concat demuxer.

    Uses stream copy (no re-encode); raises CalledProcessError if ffmpeg fails.
    """
    command = [
        'ffmpeg', '-y', '-hide_banner',
        '-f', 'concat', '-safe', '0',
        '-i', playlist_path,
        '-c', 'copy',
        output_path,
    ]
    print("Running command:", ' '.join(command))
    subprocess.run(command, check=True)

def play_audio(file_path):
    """Play the given audio file with mpv at 50% volume; blocks until playback ends."""
    player_cmd = ['mpv', '--volume=50', file_path]
    subprocess.run(player_cmd, check=True)

# Main loop: rebuild the combined playlist whenever the source files change,
# then play it back; repeats forever.
import time

while True:
    files = get_files_sorted_by_mtime(music_dir)
    if not files:
        # Previously this fell through and tried to play output_file anyway,
        # which crashes (CalledProcessError) when it does not exist. Wait and
        # rescan instead of spinning or dying.
        print("No .wav files found in the directory.")
        time.sleep(5)
        continue

    current_checksum = calculate_checksum(files)
    previous_checksum = read_previous_checksum(checksum_file)

    if current_checksum != previous_checksum or not os.path.exists(output_file):
        print("Changes detected or output file missing, regenerating...")
        generate_playlist(files, playlist_file)
        concatenate_files(playlist_file, output_file)
        write_new_checksum(current_checksum, checksum_file)
    else:
        print("No changes detected. Using existing combined audio file.")

    print("Playing combined playlist...")
    play_audio(output_file)
3 changes: 3 additions & 0 deletions bin/start_mimic3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
#
# Start the Mimic3 TTS HTTP server used by the --mimic3-tts option.
# Any arguments are forwarded to mimic3-server; exec replaces this shell
# so signals reach the server directly.
exec mimic3-server "$@"
6 changes: 6 additions & 0 deletions bin/start_sdwebui.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
#
# https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
#
# Launch stable-diffusion-webui in headless API mode on port 7860 (CPU-only).
# Run from inside the stable-diffusion-webui checkout; extra arguments are
# forwarded to webui.sh.
./webui.sh --api --api-log --nowebui --port 7860 --skip-torch-cuda-test --no-half --use-cpu all "$@"
# Alternative: keep the browser UI enabled alongside the API:
#./webui.sh --api --api-log --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
8 changes: 7 additions & 1 deletion scripts/twitch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ TWITCH_LLM_CONCURRENCY=1
TWITCH_CHAT_HISTORY=16
TWITCH_MAX_TOKENS=150
## Stable Diffusion Settings
SD_API=1
SD_MODEL=turbo
SD_INTERMEDIARY_IMAGES=0
SD_INTERMEDIARY_IMAGES=1
SD_N_STEPS=6
ALIGNMENT=right
SUBTITLES=1
Expand All @@ -50,6 +51,10 @@ NO_HISTORY_CMD=
QUANTIZED_CMD=
ASYNC_CONCURRENCY_CMD=
SD_INTERMEDIARY_IMAGES_CMD=
SD_API_CMD=
if [ "$SD_API" == 1 ]; then
SD_API_CMD="--sd-api"
fi
if [ "$SD_INTERMEDIARY_IMAGES" == 1 ]; then
SD_INTERMEDIARY_IMAGES_CMD="--sd-intermediary-images"
fi
Expand Down Expand Up @@ -92,6 +97,7 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \
--twitch-max-tokens $TWITCH_MAX_TOKENS \
--twitch-prompt "$TWITCH_PROMPT" \
--mimic3-tts \
$SD_API_CMD \
--sd-image \
--sd-model $SD_MODEL \
--sd-n-steps $SD_N_STEPS \
Expand Down
9 changes: 9 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,15 @@ pub struct Args {
)]
pub sd_image: bool,

/// Use SD API - use the stable diffusion server API from AUTOMATIC1111
#[clap(
long,
env = "SD_API",
default_value_t = false,
help = "SD API - use the stable diffusion server api from automatic111111. Must install it and run on localhost."
)]
pub sd_api: bool,

/// SD Max Length in tokens for SD Image
#[clap(
long,
Expand Down
52 changes: 51 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod network_capture;
pub mod openai_api;
pub mod openai_tts;
pub mod pipeline;
pub mod sd_automatic;
pub mod stable_diffusion;
pub mod stream_data;
pub mod system_stats;
Expand All @@ -28,7 +29,10 @@ use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
pub use system_stats::{get_system_stats, SystemStats};
pub mod candle_gemma;
use image::{ImageBuffer, Rgb, Rgba};
use image::{
imageops::{resize, FilterType},
ImageBuffer, Rgb, Rgba,
};
#[cfg(feature = "fonts")]
use imageproc::drawing::draw_text_mut;
#[cfg(feature = "fonts")]
Expand Down Expand Up @@ -455,3 +459,49 @@ pub async fn clean_tts_input(input: String) -> String {

input
}

/// Scale `image` to fit inside `new_width` x `new_height` while preserving
/// aspect ratio, letterboxing the unused area with black pixels.
///
/// `image_position` controls horizontal placement of the scaled content:
/// `"left"`, `"right"`, or anything else (including `None`) centers it.
/// Vertically the content is always centered.
///
/// Returns the input unchanged when either target dimension is missing or
/// zero, or when the source image has a zero dimension (which would
/// otherwise divide by zero while computing the scale factor).
pub fn scale_image(
    image: ImageBuffer<Rgb<u8>, Vec<u8>>,
    new_width: Option<u32>,
    new_height: Option<u32>,
    image_position: Option<String>,
) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
    let (target_width, target_height) = match (new_width, new_height) {
        (Some(w), Some(h)) if w > 0 && h > 0 => (w, h),
        // Target dimensions not fully specified (or zero): nothing to do.
        _ => return image,
    };

    let (orig_width, orig_height) = image.dimensions();
    if orig_width == 0 || orig_height == 0 {
        // Degenerate source; scaling is undefined (previously divided by zero).
        return image;
    }

    // Uniform scale factor that fits the image inside the target box.
    let scale = (target_width as f32 / orig_width as f32)
        .min(target_height as f32 / orig_height as f32);
    // Clamp the rounded sizes so float rounding can never exceed the target
    // box, which would underflow the offset subtractions below.
    let scaled_width = ((orig_width as f32 * scale).round() as u32).min(target_width);
    let scaled_height = ((orig_height as f32 * scale).round() as u32).min(target_height);

    // Scale the image while preserving the aspect ratio.
    let scaled_image = resize(&image, scaled_width, scaled_height, FilterType::Lanczos3);

    // Black canvas at the exact target dimensions.
    let mut new_image = ImageBuffer::from_pixel(target_width, target_height, Rgb([0, 0, 0]));

    // Horizontal placement per image_position; vertical is always centered.
    let x_offset = match image_position.as_deref() {
        Some("left") => 0,
        Some("right") => target_width - scaled_width,
        _ => (target_width - scaled_width) / 2, // default: center
    };
    let y_offset = (target_height - scaled_height) / 2;

    // Blit the scaled image onto the canvas at the computed offset.
    for (x, y, pixel) in scaled_image.enumerate_pixels() {
        // Defensive bounds check; offsets guarantee this holds.
        if x + x_offset < target_width && y + y_offset < target_height {
            new_image.put_pixel(x + x_offset, y + y_offset, *pixel);
        }
    }

    new_image
}
11 changes: 9 additions & 2 deletions src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::ndi::send_images_over_ndi;
use crate::openai_tts::tts as oai_tts;
use crate::openai_tts::Request as OAITTSRequest;
use crate::openai_tts::Voice as OAITTSVoice;
use crate::sd_automatic::sd_auto;
use crate::stable_diffusion::{sd, SDConfig};
use crate::ApiError;
use image::ImageBuffer;
Expand All @@ -42,8 +43,14 @@ pub async fn process_image(mut data: MessageData) -> Vec<ImageBuffer<Rgb<u8>, Ve
data.sd_config.prompt = crate::truncate_tokens(&data.sd_config.prompt, data.args.sd_text_min);
if data.args.sd_image {
debug!("Generating images with prompt: {}", data.sd_config.prompt);
let sd_clone = sd.clone();
match sd_clone(data.sd_config).await {

let images = if data.args.sd_api {
sd_auto(data.sd_config).await
} else {
sd(data.sd_config).await
};

match images {
// Ensure `sd` function is async and await its result
Ok(images) => {
// Save images to disk
Expand Down
92 changes: 92 additions & 0 deletions src/sd_automatic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use crate::scale_image;
use crate::stable_diffusion::SDConfig;
use crate::stable_diffusion::StableDiffusionVersion;
use anyhow::Result;
use base64::engine::general_purpose;
use base64::Engine;
use image::ImageBuffer;
use image::Rgb;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Generate images via an AUTOMATIC1111 stable-diffusion-webui txt2img API
/// server running on localhost:7860.
///
/// Builds a txt2img payload from `config`, posts it to the API, decodes the
/// base64-encoded images in the response, and letterbox-scales each one to
/// the configured output dimensions.
///
/// # Errors
/// Returns an error when the HTTP request fails, the server responds with a
/// non-success status, the response JSON lacks an `images` array, or an
/// entry fails to base64-decode or parse as an image. (Previously these
/// cases panicked via `unwrap()`.)
pub async fn sd_auto(
    config: SDConfig,
) -> Result<Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>, anyhow::Error> {
    let client = Client::new();

    // Map the requested SD version onto the checkpoint name the
    // AUTOMATIC1111 server knows it by.
    let model = match config.sd_version {
        StableDiffusionVersion::V1_5 => "v1-5-pruned-emaonly.ckpt",
        StableDiffusionVersion::V2_1 => "v2-1_768-ema-pruned.ckpt",
        StableDiffusionVersion::Xl => "stabilityai/stable-diffusion-xl-1024-1.0.ckpt",
        StableDiffusionVersion::Turbo => "madebyollin/turbo-diffusion.ckpt",
    };

    let payload = AutomaticPayload {
        prompt: config.prompt,
        negative_prompt: config.uncond_prompt,
        steps: config.n_steps.unwrap_or(20),
        width: config.width.unwrap_or(512),
        height: config.height.unwrap_or(512),
        cfg_scale: config.guidance_scale.unwrap_or(7.5),
        sampler_index: "Euler".to_string(),
        seed: config.seed.unwrap_or_else(rand::random) as u64,
        n_iter: config.num_samples,
        batch_size: 1,
        override_settings: OverrideSettings {
            sd_model_checkpoint: model.to_string(),
        },
    };

    let response = client
        .post("http://127.0.0.1:7860/sdapi/v1/txt2img")
        .json(&payload)
        .send()
        .await?
        // Surface HTTP-level failures instead of trying to parse an error body.
        .error_for_status()?;

    let response_json: serde_json::Value = response.json().await?;
    let image_data = response_json["images"]
        .as_array()
        .ok_or_else(|| anyhow::anyhow!("txt2img response missing 'images' array"))?;

    let mut images = Vec::new();
    for image_base64 in image_data {
        let encoded = image_base64
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("txt2img 'images' entry is not a string"))?;
        let image_bytes = general_purpose::STANDARD.decode(encoded)?;
        let image = image::load_from_memory(&image_bytes)?;
        images.push(image.to_rgb8());
    }

    // Letterbox-scale each decoded image to the configured output size.
    let scaled_images: Vec<ImageBuffer<Rgb<u8>, Vec<u8>>> = images
        .into_iter()
        .map(|image| {
            scale_image(
                image,
                config.scaled_width,
                config.scaled_height,
                config.image_position.clone(),
            )
        })
        .collect();

    Ok(scaled_images)
}

/// Per-request settings overrides sent with each txt2img call; used to select
/// which model checkpoint the AUTOMATIC1111 server should load.
#[derive(Debug, Serialize, Deserialize)]
struct OverrideSettings {
    // Checkpoint name as known to the server, e.g. "v1-5-pruned-emaonly.ckpt".
    sd_model_checkpoint: String,
}

/// Request body for the AUTOMATIC1111 `/sdapi/v1/txt2img` endpoint.
/// Field names mirror the JSON keys the API expects.
#[derive(Debug, Serialize, Deserialize)]
struct AutomaticPayload {
    prompt: String,                      // positive prompt text
    negative_prompt: String,             // things to steer away from
    steps: usize,                        // diffusion sampling steps
    width: usize,                        // output width in pixels
    height: usize,                       // output height in pixels
    cfg_scale: f64,                      // classifier-free guidance scale
    sampler_index: String,               // sampler name, e.g. "Euler"
    seed: u64,                           // RNG seed for reproducibility
    n_iter: usize,                       // number of generation iterations
    batch_size: usize,                   // images per iteration
    override_settings: OverrideSettings, // model checkpoint selection
}
Loading

0 comments on commit 2326638

Please sign in to comment.