Skip to content

Commit

Permalink
allow custom stable diffusion models
Browse files Browse the repository at this point in the history
add --sd-custom-model "string" to allow easier custom model usage
for stable diffusion images. have --sd-model "custom" enable the
custom model to be used, must use both in coordination for now.
  • Loading branch information
Chris Kennedy committed May 11, 2024
1 parent fd994d9 commit 4a4e14a
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 11 deletions.
5 changes: 3 additions & 2 deletions bin/start_sdwebui.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
#
# https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
#
./webui.sh --api --api-log --nowebui --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
#./webui.sh --api --api-log --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
#./webui.sh --api --api-log --nowebui --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
#./webui.sh --api --api-log --listen --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
./webui.sh --api --api-log --port 7860 --skip-torch-cuda-test --no-half --use-cpu all
9 changes: 6 additions & 3 deletions scripts/twitch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@ NDI_TIMEOUT=600
## Twitch Chat Settings
TWITCH_MODEL=mistral
TWITCH_LLM_CONCURRENCY=1
TWITCH_CHAT_HISTORY=3
TWITCH_MAX_TOKENS_CHAT=100
TWITCH_CHAT_HISTORY=32
TWITCH_MAX_TOKENS_CHAT=200
TWITCH_MAX_TOKENS_LLM=$MAX_TOKENS
## Stable Diffusion Settings
SD_TEXT_MIN=70
SD_WIDTH=512
SD_HEIGHT=512
SD_API=1
SD_MODEL=turbo
SD_MODEL=custom
#SD_CUSTOM_MODEL="babes_31.safetensors"
SD_CUSTOM_MODEL="sexyToon3D_v420.safetensors"
SD_INTERMEDIARY_IMAGES=1
SD_N_STEPS=20
ALIGNMENT=center
Expand Down Expand Up @@ -118,6 +120,7 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \
--sd-height $SD_HEIGHT \
--sd-image \
--sd-model $SD_MODEL \
--sd-custom-model $SD_CUSTOM_MODEL \
--sd-n-steps $SD_N_STEPS \
--image-alignment $ALIGNMENT \
$SUBTITLE_CMD \
Expand Down
9 changes: 9 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,15 @@ pub struct Args {
)]
pub sd_intermediary_images: bool,

/// Stable Diffusion Custom Model Name to load
#[clap(
long,
env = "SD_CUSTOM_MODEL",
default_value = "sd_xl_turbo_1.0.safetensors",
help = "Custom Stable Diffusion Model. for automatic 111111 API usage, the name must exist as a model locally or remotely."
)]
pub sd_custom_model: String,

/// Stable Diffusion Version
#[clap(
long,
Expand Down
18 changes: 18 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(args.image_alignment.clone());
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
// match args.sd_model with on of the strings "1.5", "2.1", "xl", "turbo" and set the sd_version accordingly
sd_config.sd_version = if args.sd_model == "1.5" {
StableDiffusionVersion::V1_5
Expand All @@ -691,6 +692,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "Custom" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down Expand Up @@ -817,6 +820,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(args.image_alignment.clone());
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
if args.sd_scaled_height > 0 {
sd_config.scaled_height = Some(args.sd_scaled_height);
}
Expand All @@ -832,6 +836,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "Custom" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down Expand Up @@ -1215,6 +1221,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(args.image_alignment.clone());
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
// match args.sd_model with on of the strings "1.5", "2.1", "xl", "turbo" and set the sd_version accordingly
sd_config.sd_version = if args.sd_model == "1.5" {
StableDiffusionVersion::V1_5
Expand All @@ -1224,6 +1231,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "Custom" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down Expand Up @@ -1374,6 +1383,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(image_alignment);
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
if args.sd_scaled_height > 0 {
sd_config.scaled_height = Some(args.sd_scaled_height);
}
Expand All @@ -1389,6 +1399,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "babes" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down Expand Up @@ -1473,6 +1485,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(image_alignment);
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
if args.sd_scaled_height > 0 {
sd_config.scaled_height = Some(args.sd_scaled_height);
}
Expand All @@ -1488,6 +1501,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "babes" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down Expand Up @@ -1530,6 +1545,7 @@ async fn main() {
sd_config.width = Some(args.sd_width);
sd_config.image_position = Some(args.image_alignment.clone());
sd_config.intermediary_images = args.sd_intermediary_images;
sd_config.custom_model = Some(args.sd_custom_model.clone());
if args.sd_scaled_height > 0 {
sd_config.scaled_height = Some(args.sd_scaled_height);
}
Expand All @@ -1545,6 +1561,8 @@ async fn main() {
StableDiffusionVersion::Xl
} else if args.sd_model == "turbo" {
StableDiffusionVersion::Turbo
} else if args.sd_model == "babes" {
StableDiffusionVersion::Custom
} else {
StableDiffusionVersion::V1_5
};
Expand Down
1 change: 1 addition & 0 deletions src/sd_automatic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub async fn sd_auto(
let client = Client::new();

let model = match config.sd_version {
StableDiffusionVersion::Custom => config.custom_model.as_deref().unwrap_or("sd_xl_turbo_1.0.safetensors"),
StableDiffusionVersion::V1_5 => "v1-5-pruned-emaonly.ckpt",
StableDiffusionVersion::V2_1 => "v2-1_768-ema-pruned.ckpt",
StableDiffusionVersion::Xl => "stabilityai/stable-diffusion-xl-1024-1.0.ckpt",
Expand Down
24 changes: 18 additions & 6 deletions src/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum StableDiffusionVersion {
V2_1,
Xl,
Turbo,
Custom,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand All @@ -39,12 +40,13 @@ impl StableDiffusionVersion {
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
Self::Turbo => "stabilityai/sdxl-turbo",
Self::Custom => "stabilityai/sdxl-turbo",
}
}

fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo | Self::Custom => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
Expand All @@ -56,7 +58,7 @@ impl StableDiffusionVersion {

fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo | Self::Custom => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
Expand All @@ -68,7 +70,7 @@ impl StableDiffusionVersion {

fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo | Self::Custom => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
Expand All @@ -80,7 +82,7 @@ impl StableDiffusionVersion {

fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo | Self::Custom => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
Expand Down Expand Up @@ -108,7 +110,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo | StableDiffusionVersion::Custom => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
Expand All @@ -127,7 +129,7 @@ impl ModelFile {
// See https://github.com/huggingface/candle/issues/1060
if matches!(
version,
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo | StableDiffusionVersion::Custom,
) && use_f16
{
(
Expand Down Expand Up @@ -260,6 +262,7 @@ pub struct SDConfig {
pub n_steps: Option<usize>,
pub num_samples: usize,
pub sd_version: StableDiffusionVersion,
pub custom_model: Option<String>,
pub intermediary_images: bool,
pub use_flash_attn: bool,
pub use_f16: bool,
Expand Down Expand Up @@ -290,6 +293,7 @@ impl SDConfig {
n_steps: None,
num_samples: 1,
sd_version: StableDiffusionVersion::Turbo,
custom_model: None,
intermediary_images: false,
use_flash_attn: false,
use_f16: false,
Expand Down Expand Up @@ -337,6 +341,7 @@ pub async fn sd(config: SDConfig) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 7.5,
StableDiffusionVersion::Turbo => 0.,
StableDiffusionVersion::Custom => 0.,
},
};
let n_steps = match config.n_steps {
Expand All @@ -346,6 +351,7 @@ pub async fn sd(config: SDConfig) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 20,
StableDiffusionVersion::Turbo => 1,
StableDiffusionVersion::Custom => 1,
},
};
let dtype = if config.use_f16 {
Expand Down Expand Up @@ -374,6 +380,11 @@ pub async fn sd(config: SDConfig) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<
config.height,
config.width,
),
StableDiffusionVersion::Custom => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
config.sliced_attention_size,
config.height,
config.width,
),
};

let scheduler = sd_config.build_scheduler(n_steps)?;
Expand Down Expand Up @@ -437,6 +448,7 @@ pub async fn sd(config: SDConfig) -> Result<Vec<ImageBuffer<image::Rgb<u8>, Vec<
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 0.18215,
StableDiffusionVersion::Turbo => 0.13025,
StableDiffusionVersion::Custom => 0.13025,
};

// array of image buffers to gather the results
Expand Down

0 comments on commit 4a4e14a

Please sign in to comment.