Commit e70b474

Merge pull request #51 from aim-qmul:open-source-version

feat: winning configurations

yoyololicon committed May 18, 2023
2 parents 47a42b4 + f9060ab

Showing 12 changed files with 969 additions and 41 deletions.
79 changes: 72 additions & 7 deletions README.md


AIMLESS (Artificial Intelligence and Music League for Effective Source Separation) is a special interest group in audio source separation at C4DM, consisting of PhD students from the AIM CDT program.
This repository is adapted from [Danna-Sep](https://github.com/yoyololicon/music-demixing-challenge-ismir-2021-entry) and
contains our training code for the [SDX23 Sound Demixing Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023).


## Quick Start

The conda environment we used for training is described in `environment.yml`.
The commands below should run as-is on the QMUL EECS servers.
If you want to run it on your local machine, change the `root` param in the config to where you downloaded the MUSDB18-HQ dataset.
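
For example, the `root` parameter can also be overridden from the command line instead of editing the config (the dataset path below is a hypothetical example):

```commandline
python main.py fit --config cfg/xumx.yaml --data.init_args.root ~/datasets/musdb18hq
```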

### Frequency-masking-based model

```commandline
python main.py fit --config cfg/xumx.yaml
```

### Time-domain model

```commandline
python main.py fit --config cfg/demucs.yaml
```

## Install the repository as a package

This step is required if you want to test our submission repositories (see section [Reproduce the winning submission](#reproduce-the-winning-submission)) locally.
```sh
pip install git+https://github.com/aim-qmul/sdx23-aimless
```
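
As a quick sanity check (assuming the install succeeded), the package root described in the [Structure](#structure) section should now be importable:

```sh
python -c "import aimless; print(aimless.__file__)"
```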

## Reproduce the winning submission

### CDX Leaderboard A, submission ID 220319

This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/-/issues/90) we used on CDX leaderboard A.
The submission consists of one HDemucs predicting all the targets and one BandSplitRNN predicting the music stem from the mixture.

To train the HDemucs:
```commandline
python main.py fit --config cfg/cdx_a/hdemucs.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/
```
Remember to change `/DNR_DATASET_ROOT/dnr_v2/` to the directory where you downloaded the [Divide and Remaster (DnR) dataset](https://zenodo.org/record/6949108).

To train the BandSplitRNN:
```commandline
python main.py fit --config cfg/cdx_a/bandsplit_rnn.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/
```

We trained the models on at most 4 GPUs, depending on the resources available at the time.
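
The configs leave `trainer.devices` unset (`null`), so the GPU count can be chosen at launch time; for example, a 4-GPU DDP run would look like this (an illustrative override, not a required setting):

```commandline
python main.py fit --config cfg/cdx_a/hdemucs.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/ --trainer.devices 4
```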

After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/).
Then, copy the last checkpoint of HDemucs (usually located at `lightning_logs/version_**/checkpoints/last.ckpt`) to `my_submission/lightning_logs/hdemucs-64-sdr/checkpoints/last.ckpt` in the submission repository.
Similarly, copy the last checkpoint of BandSplitRNN to `my_submission/lightning_logs/bandsplitRNN-music/checkpoints/last.ckpt`.
After these steps, you have reproduced our submission!
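
Concretely, the two copy steps could look like this (`/path/to/cdx-submissions` is a placeholder for your local clone of the submission repository, and `version_**` stands for the actual run directory):

```sh
# from the HDemucs training run
cp lightning_logs/version_**/checkpoints/last.ckpt \
   /path/to/cdx-submissions/my_submission/lightning_logs/hdemucs-64-sdr/checkpoints/last.ckpt
# from the BandSplitRNN training run
cp lightning_logs/version_**/checkpoints/last.ckpt \
   /path/to/cdx-submissions/my_submission/lightning_logs/bandsplitRNN-music/checkpoints/last.ckpt
```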

The inference procedure in our submission repository is a bit involved.
Briefly, the HDemucs predicts the targets independently for each channel of the stereo mixture, plus the average (the mid) and the difference (the side) of the two channels.
The stereo separated sources are then formed as a linear combination of these mono predictions.
The separated music from the BandSplitRNN is enhanced with Wiener filtering, and the final music prediction is the average of the two models' outputs.
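
As a rough sketch of the channel recombination (using `mid = (L + R) / 2` and `side = L - R` as defined above; the 0.5 blend weights are assumptions for illustration, not the submitted values):

```python
import numpy as np

def recombine_stereo(pred_l, pred_r, pred_mid, pred_side):
    # pred_*: source estimates from each mono input, shape (n_sources, n_samples).
    # With mid = (L + R) / 2 and side = L - R, the identities L = mid + side / 2
    # and R = mid - side / 2 hold; the 0.5 blend weights below are illustrative
    # assumptions, not the weights used in the actual submission.
    left = 0.5 * pred_l + 0.5 * (pred_mid + 0.5 * pred_side)
    right = 0.5 * pred_r + 0.5 * (pred_mid - 0.5 * pred_side)
    return np.stack([left, right], axis=1)  # (n_sources, 2, n_samples)
```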

### MDX Leaderboard A (Label Noise), submission ID 220426

This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/-/issues/76) we used on MDX leaderboard A.

Firstly, we manually inspected the [label noise dataset](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/dataset_files) (thanks @mhrice for the hard work!) and labeled the songs that are free of label noise.
The labels are recorded in `data/lightning/label_noise.csv`.
Then, an HDemucs was trained only on the clean songs with the following settings:

* negative SDR as the loss function (a minimal sketch is given after this list)
* training on random chunks and random stem combinations of the clean songs
* augmenting the training batches with different random effects
* due to all this randomization, validation is also done on the training dataset (there is no separate validation set)
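
A minimal sketch of a negative SDR loss, assuming time-domain tensors; the repository's actual criterion may differ in details such as scaling or filtering:

```python
import torch

def neg_sdr_loss(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SDR: maximizing SDR is equivalent to minimizing this loss.

    Assumes tensors of shape (batch, channels, samples). A hypothetical
    sketch, not the repository's exact implementation.
    """
    num = target.pow(2).sum(dim=(-2, -1))
    den = (target - estimate).pow(2).sum(dim=(-2, -1))
    sdr = 10 * torch.log10((num + eps) / (den + eps))
    return -sdr.mean()
```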

To reproduce the training:
```commandline
python main.py fit --config cfg/mdx_a/hdemucs.yaml --data.init_args.root /DATASET_ROOT/
```
Remember to place the label noise data under `/DATASET_ROOT/train/`.
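
For reference, we assume a MUSDB18-style stem layout per song (the folder name below is a placeholder; check the downloaded archive for the exact structure):

```
/DATASET_ROOT/
└── train/
    ├── SomeSong/
    │   ├── bass.wav
    │   ├── drums.wav
    │   ├── other.wav
    │   └── vocals.wav
    └── ...
```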

Other details:
* The model was trained for ~800 epochs (approx. 2 weeks on 4 RTX A5000 GPUs)
* During the last ~200 epochs, the learning rate was reduced to 0.001, gradient accumulation was increased to 64, and the probability of each random effect was increased by a factor of 1.666 (e.g., from 30% to 50%); a sketch of these overrides is shown below
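
With the Lightning CLI, the last-stage overrides could be passed roughly as follows (illustrative flags, not the exact invocation we used; the effect probabilities live in the config itself and would be edited there):

```commandline
python main.py fit --config cfg/mdx_a/hdemucs.yaml --data.init_args.root /DATASET_ROOT/ \
  --optimizer.init_args.lr 0.001 --trainer.accumulate_grad_batches 64 \
  --ckpt_path lightning_logs/version_**/checkpoints/last.ckpt
```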

After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/).
Then, copy the checkpoint to `my_submission/acc64_4devices_lr0001_e1213_last.ckpt` in the submission repository.
After these steps, you have reproduced our submission!


## Structure

* `aimless`: package root, which can be imported for submission.
Split songs in the browser with a pretrained Hybrid Demucs.
Then open [http://localhost:8501/](http://localhost:8501/) in your browser.


212 changes: 212 additions & 0 deletions cfg/cdx_a/bandsplit_rnn.yaml

```yaml
# pytorch_lightning==1.8.5.post0
seed_everything: 2434
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: null
        filename: null
        monitor: null
        verbose: false
        save_last: true
        save_top_k: 1
        save_weights_only: false
        mode: min
        auto_insert_metric_name: true
        every_n_train_steps: 2000
        train_time_interval: null
        every_n_epochs: null
        save_on_train_epoch_end: null
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: null
        filename: null
        monitor: null
        verbose: false
        save_last: null
        save_top_k: -1
        save_weights_only: true
        mode: min
        auto_insert_metric_name: true
        every_n_train_steps: 1000
        train_time_interval: null
        every_n_epochs: null
        save_on_train_epoch_end: null
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
  default_root_dir: null
  gradient_clip_val: null
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: null
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: null
  min_epochs: null
  max_steps: 99000
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: 0
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 1
  accelerator: gpu
  strategy: ddp
  sync_batchnorm: false
  precision: 32
  enable_model_summary: true
  num_sanity_val_steps: 0
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  inference_mode: true
ckpt_path: null
model:
  class_path: aimless.lightning.freq_mask.MaskPredictor
  init_args:
    model:
      class_path: aimless.models.band_split_rnn.BandSplitRNN
      init_args:
        n_fft: 4096
        split_freqs:
          - 100
          - 200
          - 300
          - 400
          - 500
          - 600
          - 700
          - 800
          - 900
          - 1000
          - 1250
          - 1500
          - 1750
          - 2000
          - 2250
          - 2500
          - 2750
          - 3000
          - 3250
          - 3500
          - 3750
          - 4000
          - 4500
          - 5000
          - 5500
          - 6000
          - 6500
          - 7000
          - 7500
          - 8000
          - 9000
          - 10000
          - 11000
          - 12000
          - 13000
          - 14000
          - 15000
          - 16000
          - 18000
          - 20000
          - 22000
        hidden_size: 128
        num_layers: 12
        norm_groups: 4
    criterion:
      class_path: aimless.loss.freq.MDLoss
      init_args:
        mcoeff: 10
        n_fft: 4096
        hop_length: 1024
    transforms:
      - class_path: aimless.augment.SpeedPerturb
        init_args:
          orig_freq: 44100
          speeds:
            - 90
            - 100
            - 110
          p: 0.2
      - class_path: aimless.augment.RandomPitch
        init_args:
          semitones:
            - -1
            - 1
            - 0
            - 1
            - 2
          n_fft: 2048
          hop_length: 512
          p: 0.2
    target_track: sdx
    targets:
      music: null
    n_fft: 4096
    hop_length: 1024
    residual_model: true
    softmask: false
    alpha: 1.0
    n_iter: 1
data:
  class_path: data.lightning.DnR
  init_args:
    root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/
    seq_duration: 3.0
    samples_per_track: 144
    random: true
    include_val: true
    random_track_mix: true
    transforms:
      - class_path: data.augment.RandomGain
        init_args:
          low: 0.25
          high: 1.25
          p: 1.0
      - class_path: data.augment.RandomFlipPhase
        init_args:
          p: 0.5
      - class_path: data.augment.RandomSwapLR
        init_args:
          p: 0.5
    batch_size: 16
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.0003
    betas:
      - 0.9
      - 0.999
    eps: 1.0e-08
    weight_decay: 0
    amsgrad: false
    foreach: null
    maximize: false
    capturable: false
    differentiable: false
    fused: false
```
