WIP: Use DirectStorage with CUDA interop to load tensors more efficiently #7796

Draft
wants to merge 4 commits into base: master

Conversation

mtavenrath
Contributor

@mtavenrath mtavenrath commented Jun 6, 2024

On Windows I've seen file IO in the range of 3GB/s - 4GB/s using a single IO thread and mmapped files. The newest NVMe drives can do >14GB/s and good RAID controllers can read at up to ~55GB/s. To get read speeds close to NVMe RAID speed without stressing CPU RAM bandwidth, DirectStorage can be used.

On Linux CUDA supports the cuFile API; on Windows one currently has to use DirectStorage with CUDA interop, as done in this POC. I've seen speedups of 3x over mmap (15s -> 5s) when streaming from a single NVMe drive. There is a code path which can stream from two NVMe drives at once (in the absence of a RAID), which improves the speedup even more, but it is currently limited by DX/CUDA interop limitations in combination with llama.cpp.

For now I have hijacked the non-mmapped code path and pass a struct with the file information to the tensor_set function. For a clean solution it'd be good to have an abstract way to import tensor data from a ggml file handle, with the implementation depending on the backend. A special file handle is created because the different DirectStorage-style APIs all have their own way of opening a file for the DirectStorage operation.

The simplest interface one could imagine would be ggml_tensor_set(filename, offset, size). Passing only a filename would require opening the file on each set operation, which is potentially more expensive than the read itself. Thus my proposal is to have two new functions:

file = ggml_backend_file_open(filename)
ggml_backend_tensor_set_from_file(tensor, file, offset, size);

Since the IO ops are completely asynchronous, there eventually must be a way to synchronize all file IO, or at least to add an event to the file IO queue to ensure that all file IO is done. Currently the hack is to pass a nullptr as the file to trigger this sync.
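
For illustration, a loading loop built on the two proposed functions could look like the sketch below. This only mirrors the proposal above; the loop variables, the pointer type and the final nullptr-sync call are assumptions for illustration, not existing ggml API.

// hypothetical sketch of the proposed API, not existing ggml code
ggml_backend_file * file = ggml_backend_file_open("model.gguf");
for (size_t i = 0; i < n_tensors; ++i) {
    // enqueue an asynchronous read of tensor i directly from the file
    ggml_backend_tensor_set_from_file(tensors[i], file, offsets[i], sizes[i]);
}
// current hack: a nullptr file acts as a barrier until all queued file IO has completed
ggml_backend_tensor_set_from_file(tensors[0], nullptr, 0, 0);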

@mtavenrath mtavenrath marked this pull request as draft June 6, 2024 08:29
@github-actions github-actions bot added the build (Compilation issues), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) labels Jun 6, 2024
Contributor

github-actions bot commented Jun 6, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 426 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=11080.85ms p(95)=29604.7ms fails=, finish reason: stop=367 truncated=59
  • Prompt processing (pp): avg=122.76tk/s p(95)=532.52tk/s
  • Token generation (tg): avg=23.58tk/s p(95)=36.28tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=direct_storage_cuda commit=18dbe4b8af23765ccc2c824adc13202a25f0afb1

[Benchmark charts (Mermaid xychart data omitted): prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing; llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 426 iterations]

@JohannesGaessler
Collaborator

When I previously investigated this in PR #1483 I found that cuFile is not significantly faster, but it's always possible that my implementation was just bad.

In any case, be aware that cuFile is incompatible with some filesystems; on my Linux machine I could for instance only load models stored on EXT4 partitions but not models stored on Btrfs partitions.

@mtavenrath
Contributor Author

Your PR doesn't state which NVMe devices you used for benchmarking with cuFile.

On a modern x64 system a single CPU core can move at most ~25GB/s with specialized memcpy implementations and most likely ~12GB/s with libc's memcpy.

NVMe->GPU involves quite a few reads/writes through host memory with the default pipeline, limiting perf if a single thread is used for IO:

PCI->host kernel space (write), kernel->disk cache (rw), disk cache -> user space (rw), user space -> pinned (CUDA, rw), pinned -> GPU (read)

Assuming the DMA reads directly into the disk cache (or the disk cache is bypassed), the best case is 2x DMA + 4x CPU reads or writes. Assuming the DMA is asynchronous and a ~12GB/s memcpy implementation is used, one would get at most ~3GB/s with a single thread (the four CPU-side copies share one core: 12GB/s / 4 = 3GB/s), which is close to the 2.66GB/s I've been seeing (40GB in 15s).

Instead of DirectStorage I could imagine another pipeline as well which uses multiple threads copying from mmapped memory to pinned memory and a single thread spawning the cudaMemcpys for the uploads.

Yet using mmapped memory in its current form has another problem: It commits pages and thus increases physical memory utilization for tensors which are actually required only on the GPU. This could be solved easily by closing the mmap handles after uploading the data to the GPU.

@slaren
Collaborator

slaren commented Jun 6, 2024

Instead of DirectStorage I could imagine another pipeline as well which uses multiple threads copying from mmapped memory to pinned memory and a single thread spawning the cudaMemcpys for the uploads.

I think we should give it a try using multiple threads with direct I/O into a pinned buffer. I think that would remove all unnecessary copies. If the overall performance is similar, it would save us a significant amount of complexity from having to implement backend-specific interfaces.

Yet using mmapped memory in its current form has another problem: It commits pages and thus increases physical memory utilization for tensors which are actually required only on the GPU. This could be solved easily by closing the mmap handles after uploading the data to the GPU.

This is done on POSIX systems when possible by calling munmap on the fraction of the model offloaded to the GPU. There is always at least one tensor on the CPU side (the token embeddings), so we cannot unmap the file completely. I couldn't find any way to unmap a file partially under Windows, so it is not supported there.
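
For reference, POSIX allows munmap on a page-aligned sub-range of an existing mapping, which is what makes this partial unmap possible. A minimal sketch, with illustrative names (fd, file_size and cpu_part_size are assumed to be known; requires <sys/mman.h> and <unistd.h>):

// map the whole model file read-only
void * addr = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
// round the start of the GPU-only region up to a page boundary
long page = sysconf(_SC_PAGESIZE);
size_t gpu_start = (cpu_part_size + page - 1) / page * page;
// after uploading the offloaded tensors, release only their backing pages;
// the CPU-side tensors at the front of the file stay mapped
munmap((char *) addr + gpu_start, file_size - gpu_start);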

@mtavenrath
Contributor Author

Have you tried how much perf is lost when not mapping the whole file at once but instead having one mapping for each tensor? This would allow unmapping tensors which are no longer required on the CPU.

@slaren
Collaborator

slaren commented Jun 6, 2024

I have not tried it, but we absolutely could do something like that. Mapping only the fraction of the file used on the CPU, and using direct I/O or DirectStorage/cuFile/etc for the offloaded fraction could be a good default behavior.

@mtavenrath
Contributor Author

Wouldn't partial mapping require the same infrastructure as DirectStorage? Instead of relying on getting a host pointer in tensor_set, the backend would get a file reference and be responsible for mapping/unmapping data or using direct IO as desired per tensor?

@slaren
Collaborator

slaren commented Jun 6, 2024

I don't think it would be exactly the same. Moving the mmap code into the backends would result in code duplication on systems with unified memory, because in that case mmap can be used for both the CPU and the GPU backends. Currently this is only Metal on Apple silicon, but we have experimented with supporting mmap with CUDA and HIP. Here we tested that it does work on AMD iGPUs. I would also expect it to work on NVIDIA systems like Tegra. I would also be wary about creating one mapping per tensor: even small models have hundreds of tensors, and larger models go above a thousand. In most cases we should be able to consolidate the mappings into one or two mappings per backend.

I would also expect to be able to use the same implementation of direct I/O for all the backends, except maybe the pinned buffer allocation.

@JohannesGaessler
Collaborator

Your PR doesn't state which NVMe devices you used for benchmarking with cuFile.

I was using a SanDisk Ultra 3D NVMe.

@slaren
Collaborator

slaren commented Jun 6, 2024

That said, I wouldn't be against moving the entire loading logic, including mmap and possibly direct I/O, to ggml or ggml-backend. @ggerganov has disagreed in the past, but I think this should be dealt with by the library so that all applications can benefit from it (in this context, ggml is the library and llama.cpp is the application).

@mtavenrath
Contributor Author

mtavenrath commented Jun 6, 2024

SanDisk Ultra 3D NVMe

This disk can read ~3500MB/s, which means perf improvements are not expected with DS.

Mapping topic

On Windows it's legal to call MapViewOfFile twice, and one gets two different pointers with the same physical backing store. What I do not know yet is whether unmapping one view of the file will also free the physical backing store allocated for that view only.
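
For illustration, creating two views of the same file mapping object looks like this (a minimal Win32 sketch, error handling omitted, filename is a placeholder):

// two independent views of the same file mapping share the same physical backing store
HANDLE file    = CreateFileA("model.gguf", GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
HANDLE mapping = CreateFileMappingA(file, NULL, PAGE_READONLY, 0, 0, NULL);
void * view1   = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0); // maps the whole file
void * view2   = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0); // second view, different pointer
// ... read through either view ...
UnmapViewOfFile(view2); // the open question above: does this release the physical pages for this view only?
UnmapViewOfFile(view1);
CloseHandle(mapping);
CloseHandle(file);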

@mofosyne mofosyne added the Review Complexity : Medium (generally requires more time to grok but manageable by beginner to medium expertise level) label Jun 6, 2024
@ggerganov
Owner

That said, I wouldn't be against moving the entire loading logic, including mmap and possibly direct I/O, to ggml or ggml-backend. @ggerganov has disagreed in the past, but I think this should be dealt with by the library so that all applications can benefit from it (in this context, ggml is the library and llama.cpp is the application).

We can reconsider. The main reason for disagreeing in the beginning was that I wasn't familiar with mmap and didn't want to have something that I don't understand in the core library. The functionality has now been exercised extensively, and although I'm still not deeply familiar with all the aspects of mmap, I think we can look into merging it into ggml.

On the topic of this PR - I don't have much to add for now. I agree that the direct I/O approach suggested earlier should be explored, because it seems the implementation would carry less baggage (i.e. dependencies and headers) and it would be more portable.

@mtavenrath
Contributor Author

I wrote a tiny benchmark today to determine the real IO throughput reading into system memory with the following results:

I ran the benchmark on a system with 128GB of memory and a Corsair T705 NVMe drive. The system is large enough to keep one 39GB file completely cached, whereas it is not large enough to keep 5 files of this size cached.

I ran two variants of the benchmark: one where each benchmark run read a different file, to ensure that the FS cache is not utilized, and one where all benchmark runs read the same file, with a single warmup iteration.

As a result, std::fstream is pretty bad; fread is okay once data is in the FS cache, but still only 55% as fast as using the Win32 API to read data which is not yet in the FS cache. mmap is slower than direct IO as well.

Caching improves things a lot, yet unbuffered IO with the Win32 API on a fast NVMe drive is still the fastest option.

The outcome of this benchmark is that for a non-RAID NVMe drive, unbuffered file IO is already quite good, provided that reads into pinned memory and host->device transfers can be pipelined to overlap instead of running serially.

Filesize: 39503 MB
                        Not in FS Cache | in FS Cache
std::fstream            1777.82 MB/s    | 1638.84 MB/s
fread                   3005.38 MB/s    | 7965.70 MB/s
CreateFile buffered     5501.46 MB/s    | 8155.88 MB/s
CreateFile unbuffered   8538.62 MB/s    | 8383.86 MB/s
mmap                    2193.19 MB/s    | 3602.94 MB/s
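
For reference, the "CreateFile unbuffered" row corresponds to opening the file with FILE_FLAG_NO_BUFFERING, which bypasses the Windows file cache but requires the buffer address, file offset and read size to be multiples of the volume sector size. A minimal sketch (chunk size and filename are illustrative, error handling omitted):

// unbuffered (direct) sequential read on Windows
HANDLE h = CreateFileA("model.gguf", GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING,
                       FILE_FLAG_NO_BUFFERING | FILE_FLAG_SEQUENTIAL_SCAN, NULL);
const size_t chunk = 1 << 20; // must be a multiple of the volume sector size
void * buf = VirtualAlloc(NULL, chunk, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE); // page-aligned, so sector-aligned
DWORD got = 0;
while (ReadFile(h, buf, (DWORD) chunk, &got, NULL) && got > 0) {
    // consume 'got' bytes here
}
VirtualFree(buf, 0, MEM_RELEASE);
CloseHandle(h);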

@mtavenrath
Contributor Author

@slaren I've prototyped a small piece of code which takes the direct file IO path into a set of pinned memory buffers used in round-robin fashion, and achieved ~8.5GB/s with 4 pinned memory buffers of 1MB each. While this is still slower than using DS, it is a good intermediate step towards faster IO bandwidth and also prepares an API which would allow the use of DirectStorage in the future.

To achieve this bandwidth the file has to be opened with CreateFile(A) on Windows, while ggml_fopen currently supports only a C-style FILE*.

My proposal for the API changes would be:

Add:
// File API
ggml_file* ggml_open(const char *filename);
size_t ggml_file_read(ggml_file*, size);
size_t ggml_file_get_size(ggml_file*);
ggml_file_close(ggml_file*);

// mapping api -> will this still be required?
ggml_mapping *ggml_file_map(ggml_file, start, size);
ggml_mapping_unmap(ggml_mapping* mapping);

// tensor backend api
ggml_backend_tensor_set_data(ggml_file, dst_offset, src_offset, src_size);

The following symbols can potentially be removed: llama_file, llama_mmap

For the async upload one would have to add n pinned memory buffers as temporary storage and n cudaEvent_t for synchronization. Given that all uploads happen within the ctx CUDA stream, no further synchronization would be required.

In the future, in case DS is implemented, one would potentially have to add ggml_backend_file_open(ggml_file* file); to obtain the special file handle required by the DirectStorage-style APIs.
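
For clarity, the declarations above might be fleshed out roughly as follows; the parameter lists (in particular the destination buffer for ggml_file_read and the tensor argument for ggml_backend_tensor_set_data) are my assumptions about the intent, not settled API:

// assumed signatures for the proposal above, parameters are guesses at the intent
typedef struct ggml_file ggml_file;

ggml_file * ggml_open(const char * filename);
size_t      ggml_file_read(ggml_file * file, void * dst, size_t size);
size_t      ggml_file_get_size(ggml_file * file);
void        ggml_file_close(ggml_file * file);

// read src_size bytes starting at src_offset in the file into the tensor data at dst_offset
void        ggml_backend_tensor_set_data(struct ggml_tensor * tensor, ggml_file * file,
                                         size_t dst_offset, size_t src_offset, size_t src_size);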

Do you like those API changes, or do you have suggestions on how to do this more efficiently within the ggml framework?

@mtavenrath
Contributor Author

Besides being more efficient with regard to perf (measured on Windows), the other benefit of this change will be that the issues with committed memory on Windows will be gone as well, since data is read into temporary pinned memory buffers only.

@slaren
Collaborator

slaren commented Jun 10, 2024

If I understand correctly, currently what you are doing is reading the file from a single thread with a loop similar to this:

cudaEventSynchronize(buffer_event[i]);
read_file(file, buffer[i], ..);
cudaMemcpyAsync(.., buffer[i], .., stream);
cudaEventRecord(buffer_event[i], stream);
i = (i + 1) % n_bufs;

Is this correct? Would there be any advantage to using multiple threads?

@mtavenrath
Contributor Author

mtavenrath commented Jun 10, 2024

That is correct. Here is the prototype I used for benchmarking.

I haven't used multiple threads yet. I suggest delaying experimenting with multiple threads until the basic algorithm is implemented. With multiple threads, questions arise such as: do we want to have one worker thread for each GPU, or flush the memcpys in the main thread?

If we have one thread per GPU, new questions arise, like: is it still necessary to call cudaSetDevice before every CUDA call? What is the overhead of putting kernel launches in a queue per thread vs. executing them in the local thread?

Prototype
#include <iostream>
#include <fstream>
#include <vector>
#include <cuda_runtime.h>
#include <Windows.h>
#undef min
#undef max

#include <string>
#include <chrono>
#include <iostream>
#include <sstream>
#include <algorithm>

class File {
public:
    File() {};
    File(File const&rhs) = delete;
    File& operator=(File const&rhs) = delete;
    
    File(File &&rhs) {
        m_handle = rhs.m_handle;
        rhs.m_handle = nullptr;
    }
    File& operator=(File &&rhs) {
        m_handle = rhs.m_handle;
        rhs.m_handle = nullptr;
        return *this;
    }

    static File openFile(const std::string& filename) {
        HANDLE handle = CreateFileA(filename.c_str(),
                                   GENERIC_READ,
                                   0,
                                   NULL,
                                   OPEN_EXISTING,
                                   FILE_ATTRIBUTE_NORMAL,
                                   NULL);

        if (handle == INVALID_HANDLE_VALUE) {
            std::cerr << "Error opening file: " << GetLastError() << std::endl;
            return File{};
        }

        return File(handle);
    }

    size_t getFileSize() const {
        LARGE_INTEGER fileSize;
        if (!GetFileSizeEx(m_handle, &fileSize)) {
            std::cerr << "Error getting file size: " << GetLastError() << std::endl;
            return 0; // Or throw an exception, depending on your error handling strategy.
        }
        return static_cast<size_t>(fileSize.QuadPart);
    }

    HANDLE getHandle() const {
        return m_handle;
    }

    ~File() 
    {
        if (m_handle != INVALID_HANDLE_VALUE && m_handle != NULL) {
            CloseHandle(m_handle);
            m_handle = NULL;
        }
    }
    
private:
    File(HANDLE handle): m_handle(handle)
    {
        std::cout << "Handle " << handle << std::endl;
    }



    HANDLE m_handle = {};
};


class AsyncFileUploadCUDA {
public:
    explicit AsyncFileUploadCUDA(size_t bufferSize, size_t numBuffers) : m_bufferSize(bufferSize), m_numBuffers(numBuffers) {
        m_buffers.reserve(m_numBuffers);

        for (size_t i = 0; i < m_numBuffers; ++i) {
            void* buffer;
            cudaError_t err = cudaMallocHost(&buffer, m_bufferSize);
            if (err != cudaSuccess) {
                std::cerr << "Failed to allocate pinned CUDA buffer: " << cudaGetErrorString(err) << std::endl;
                // Clean up allocated buffers and rethrow the exception.
                for (size_t j = 0; j < i; ++j) {
                    cudaFreeHost(m_buffers[j]);
                }
                throw std::runtime_error("Failed to allocate pinned CUDA buffer");
            }
            m_buffers.push_back(buffer);

            cudaEvent_t event;
            err = cudaEventCreate(&event);
            if (err != cudaSuccess) {
                std::cerr << "Failed to create CUDA event: " << cudaGetErrorString(err) << std::endl;
                // Clean up allocated buffers and events, then rethrow the exception.
                for (size_t j = 0; j < i; ++j) {
                    cudaFreeHost(m_buffers[j]);
                    if (m_events.size() > j) {
                        cudaEventDestroy(m_events[j]);
                    }
                }
                throw std::runtime_error("Failed to create CUDA event");
            }
            m_events.push_back(event);

        }

        // Create CUDA stream
        cudaError_t err = cudaStreamCreate(&m_stream);
        if (err != cudaSuccess) {
            std::cerr << "Failed to create CUDA stream: " << cudaGetErrorString(err) << std::endl;
            // Clean up allocated buffers and events, then rethrow the exception.
            for (auto& buffer : m_buffers) {
                cudaFreeHost(buffer);
            }
            for (auto& event : m_events) {
                cudaEventDestroy(event);
            }
            throw std::runtime_error("Failed to create CUDA stream");
        }

    }

    ~AsyncFileUploadCUDA() {
        for (auto& buffer : m_buffers) {
            cudaFreeHost(buffer);
        }

        // Destroy CUDA events.
        for (auto& event : m_events) {
            cudaEventDestroy(event);
        }

        cudaError_t err = cudaStreamDestroy(m_stream);
        if (err != cudaSuccess) {
            std::cerr << "Failed to destroy CUDA stream: " << cudaGetErrorString(err) << std::endl;
        }

    }

    // Disable copy and move operations to ensure the class is not copied or moved.
    AsyncFileUploadCUDA(const AsyncFileUploadCUDA&) = delete;
    AsyncFileUploadCUDA& operator=(const AsyncFileUploadCUDA&) = delete;
    AsyncFileUploadCUDA(AsyncFileUploadCUDA&&) = delete;
    AsyncFileUploadCUDA& operator=(AsyncFileUploadCUDA&&) = delete;


    void uploadFile(File const& file, void *cuda_memory) {
        size_t bytesRead = 0;
        size_t fileSize = file.getFileSize();

        std::cout << "filesize: " << fileSize << std::endl;

        size_t i = 0;
        while (bytesRead < fileSize)
        {
            cudaEventSynchronize(m_events[i]); // wait for buffer i to finish.
            DWORD bytesReadIteration = 0;
            bool success = ReadFile(file.getHandle(), m_buffers[i], DWORD(m_bufferSize), &bytesReadIteration, nullptr);
            if (!success) {
                std::cerr << "Failed to read from file: " << std::endl;
                break; // Error reading file
            }
            if (bytesReadIteration == 0) {
                break; // Error reading file or end of file reached
            }

            // copy the freshly read chunk from the pinned staging buffer to its offset in device memory
            cudaMemcpyAsync(reinterpret_cast<std::byte*>(cuda_memory) + bytesRead, m_buffers[i], bytesReadIteration, cudaMemcpyHostToDevice, m_stream);
            cudaEventRecord(m_events[i], m_stream);
            bytesRead += bytesReadIteration;
            ++i;
            i %= m_numBuffers;
        }
    }

private:
    size_t m_bufferSize;
    size_t m_numBuffers;
    cudaStream_t m_stream;
    std::vector<void*> m_buffers;
    std::vector<cudaEvent_t> m_events;
};

int main(int argc, char** argv) {
    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " <filename>" << std::endl;
        return 1;
    }

    std::string filename = argv[1];

    // Open the file and get its size.
    File file = File::openFile(filename);
    if (file.getHandle() == nullptr) {
        std::cerr << "Failed to open file: " << filename << std::endl;
        return 1;
    }
    size_t fileSize = file.getFileSize();

    // Allocate CUDA device memory for the file data.
    void* cudaDeviceMemory;
    cudaError_t err = cudaMalloc(&cudaDeviceMemory, fileSize);
    if (err != cudaSuccess) {
        std::cerr << "Failed to allocate CUDA device memory: " << cudaGetErrorString(err) << std::endl;
        return 1;
    }

    // Create the asynchronous file uploader.
    size_t bufferSize = 1024 * 1024; // 1MB buffers, adjust as needed.
    size_t numBuffers = 2;           // Number of pinned staging buffers used to overlap file I/O with uploads.
    AsyncFileUploadCUDA asyncUpload(bufferSize, numBuffers);

    try {
        // Upload the file to CUDA device memory.
        auto start = std::chrono::high_resolution_clock::now();
        asyncUpload.uploadFile(file, cudaDeviceMemory);
        auto end = std::chrono::high_resolution_clock::now();

        auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
        double readBandwidthMBps = (fileSize / 1024.0 / 1024.0) / (elapsed / 1e6);
        std::cout << "File read bandwidth: " << readBandwidthMBps << " MB/s" << std::endl;


    // Download CUDA device memory to host memory for processing or verification.
    void* cudaHostMemory = nullptr;
    err = cudaMallocHost(&cudaHostMemory, fileSize);
    if (err != cudaSuccess) {
        std::cerr << "Failed to allocate CUDA host memory: " << cudaGetErrorString(err) << std::endl;
        return 1;
    }

    err = cudaMemcpy(cudaHostMemory, cudaDeviceMemory, fileSize, cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        std::cerr << "Failed to copy data from CUDA device memory to host: " << cudaGetErrorString(err) << std::endl;
        cudaFreeHost(cudaHostMemory);
        return 1;
    }

    // Now you can process or verify the file content in cudaHostMemory.
    // Open the file and read its content into a std::string.
    std::ifstream inputFile(filename, std::ios::binary);
    std::stringstream ss;
    ss << inputFile.rdbuf();
    std::string fileContent = ss.str();

    // Compare the content of cudaHostMemory with fileContent.
    int compareResult = memcmp(cudaHostMemory, fileContent.data(), std::min(fileSize, fileContent.size()));
    if (compareResult == 0) {
        std::cout << "Data in CUDA host memory matches the file content." << std::endl;
    } else {
        std::cout << "Data in CUDA host memory does not match the file content." << std::endl;
    }


    // Don't forget to free the host memory when done:
    cudaFreeHost(cudaHostMemory);

    } catch (const std::exception& e) {
        std::cerr << "Error uploading file: " << e.what() << std::endl;
        cudaFree(cudaDeviceMemory);
        return 1;
    }
    // You can now use the data in CUDA device memory for computations.
    // Don't forget to free it when you're done: cudaFree(cudaDeviceMemory);
    cudaFree(cudaDeviceMemory);


    return 0;
}



@mtavenrath mtavenrath closed this Jun 10, 2024
@mtavenrath mtavenrath reopened this Jun 10, 2024
@slaren
Collaborator

slaren commented Jun 11, 2024

I think there are a few problems with the proposed API. It would require backends to implement a function that is not really necessary, since ggml-backend already exposes (almost) all the functionality necessary to implement this. It is also not clear to me how to free the temporary buffers and events used during loading. ggml-backend is missing a generic way to allocate pinned buffers, but it is available in llama.cpp with calls to the specific implementations of each backend.

I also think it would be preferable to have a higher-level API in ggml that can load multiple tensors in a single call and hide most of the details of loading tensors from gguf files, but it is not clear what that would look like. My conclusion is that we need more time to figure out how to move this functionality to ggml.

For the time being, this functionality could be implemented in llama.cpp instead. We would need to extend llama_file to support direct (unbuffered) I/O to avoid unnecessary copies, and the buffers and events could be allocated through ggml-backend. Here is a rough overview of how you could implement this using the existing ggml-backend interfaces:

// allocate resources
ggml_backend_buffer_t host_buffer = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), size);
ggml_backend_t backend = ggml_backend_cuda_init(device);
ggml_backend_event_t event = ggml_backend_event_new(backend);

void * host_buffer_ptr = ggml_backend_buffer_get_base(host_buffer);

// copy loop
ggml_backend_event_synchronize(event);
file->read(host_buffer_ptr, ..);
ggml_backend_tensor_set_async(backend, tensor, host_buffer_ptr, ...);
ggml_backend_event_record(event);

// wait for all copies to finish
ggml_backend_synchronize(backend);

//  free resources
ggml_backend_buffer_free(host_buffer);
ggml_backend_event_free(event);
ggml_backend_free(backend);

Currently we are missing a function in llama.cpp to initialize a ggml_backend instance in a generic way (hence the call to ggml_backend_cuda_init), but at the moment only the CUDA backend implements the async and event interfaces, so it would be ok to keep it CUDA-only for now, and we will figure out how to extend it to other backends later.
